#!/usr/bin/env python3
"""Analyze prediction performance by gait (Pace vs Trot).

Computes winner accuracy, top-3 hit rate, Brier score, and field size
for each gait type and gait+distance-bucket combinations over a date range.

Usage:
    python scripts/per_gait_analysis.py --from 2025-01-01 --to 2025-12-31
    python scripts/per_gait_analysis.py --out report.json
"""

from __future__ import annotations

import json
import os
import sys
import time
from datetime import datetime
from typing import Any

import click
from rich.console import Console
from rich.table import Table
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from packages.core.common.utils import get_distance_bucket
from packages.core.ratings.predictions import PredictionEngine
from packages.core.storage.models import Meeting, Race, Starter

# Bucket boundaries handed to get_distance_bucket (mode="thresholds") by
# _distance_bucket. Presumably race distances in metres -- confirm against
# packages.core.common.utils.
_DISTANCE_THRESHOLDS = [1600, 1800, 2000, 2200, 2400]
# Shared Rich console that cli() uses for its tables and status messages.
console = Console()


def _parse_date(value: str | None) -> datetime | None:
    """Convert an optional ISO-8601 string into a ``datetime``.

    Missing or empty input (``None`` or ``""``) yields ``None``. Any other
    value is handed to ``datetime.fromisoformat``, which raises ``ValueError``
    for text it cannot parse.
    """
    return datetime.fromisoformat(value) if value else None


def _distance_bucket(distance_m: int | None) -> str:
    """Return the distance-bucket label for a race distance.

    Thin wrapper that delegates to ``get_distance_bucket`` with the
    module-level ``_DISTANCE_THRESHOLDS`` and ``mode="thresholds"``. The label
    format and the handling of ``None`` are defined by that helper, not here.
    """
    return get_distance_bucket(distance_m, _DISTANCE_THRESHOLDS, mode="thresholds")


def _score_race(engine: Any, race: Any) -> tuple[float, float, float, float] | None:
    """Score one race's predictions against its recorded placings.

    Returns ``(winner_hit, top3_overlap, winner_brier, field_size)``, or
    ``None`` when the race cannot be scored: no starter has a placing, the
    engine yields no predictions, or no starter has placing 1.
    """
    starters = [s for s in race.starters if s.placing is not None]
    if not starters:
        return None

    res = engine.predict_race(race, starters)
    if not res or not res.predictions:
        return None

    winner = next((s for s in starters if s.placing == 1), None)
    if not winner:
        return None

    # A stable descending sort keeps the first-listed prediction on ties, so
    # ranked[0] is the same element max(..., key=win_probability) would pick.
    ranked = sorted(res.predictions, key=lambda p: p.win_probability, reverse=True)
    winner_hit = 1.0 if ranked[0].starter_id == winner.id else 0.0

    top3_pred = {p.starter_id for p in ranked[:3]}
    top3_actual = {
        s.id for s in sorted(starters, key=lambda s: (s.placing or 999))[:3]
    }
    # Divide by the number of places actually available: with fewer than
    # three placed starters a perfect pick must still be able to score 1.0.
    top3_overlap = len(top3_pred & top3_actual) / min(3, len(top3_actual))

    # Only the winner's squared error is used (not the full multi-class Brier
    # sum); a winner missing from the predictions is treated as probability 0.
    prob_map = {p.starter_id: p.win_probability for p in res.predictions}
    p_win = prob_map.get(winner.id, 0.0)
    brier = (1.0 - p_win) ** 2

    return winner_hit, top3_overlap, brier, float(len(starters))


def _accumulate(
    acc: dict[str, list[float]], key: str, sample: tuple[float, float, float, float]
) -> None:
    """Add one race's scores to ``acc[key]``, creating the slot if needed.

    Slot layout: ``[winner_sum, top3_sum, brier_sum, field_size_sum, count]``.
    """
    slot = acc.setdefault(key, [0.0, 0.0, 0.0, 0.0, 0.0])
    for pos, value in enumerate(sample):
        slot[pos] += value
    slot[4] += 1.0


