#!/usr/bin/env python3
"""Evaluate prediction accuracy using current rating settings.

Prints aggregate metrics for all races, plus separate splits for races in which
at least one starter has both a driver and a trainer recorded and races in which
none do.
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import time
from datetime import datetime
from math import log

from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

from packages.core.ratings.predictions import PredictionEngine
from packages.core.storage.models import Race, Starter


def _parse_date(value: str | None) -> datetime | None:
    """Convert an optional ISO-8601 string into a ``datetime``.

    ``None`` and the empty string both mean "no bound" and yield ``None``.
    Anything else goes through ``datetime.fromisoformat``, so malformed input
    raises ``ValueError``.
    """
    return datetime.fromisoformat(value) if value else None


# Per-race metrics averaged by _evaluate, in the order they are reported.
_METRIC_KEYS = (
    "winner_acc",
    "top3_overlap",
    "brier",
    "log_loss",
    "place_top3_overlap",
    "place_brier",
    "place_log_loss",
)

# Probabilities are clamped to [_LOG_EPS, 1 - _LOG_EPS] before taking logs so a
# confident miss produces a large finite loss instead of infinity.
_LOG_EPS = 1e-15


def _clamped_neg_log(prob: float) -> float:
    """Return ``-log(prob)`` with *prob* clamped into ``[_LOG_EPS, 1 - _LOG_EPS]``."""
    return -log(max(min(prob, 1 - _LOG_EPS), _LOG_EPS))


def _race_metrics(starters, predictions, winner) -> dict[str, float]:
    """Score one race's predictions against its actual finishing order.

    Args:
        starters: Non-empty list of starters, each with a non-null ``placing``.
        predictions: Non-empty collection of prediction objects exposing
            ``starter_id``, ``win_probability``, ``place_probability`` and
            ``predicted_placing``.
        winner: The starter whose ``placing`` is 1.

    Returns:
        A dict keyed by ``_METRIC_KEYS``. The top-3 overlaps are the fraction
        of the actual top three (or of the whole field, if it is smaller) that
        the prediction captured. The place metrics are averaged over
        ``starters``.
    """
    win_prob = {p.starter_id: p.win_probability for p in predictions}
    place_prob = {p.starter_id: p.place_probability for p in predictions}

    # reverse=True preserves sort stability, so on ties element 0 is the same
    # prediction max() would have returned.
    by_win_prob = sorted(predictions, key=lambda p: p.win_probability, reverse=True)
    by_predicted_placing = sorted(predictions, key=lambda p: p.predicted_placing)

    top3_actual = {s.id for s in sorted(starters, key=lambda s: (s.placing or 999))[:3]}
    # Normalise by the achievable maximum: a perfect pick in a field of fewer
    # than three finishers scores 1.0 rather than 1/3 or 2/3.
    top3_denom = len(top3_actual)
    top3_pred_win = {p.starter_id for p in by_win_prob[:3]}
    top3_pred_place = {p.starter_id for p in by_predicted_placing[:3]}

    # Starters missing from the predictions are treated as probability 0.
    p_win = win_prob.get(winner.id, 0.0)

    place_brier_sum = 0.0
    place_log_loss_sum = 0.0
    for starter in starters:
        prob = place_prob.get(starter.id, 0.0)
        placed = bool(starter.placing and starter.placing <= 3)
        place_brier_sum += (prob - (1.0 if placed else 0.0)) ** 2
        # Binary log loss: -log(p) if placed, -log(1 - p) otherwise.
        place_log_loss_sum += _clamped_neg_log(prob if placed else 1.0 - prob)

    field_size = len(starters)
    return {
        "winner_acc": 1.0 if by_win_prob[0].starter_id == winner.id else 0.0,
        "top3_overlap": len(top3_pred_win & top3_actual) / top3_denom,
        "brier": (1 - p_win) ** 2,
        "log_loss": _clamped_neg_log(p_win),
        "place_top3_overlap": len(top3_pred_place & top3_actual) / top3_denom,
        "place_brier": place_brier_sum / field_size,
        "place_log_loss": place_log_loss_sum / field_size,
    }


def _score_race(engine_pred, race):
    """Predict *race* with *engine_pred* and score the result.

    Returns ``(metrics, has_dt)``, or ``None`` when the race cannot be scored:
    - no starter has a placing;
    - the engine returns no predictions;
    - no starter placed first.

    ``has_dt`` is true when at least one starter has both ``driver_id`` and
    ``trainer_id`` set.
    """
    starters = [s for s in race.starters if s.placing is not None]
    if not starters:
        return None

    res = engine_pred.predict_race(race, starters)
    if not res or not res.predictions:
        return None

    winner = next((s for s in starters if s.placing == 1), None)
    if not winner:
        return None

    has_dt = any(s.driver_id and s.trainer_id for s in starters)
    return _race_metrics(starters, res.predictions, winner), has_dt


def _summarise(sums: dict[str, float], count: int) -> dict[str, float] | None:
    """Average the accumulated metric *sums* over *count* races.

    Returns ``None`` when no races were scored.
    """
    if count == 0:
        return None
    return {"races": count, **{name: total / count for name, total in sums.items()}}


def _evaluate(
    session: Session,
    date_from: datetime | None,
    date_to: datetime | None,
    *,
    progress_every: int = 250,
):
    """Score predictions for every finished race in a date window.

    A race is selected when at least one of its starters has a non-null
    ``placing``. ``date_from`` and ``date_to``, when given, bound
    ``Race.race_datetime`` inclusively. Each race is predicted with
    ``PredictionEngine`` and scored by ``_race_metrics``. Races that cannot be
    scored (see ``_score_race``) are skipped.

    Args:
        session: Open SQLAlchemy session, used for the query and passed to
            ``PredictionEngine``.
        date_from: Inclusive lower bound, or ``None`` for no bound.
        date_to: Inclusive upper bound, or ``None`` for no bound.
        progress_every: Write a progress line to stderr every this many races.
            0 disables progress output.

    Returns:
        ``{"error": "no_races"}`` when the query matches nothing. Otherwise a
        dict with keys ``"all"``, ``"with_dt"`` and ``"no_dt"``. Each maps to
        the averaged metrics (``"races"`` plus ``_METRIC_KEYS``), or to
        ``None`` if that bucket received no scored races.
    """
    query = select(Race).join(Starter).where(Starter.placing.isnot(None)).distinct()
    if date_from:
        query = query.where(Race.race_datetime >= date_from)
    if date_to:
        query = query.where(Race.race_datetime <= date_to)
    # Fixed order so repeated runs process races (and report progress) the same way.
    query = query.order_by(Race.race_datetime)

    races = session.execute(query).scalars().all()
    if not races:
        return {"error": "no_races"}

    engine_pred = PredictionEngine(session)
    buckets = ("all", "with_dt", "no_dt")
    sums = {bucket: dict.fromkeys(_METRIC_KEYS, 0.0) for bucket in buckets}
    counts = dict.fromkeys(buckets, 0)

    total = len(races)
    start_time = time.time()
    for idx, race in enumerate(races, start=1):
        scored = _score_race(engine_pred, race)
        if scored is not None:
            metrics, has_dt = scored
            # Every scored race counts towards "all" plus exactly one split.
            for bucket in ("all", "with_dt" if has_dt else "no_dt"):
                counts[bucket] += 1
                bucket_sums = sums[bucket]
                for name, value in metrics.items():
                    bucket_sums[name] += value

        # Checked for every race, scored or skipped, so a progress line is not
        # lost just because the Nth race was unscorable.
        if progress_every and idx % progress_every == 0:
            elapsed = time.time() - start_time
            msg = f"processed={idx}/{total} elapsed_s={elapsed:.1f}"
            print(msg, file=sys.stderr, flush=True)

    return {bucket: _summarise(sums[bucket], counts[bucket]) for bucket in buckets}


def main() -> int:
    """Command-line entry point: evaluate predictions and print the metrics.

    Reads the database URL from the ``DATABASE_URL`` environment variable.
    ``--from`` and ``--to`` take ISO-8601 dates or datetimes (parsed by
    ``_parse_date``) and are passed to ``_evaluate`` as race-datetime bounds.
    ``--json`` switches the output to a JSON document; otherwise one
    ``<bucket> <metrics>`` line is printed per bucket.

    Returns:
        Exit status ``0``. A missing ``DATABASE_URL`` or a malformed date
        exits via ``SystemExit`` instead.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--from",
        dest="date_from",
        default=None,
        help="earliest race datetime, inclusive (ISO-8601)",
    )
    parser.add_argument(
        "--to",
        dest="date_to",
        default=None,
        help="latest race datetime, inclusive (ISO-8601; a bare date means midnight)",
    )
    parser.add_argument("--json", action="store_true", help="Emit JSON output")
    args = parser.parse_args()

    try:
        date_from = _parse_date(args.date_from)
        date_to = _parse_date(args.date_to)
    except ValueError as exc:
        # Report malformed dates as a usage error rather than a traceback.
        parser.error(f"invalid date: {exc}")

    db_url = os.getenv("DATABASE_URL")
    if not db_url:
        raise SystemExit("DATABASE_URL is required")

    engine = create_engine(db_url)
    try:
        with Session(engine) as session:
            results = _evaluate(session, date_from, date_to)
    finally:
        # Release pooled connections whether or not the evaluation succeeded.
        engine.dispose()

    if args.json:
        print(json.dumps(results, indent=2, sort_keys=True))
    elif "error" in results:
        # Without this branch, text mode printed "all None" etc. and hid the reason.
        print("error", results["error"])
    else:
        for key in ("all", "with_dt", "no_dt"):
            print(key, results.get(key))

    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
