Coverage for packages / core / storage / database.py: 47%

88 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 08:37 +1200

1"""Database connection and session management.""" 

2 

3import time 

4from collections.abc import Generator 

5from contextlib import contextmanager 

6 

7from sqlalchemy import create_engine, event, text 

8from sqlalchemy.orm import Session, sessionmaker 

9 

10from packages.core.common.logging import get_logger 

11from packages.core.common.settings import get_settings 

12 

# Module-level logger named after this module.
logger = get_logger(__name__)

# Global engine and session factory. Both start as None, are populated by
# init_db(), and are reset to None by dispose_db().
_engine = None
_SessionLocal = None

# Maximum number of connection attempts init_db() makes before giving up.
_MAX_DB_RETRIES = 5
# Delay before the first retry; init_db() doubles it after each failed attempt.
_INITIAL_RETRY_DELAY = 1.0  # seconds

21 

22 

def init_db():
    """Initialize the database engine and session factory with retry logic.

    Does nothing when an engine is already initialized. Otherwise an engine
    is built from ``get_settings().database`` (``url``, ``pool_size``,
    ``max_overflow``) and verified with a ``SELECT 1`` round trip. Failed
    attempts are retried up to ``_MAX_DB_RETRIES`` times, with the delay
    starting at ``_INITIAL_RETRY_DELAY`` seconds and doubling each time.

    The module globals are published only after the engine has been verified
    and fully configured, so a failed initialization leaves the module
    uninitialized and a later call starts again from scratch.

    Raises:
        RuntimeError: If every connection attempt fails. The last underlying
            error is attached as ``__cause__``.
    """
    global _engine, _SessionLocal

    if _engine is not None:
        return

    db_config = get_settings().database

    logger.info("Initializing database connection (max_retries=%d)", _MAX_DB_RETRIES)

    # Work on a local engine so the global never points at an engine that
    # failed verification.
    engine = None
    last_exception = None
    for attempt in range(1, _MAX_DB_RETRIES + 1):
        try:
            engine = create_engine(
                db_config.url,
                pool_size=db_config.pool_size,
                max_overflow=db_config.max_overflow,
                pool_pre_ping=True,
                echo=False,
            )
            # Verify the connection works by executing a simple query.
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            last_exception = None
            break
        except Exception as e:
            last_exception = e
            # Release the pool of the engine that failed before retrying.
            if engine is not None:
                engine.dispose()
                engine = None
            if attempt < _MAX_DB_RETRIES:
                # Exponential backoff: 1x, 2x, 4x, ... the initial delay.
                delay = _INITIAL_RETRY_DELAY * (2 ** (attempt - 1))
                logger.warning(
                    "Database connection attempt %d/%d failed: %s. "
                    "Retrying in %.1fs...",
                    attempt,
                    _MAX_DB_RETRIES,
                    e,
                    delay,
                )
                time.sleep(delay)
            else:
                logger.error(
                    "All %d database connection attempts failed.",
                    _MAX_DB_RETRIES,
                )

    if last_exception is not None:
        raise RuntimeError(
            f"Failed to connect to database after {_MAX_DB_RETRIES} attempts. "
            f"Last error: {last_exception}"
        ) from last_exception

    try:
        # Set up slow query detection listeners.
        _setup_slow_query_detection(engine)
        session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    except Exception:
        # Do not leak a connected pool if post-connection setup fails.
        engine.dispose()
        raise

    # Publish both globals together so callers never see a half-initialized
    # module (engine set but session factory missing).
    _engine = engine
    _SessionLocal = session_factory

    logger.info("Database connection initialized")

82 

83 