def _metrics_from_accumulator(acc: list[float]) -> dict[str, Any] | None:
    """Turn an accumulator slot into rounded averages (``None`` if empty)."""
    count = int(acc[4])
    if count == 0:
        return None
    return {
        "races": count,
        "winner_accuracy": round(acc[0] / count, 4),
        "top3_accuracy": round(acc[1] / count, 4),
        "avg_brier_score": round(acc[2] / count, 4),
        "avg_field_size": round(acc[3] / count, 1),
    }


def _evaluate(
    session: Session,
    date_from: datetime | None,
    date_to: datetime | None,
) -> dict[str, Any]:
    """Compute per-gait and per-gait+distance prediction metrics.

    Args:
        session: Open SQLAlchemy session used for the race query and handed
            to ``PredictionEngine``.
        date_from: Inclusive lower bound on ``Meeting.meeting_date`` (only the
            date part is used); ``None`` means unbounded.
        date_to: Inclusive upper bound on ``Meeting.meeting_date`` (only the
            date part is used); ``None`` means unbounded.

    Returns:
        ``{"error": "no_races", "detail": ...}`` when the query matches
        nothing. Otherwise a dict with ``date_from``/``date_to`` (ISO strings
        or ``None``), ``total_races`` (races fetched), ``evaluated_races``
        (races actually scored), ``by_gait`` and ``by_gait_distance``. The
        last two map a key (lower-cased gait, or ``"gait|bucket"``) to
        ``races``, ``winner_accuracy``, ``top3_accuracy``,
        ``avg_brier_score`` and ``avg_field_size``.
    """
    # Completed races with a gait. "Has at least one placed starter" is an
    # EXISTS sub-query rather than JOIN + DISTINCT: that avoids the row
    # fan-out and keeps ORDER BY on Meeting columns valid without relying on
    # Query.distinct() adding ORDER BY columns to the SELECT list.
    # NOTE(review): assumes Race.meeting is many-to-one, so the join cannot
    # duplicate races -- confirm against the model.
    query = (
        session.query(Race)
        .join(Race.meeting)
        .filter(Race.gait.isnot(None))
        .filter(Race.starters.any(Starter.placing.isnot(None)))
    )
    if date_from:
        query = query.filter(Meeting.meeting_date >= date_from.date())
    if date_to:
        query = query.filter(Meeting.meeting_date <= date_to.date())

    races = query.order_by(Meeting.meeting_date, Race.race_datetime).all()

    if not races:
        return {"error": "no_races", "detail": "No completed races found in date range"}

    engine = PredictionEngine(session)

    # Accumulators: {key: [winner_acc_sum, top3_sum, brier_sum, field_size_sum, count]}
    gait_acc: dict[str, list[float]] = {}
    combo_acc: dict[str, list[float]] = {}

    total = len(races)
    evaluated = 0
    started = time.monotonic()
    for idx, race in enumerate(races, start=1):
        sample = _score_race(engine, race)
        if sample is not None:
            gait = (race.gait or "Unknown").strip().lower()
            dist_bucket = _distance_bucket(race.distance_m)  # type: ignore[arg-type]
            _accumulate(gait_acc, gait, sample)
            _accumulate(combo_acc, f"{gait}|{dist_bucket}", sample)
            evaluated += 1

        # Report progress on every 500th race, whether or not it was scorable.
        if idx % 500 == 0:
            elapsed = time.monotonic() - started
            print(
                f"  processed={idx}/{total} elapsed_s={elapsed:.1f}",
                file=sys.stderr,
            )

    return {
        "date_from": date_from.isoformat() if date_from else None,
        "date_to": date_to.isoformat() if date_to else None,
        "total_races": total,
        "evaluated_races": evaluated,
        "by_gait": {
            key: _metrics_from_accumulator(acc)
            for key, acc in sorted(gait_acc.items())
        },
        "by_gait_distance": {
            key: _metrics_from_accumulator(acc)
            for key, acc in sorted(combo_acc.items())
        },
    }


