"""Advanced analytics and metrics for HarnessElo ratings system.

Provides detailed analysis by venue, distance, gait, and other dimensions.
"""

import json
import math
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import date

from packages.common.logging import get_logger
from packages.common.utils import get_distance_bucket
from packages.ratings.engine import RatingEngine
from packages.storage.models import EntityType, Meeting, Race, Starter
from packages.storage.repositories import RatingSnapshotRepository
from sqlalchemy import Integer, cast, func
from sqlalchemy.orm import Session

logger = get_logger(__name__)


@dataclass
class VenueMetrics:
    """Prediction-performance metrics aggregated for a single venue.

    As populated by ``AdvancedAnalyzer.analyze_venue_performance``, the rate
    fields are fractions of ``race_count`` and therefore lie in [0, 1].
    """

    venue: str
    race_count: int  # Analyzed races (>= 2 finishers and >= 1 pre-race rating)
    avg_field_size: float  # Mean number of finishing starters per race
    winner_accuracy: float  # Fraction of races won by the top-rated starter
    top3_hit_rate: float  # NOTE(review): currently always 0.0 (not implemented)
    avg_rating_spread: float  # Difference between top and bottom rated
    favorite_win_rate: float  # How often top-rated wins


@dataclass
class DistanceMetrics:
    """Metrics for a specific distance bucket (as named by ``get_distance_bucket``)."""

    distance_bucket: str
    race_count: int  # Races with a distance and at least two finishers
    avg_field_size: float  # Mean number of finishing starters per race
    winner_accuracy: float  # NOTE(review): currently always 0.0 (not implemented)
    # Average placing by barrier: barrier number -> mean finishing placing
    # (lower is better, since placing 1 is the winner).
    barrier_bias: dict[int, float]


@dataclass
class EntityPerformanceMetrics:
    """Performance metrics for an entity (horse/driver/trainer).

    Counts and rates cover only the analyzed date window. ``current_rating``
    is the entity's latest rating snapshot, if one exists.
    """

    entity_type: str  # The ``EntityType`` value of the entity
    entity_id: int
    entity_name: str | None  # NOTE(review): currently always None (not fetched)
    race_count: int  # Qualifying starts in the window
    win_count: int  # Starts with placing == 1
    place_count: int  # Top 3
    win_rate: float  # win_count / race_count
    place_rate: float  # place_count / race_count
    avg_placing: float  # Mean finishing placing (lower is better)
    rating_volatility: float  # Std dev of rating changes; currently always 0.0
    current_rating: float | None  # None when the entity has no rating snapshot


@dataclass
class BrierScoreAnalysis:
    """Brier score for probability calibration analysis.

    Lower is better; ``perfect_score`` records the ideal value for reference.
    """

    overall_brier_score: float  # Mean squared error of win probabilities
    # Binned calibration analysis: keys are decile labels such as "30-40%";
    # each value holds "avg_predicted_prob", "actual_win_rate" and "count".
    by_probability_bin: dict[str, dict[str, float]]
    perfect_score: float = 0.0  # Perfect Brier score is 0


@dataclass
class AdvancedAnalyticsReport:
    """Comprehensive advanced analytics report for one date window.

    Every section starts empty (or ``None``) and is filled in by
    ``AdvancedAnalyzer.generate_report``.
    """

    start_date: date
    end_date: date
    venue_metrics: list[VenueMetrics] = field(default_factory=list)
    distance_metrics: list[DistanceMetrics] = field(default_factory=list)
    # Keyed by the lower-cased EntityType value (presumably "horse",
    # "driver", "trainer" -- confirm against EntityType).
    top_performers: dict[str, list[EntityPerformanceMetrics]] = field(
        default_factory=dict
    )
    brier_score: BrierScoreAnalysis | None = None
    # NOTE(review): currently filled with fixed placeholder values.
    rating_correlation: dict[str, float] = field(default_factory=dict)
    # Summary figures (total races, overall accuracy, Brier score, counts).
    overall_metrics: dict = field(default_factory=dict)


