"""Enhanced race prediction functionality with confidence intervals and tracking."""

import math
from dataclasses import dataclass
from datetime import date, datetime

from sqlalchemy import and_, func, or_
from sqlalchemy.orm import Session

from packages.core.common.logging import get_logger
from packages.core.common.settings import get_settings
from packages.core.ratings.engine import RatingEngine
from packages.core.storage.models import EntityType, Race, Starter
from packages.core.storage.repositories import RatingSnapshotRepository

logger = get_logger(__name__)


@dataclass
class PredictionResult:
    """Prediction for a single starter.

    One instance is built per rated starter by
    ``PredictionEngine.predict_race``. Identity fields are copied from the
    starter; the numeric fields are derived from the rating engine and the
    horse's recent finishing history.
    """

    # Identity of the starter and its connections. The ``*_name`` fields are
    # None when the starter has no related horse/driver/trainer object.
    starter_id: int
    horse_id: int
    horse_name: str | None
    driver_id: int | None
    driver_name: str | None
    trainer_id: int | None
    trainer_name: str | None
    # Start conditions, copied verbatim from the starter.
    barrier: int | None
    handicap_m: float | None
    # Value returned by RatingEngine.compute_effective_rating for this starter.
    effective_rating: float
    # Softmax share of the field computed from the effective ratings.
    win_probability: float
    place_probability: float  # Top 3
    # Effective rating adjusted by recent top-3 rate and finish consistency.
    place_score: float
    # Bounds are effective_rating -/+ 1.96 * the horse's rating deviation (RD).
    confidence_interval_low: float  # 95% CI lower bound
    confidence_interval_high: float  # 95% CI upper bound
    # 1-based rank within the field when ordered by descending place_score.
    predicted_placing: int


@dataclass
class RacePrediction:
    """Complete prediction for a race.

    Returned by ``PredictionEngine.predict_race``. When the race has no
    starters or no starter could be rated, ``predictions`` is an empty list
    and ``metadata`` is an empty dict.
    """

    race_id: int
    race_number: int | None
    # Venue and date come from the race's meeting; None when no meeting is set.
    venue: str | None
    distance_m: int | None
    race_date: date | None
    # Per-starter results, sorted by ascending predicted_placing.
    predictions: list[PredictionResult]
    # Set with datetime.now() (no timezone attached) when the result is built.
    prediction_timestamp: datetime
    # Field summary with keys "field_size", "avg_rating", "rating_spread"
    # and "avg_place_score" when predictions exist.
    metadata: dict


