"""Multi-runner Elo rating engine for harness racing.

Implements pairwise logistic Elo with support for:
- Multi-entity ratings (horse, driver, trainer)
- Condition adjustments (barrier, handicap)
- Rating deviation (RD) for uncertainty tracking
"""

import math
from dataclasses import dataclass
from datetime import date

from sqlalchemy.orm import Session

from packages.core.common.logging import get_logger
from packages.core.common.settings import get_settings
from packages.core.common.utils import get_distance_bucket
from packages.core.storage.models import EntityType, Race, Starter
from packages.core.storage.repositories import (
    BarrierAdjustmentRepository,
    HandicapAdjustmentRepository,
)

logger = get_logger(__name__)


@dataclass
class RatingState:
    """Current rating state for an entity."""

    rating: float
    rd: float | None = None
    race_count: int = 0
    last_race_id: int | None = None
    last_race_date: date | None = None


@dataclass()
class RatingUpdate:
    """One rating change produced while processing a race.

    Attributes:
        entity_type: Kind of entity whose rating changed.
        entity_id: ID of that entity.
        old_rating: Rating before the change.
        new_rating: Rating after the change.
        delta: Size of the rating change.
        rd: Rating deviation after the change, or None when not tracked.
        meta: Optional free-form extra information about the update.
    """

    entity_type: EntityType
    entity_id: int
    old_rating: float
    new_rating: float
    delta: float
    rd: float | None = None  # None when RD tracking is not in use
    meta: dict | None = None  # optional extra details


