"""Database connection and session management."""

import time
from collections.abc import Generator
from contextlib import contextmanager

from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import Session, sessionmaker

from packages.core.common.logging import get_logger
from packages.core.common.settings import get_settings

# Module-level logger for connection lifecycle and slow-query messages.
logger = get_logger(__name__)

# Global engine and session factory.
# Assigned by init_db() and reset to None by dispose_db(); read lazily by
# get_engine() / get_session_factory().
_engine = None
_SessionLocal = None

# Retry policy for the connectivity check in init_db(): up to _MAX_DB_RETRIES
# attempts, sleeping _INITIAL_RETRY_DELAY * 2 ** (attempt - 1) seconds after
# each failed attempt except the last (1s, 2s, 4s, 8s with these values).
_MAX_DB_RETRIES = 5
_INITIAL_RETRY_DELAY = 1.0  # seconds


def init_db():
    """Initialize database engine and session factory with retry logic.

    Returns immediately if an engine has already been installed. Otherwise
    builds an engine from ``get_settings().database`` and verifies that it can
    actually connect by running ``SELECT 1``. Failed attempts are retried up to
    ``_MAX_DB_RETRIES`` times with exponential backoff
    (``_INITIAL_RETRY_DELAY * 2 ** (attempt - 1)`` seconds between attempts).

    The module globals ``_engine`` and ``_SessionLocal`` are assigned together,
    and only after connectivity has been verified. A failed initialization
    therefore leaves both as ``None``, and a later call starts over instead of
    returning early with a half-initialized state. Engines from failed attempts
    are disposed so their connection pools are not leaked.

    Raises:
        RuntimeError: If every connection attempt fails. The last underlying
            exception is chained as ``__cause__``.
    """
    global _engine, _SessionLocal

    if _engine is not None:
        return

    settings = get_settings()
    db_config = settings.database

    logger.info("Initializing database connection (max_retries=%d)", _MAX_DB_RETRIES)

    engine = None
    last_exception = None
    # Retry loop with exponential backoff.
    for attempt in range(1, _MAX_DB_RETRIES + 1):
        candidate = None
        try:
            candidate = create_engine(
                db_config.url,
                pool_size=db_config.pool_size,
                max_overflow=db_config.max_overflow,
                pool_pre_ping=True,
                echo=False,
            )
            # create_engine() does not connect by itself; force a round trip
            # so connectivity problems surface here rather than on first use.
            with candidate.connect() as conn:
                conn.execute(text("SELECT 1"))
        except Exception as e:
            last_exception = e
            if candidate is not None:
                # Release the failed engine's pool before trying again.
                candidate.dispose()
            if attempt < _MAX_DB_RETRIES:
                delay = _INITIAL_RETRY_DELAY * (2 ** (attempt - 1))
                logger.warning(
                    "Database connection attempt %d/%d failed: %s. "
                    "Retrying in %.1fs...",
                    attempt,
                    _MAX_DB_RETRIES,
                    e,
                    delay,
                )
                time.sleep(delay)
            else:
                logger.error(
                    "All %d database connection attempts failed.",
                    _MAX_DB_RETRIES,
                )
        else:
            engine = candidate
            break

    if engine is None:
        raise RuntimeError(
            f"Failed to connect to database after {_MAX_DB_RETRIES} attempts. "
            f"Last error: {last_exception}"
        ) from last_exception

    # Set up slow query detection listeners.
    _setup_slow_query_detection(engine)

    session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)

    # Publish both globals together, only once everything above succeeded.
    _engine = engine
    _SessionLocal = session_factory

    logger.info("Database connection initialized")