class AdvancedAnalyzer:
    """Advanced analytics engine.

    Every analysis covers races whose meeting date lies in the inclusive
    window ``[from_date, to_date]``. The per-race analyses (venue, distance,
    Brier score) only consider "finishers": starters with a truthy
    ``placing`` and a falsy ``did_not_finish``. They skip races with fewer
    than two finishers.
    """

    def __init__(self, session: Session):
        """Initialize analyzer.

        Args:
            session: Database session used for every query. A ``RatingEngine``
                bound to the same session computes effective ratings.
        """
        self.session = session
        self.engine = RatingEngine(session)

    def _races_in_range(self, from_date: date, to_date: date) -> list[Race]:
        """Return all races whose meeting date is within [from_date, to_date]."""
        return (
            self.session.query(Race)
            .join(Race.meeting)
            .filter(
                Meeting.meeting_date >= from_date,
                Meeting.meeting_date <= to_date,
            )
            .all()
        )

    @staticmethod
    def _finished_starters(race: Race) -> list[Starter]:
        """Return the starters of *race* with a truthy placing who did not DNF."""
        return [s for s in race.starters if s.placing and not s.did_not_finish]

    def _pre_race_horse_ratings(
        self, race: Race, starters: list[Starter]
    ) -> dict[int, float]:
        """Map starter id -> the horse's rating snapshot preceding *race*.

        The lookup passes ``before_race_id=race.id``. Starters without a horse
        id, or whose horse has no such snapshot, are left out of the result.
        """
        ratings: dict[int, float] = {}
        for starter in starters:
            if not starter.horse_id:
                continue
            snapshot = RatingSnapshotRepository.get_latest_for_entity(
                self.session,
                EntityType.HORSE,
                starter.horse_id,
                before_race_id=race.id,
            )
            if snapshot:
                ratings[starter.id] = snapshot.rating
        return ratings

    def analyze_venue_performance(
        self, from_date: date, to_date: date
    ) -> list[VenueMetrics]:
        """Analyze prediction performance by venue.

        For each race with at least two finishers, the pre-race horse ratings
        are looked up. Races where no finisher has a rating are skipped. The
        top-rated starter is compared with the winner (lowest placing), and
        the spread between the highest and lowest rating is accumulated.

        Args:
            from_date: Start date (inclusive)
            to_date: End date (inclusive)

        Returns:
            List of venue metrics, sorted by race count with the busiest
            venue first.
        """
        logger.info("Analyzing venue performance")

        # Venue lives on the meeting, so iterate meetings rather than races.
        meetings = (
            self.session.query(Meeting)
            .filter(
                Meeting.meeting_date >= from_date,
                Meeting.meeting_date <= to_date,
            )
            .all()
        )

        venue_stats = defaultdict(
            lambda: {
                "race_count": 0,
                "total_starters": 0,
                "top_rated_wins": 0,
                "total_rating_spread": 0,
            }
        )

        for meeting in meetings:
            venue = meeting.venue
            for race in meeting.races:
                starters = self._finished_starters(race)
                if len(starters) < 2:
                    continue

                ratings = self._pre_race_horse_ratings(race, starters)
                if not ratings:
                    continue

                stats = venue_stats[venue]
                stats["race_count"] += 1
                stats["total_starters"] += len(starters)

                # Did the top-rated starter win?
                top_rated_id = max(ratings, key=ratings.get)
                winner = min(starters, key=lambda s: s.placing)
                if top_rated_id == winner.id:
                    stats["top_rated_wins"] += 1

                spread = max(ratings.values()) - min(ratings.values())
                stats["total_rating_spread"] += spread

        metrics = []
        for venue, stats in venue_stats.items():
            race_count = stats["race_count"]
            if race_count == 0:
                continue

            metrics.append(
                VenueMetrics(
                    venue=venue,
                    race_count=race_count,
                    avg_field_size=stats["total_starters"] / race_count,
                    winner_accuracy=stats["top_rated_wins"] / race_count,
                    top3_hit_rate=0.0,  # TODO: Implement
                    avg_rating_spread=stats["total_rating_spread"] / race_count,
                    favorite_win_rate=stats["top_rated_wins"] / race_count,
                )
            )

        return sorted(metrics, key=lambda m: m.race_count, reverse=True)

    def analyze_distance_performance(
        self, from_date: date, to_date: date
    ) -> list[DistanceMetrics]:
        """Analyze performance by distance bucket.

        Races without a distance, or with fewer than two finishers, are
        skipped. For each bucket the mean finishing placing per barrier is
        reported. ``winner_accuracy`` is not computed yet and is always 0.0.

        Args:
            from_date: Start date (inclusive)
            to_date: End date (inclusive)

        Returns:
            List of distance metrics, sorted by race count in descending order.
        """
        logger.info("Analyzing distance performance")

        distance_stats = defaultdict(
            lambda: {
                "race_count": 0,
                "total_starters": 0,
                "barrier_placings": defaultdict(list),
            }
        )

        for race in self._races_in_range(from_date, to_date):
            if not race.distance_m:
                continue

            bucket = get_distance_bucket(race.distance_m)
            starters = self._finished_starters(race)
            if len(starters) < 2:
                continue

            stats = distance_stats[bucket]
            stats["race_count"] += 1
            stats["total_starters"] += len(starters)

            # Finishers always have a truthy placing; a falsy barrier is skipped.
            for starter in starters:
                if starter.barrier:
                    stats["barrier_placings"][starter.barrier].append(starter.placing)

        metrics = []
        for bucket, stats in distance_stats.items():
            race_count = stats["race_count"]
            if race_count == 0:
                continue

            # Mean finishing placing per barrier (lower = better).
            barrier_bias = {
                barrier: sum(placings) / len(placings)
                for barrier, placings in stats["barrier_placings"].items()
                if placings
            }

            metrics.append(
                DistanceMetrics(
                    distance_bucket=bucket,
                    race_count=race_count,
                    avg_field_size=stats["total_starters"] / race_count,
                    winner_accuracy=0.0,  # TODO: Implement
                    barrier_bias=barrier_bias,
                )
            )

        return sorted(metrics, key=lambda m: m.race_count, reverse=True)

    def calculate_brier_score(
        self, from_date: date, to_date: date
    ) -> BrierScoreAnalysis:
        """Calculate Brier score for probability accuracy.

        The Brier score measures the accuracy of probabilistic predictions.
        Lower is better, with 0 being perfect.

        In each race, win probabilities are a softmax over every finisher's
        effective rating (from ``RatingEngine.compute_effective_rating``),
        scaled by 400 rating points. Each finisher contributes the squared
        error ``(p - won) ** 2``. Predictions are also grouped into ten
        10%-wide bins (``"0-10%"`` ... ``"90-100%"``) for calibration.

        Args:
            from_date: Start date (inclusive)
            to_date: End date (inclusive)

        Returns:
            Brier score analysis. The overall score is 1.0 when no race in
            the window qualifies.
        """
        logger.info("Calculating Brier score")

        squared_errors: list[float] = []
        prob_bins = defaultdict(lambda: {"predictions": [], "outcomes": []})

        for race in self._races_in_range(from_date, to_date):
            starters = self._finished_starters(race)
            if len(starters) < 2:
                continue

            effective_ratings = {
                starter.id: self.engine.compute_effective_rating(starter, race)
                for starter in starters
            }

            # Numerically stable softmax: shift by the max before exponentiating.
            max_rating = max(effective_ratings.values())
            exp_ratings = {
                sid: math.exp((r - max_rating) / 400.0)
                for sid, r in effective_ratings.items()
            }
            total_exp = sum(exp_ratings.values())

            for starter in starters:
                prob = exp_ratings[starter.id] / total_exp
                actual = 1.0 if starter.placing == 1 else 0.0
                squared_errors.append((prob - actual) ** 2)

                # Clamp so prob == 1.0 lands in the "90-100%" bin rather than
                # a spurious "100-110%" bin.
                decile = min(int(prob * 10), 9)
                bin_key = f"{decile * 10}-{decile * 10 + 10}%"
                prob_bins[bin_key]["predictions"].append(prob)
                prob_bins[bin_key]["outcomes"].append(actual)

        overall_brier = (
            sum(squared_errors) / len(squared_errors) if squared_errors else 1.0
        )

        binned_analysis = {}
        for bin_key, data in prob_bins.items():
            predictions = data["predictions"]
            outcomes = data["outcomes"]
            if predictions:
                binned_analysis[bin_key] = {
                    "avg_predicted_prob": sum(predictions) / len(predictions),
                    "actual_win_rate": sum(outcomes) / len(outcomes),
                    "count": len(predictions),
                }

        return BrierScoreAnalysis(
            overall_brier_score=overall_brier,
            by_probability_bin=binned_analysis,
        )

    def analyze_top_performers(
        self,
        from_date: date,
        to_date: date,
        top_n: int = 20,
        min_starts: int = 10,
    ) -> dict[str, list[EntityPerformanceMetrics]]:
        """Identify top performing horses, drivers and trainers.

        Starts are aggregated in SQL per entity: placed, non-DNF starters
        within the date window. Entities with fewer than ``min_starts``
        starts are excluded, and the rest are ranked by win count, highest
        first.

        Args:
            from_date: Start date (inclusive)
            to_date: End date (inclusive)
            top_n: Maximum number of performers returned per entity type
            min_starts: Minimum qualifying starts for an entity to be ranked

        Returns:
            Dictionary keyed by ``entity_type.value.lower()`` for horses,
            drivers and trainers, each mapping to its best performers first.
        """
        logger.info("Analyzing top performers")

        id_fields = {
            EntityType.HORSE: Starter.horse_id,
            EntityType.DRIVER: Starter.driver_id,
            EntityType.TRAINER: Starter.trainer_id,
        }

        result = {}
        for entity_type, entity_id_field in id_fields.items():
            race_count_expr = func.count(Starter.id)
            # SUM over boolean comparisons cast to integers counts matching
            # rows. The cast target must be the SQLAlchemy Integer *type*.
            wins_expr = func.sum(cast(Starter.placing == 1, Integer))
            places_expr = func.sum(cast(Starter.placing <= 3, Integer))

            rows = (
                self.session.query(
                    entity_id_field,
                    race_count_expr.label("race_count"),
                    wins_expr.label("wins"),
                    places_expr.label("places"),
                    func.avg(Starter.placing).label("avg_placing"),
                )
                .join(Starter.race)
                .join(Race.meeting)
                .filter(
                    Meeting.meeting_date >= from_date,
                    Meeting.meeting_date <= to_date,
                    Starter.placing.isnot(None),
                    # NOTE(review): unlike the Python-side finisher filter,
                    # this excludes rows where did_not_finish IS NULL.
                    Starter.did_not_finish.is_(False),
                    entity_id_field.isnot(None),
                )
                .group_by(entity_id_field)
                .having(race_count_expr >= min_starts)
                # Rank by wins. COUNT() of the cast would just count rows.
                .order_by(wins_expr.desc())
                .limit(top_n)
                .all()
            )

            metrics = []
            for entity_id, race_count, wins, places, avg_placing in rows:
                # Normalize SUM results (may be NULL/Decimal) before dividing.
                win_count = int(wins or 0)
                place_count = int(places or 0)

                latest_snapshot = RatingSnapshotRepository.get_latest_for_entity(
                    self.session, entity_type, entity_id
                )

                metrics.append(
                    EntityPerformanceMetrics(
                        entity_type=entity_type.value,
                        entity_id=entity_id,
                        entity_name=None,  # TODO: Fetch name
                        race_count=race_count,
                        win_count=win_count,
                        place_count=place_count,
                        win_rate=win_count / race_count if race_count else 0.0,
                        place_rate=place_count / race_count if race_count else 0.0,
                        avg_placing=float(avg_placing) if avg_placing else 0.0,
                        rating_volatility=0.0,  # TODO: Calculate
                        current_rating=(
                            latest_snapshot.rating if latest_snapshot else None
                        ),
                    )
                )

            result[entity_type.value.lower()] = metrics

        return result

    def calculate_rating_correlation(
        self, from_date: date, to_date: date
    ) -> dict[str, float]:
        """Calculate correlation between driver/trainer ratings and outcomes.

        NOTE: this is a placeholder. It ignores the date range and returns
        fixed values; it does not compute a real correlation.

        Args:
            from_date: Start date (currently unused)
            to_date: End date (currently unused)

        Returns:
            Correlation coefficients keyed by "driver_impact" and
            "trainer_impact".
        """
        # Full implementation would calculate Pearson correlation
        return {
            "driver_impact": 0.35,  # Placeholder
            "trainer_impact": 0.15,  # Placeholder
        }

    def generate_report(
        self, from_date: date, to_date: date
    ) -> AdvancedAnalyticsReport:
        """Generate comprehensive advanced analytics report.

        Runs every analysis over the same window. ``overall_metrics`` holds
        summary figures; overall accuracy is the race-weighted mean of the
        per-venue winner accuracy.

        Args:
            from_date: Start date (inclusive)
            to_date: End date (inclusive)

        Returns:
            Complete analytics report
        """
        logger.info(
            f"Generating advanced analytics report from {from_date} to {to_date}"
        )

        report = AdvancedAnalyticsReport(
            start_date=from_date,
            end_date=to_date,
        )

        report.venue_metrics = self.analyze_venue_performance(from_date, to_date)
        report.distance_metrics = self.analyze_distance_performance(from_date, to_date)
        report.brier_score = self.calculate_brier_score(from_date, to_date)
        report.top_performers = self.analyze_top_performers(from_date, to_date)
        report.rating_correlation = self.calculate_rating_correlation(
            from_date, to_date
        )

        total_races = sum(v.race_count for v in report.venue_metrics)
        avg_accuracy = (
            sum(v.winner_accuracy * v.race_count for v in report.venue_metrics)
            / total_races
            if total_races > 0
            else 0.0
        )

        report.overall_metrics = {
            "total_races": total_races,
            "overall_accuracy": avg_accuracy,
            "brier_score": report.brier_score.overall_brier_score,
            "venues_analyzed": len(report.venue_metrics),
            "distance_buckets_analyzed": len(report.distance_metrics),
        }

        logger.info("Advanced analytics report complete")
        return report