def _parse_cli_date(value: str | None, option: str) -> datetime | None:
    """Parse a date option, reporting bad input as a click usage error."""
    try:
        return _parse_date(value)
    except ValueError as exc:
        raise click.BadParameter(
            f"{value!r} is not a valid ISO date (expected YYYY-MM-DD)",
            param_hint=option,
        ) from exc


def _new_metrics_table(title: str, label_columns: list[tuple[str, str]]) -> Table:
    """Create a table with the given label columns plus the metric columns."""
    table = Table(title=title, title_style="bold cyan")
    for header, style in label_columns:
        table.add_column(header, style=style)
    for header in ("Races", "Win Acc", "Top-3 Rate", "Avg Brier", "Avg Field"):
        table.add_column(header, justify="right")
    return table


def _metric_cells(metrics: dict[str, Any]) -> list[str]:
    """Format one metrics dict as the five metric cells of a table row."""
    return [
        str(metrics["races"]),
        f"{metrics['winner_accuracy']:.1%}",
        f"{metrics['top3_accuracy']:.1%}",
        f"{metrics['avg_brier_score']:.4f}",
        str(metrics["avg_field_size"]),
    ]


@click.command()
@click.option("--from", "date_from", default=None, help="Start date (YYYY-MM-DD)")
@click.option("--to", "date_to", default=None, help="End date (YYYY-MM-DD)")
@click.option("--out", "output", default=None, help="Write JSON report to file")
@click.option("--db-url", default=None, help="Database URL (default: $DATABASE_URL)")
def cli(
    date_from: str | None, date_to: str | None, output: str | None, db_url: str | None
):
    """Analyze prediction performance by gait (Pace vs Trot).

    Computes winner accuracy, top-3 hit rate, average Brier score, and
    average field size for each gait type, plus gait+distance-bucket combos.
    """
    # Local import: only this command needs it, and the output path is
    # user-supplied text that must not be interpreted as Rich markup.
    from rich.markup import escape

    db_url = db_url or os.getenv("DATABASE_URL")
    if not db_url:
        console.print("[red]ERROR: --db-url or DATABASE_URL is required[/]")
        raise SystemExit(1)

    # Malformed dates become a normal click usage error, not a traceback.
    parsed_from = _parse_cli_date(date_from, "--from")
    parsed_to = _parse_cli_date(date_to, "--to")

    engine = create_engine(db_url)
    with Session(engine) as session:
        results = _evaluate(session, parsed_from, parsed_to)

    if "error" in results:
        # Nothing to tabulate: warn, still honour --out, then stop.
        console.print(f"[yellow]Warning: {results.get('detail', results['error'])}[/]")
        if output:
            with open(output, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2)
        return

    # ── Rich table: by gait ──
    gait_table = _new_metrics_table("Performance by Gait", [("Gait", "cyan")])
    for gait, metrics in results["by_gait"].items():
        if metrics is None:
            continue
        gait_table.add_row(gait.capitalize(), *_metric_cells(metrics))
    console.print(gait_table)

    # ── Rich table: by gait + distance bucket ──
    combo_table = _new_metrics_table(
        "Performance by Gait + Distance Bucket",
        [("Gait", "cyan"), ("Distance", "green")],
    )
    for ck, metrics in results["by_gait_distance"].items():
        if metrics is None:
            continue
        # Keys are "gait|bucket"; split once so a "|" in the bucket survives.
        gait, dist = ck.split("|", 1)
        combo_table.add_row(gait.capitalize(), dist, *_metric_cells(metrics))
    console.print(combo_table)

    # ── JSON output ──
    if output:
        with open(output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, sort_keys=True)
        console.print(f"[green]Report written to {escape(output)}[/]")
    else:
        console.print_json(json.dumps(results, indent=2, sort_keys=True))


# Invoke the click command only when run as a script, not on import.
if __name__ == "__main__":
    cli()