class RatingEngine:
    """Multi-runner Elo rating engine.

    Ratings for horses, drivers and trainers are kept in an in-memory cache
    (``self.states``) while races are processed. Each race is scored as a set
    of pairwise logistic Elo contests between starters. The horse's rating
    change is then propagated to its driver and trainer, scaled by the
    configured weights. Optional barrier/handicap adjustment tables shift a
    starter's effective rating and can be learned from race residuals.
    """

    def __init__(self, db_session: Session | None = None):
        """Initialize rating engine with configuration.

        Args:
            db_session: Optional database session for loading/saving
                adjustments. When given and adjustments are enabled, the
                barrier/handicap tables are loaded immediately.
        """
        self.settings = get_settings().rating
        self.db_session = db_session

        # Rating states (in-memory cache during computation), keyed by
        # (entity type, entity id).
        self.states: dict[tuple[EntityType, int], RatingState] = {}

        # Condition adjustment tables keyed by
        # (venue, start_type, distance_bucket, barrier | handicap_m). The
        # sample counts gate table entries via ``adj_min_samples``.
        self.barrier_adjustments: dict[tuple, float] = {}
        self.handicap_adjustments: dict[tuple, float] = {}
        self.barrier_adjustment_samples: dict[tuple, int] = {}
        self.handicap_adjustment_samples: dict[tuple, int] = {}

        # Load adjustments from database if session provided
        if self.db_session and self.settings.enable_adjustments:
            self.load_adjustments_from_db()

    def get_or_init_rating(
        self, entity_type: EntityType, entity_id: int
    ) -> RatingState:
        """Get current rating state or initialize new entity.

        New entities start at ``initial_rating`` and, when RD is enabled,
        at ``initial_rd``. The new state is stored in the cache.

        Args:
            entity_type: Type of entity
            entity_id: Entity ID

        Returns:
            Current rating state (the cached object, mutable in place)
        """
        key = (entity_type, entity_id)
        if key not in self.states:
            self.states[key] = RatingState(
                rating=self.settings.initial_rating,
                rd=self.settings.initial_rd if self.settings.enable_rd else None,
                race_count=0,
            )
        return self.states[key]

    def load_rating_state(
        self,
        entity_type: EntityType,
        entity_id: int,
        rating: float,
        rd: float | None = None,
        last_race_date: date | None = None,
    ) -> None:
        """Load existing rating state from database.

        Replaces any cached state for the entity. ``race_count`` starts at 0
        and ``last_race_id`` at None for the loaded state.

        Args:
            entity_type: Type of entity
            entity_id: Entity ID
            rating: Current rating
            rd: Current rating deviation
            last_race_date: Date of last race
        """
        key = (entity_type, entity_id)
        self.states[key] = RatingState(
            rating=rating,
            rd=rd,
            last_race_date=last_race_date,
        )

    def sigmoid(self, x: float) -> float:
        """Logistic sigmoid function, evaluated without overflow.

        Args:
            x: Input value

        Returns:
            Value between 0 and 1
        """
        # Branching keeps the exp() argument non-positive so it cannot overflow.
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        exp_x = math.exp(x)
        return exp_x / (1.0 + exp_x)

    def get_effective_k_factor(self, entity_type: EntityType, entity_id: int) -> float:
        """Compute effective K-factor based on rating deviation.

        When RD is enabled, the K-factor is scaled by ``rd / initial_rd``
        (or its square root in "sqrt" mode; no scaling in "none" mode):
        - High RD (new/inactive) → larger K → faster rating changes
        - Low RD (established) → smaller K → more stable ratings
        The result is then clamped to ``[elo_k_min, elo_k_max]`` where set.

        Args:
            entity_type: Type of entity (HORSE, DRIVER, TRAINER)
            entity_id: Entity ID

        Returns:
            Effective K-factor for this entity
        """
        k_eff = self.settings.elo_k_base

        if self.settings.enable_rd:
            state = self.get_or_init_rating(entity_type, entity_id)
            if state.rd is not None:
                initial_rd = self.settings.initial_rd
                if initial_rd > 0:
                    ratio = state.rd / initial_rd
                    mode = self.settings.rd_scaling_mode
                    if mode == "sqrt":
                        ratio = math.sqrt(ratio)
                    elif mode == "none":
                        ratio = 1.0
                    k_eff *= ratio

        if self.settings.elo_k_min is not None:
            k_eff = max(k_eff, self.settings.elo_k_min)
        if self.settings.elo_k_max is not None:
            k_eff = min(k_eff, self.settings.elo_k_max)

        return k_eff

    @staticmethod
    def _race_venue(race: Race) -> str | None:
        """Return the race's meeting venue, or None if the race has no meeting."""
        return race.meeting.venue if race.meeting else None

    def _distance_bucket(self, distance_m: int | None):
        """Map a race distance to its configured adjustment bucket."""
        return get_distance_bucket(
            distance_m,
            self.settings.distance_buckets,
            mode=self.settings.distance_bucket_mode,
            bucket_size=self.settings.distance_bucket_size,
        )

    def _pairwise_outcome(self, placing_i: int, placing_j: int) -> float | None:
        """Return the actual outcome S_ij of starter i against starter j.

        Returns 1.0 if i finished ahead, 0.0 if behind. Ties follow
        ``tie_handling``: "skip" returns None (the pair is ignored), "half"
        returns 0.5, and any other value scores the tie as 0.0.
        """
        if placing_i == placing_j:
            if self.settings.tie_handling == "skip":
                return None
            return 0.5 if self.settings.tie_handling == "half" else 0.0
        return 1.0 if placing_i < placing_j else 0.0

    def compute_effective_rating(
        self,
        starter: Starter,
        race: Race,
        *,
        include_adjustments: bool = True,
    ) -> float:
        """Compute effective rating for a starter.

        R_eff = R_horse + α*R_driver + β*R_trainer + barrier_adj + handicap_adj

        Args:
            starter: Starter instance
            race: Race instance
            include_adjustments: When False, the barrier/handicap terms are
                omitted even if adjustments are enabled. Learning uses this so
                that it never sees its own output.

        Returns:
            Effective rating
        """
        horse_state = self.get_or_init_rating(EntityType.HORSE, starter.horse_id)
        r_eff = horse_state.rating

        if self.settings.enable_driver and starter.driver_id:
            driver_state = self.get_or_init_rating(EntityType.DRIVER, starter.driver_id)
            r_eff += self.settings.driver_weight_alpha * driver_state.rating

        if self.settings.enable_trainer and starter.trainer_id:
            trainer_state = self.get_or_init_rating(
                EntityType.TRAINER, starter.trainer_id
            )
            r_eff += self.settings.trainer_weight_beta * trainer_state.rating

        if include_adjustments and self.settings.enable_adjustments:
            venue = self._race_venue(race)

            if starter.barrier is not None:
                r_eff += self._get_barrier_adjustment(
                    venue,
                    race.start_type,
                    race.distance_m,
                    starter.barrier,
                )

            # A zero handicap (front mark) carries no adjustment.
            if starter.handicap_m is not None and starter.handicap_m != 0:
                r_eff += self._get_handicap_adjustment(
                    venue,
                    race.start_type,
                    race.distance_m,
                    starter.handicap_m,
                )

        return r_eff

    def _get_barrier_adjustment(
        self,
        venue: str | None,
        start_type: str | None,
        distance_m: int | None,
        barrier: int,
    ) -> float:
        """Get barrier adjustment from learned table.

        Args:
            venue: Venue name
            start_type: mobile/standing
            distance_m: Distance in meters
            barrier: Barrier number

        Returns:
            Adjustment value (default 0.0)
        """
        if not self.settings.adj_barrier_enabled:
            return 0.0

        distance_bucket = self._distance_bucket(distance_m)
        # Try the venue/start-type specific key first, then the global one.
        return self._resolve_adjustment(
            (venue, start_type, distance_bucket, barrier),
            (None, None, distance_bucket, barrier),
            self.barrier_adjustments,
            self.barrier_adjustment_samples,
        )

    def _get_handicap_adjustment(
        self,
        venue: str | None,
        start_type: str | None,
        distance_m: int | None,
        handicap_m: int,
    ) -> float:
        """Get handicap adjustment from learned table.

        Args:
            venue: Venue name
            start_type: mobile/standing
            distance_m: Distance in meters
            handicap_m: Handicap in meters

        Returns:
            Adjustment value (default 0.0)
        """
        if not self.settings.adj_handicap_enabled:
            return 0.0

        distance_bucket = self._distance_bucket(distance_m)
        # Try the venue/start-type specific key first, then the global one.
        return self._resolve_adjustment(
            (venue, start_type, distance_bucket, handicap_m),
            (None, None, distance_bucket, handicap_m),
            self.handicap_adjustments,
            self.handicap_adjustment_samples,
        )

    def _resolve_adjustment(
        self,
        key: tuple,
        global_key: tuple,
        adjustments: dict[tuple, float],
        samples: dict[tuple, int],
    ) -> float:
        """Look up an adjustment, preferring ``key`` over ``global_key``.

        A table entry is skipped when its sample count is below
        ``adj_min_samples``. The first usable entry is clamped and returned;
        if neither key is usable the adjustment is 0.0.
        """
        for candidate in (key, global_key):
            if candidate not in adjustments:
                continue
            if self.settings.adj_min_samples > 0:
                if samples.get(candidate, 0) < self.settings.adj_min_samples:
                    continue
            return self._clamp_adjustment(adjustments[candidate])
        return 0.0

    def _clamp_adjustment(self, adjustment: float) -> float:
        """Clamp an adjustment to ``[adj_clamp_min, adj_clamp_max]`` where set."""
        if self.settings.adj_clamp_min is not None:
            adjustment = max(adjustment, self.settings.adj_clamp_min)
        if self.settings.adj_clamp_max is not None:
            adjustment = min(adjustment, self.settings.adj_clamp_max)
        return adjustment

    def load_adjustments_from_db(self) -> None:
        """Load barrier and handicap adjustments from database into memory."""
        if not self.db_session:
            return

        barrier_adjs = BarrierAdjustmentRepository.get_all(self.db_session)
        for adj in barrier_adjs:
            key = (adj.venue, adj.start_type, adj.distance_bucket, adj.barrier)
            self.barrier_adjustments[key] = adj.adjustment
            self.barrier_adjustment_samples[key] = adj.sample_count

        handicap_adjs = HandicapAdjustmentRepository.get_all(self.db_session)
        for adj in handicap_adjs:
            key = (adj.venue, adj.start_type, adj.distance_bucket, adj.handicap_m)
            self.handicap_adjustments[key] = adj.adjustment
            self.handicap_adjustment_samples[key] = adj.sample_count

        logger.info(
            f"Loaded {len(barrier_adjs)} barrier adjustments and "
            f"{len(handicap_adjs)} handicap adjustments from database"
        )

    def learn_adjustments_from_race(
        self, race: Race, starters: list[Starter], use_global_only: bool | None = None
    ) -> None:
        """Learn barrier and handicap adjustments from a completed race.

        Uses performance residuals: if a horse performs better than expected
        given its rating, some of that is attributed to favorable conditions.
        Updates are written through the adjustment repositories. The
        in-memory tables on this engine are not modified here, so
        ``load_adjustments_from_db`` is needed to pick up the new values.

        Args:
            race: Race instance
            starters: List of starters with results
            use_global_only: If True, only update global adjustments (no
                venue-specific). None defers to ``adj_global_only``.
        """
        if not self.db_session or not self.settings.enable_adjustments:
            return
        if (
            not self.settings.adj_barrier_enabled
            and not self.settings.adj_handicap_enabled
        ):
            return

        valid_starters = [
            s for s in starters if s.placing is not None and not s.did_not_finish
        ]
        if len(valid_starters) < self.settings.min_finishers:
            return

        # Exclude condition adjustments from the baseline so the residual is
        # not fed back into the tables it is used to learn. This is done via a
        # parameter rather than by toggling the (possibly shared) settings.
        effective_ratings = {
            starter.id: self.compute_effective_rating(
                starter, race, include_adjustments=False
            )
            for starter in valid_starters
        }

        if use_global_only is None:
            use_global_only = self.settings.adj_global_only
        if use_global_only:
            venue, start_type = None, None
        else:
            venue, start_type = self._race_venue(race), race.start_type

        distance_bucket = self._distance_bucket(race.distance_m)
        normalizer = max(len(valid_starters) - 1, 1)
        scale_c = self.settings.elo_scale_c

        for i, starter in enumerate(valid_starters):
            if not starter.horse_id:
                continue

            r_eff = effective_ratings[starter.id]

            # Compare expected vs actual over all pairwise contests.
            expected_sum = 0.0
            actual_sum = 0.0
            for j, other in enumerate(valid_starters):
                if i == j or not other.horse_id:
                    continue
                actual = self._pairwise_outcome(starter.placing, other.placing)
                if actual is None:  # tie skipped per tie_handling
                    continue
                expected_sum += self.sigmoid(
                    (r_eff - effective_ratings[other.id]) / scale_c
                )
                actual_sum += actual

            # Positive means the starter outperformed expectations.
            delta = (actual_sum - expected_sum) / normalizer
            scaled_delta = delta * self.settings.adj_update_scale

            if self.settings.adj_barrier_enabled and starter.barrier is not None:
                BarrierAdjustmentRepository.increment_sample(
                    self.db_session,
                    venue=venue,
                    start_type=start_type,
                    distance_bucket=distance_bucket,
                    barrier=starter.barrier,
                    delta=scaled_delta,
                    learning_rate=self.settings.adj_learning_rate,
                )

            if (
                self.settings.adj_handicap_enabled
                and starter.handicap_m is not None
                and starter.handicap_m != 0
            ):
                HandicapAdjustmentRepository.increment_sample(
                    self.db_session,
                    venue=venue,
                    start_type=start_type,
                    distance_bucket=distance_bucket,
                    handicap_m=starter.handicap_m,
                    delta=scaled_delta,
                    learning_rate=self.settings.adj_learning_rate,
                )

    def process_race(self, race: Race, starters: list[Starter]) -> list[RatingUpdate]:
        """Process a race and compute rating updates.

        Uses pairwise logistic Elo:
        - For each pair (i, j), compute expected outcome E_ij
        - Update based on actual outcome S_ij (1 if i beat j, 0 otherwise;
          ties follow ``tie_handling``)
        - ΔR_i = K * (1/normalizer) * Σ_j (S_ij - E_ij)

        The horse's ΔR is also applied to its driver and trainer when those
        are enabled, scaled by their weight and K-scale settings. States in
        ``self.states`` are updated as a side effect.

        Args:
            race: Race instance
            starters: List of starters in race

        Returns:
            List of rating updates applied, in application order (empty if
            the race has too few valid starters)
        """
        finishers = [
            s for s in starters if s.placing is not None and not s.did_not_finish
        ]
        dnf_starters = [s for s in starters if s.did_not_finish or s.placing is None]

        placing_by_id: dict[int, int] = {s.id: s.placing for s in finishers}
        if self.settings.dnf_treated_as_last and dnf_starters:
            # Every non-finisher shares the place behind the last finisher.
            dnf_place = max(placing_by_id.values(), default=0) + 1
            for starter in dnf_starters:
                placing_by_id[starter.id] = dnf_place
            valid_starters = finishers + dnf_starters
        else:
            valid_starters = finishers

        if len(valid_starters) < self.settings.min_finishers:
            logger.debug(
                f"Skipping race {race.id} - fewer than {self.settings.min_finishers} finishers"
            )
            return []

        n = len(valid_starters)
        updates: list[RatingUpdate] = []
        race_date = race.meeting.meeting_date if race.meeting else None
        scale_c = self.settings.elo_scale_c
        normalizer_mode = self.settings.pairwise_normalizer

        # Snapshot effective ratings before any update so that every pairwise
        # expectation in this race uses pre-race values.
        effective_ratings = {
            s.id: self.compute_effective_rating(s, race) for s in valid_starters
        }

        for i, starter_i in enumerate(valid_starters):
            if not starter_i.horse_id:
                continue

            r_eff_i = effective_ratings[starter_i.id]
            placing_i = placing_by_id.get(starter_i.id, starter_i.placing)

            delta_sum = 0.0
            comparisons = 0
            for j, starter_j in enumerate(valid_starters):
                if i == j or not starter_j.horse_id:
                    continue
                placing_j = placing_by_id.get(starter_j.id, starter_j.placing)
                s_ij = self._pairwise_outcome(placing_i, placing_j)
                if s_ij is None:  # tie skipped per tie_handling
                    continue
                e_ij = self.sigmoid(
                    (r_eff_i - effective_ratings[starter_j.id]) / scale_c
                )
                delta_sum += s_ij - e_ij
                comparisons += 1

            if normalizer_mode == "comparisons":
                normalizer = comparisons
            elif normalizer_mode == "n":
                normalizer = n
            else:
                normalizer = n - 1
            if normalizer <= 0:
                continue

            k_eff = (
                self.get_effective_k_factor(EntityType.HORSE, starter_i.horse_id)
                * self.settings.horse_k_scale
            )
            delta_r = k_eff * (delta_sum / normalizer)

            self._apply_update(
                EntityType.HORSE,
                starter_i.horse_id,
                delta_r,
                race.id,
                race_date,
                updates,
            )

            if self.settings.enable_driver and starter_i.driver_id:
                driver_delta = (
                    delta_r
                    * self.settings.driver_weight_alpha
                    * self.settings.driver_k_scale
                )
                self._apply_update(
                    EntityType.DRIVER,
                    starter_i.driver_id,
                    driver_delta,
                    race.id,
                    race_date,
                    updates,
                )

            if self.settings.enable_trainer and starter_i.trainer_id:
                trainer_delta = (
                    delta_r
                    * self.settings.trainer_weight_beta
                    * self.settings.trainer_k_scale
                )
                self._apply_update(
                    EntityType.TRAINER,
                    starter_i.trainer_id,
                    trainer_delta,
                    race.id,
                    race_date,
                    updates,
                )

        return updates

    def _apply_update(
        self,
        entity_type: EntityType,
        entity_id: int,
        delta: float,
        race_id: int,
        race_date: date | None,
        updates: list[RatingUpdate],
    ) -> None:
        """Apply rating update to an entity.

        The new rating is clamped to ``[rating_min, rating_max]`` where set,
        and the recorded delta is the change actually applied. RD inflation
        and decay, together with the race count, are applied once per race.
        An entity that receives several updates for the same race (for
        example a driver or trainer shared by several starters) accumulates
        the rating deltas without repeating that per-race bookkeeping.

        Args:
            entity_type: Type of entity
            entity_id: Entity ID
            delta: Rating change
            race_id: Race ID
            race_date: Date of the race
            updates: List to append update to
        """
        state = self.get_or_init_rating(entity_type, entity_id)
        old_rating = state.rating
        new_rating = old_rating + delta
        if self.settings.rating_min is not None:
            new_rating = max(new_rating, self.settings.rating_min)
        if self.settings.rating_max is not None:
            new_rating = min(new_rating, self.settings.rating_max)
        delta = new_rating - old_rating

        # Without a race id we cannot detect repeats, so every call counts.
        first_in_race = race_id is None or state.last_race_id != race_id

        if self.settings.enable_rd and state.rd is not None and first_in_race:
            # First inflate RD for inactivity since the previous race.
            if state.last_race_date and race_date:
                days_inactive = (race_date - state.last_race_date).days
                if days_inactive > 0:
                    if self.settings.rd_inflation_cap_days is not None:
                        days_inactive = min(
                            days_inactive, self.settings.rd_inflation_cap_days
                        )
                    inflation = days_inactive * self.settings.rd_inflation_per_day
                    state.rd = min(state.rd + inflation, self.settings.rd_max)

            # Then decay RD for participating in this race.
            decay = max(self.settings.rd_decay_per_race, self.settings.rd_decay_floor)
            state.rd = max(state.rd - decay, self.settings.rd_min)

        state.rating = new_rating
        if first_in_race:
            state.race_count += 1
        state.last_race_id = race_id
        state.last_race_date = race_date

        updates.append(
            RatingUpdate(
                entity_type=entity_type,
                entity_id=entity_id,
                old_rating=old_rating,
                new_rating=new_rating,
                delta=delta,
                rd=state.rd,
                meta={
                    "race_count": state.race_count,
                },
            )
        )

        logger.debug(
            f"{entity_type.value} {entity_id}: "
            f"{old_rating:.1f} -> {new_rating:.1f} (Δ{delta:+.1f})"
            + (f", RD={state.rd:.1f}" if state.rd is not None else "")
        )
