Coverage for secondary.py: 98%

70 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 09:44 +0000

1import os 

2import re 

3import time 

4from subprocess import Popen 

5 

6from database import Database 

7 

# Seconds to sleep between two polls of the replication status in
# Secondary.wait_first_step_of_replication.
WAITING_PROGRESS_IN_SECONDS = 10

9 

10 

11class Secondary: 

12 def __init__(self, db: Database): 

13 self.db = db 

14 

15 def get_subscription_name(self, db_primary: str) -> str | None: 

16 query = f"select subname from pg_subscription where subname like 'subscription_{db_primary}_%'" 

17 print(f"psql \"{self.db.conn_string}\" --no-align -tc \"{query}\"") 

18 results = self.db.execute_query(query) 

19 if results is None: 

20 return None 

21 if not results: 

22 return "" 

23 

24 return results[0][0] 

25 

26 def create_subscription(self, unique_name: str): 

27 # Create subscription on secondary 

28 subscription_name = f"subscription_{unique_name}" 

29 print( 

30 f"Create subscription on secondary {self.db.conn_string} database {self.db.db_name}") 

31 # Get the primary db connexion string fron environment 

32 connection_primary_full = os.environ.get('CONN_DB_PRIMARY_FULL') 

33 self.db.execute_query( 

34 f"CREATE SUBSCRIPTION {subscription_name} CONNECTION '{connection_primary_full}' PUBLICATION publication_{unique_name} with (copy_data=true, create_slot=true, enabled=true, slot_name='{subscription_name}');", 

35 fetch=False) 

36 

37 def wait_first_step_of_replication(self): 

38 print( 

39 "The first step of logical replication is not finished - retrying later") 

40 while True: 

41 try: 

42 results = self.db.execute_query("select a.* from pg_subscription_rel a inner join pg_class on srrelid=pg_class.oid where relname <> 'spatial_ref_sys' and srsubstate <> 'r';") 

43 if not results: 

44 break 

45 

46 # Log progress 

47 progress_query = """ 

48 with ready as (select count(a.*) as ready 

49 from pg_subscription_rel a 

50 inner join pg_class on srrelid = pg_class.oid 

51 where relname <> 'spatial_ref_sys' 

52 and srsubstate = 'r'), 

53 total as (select count(a.*) as total 

54 from pg_subscription_rel a 

55 inner join pg_class on srrelid = pg_class.oid 

56 where relname <> 'spatial_ref_sys') 

57 select * 

58 from ready, 

59 total; \ 

60 """ 

61 results = self.db.execute_query(progress_query) 

62 print( 

63 f"Replication progress : {results[0][0]}/{results[0][1]}") 

64 

65 time.sleep(WAITING_PROGRESS_IN_SECONDS) 

66 except: 

67 # If the query fails, it means there are no more tables in non-ready state 

68 break 

69 

70 def disable_subscription(self, subscription_name): 

71 print(f"Disable subscription on {self.db.conn_string}") 

72 self.db.execute_query(f"ALTER SUBSCRIPTION {subscription_name} DISABLE;", fetch=False) 

73 

74 def enable_subscription(self, subscription_name): 

75 print(f"Enable subscription on {self.db.conn_string}") 

76 self.db.execute_query(f"ALTER SUBSCRIPTION {subscription_name} ENABLE;", fetch=False) 

77 

78 def execute_pre_data_dump(self, dump: Popen[str]): 

79 dump_queries = dump.stdout.read() 

80 

81 queries = dump_queries.replace("CREATE SCHEMA public;", "") 

82 # ignore "\restrict" and "\unrestrict" lines 

83 queries = re.sub("\\\\(un)?restrict.*\n", "", queries) 

84 

85 self.db.execute_query_rollback_on_error(queries) 

86 

87 def execute_post_data_dump_only_pk(self, dump: Popen[str]): 

88 dump_queries = dump.stdout.read() 

89 

90 # ignore "\restrict" and "\unrestrict" lines 

91 dump_queries = re.sub("\\\\(un)?restrict.*\n", "", dump_queries) 

92 splitlines = dump_queries.splitlines() 

93 queries = "" 

94 for i in range(0, len(splitlines) - 1): 

95 if re.match(r'.*ADD CONSTRAINT.*PRIMARY KEY.*', splitlines[i]): 

96 queries = queries + \ 

97 splitlines[i - 1] + splitlines[i] 

98 

99 self.db.execute_query_rollback_on_error(queries) 

100 

101 def execute_post_data_dump_without_pk(self, dump: Popen[str]): 

102 dump_queries = dump.stdout.read() 

103 

104 # ignore "\restrict" and "\unrestrict" lines 

105 dump_queries = re.sub("\\\\(un)?restrict.*\n", "", dump_queries) 

106 splitlines = dump_queries.splitlines() 

107 queries = "" 

108 line_before = "" 

109 for i in range(0, len(splitlines) - 1): 

110 # Skip primary key constraints 

111 if re.match(r'.*ADD CONSTRAINT.*PRIMARY KEY.*', splitlines[i]): 

112 line_before = "" 

113 continue 

114 else: 

115 queries += line_before 

116 line_before = splitlines[i] + "\n" 

117 

118 queries += line_before 

119 

120 self.db.execute_query_rollback_on_error(queries)