class PredictionEngine:
    """Enhanced prediction engine with confidence intervals and tracking."""

    def __init__(self, session: Session):
        """Initialize prediction engine.

        Args:
            session: Database session
        """
        self.session = session
        self.rating_engine = RatingEngine(session)
        settings = get_settings().rating
        self.place_history_limit = settings.place_history_limit
        self.place_prior_rate = settings.place_prior_rate
        self.place_prior_weight = settings.place_prior_weight
        self.place_top3_weight = settings.place_top3_weight
        self.place_consistency_weight = settings.place_consistency_weight

    def predict_race(self, race: Race, starters: list[Starter]) -> RacePrediction:
        """Generate predictions for a race.

        Args:
            race: Race instance
            starters: List of starters

        Returns:
            Complete race prediction with probabilities
        """
        if not starters:
            logger.warning(f"No starters for race {race.id}")
            return RacePrediction(
                race_id=race.id,
                race_number=race.race_number,
                venue=race.meeting.venue if race.meeting else None,
                distance_m=race.distance_m,
                race_date=race.meeting.meeting_date if race.meeting else None,
                predictions=[],
                prediction_timestamp=datetime.now(),
                metadata={},
            )

        # Load latest ratings for all starters (use pre-race ratings if available)
        for starter in starters:
            if starter.horse_id:
                latest = RatingSnapshotRepository.get_latest_rating(
                    self.session,
                    EntityType.HORSE,
                    starter.horse_id,
                    before_race_id=race.id,
                )
                if latest:
                    self.rating_engine.load_rating_state(
                        EntityType.HORSE,
                        starter.horse_id,
                        latest.rating,
                        latest.rd,
                        last_race_date=(
                            latest.race.meeting.meeting_date
                            if latest.race and latest.race.meeting
                            else None
                        ),
                    )

            if starter.driver_id:
                latest = RatingSnapshotRepository.get_latest_rating(
                    self.session,
                    EntityType.DRIVER,
                    starter.driver_id,
                    before_race_id=race.id,
                )
                if latest:
                    self.rating_engine.load_rating_state(
                        EntityType.DRIVER,
                        starter.driver_id,
                        latest.rating,
                        latest.rd,
                        last_race_date=(
                            latest.race.meeting.meeting_date
                            if latest.race and latest.race.meeting
                            else None
                        ),
                    )

            if starter.trainer_id:
                latest = RatingSnapshotRepository.get_latest_rating(
                    self.session,
                    EntityType.TRAINER,
                    starter.trainer_id,
                    before_race_id=race.id,
                )
                if latest:
                    self.rating_engine.load_rating_state(
                        EntityType.TRAINER,
                        starter.trainer_id,
                        latest.rating,
                        latest.rd,
                        last_race_date=(
                            latest.race.meeting.meeting_date
                            if latest.race and latest.race.meeting
                            else None
                        ),
                    )

        # Compute effective ratings for all starters
        effective_ratings = {}
        rating_uncertainties = {}

        for starter in starters:
            if not starter.horse_id:
                continue

            r_eff = self.rating_engine.compute_effective_rating(starter, race)
            effective_ratings[starter.id] = r_eff

            # Get rating uncertainty (RD) for confidence interval
            horse_state = self.rating_engine.get_or_init_rating(
                EntityType.HORSE, starter.horse_id
            )
            rating_uncertainties[starter.id] = horse_state.rd or 100.0

        if not effective_ratings:
            logger.warning(f"No ratings available for race {race.id}")
            return RacePrediction(
                race_id=race.id,
                race_number=race.race_number,
                venue=race.meeting.venue if race.meeting else None,
                distance_m=race.distance_m,
                race_date=race.meeting.meeting_date if race.meeting else None,
                predictions=[],
                prediction_timestamp=datetime.now(),
                metadata={},
            )

        # Compute win probabilities using softmax
        win_probs = self._compute_win_probabilities(effective_ratings)

        # Compute place scores/probabilities (top 3)
        place_scores = self._compute_place_scores(race, starters, effective_ratings)
        place_probs = self._compute_place_probabilities(place_scores, top_n=3)

        # Generate predictions for each starter
        predictions = []
        place_ranks = self._compute_place_ranks(place_scores)
        for starter in starters:
            if starter.id not in effective_ratings:
                continue

            r_eff = effective_ratings[starter.id]
            win_prob = win_probs.get(starter.id, 0.0)
            place_prob = place_probs.get(starter.id, 0.0)
            place_score = place_scores.get(starter.id, r_eff)
            rd = rating_uncertainties.get(starter.id, 100.0)

            # Compute confidence interval (95% CI = ±1.96 * RD)
            ci_low = r_eff - 1.96 * rd
            ci_high = r_eff + 1.96 * rd

            predicted_placing = place_ranks.get(starter.id, len(place_ranks))

            predictions.append(
                PredictionResult(
                    starter_id=starter.id,
                    horse_id=starter.horse_id,
                    horse_name=starter.horse.name if starter.horse else None,
                    driver_id=starter.driver_id,
                    driver_name=starter.driver.name if starter.driver else None,
                    trainer_id=starter.trainer_id,
                    trainer_name=starter.trainer.name if starter.trainer else None,
                    barrier=starter.barrier,
                    handicap_m=starter.handicap_m,
                    effective_rating=r_eff,
                    win_probability=win_prob,
                    place_probability=place_prob,
                    place_score=place_score,
                    confidence_interval_low=ci_low,
                    confidence_interval_high=ci_high,
                    predicted_placing=predicted_placing,
                )
            )

        # Sort by predicted placing
        predictions.sort(key=lambda p: p.predicted_placing)

        metadata = {
            "field_size": len(predictions),
            "avg_rating": (
                sum(p.effective_rating for p in predictions) / len(predictions)
                if predictions
                else 0
            ),
            "rating_spread": (
                max(p.effective_rating for p in predictions)
                - min(p.effective_rating for p in predictions)
                if predictions
                else 0
            ),
            "avg_place_score": (
                sum(p.place_score for p in predictions) / len(predictions)
                if predictions
                else 0
            ),
        }

        return RacePrediction(
            race_id=race.id,
            race_number=race.race_number,
            venue=race.meeting.venue if race.meeting else None,
            distance_m=race.distance_m,
            race_date=race.meeting.meeting_date if race.meeting else None,
            predictions=predictions,
            prediction_timestamp=datetime.now(),
            metadata=metadata,
        )

    def _compute_win_probabilities(
        self, effective_ratings: dict[int, float]
    ) -> dict[int, float]:
        """Compute win probabilities using softmax.

        Args:
            effective_ratings: Dictionary of starter_id -> effective rating

        Returns:
            Dictionary of starter_id -> win probability
        """
        if not effective_ratings:
            return {}

        # Numerically stable softmax
        rating_values = list(effective_ratings.values())
        max_rating = max(rating_values)

        # Scale factor (400 = standard Elo scale)
        scale = 400.0

        # Compute exp(rating/scale) for all starters
        exp_ratings = {
            sid: math.exp((rating - max_rating) / scale)
            for sid, rating in effective_ratings.items()
        }

        total_exp = sum(exp_ratings.values())

        # Normalize to probabilities
        return {sid: exp_val / total_exp for sid, exp_val in exp_ratings.items()}

    def _compute_place_probabilities(
        self, place_scores: dict[int, float], top_n: int = 3
    ) -> dict[int, float]:
        """Compute place probabilities (finishing in top N).

        Args:
            place_scores: Dictionary of starter_id -> place score
            top_n: Number of top placings to consider

        Returns:
            Dictionary of starter_id -> place probability
        """
        if not place_scores or len(place_scores) <= top_n:
            return dict.fromkeys(place_scores, 1.0)

        base_probs = self._compute_win_probabilities(place_scores)
        return {sid: min(1.0, prob * (top_n + 0.5)) for sid, prob in base_probs.items()}

    def _compute_place_scores(
        self,
        race: Race,
        starters: list[Starter],
        effective_ratings: dict[int, float],
    ) -> dict[int, float]:
        """Compute place scores from ratings and consistency history."""
        scores = {}
        scale = self.rating_engine.settings.elo_scale_c

        for starter in starters:
            if starter.id not in effective_ratings or not starter.horse_id:
                continue

            history = self._get_recent_finish_stats(race, starter.horse_id)
            top3_rate = history["top3_rate"]
            consistency = history["consistency"]

            rating_adjustment = (
                self.place_top3_weight * (top3_rate - self.place_prior_rate) * scale
                + self.place_consistency_weight * (consistency - 0.5) * scale
            )

            scores[starter.id] = effective_ratings[starter.id] + rating_adjustment

        return scores

    def _compute_place_ranks(self, place_scores: dict[int, float]) -> dict[int, int]:
        """Convert place scores into ordinal ranks."""
        ordered = sorted(place_scores.items(), key=lambda item: item[1], reverse=True)
        return {starter_id: idx + 1 for idx, (starter_id, _) in enumerate(ordered)}

    def _get_recent_finish_stats(self, race: Race, horse_id: int) -> dict[str, float]:
        """Compute smoothed top-3 rate and consistency from recent finishes."""
        from packages.core.storage.models import Meeting

        field_size_subquery = (
            self.session.query(
                Starter.race_id.label("race_id"),
                func.count(Starter.id).label("field_size"),
            )
            .group_by(Starter.race_id)
            .subquery()
        )

        query = (
            self.session.query(
                Starter.placing,
                field_size_subquery.c.field_size,
                Race.race_datetime,
                Race.race_number,
                Meeting.meeting_date,
            )
            .join(Race, Starter.race_id == Race.id)
            .join(Meeting, Race.meeting_id == Meeting.id)
            .join(field_size_subquery, Starter.race_id == field_size_subquery.c.race_id)
            .filter(
                Starter.horse_id == horse_id,
                Starter.placing.isnot(None),
                Starter.did_not_finish.is_(False),
            )
        )

        if race.race_datetime:
            query = query.filter(Race.race_datetime < race.race_datetime)
            query = query.order_by(Race.race_datetime.desc(), Race.race_number.desc())
        elif race.meeting and race.meeting.meeting_date:
            query = query.filter(
                or_(
                    Meeting.meeting_date < race.meeting.meeting_date,
                    and_(
                        Meeting.meeting_date == race.meeting.meeting_date,
                        Race.race_number < race.race_number,
                    ),
                )
            )
            query = query.order_by(Meeting.meeting_date.desc(), Race.race_number.desc())
        else:
            query = query.order_by(Race.id.desc())

        rows = query.limit(self.place_history_limit).all()

        if not rows:
            return {"top3_rate": self.place_prior_rate, "consistency": 0.5}

        top3_count = 0
        percentiles = []
        for placing, field_size, _, _, _ in rows:
            field_size = field_size or 1
            if placing <= 3:
                top3_count += 1
            if field_size <= 1:
                percentiles.append(1.0)
            else:
                percentiles.append(1.0 - (placing - 1) / (field_size - 1))

        sample_count = len(rows)
        top3_rate = (top3_count + self.place_prior_rate * self.place_prior_weight) / (
            sample_count + self.place_prior_weight
        )

        if sample_count < 2:
            consistency = 0.5
        else:
            mean = sum(percentiles) / sample_count
            variance = sum((p - mean) ** 2 for p in percentiles) / sample_count
            stddev = math.sqrt(variance)
            consistency = 1.0 - min(1.0, stddev / 0.5)

        return {"top3_rate": top3_rate, "consistency": consistency}

    def get_upcoming_races(self, race_date: date | None = None) -> list[Race]:
        """Get upcoming races for prediction.

        Args:
            race_date: Date to get races for (defaults to today)

        Returns:
            List of races
        """
        if race_date is None:
            race_date = date.today()

        from packages.core.storage.models import Meeting

        races = (
            self.session.query(Race)
            .join(Race.meeting)
            .filter(Meeting.meeting_date == race_date)
            .order_by(Race.race_number)
            .all()
        )

        return races

    def compare_prediction_to_actual(self, race_id: int) -> dict | None:
        """Compare prediction to actual result for completed race.

        Args:
            race_id: Race ID

        Returns:
            Comparison dictionary with prediction accuracy metrics
        """
        race = self.session.query(Race).filter(Race.id == race_id).first()
        if not race:
            return None

        starters = race.starters
        if not starters:
            return None

        # Generate prediction
        prediction = self.predict_race(race, starters)

        # Compare to actual results
        starter_by_id = {starter.id: starter for starter in starters}
        actual_winner_id = None
        actual_top3_ids = []

        for starter in starters:
            if starter.placing and not starter.did_not_finish:
                if starter.placing == 1:
                    actual_winner_id = starter.id
                if starter.placing <= 3:
                    actual_top3_ids.append(starter.id)

        if not actual_winner_id:
            return None  # Race not completed

        # Find predicted winner
        predicted_winner = max(prediction.predictions, key=lambda p: p.win_probability)

        # Check if prediction was correct
        winner_correct = predicted_winner.starter_id == actual_winner_id

        # Check top-3 overlap
        predicted_top3_ids = [
            p.starter_id
            for p in sorted(prediction.predictions, key=lambda p: p.predicted_placing)[
                :3
            ]
        ]
        top3_overlap = len(set(predicted_top3_ids) & set(actual_top3_ids))

        # Calculate Brier score for winner prediction
        brier_score = sum(
            (p.win_probability - (1.0 if p.starter_id == actual_winner_id else 0.0))
            ** 2
            for p in prediction.predictions
        ) / len(prediction.predictions)

        predictions_with_actuals = []
        for pred in prediction.predictions:
            starter = starter_by_id.get(pred.starter_id)
            actual_placing = None
            if starter and starter.placing and not starter.did_not_finish:
                actual_placing = starter.placing
            predictions_with_actuals.append(
                {
                    "starter_id": pred.starter_id,
                    "horse_id": pred.horse_id,
                    "horse_name": pred.horse_name,
                    "driver_id": pred.driver_id,
                    "driver_name": pred.driver_name,
                    "trainer_id": pred.trainer_id,
                    "trainer_name": pred.trainer_name,
                    "barrier": pred.barrier,
                    "handicap_m": pred.handicap_m,
                    "effective_rating": pred.effective_rating,
                    "win_probability": pred.win_probability,
                    "place_probability": pred.place_probability,
                    "place_score": pred.place_score,
                    "ci_lower": pred.confidence_interval_low,
                    "ci_upper": pred.confidence_interval_high,
                    "predicted_placing": pred.predicted_placing,
                    "actual_placing": actual_placing,
                }
            )

        return {
            "race_id": race_id,
            "race_number": race.race_number,
            "venue": race.meeting.venue if race.meeting else None,
            "race_date": (
                race.meeting.meeting_date.isoformat() if race.meeting else None
            ),
            "winner_correct": winner_correct,
            "predicted_winner_id": predicted_winner.starter_id,
            "actual_winner_id": actual_winner_id,
            "top3_overlap": top3_overlap,
            "top3_overlap_rate": top3_overlap / 3.0,
            "brier_score": brier_score,
            "field_size": len(prediction.predictions),
            "predictions": predictions_with_actuals,
        }


