#!/usr/bin/env python3
"""Evaluation script for HarnessElo ratings system.

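Computes:
- Winner accuracy (how often the highest-rated finisher actually wins)
- Top-3 hit rate (how often at least one of the three highest-rated
  finishers also finishes in the top three)
- Calibration of softmax win probabilities, grouped into predicted-probability bins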
"""

import argparse
import json
import math
import sys
from collections import defaultdict
from datetime import date
from pathlib import Path

from packages.common.logging import get_logger, setup_logging
from packages.common.settings import get_settings
from packages.common.utils import parse_date
from packages.ratings.engine import RatingEngine
from packages.storage.database import get_session
from packages.storage.models import EntityType, Race, Starter
from packages.storage.repositories import RaceRepository, RatingSnapshotRepository
from sqlalchemy.orm import joinedload

# Module-level initialisation. Logging is configured before the module logger
# is obtained (presumably so the logger picks up that configuration — keep
# this order).
setup_logging()
logger = get_logger(__name__)
# Shared application settings; RatingEvaluator.softmax reads
# settings.rating.elo_scale_c from here.
settings = get_settings()


class RatingEvaluator:
    """Evaluate rating system performance against historical race results.

    For each race, every valid finisher's pre-race rating is restored from the
    latest stored snapshot before that race. The ``RatingEngine`` turns it into
    an effective rating, and a softmax converts those ratings into win
    probabilities. The predictions are then scored against the actual placings.
    """

    # Calibration bin edges. Bins are half-open ``[lo, hi)`` except the last,
    # which is closed at its upper edge. Otherwise a probability of exactly 1.0
    # (reachable when softmax underflows the other runners) would be dropped.
    CALIBRATION_BIN_EDGES = (0.0, 0.05, 0.10, 0.15, 0.20, 0.30, 0.40, 0.50, 1.0)

    def __init__(self):
        """Initialize evaluator with a fresh rating engine and empty results."""
        self.engine = RatingEngine()
        self.results = {
            "winner_accuracy": 0.0,
            "top3_hit_rate": 0.0,
            "calibration": {},
            "total_races": 0,
            "races_evaluated": 0,
        }

    def softmax(self, ratings: list[float]) -> list[float]:
        """Compute softmax probabilities from ratings.

        The temperature is ``settings.rating.elo_scale_c``, for consistency
        with the Elo scale.

        Args:
            ratings: Non-empty list of ratings.

        Returns:
            List of probabilities, one per rating, summing to 1.0.

        Raises:
            ValueError: If ``ratings`` is empty (raised by ``max``).
        """
        temperature = settings.rating.elo_scale_c

        # Subtract the max before exponentiating to avoid overflow.
        max_rating = max(ratings)
        exp_ratings = [math.exp((r - max_rating) / temperature) for r in ratings]
        total = sum(exp_ratings)

        return [e / total for e in exp_ratings]

    def evaluate_race(
        self, session, race: Race, starters: list[Starter]
    ) -> "dict | None":
        """Evaluate predictions for a single race.

        Args:
            session: Database session used to look up rating snapshots.
            race: Race being evaluated.
            starters: All starters in the race.

        Returns:
            ``None`` when the race cannot be evaluated. That happens with
            fewer than two valid finishers, or when a finisher lacks a
            ``horse_id``. Otherwise a dict with:

            - ``winner_correct``: whether the top-rated finisher won.
            - ``top3_overlap``: size of the overlap between the predicted
              and actual top three.
            - ``win_probability``: predicted probability of the actual winner.
            - ``predicted_winner_prob``: probability of the top-rated finisher.
            - ``probabilities``: predicted win probability of every finisher.
            - ``winner_index``: index of the actual winner in ``probabilities``.
        """
        # Only starters with a recorded placing who did not DNF are scored.
        finishers = [
            s for s in starters if s.placing is not None and not s.did_not_finish
        ]

        if len(finishers) < 2:
            return None

        # Bail out before touching the engine if any finisher cannot be rated,
        # so we never leave a partially-loaded race behind.
        if any(not s.horse_id for s in finishers):
            return None

        ratings = []
        for starter in finishers:
            # Rating as it stood before this race.
            snapshot = RatingSnapshotRepository.get_latest_rating(
                session, EntityType.HORSE, starter.horse_id, before_race_id=race.id
            )

            if snapshot:
                self.engine.load_rating_state(
                    EntityType.HORSE,
                    starter.horse_id,
                    snapshot.rating,
                    snapshot.rd,
                )
            # NOTE(review): when no snapshot exists nothing is loaded here. In
            # that case compute_effective_rating relies on whatever state the
            # engine already has for this horse; confirm that is the intended prior.

            ratings.append(self.engine.compute_effective_rating(starter, race))

        probabilities = self.softmax(ratings)

        indices = range(len(finishers))
        actual_winner_idx = min(indices, key=lambda i: finishers[i].placing)
        actual_top3 = {i for i, s in enumerate(finishers) if s.placing <= 3}

        # sorted() is stable, so on rating ties ranked[0] is the first maximum,
        # matching max(..., key=...).
        ranked = sorted(indices, key=lambda i: ratings[i], reverse=True)
        predicted_winner_idx = ranked[0]
        predicted_top3 = set(ranked[:3])

        return {
            "winner_correct": predicted_winner_idx == actual_winner_idx,
            "top3_overlap": len(actual_top3 & predicted_top3),
            "win_probability": probabilities[actual_winner_idx],
            "predicted_winner_prob": probabilities[predicted_winner_idx],
            "probabilities": probabilities,
            "winner_index": actual_winner_idx,
        }

    @staticmethod
    def _load_starters(session, race: Race) -> list[Starter]:
        """Load a race's starters with horse, driver and trainer eagerly joined."""
        return (
            session.query(Starter)
            .filter(Starter.race_id == race.id)
            .options(
                joinedload(Starter.horse),
                joinedload(Starter.driver),
                joinedload(Starter.trainer),
            )
            .all()
        )

    @classmethod
    def _calibration_bin_key(cls, prob: float) -> "str | None":
        """Return the ``"lo-hi"`` bin label for *prob*, or None if out of range."""
        edges = cls.CALIBRATION_BIN_EDGES
        last = len(edges) - 2
        for i, (lo, hi) in enumerate(zip(edges, edges[1:])):
            # The final bin also accepts its upper edge (prob == 1.0).
            if lo <= prob < hi or (i == last and prob == hi):
                return f"{lo:.2f}-{hi:.2f}"
        return None

    @staticmethod
    def _format_calibration(calibration_bins: dict) -> dict:
        """Turn raw per-bin accumulators into rounded calibration statistics."""
        calibration = {}
        # Keys are zero-padded "0.xx-0.yy" labels, so string order is numeric.
        for bin_key, data in sorted(calibration_bins.items()):
            count = data["count"]
            if count == 0:
                continue
            avg_predicted = data["predicted"] / count
            actual_win_rate = data["actual"] / count
            calibration[bin_key] = {
                "predicted_avg": round(avg_predicted, 4),
                "actual_win_rate": round(actual_win_rate, 4),
                "sample_size": count,
                "error": round(abs(avg_predicted - actual_win_rate), 4),
            }
        return calibration

    def evaluate_date_range(self, session, date_from: date, date_to: date) -> dict:
        """Evaluate all races in a date range.

        Calibration is computed over every scored finisher, not just winners.
        Each finisher's predicted win probability is binned, and the bin's
        ``actual_win_rate`` is the fraction of those finishers that actually
        won. ``sample_size`` is therefore a count of finishers.

        Args:
            session: Database session.
            date_from: Start date.
            date_to: End date.

        Returns:
            ``self.results``, updated in place with winner accuracy, top-3 hit
            rate, calibration table and race counts.
        """
        logger.info(f"Evaluating races from {date_from} to {date_to}")

        races = RaceRepository.get_races_for_recompute(session, date_from, date_to)
        logger.info(f"Found {len(races)} races to evaluate")

        winner_correct_count = 0
        top3_hit_count = 0
        races_evaluated = 0

        # Per-bin accumulators: summed predicted probability, number of actual
        # winners, and number of finishers that fell in the bin.
        calibration_bins = defaultdict(
            lambda: {"predicted": 0.0, "actual": 0, "count": 0}
        )

        for race in races:
            starters = self._load_starters(session, race)

            # Make sure the meeting relationship is populated.
            if not race.meeting:
                session.refresh(race, ["meeting"])

            result = self.evaluate_race(session, race, starters)
            if result is None:
                continue

            races_evaluated += 1

            if result["winner_correct"]:
                winner_correct_count += 1

            if result["top3_overlap"] > 0:
                top3_hit_count += 1

            # Every finisher contributes its prediction; only the actual
            # winner counts as a success. Binning only the winner (with a
            # success every time) would pin every bin's win rate at 1.0.
            winner_idx = result["winner_index"]
            for idx, prob in enumerate(result["probabilities"]):
                bin_key = self._calibration_bin_key(prob)
                if bin_key is None:
                    continue
                bucket = calibration_bins[bin_key]
                bucket["predicted"] += prob
                bucket["actual"] += int(idx == winner_idx)
                bucket["count"] += 1

        self.results["winner_accuracy"] = (
            winner_correct_count / races_evaluated if races_evaluated else 0.0
        )
        self.results["top3_hit_rate"] = (
            top3_hit_count / races_evaluated if races_evaluated else 0.0
        )
        self.results["races_evaluated"] = races_evaluated
        self.results["total_races"] = len(races)
        self.results["calibration"] = self._format_calibration(calibration_bins)

        return self.results


def _print_summary(results: dict, start_date: date, end_date: date) -> None:
    """Print a human-readable summary of evaluation *results* to stdout."""
    print("\n" + "=" * 60)
    print("HarnessElo Evaluation Results")
    print("=" * 60)
    print(f"Date range: {start_date} to {end_date}")
    print(f"Races evaluated: {results['races_evaluated']} / {results['total_races']}")
    print(f"\nWinner Accuracy: {results['winner_accuracy']:.2%}")
    print(f"Top-3 Hit Rate: {results['top3_hit_rate']:.2%}")

    print("\nCalibration by Probability Bin:")
    print("-" * 60)
    print(f"{'Bin':<15} {'Predicted':<12} {'Actual':<12} {'Error':<10} {'N':<8}")
    print("-" * 60)

    for bin_key, data in sorted(results["calibration"].items()):
        print(
            f"{bin_key:<15} "
            f"{data['predicted_avg']:<12.4f} "
            f"{data['actual_win_rate']:<12.4f} "
            f"{data['error']:<10.4f} "
            f"{data['sample_size']:<8}"
        )

    print("=" * 60)


def _write_markdown(
    md_path: Path, results: dict, start_date: date, end_date: date
) -> None:
    """Write evaluation *results* as a Markdown report to *md_path*."""
    with open(md_path, "w", encoding="utf-8") as f:
        f.write("# HarnessElo Evaluation Report\n\n")
        f.write(f"**Date Range:** {start_date} to {end_date}\n\n")
        f.write(
            f"**Races Evaluated:** {results['races_evaluated']} / {results['total_races']}\n\n"
        )
        f.write("## Summary Metrics\n\n")
        f.write(f"- **Winner Accuracy:** {results['winner_accuracy']:.2%}\n")
        f.write(f"- **Top-3 Hit Rate:** {results['top3_hit_rate']:.2%}\n\n")
        f.write("## Calibration\n\n")
        f.write(
            "| Probability Bin | Predicted Avg | Actual Win Rate | Error | Sample Size |\n"
        )
        f.write(
            "|-----------------|---------------|-----------------|-------|-------------|\n"
        )

        for bin_key, data in sorted(results["calibration"].items()):
            f.write(
                f"| {bin_key} | "
                f"{data['predicted_avg']:.4f} | "
                f"{data['actual_win_rate']:.4f} | "
                f"{data['error']:.4f} | "
                f"{data['sample_size']} |\n"
            )


def main():
    """Run the evaluation CLI.

    Parses ``--from``/``--to``/``--out``, evaluates the date range, and writes
    JSON results to ``--out``. A Markdown summary is written next to it, and
    a summary table is printed to stdout.

    Returns:
        Process exit status (0 on success). Argument errors exit via
        ``argparse`` with status 2.
    """
    parser = argparse.ArgumentParser(description="Evaluate HarnessElo ratings")
    parser.add_argument(
        "--from", dest="date_from", required=True, help="Start date (YYYY-MM-DD)"
    )
    parser.add_argument(
        "--to", dest="date_to", required=True, help="End date (YYYY-MM-DD)"
    )
    parser.add_argument("--out", required=True, help="Output file path (JSON)")

    args = parser.parse_args()

    start_date = parse_date(args.date_from)
    end_date = parse_date(args.date_to)
    # A reversed range would silently evaluate zero races.
    if start_date > end_date:
        parser.error(f"--from ({start_date}) must not be after --to ({end_date})")

    output_path = Path(args.out)

    # Ensure output directory exists
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # The Markdown summary normally sits beside the JSON with an .md suffix.
    # If --out already ends in .md, those paths would collide and the report
    # would overwrite the JSON results, so use a distinct name instead.
    md_path = output_path.with_suffix(".md")
    if md_path == output_path:
        md_path = output_path.with_name(output_path.stem + ".summary.md")

    logger.info(f"Evaluating from {start_date} to {end_date}")

    with get_session() as session:
        evaluator = RatingEvaluator()
        results = evaluator.evaluate_date_range(session, start_date, end_date)

    # Save JSON results
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    logger.info(f"Results saved to {output_path}")

    _print_summary(results, start_date, end_date)

    _write_markdown(md_path, results, start_date, end_date)
    logger.info(f"Markdown summary saved to {md_path}")

    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
