Coverage for packages/core/storage/database.py: 47%
88 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 08:37 +1200
1"""Database connection and session management."""
3import time
4from collections.abc import Generator
5from contextlib import contextmanager
7from sqlalchemy import create_engine, event, text
8from sqlalchemy.orm import Session, sessionmaker
10from packages.core.common.logging import get_logger
11from packages.core.common.settings import get_settings
logger = get_logger(__name__)

# Lazily created module singletons: populated by init_db(), cleared by dispose_db().
_engine = _SessionLocal = None

# Retry policy used by init_db(): at most _MAX_DB_RETRIES connection attempts,
# sleeping _INITIAL_RETRY_DELAY * 2 ** (attempt - 1) seconds between attempts.
_MAX_DB_RETRIES = 5
_INITIAL_RETRY_DELAY = 1.0  # seconds
def init_db():
    """Initialize the database engine and session factory, retrying on failure.

    Does nothing if the engine has already been initialized. Otherwise builds a
    SQLAlchemy engine from ``get_settings().database`` (``url``, ``pool_size``,
    ``max_overflow``) and verifies it with ``SELECT 1``. It makes up to
    ``_MAX_DB_RETRIES`` attempts, with exponential backoff starting at
    ``_INITIAL_RETRY_DELAY`` seconds. On success it attaches slow-query
    listeners and creates the session factory.

    The module globals are assigned only after an engine has been verified.
    A failed call therefore leaves the module uninitialized, and a later call
    retries instead of returning early with a half-initialized engine.

    Raises:
        RuntimeError: If every connection attempt fails. The last underlying
            error is chained as ``__cause__``.
    """
    global _engine, _SessionLocal

    if _engine is not None:
        return

    db_config = get_settings().database

    logger.info("Initializing database connection (max_retries=%d)", _MAX_DB_RETRIES)

    # Retry loop with exponential backoff; build into a local so a failure
    # never leaves a broken engine in the module global.
    engine = None
    last_exception = None
    for attempt in range(1, _MAX_DB_RETRIES + 1):
        candidate = None
        try:
            candidate = create_engine(
                db_config.url,
                pool_size=db_config.pool_size,
                max_overflow=db_config.max_overflow,
                pool_pre_ping=True,
                echo=False,
            )
            # Verify connection works by executing a simple query
            with candidate.connect() as conn:
                conn.execute(text("SELECT 1"))
        except Exception as e:
            last_exception = e
            # Release the failed engine's pool instead of leaking it.
            if candidate is not None:
                candidate.dispose()
            if attempt < _MAX_DB_RETRIES:
                delay = _INITIAL_RETRY_DELAY * (2 ** (attempt - 1))
                logger.warning(
                    "Database connection attempt %d/%d failed: %s. "
                    "Retrying in %.1fs...",
                    attempt,
                    _MAX_DB_RETRIES,
                    e,
                    delay,
                )
                time.sleep(delay)
            else:
                logger.error(
                    "All %d database connection attempts failed.",
                    _MAX_DB_RETRIES,
                )
        else:
            engine = candidate
            break

    if engine is None:
        raise RuntimeError(
            f"Failed to connect to database after {_MAX_DB_RETRIES} attempts. "
            f"Last error: {last_exception}"
        ) from last_exception

    # Set up slow query detection listeners
    _setup_slow_query_detection(engine)

    _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
    # Publish the engine last: get_engine()/init_db() treat a non-None
    # _engine as "fully initialized".
    _engine = engine

    logger.info("Database connection initialized")
def _setup_slow_query_detection(engine):
    """Register SQLAlchemy event listeners that log slow queries.

    A ``before_cursor_execute`` listener stores a start timestamp in
    ``conn.info``. An ``after_cursor_execute`` listener computes the elapsed
    time. When that time exceeds
    ``get_settings().database.slow_query_threshold_ms``, the listener logs the
    statement, its duration and its parameters at WARNING level. Parameters
    are stringified and truncated by ``_sanitize_params``, not redacted.

    The threshold is read once, when the listeners are registered.

    Args:
        engine: SQLAlchemy engine to attach listeners to.
    """
    threshold_ms = get_settings().database.slow_query_threshold_ms

    @event.listens_for(engine, "before_cursor_execute")
    def _before_cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        # perf_counter is monotonic, so wall-clock adjustments cannot yield
        # negative or inflated durations the way time.time() can.
        conn.info["query_start_time"] = time.perf_counter()

    @event.listens_for(engine, "after_cursor_execute")
    def _after_cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        start = conn.info.pop("query_start_time", None)
        if start is None:
            # No matching start timestamp was recorded; nothing to measure.
            return
        duration_ms = (time.perf_counter() - start) * 1000

        if duration_ms > threshold_ms:
            sanitized = _sanitize_params(parameters)
            logger.warning(
                "Slow query (%.1f ms): %s | params=%s",
                duration_ms,
                statement,
                sanitized,
            )
118def _sanitize_params(parameters):
119 """Sanitize query parameters for safe logging.
121 Truncates long string values to prevent log flooding.
122 Handles dict, list, tuple, and None parameter formats.
124 Args:
125 parameters: Raw SQLAlchemy query parameters.
127 Returns:
128 Sanitized parameters safe for logging.
129 """
130 if parameters is None:
131 return None
132 if isinstance(parameters, dict):
133 return {k: _trunc(v) for k, v in parameters.items()}
134 if isinstance(parameters, (list, tuple)):
135 if parameters and isinstance(parameters[0], dict):
136 return [{k: _trunc(v) for k, v in p.items()} for p in parameters]
137 return tuple(_trunc(v) for v in parameters)
138 return _trunc(parameters)
141def _trunc(value, max_len: int = 100):
142 """Truncate a value's string representation for safe logging."""
143 s = str(value)
144 return s[:max_len] + "..." if len(s) > max_len else s
def get_engine():
    """Return the module's SQLAlchemy engine, running init_db() first if it is unset.

    Returns:
        The engine stored in the module-level ``_engine``.
    """
    if _engine is not None:
        return _engine
    init_db()
    return _engine
def get_session_factory():
    """Return the module's session factory, running init_db() first if it is unset.

    Returns:
        The ``sessionmaker`` stored in the module-level ``_SessionLocal``.
    """
    if _SessionLocal is not None:
        return _SessionLocal
    init_db()
    return _SessionLocal
@contextmanager
def get_session() -> Generator[Session, None, None]:
    """Provide a transactional session scope.

    The session behaves as follows:

    * It is committed when the ``with`` body finishes normally.
    * It is rolled back, and the exception re-raised, if the body or the
      commit raises an ``Exception``.
    * It is closed in every case.

    Yields:
        A new session created from ``get_session_factory()``.

    Example:
        >>> with get_session() as session:
        ...     horses = session.query(Horse).all()
    """
    session_factory = get_session_factory()
    db_session = session_factory()
    try:
        try:
            yield db_session
            db_session.commit()
        except Exception:
            db_session.rollback()
            raise
    finally:
        db_session.close()
def dispose_db():
    """Dispose the database engine (if any) and reset the module globals.

    Both ``_engine`` and ``_SessionLocal`` are cleared before the old engine is
    disposed. If ``dispose()`` raises, the module is still left uninitialized
    rather than holding a dead engine that would make ``init_db()`` return
    early. Useful for testing and cleanup.
    """
    global _engine, _SessionLocal

    engine = _engine
    _engine = None
    _SessionLocal = None

    if engine is not None:
        engine.dispose()

    logger.info("Database connection disposed")