def export_predictions_csv(predictions: list[RacePrediction], output_file: str) -> None:
    """Export predictions to CSV file.

    Writes one header row followed by one row per starter prediction. The
    file is written as UTF-8 so non-ASCII horse and driver names are kept
    on every platform. Missing (None) values become empty cells, while
    legitimate zero values such as a 0 m handicap are written as numbers.

    Args:
        predictions: List of race predictions
        output_file: Output CSV file path
    """
    import csv

    header = [
        "Race ID",
        "Race Number",
        "Venue",
        "Distance (m)",
        "Starter ID",
        "Horse Name",
        "Driver Name",
        "Barrier",
        "Handicap (m)",
        "Effective Rating",
        "Win Probability",
        "Place Probability",
        "Predicted Placing",
        "CI Low",
        "CI High",
    ]

    with open(output_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)

        for race_pred in predictions:
            # csv.writer emits None as an empty field, so optional values are
            # passed through untouched; "value or ''" would also blank out 0.
            race_columns = [
                race_pred.race_id,
                race_pred.race_number,
                race_pred.venue,
                race_pred.distance_m,
            ]
            for pred in race_pred.predictions:
                writer.writerow(
                    race_columns
                    + [
                        pred.starter_id,
                        pred.horse_name,
                        pred.driver_name,
                        pred.barrier,
                        pred.handicap_m,
                        f"{pred.effective_rating:.1f}",
                        f"{pred.win_probability:.3f}",
                        f"{pred.place_probability:.3f}",
                        pred.predicted_placing,
                        f"{pred.confidence_interval_low:.1f}",
                        f"{pred.confidence_interval_high:.1f}",
                    ]
                )

    logger.info(f"Predictions exported to {output_file}")