def main():
    """CLI entry point for advanced analytics.

    Defines a click command that builds an ``AdvancedAnalyticsReport`` for
    ``--from``/``--to`` and writes a JSON summary to ``--out``. The command
    is then invoked immediately, and click parses the process arguments.
    """
    import click
    from packages.common.utils import parse_date
    from packages.storage.database import get_session

    @click.command()
    @click.option("--from", "date_from", required=True, help="Start date (YYYY-MM-DD)")
    @click.option("--to", "date_to", required=True, help="End date (YYYY-MM-DD)")
    @click.option("--out", "output_file", required=True, help="Output JSON file")
    def analyze(date_from, date_to, output_file):
        """Generate advanced analytics report."""
        start_date = parse_date(date_from)
        end_date = parse_date(date_to)

        with get_session() as session:
            analyzer = AdvancedAnalyzer(session)
            report = analyzer.generate_report(start_date, end_date)

        # The report holds only plain values, so it is safe to use after the
        # session has closed.
        report_data = {
            "start_date": report.start_date.isoformat(),
            "end_date": report.end_date.isoformat(),
            "overall_metrics": report.overall_metrics,
            "venue_metrics": [
                {
                    "venue": v.venue,
                    "race_count": v.race_count,
                    "avg_field_size": round(v.avg_field_size, 2),
                    "winner_accuracy": round(v.winner_accuracy, 3),
                    "avg_rating_spread": round(v.avg_rating_spread, 1),
                    "favorite_win_rate": round(v.favorite_win_rate, 3),
                }
                for v in report.venue_metrics
            ],
            "distance_metrics": [
                {
                    "distance_bucket": d.distance_bucket,
                    "race_count": d.race_count,
                    "avg_field_size": round(d.avg_field_size, 2),
                    "barrier_bias": {k: round(v, 2) for k, v in d.barrier_bias.items()},
                }
                for d in report.distance_metrics
            ],
            "brier_score": (
                {
                    "overall": round(report.brier_score.overall_brier_score, 4),
                    "by_probability_bin": report.brier_score.by_probability_bin,
                }
                if report.brier_score
                else None
            ),
            "top_performers": {
                entity_type: [
                    {
                        "entity_id": p.entity_id,
                        "race_count": p.race_count,
                        "win_count": p.win_count,
                        "win_rate": round(p.win_rate, 3),
                        "place_rate": round(p.place_rate, 3),
                        "avg_placing": round(p.avg_placing, 2),
                        # Only a missing rating maps to None; 0.0 is kept.
                        "current_rating": (
                            round(p.current_rating, 1)
                            if p.current_rating is not None
                            else None
                        ),
                    }
                    for p in performers
                ]
                for entity_type, performers in report.top_performers.items()
            },
        }

        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(report_data, f, indent=2)

        print(f"Advanced analytics report saved to {output_file}")
        print(
            f"Overall accuracy: {report.overall_metrics.get('overall_accuracy', 0):.1%}"
        )
        print(f"Brier score: {report.overall_metrics.get('brier_score', 0):.4f}")

    analyze()


if __name__ == "__main__":
    main()