def _setup_slow_query_detection(engine):
    """Register SQLAlchemy event listeners for detecting slow queries.

    Statements whose execution time exceeds
    ``get_settings().database.slow_query_threshold_ms`` are logged at WARNING
    level with the statement text, the duration in milliseconds, and the
    parameters as processed by ``_sanitize_params``. The threshold is read
    once, when the listeners are registered.

    Args:
        engine: SQLAlchemy engine to attach listeners to.
    """
    threshold_ms = get_settings().database.slow_query_threshold_ms

    @event.listens_for(engine, "before_cursor_execute")
    def _before_cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments the way time.time() can.
        conn.info["query_start_time"] = time.perf_counter()

    @event.listens_for(engine, "after_cursor_execute")
    def _after_cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        finished = time.perf_counter()
        # A missing start time is treated as a zero-length query.
        started = conn.info.pop("query_start_time", finished)
        duration_ms = (finished - started) * 1000

        if duration_ms > threshold_ms:
            sanitized = _sanitize_params(parameters)
            logger.warning(
                "Slow query (%.1f ms): %s | params=%s",
                duration_ms,
                statement,
                sanitized,
            )

116 

117 

def _sanitize_params(parameters):
    """Prepare query parameters for logging by truncating every value.

    Each value is passed through ``_trunc`` so that long values cannot flood
    the log. The shape of the result depends on the input:

    * ``None`` is returned unchanged.
    * A dict yields a dict with the same keys.
    * A non-empty list/tuple whose first element is a dict yields a list of
      dicts (one per parameter set).
    * Any other list/tuple yields a tuple of truncated values.
    * Anything else is truncated as a single value.

    Args:
        parameters: Raw SQLAlchemy query parameters.

    Returns:
        Parameters with every value truncated for logging.
    """

    def _clean_mapping(mapping):
        return {key: _trunc(val) for key, val in mapping.items()}

    if parameters is None:
        return None
    if isinstance(parameters, dict):
        return _clean_mapping(parameters)
    if not isinstance(parameters, (list, tuple)):
        return _trunc(parameters)

    # Only the first element is inspected to decide whether this is a batch
    # of named parameter sets.
    batch_of_mappings = bool(parameters) and isinstance(parameters[0], dict)
    if batch_of_mappings:
        return [_clean_mapping(row) for row in parameters]
    return tuple(map(_trunc, parameters))

139 

140 

141def _trunc(value, max_len: int = 100): 

142 """Truncate a value's string representation for safe logging.""" 

143 s = str(value) 

144 return s[:max_len] + "..." if len(s) > max_len else s 

145 

146 

def get_engine():
    """Return the module-wide SQLAlchemy engine.

    Calls ``init_db()`` first when no engine has been created yet.

    Returns:
        SQLAlchemy engine
    """
    needs_init = _engine is None
    if needs_init:
        init_db()
    # Re-read the global: init_db() rebinds it.
    return _engine

156 

157 

def get_session_factory():
    """Return the module-wide session factory.

    Calls ``init_db()`` first when no session factory has been created yet.

    Returns:
        Session factory
    """
    needs_init = _SessionLocal is None
    if needs_init:
        init_db()
    # Re-read the global: init_db() rebinds it.
    return _SessionLocal

167 

168 

@contextmanager
def get_session() -> Generator[Session, None, None]:
    """Provide a database session as a context manager.

    The session is committed when the ``with`` body finishes normally. If the
    body or the commit raises an ``Exception``, the session is rolled back and
    the exception is re-raised. The session is closed in every case.

    Yields:
        Database session

    Example:
        >>> with get_session() as session:
        ...     horses = session.query(Horse).all()
    """
    make_session = get_session_factory()
    db = make_session()
    try:
        try:
            yield db
            # Commit stays inside the guarded region so that a failing
            # commit is rolled back as well.
            db.commit()
        except Exception:
            db.rollback()
            raise
    finally:
        db.close()

190 

191 

def dispose_db():
    """Dispose the database engine and reset the module globals.

    Does nothing when no engine has been created. Useful for testing and
    cleanup.
    """
    global _engine, _SessionLocal

    if _engine is None:
        return

    _engine.dispose()
    _engine, _SessionLocal = None, None

    logger.info("Database connection disposed")