def _setup_slow_query_detection(engine):
    """Register SQLAlchemy event listeners for detecting slow queries.

    Cursor executions that take longer than
    ``get_settings().database.slow_query_threshold_ms`` milliseconds are logged
    at WARNING level. Each log record includes the duration, the statement text,
    and the parameters as rendered by ``_sanitize_params``. The threshold is read
    once, when the listeners are registered.

    Durations are measured with ``time.perf_counter()``, a monotonic clock.
    Wall-clock adjustments (NTP corrections, manual changes) therefore cannot
    produce negative or inflated timings.

    Args:
        engine: SQLAlchemy engine to attach listeners to.
    """
    threshold_ms = get_settings().database.slow_query_threshold_ms

    @event.listens_for(engine, "before_cursor_execute")
    def _before_cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        conn.info["query_start_time"] = time.perf_counter()

    @event.listens_for(engine, "after_cursor_execute")
    def _after_cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        finished = time.perf_counter()
        # If no start time was recorded, fall back to "now" (0 ms elapsed).
        started = conn.info.pop("query_start_time", finished)
        duration_ms = (finished - started) * 1000

        if duration_ms > threshold_ms:
            sanitized = _sanitize_params(parameters)
            logger.warning(
                "Slow query (%.1f ms): %s | params=%s",
                duration_ms,
                statement,
                sanitized,
            )


def _sanitize_params(parameters):
    """Build a log-friendly rendering of raw SQLAlchemy query parameters.

    Every individual value is passed through ``_trunc``, so one oversized
    value cannot flood the log. Values are only shortened; nothing is masked
    or redacted. Shapes are mapped as follows:

    * ``None`` -> ``None``
    * ``dict`` -> ``dict`` with truncated values
    * non-empty ``list``/``tuple`` whose first item is a ``dict``
      -> ``list`` of such dicts (executemany-style rows)
    * any other ``list``/``tuple`` -> ``tuple`` of truncated values
    * anything else -> a single truncated string

    Args:
        parameters: Raw SQLAlchemy query parameters.

    Returns:
        Sanitized parameters safe for logging.
    """

    def _clip_mapping(mapping):
        # One truncated string per key, keys left untouched.
        return {key: _trunc(val) for key, val in mapping.items()}

    if parameters is None:
        return None
    if isinstance(parameters, dict):
        return _clip_mapping(parameters)
    if not isinstance(parameters, (list, tuple)):
        return _trunc(parameters)

    rows_are_mappings = bool(parameters) and isinstance(parameters[0], dict)
    if rows_are_mappings:
        return [_clip_mapping(row) for row in parameters]
    return tuple(map(_trunc, parameters))


def _trunc(value, max_len: int = 100):
    """Truncate a value's string representation for safe logging."""
    s = str(value)
    return s[:max_len] + "..." if len(s) > max_len else s


def get_engine():
    """Return the module-wide SQLAlchemy engine.

    Calls ``init_db()`` first when no engine has been created yet.

    Returns:
        SQLAlchemy engine
    """
    if _engine is not None:
        return _engine
    init_db()
    return _engine


def get_session_factory():
    """Return the module-wide session factory.

    Calls ``init_db()`` first when the factory has not been created yet.

    Returns:
        Session factory
    """
    if _SessionLocal is not None:
        return _SessionLocal
    init_db()
    return _SessionLocal


@contextmanager
def get_session() -> Generator[Session, None, None]:
    """Provide a transactional database session as a context manager.

    The session is committed when the ``with`` body finishes normally. It is
    rolled back (and the exception re-raised) if the body or the commit raises
    an ``Exception``. It is always closed on exit.

    Yields:
        Database session

    Example:
        >>> with get_session() as session:
        ...     horses = session.query(Horse).all()
    """
    db_session = get_session_factory()()
    try:
        try:
            yield db_session
            db_session.commit()
        except Exception:
            # Undo any partial work, then let the caller see the error.
            db_session.rollback()
            raise
    finally:
        db_session.close()


def dispose_db():
    """Dispose of the engine (if any) and clear the module-level globals.

    When an engine exists, its pool is disposed and both ``_engine`` and
    ``_SessionLocal`` are reset to ``None``. A message is logged either way.
    Useful for testing and cleanup.
    """
    global _engine, _SessionLocal

    current_engine = _engine
    if current_engine is not None:
        current_engine.dispose()
        _engine, _SessionLocal = None, None

    logger.info("Database connection disposed")
