#!/usr/bin/env python3
"""Generate monthly trend reports from prediction accuracy data.

For each month in the specified range, computes:
  - Winner accuracy (top-1)
  - Top-3 hit rate
  - Average Brier score
  - Average field size
  - ROI from predicted-winner betting
  - Number of races ingested

Outputs a JSON report with per-month data points.

Usage:
    python scripts/monthly_trends.py --from 2025-01-01 --to 2025-12-31 --out trends.json
"""

from __future__ import annotations

import json
import logging
from collections import defaultdict
from datetime import date, datetime, timedelta

import click
from rich.console import Console
from rich.table import Table
from sqlalchemy import create_engine, func
from sqlalchemy.orm import Session

from packages.core.ratings.predictions import PredictionEngine
from packages.core.storage.models import Meeting, Race, Starter

# Module-wide Rich console; the status messages, tables and report text
# emitted by this script are all printed through it.
console = Console()


# ── Helpers ──────────────────────────────────────────────────────────────


def _parse_date(value: str | None) -> date | None:
    if not value:
        return None
    return datetime.strptime(value, "%Y-%m-%d").date()


def _year_month_key(d: date) -> str:
    """Format *d* as the ``YYYY-MM`` string used to key monthly buckets."""
    # A non-empty format spec on a date is forwarded to ``strftime``.
    return f"{d:%Y-%m}"


def _iter_months(from_date: date, to_date: date) -> list[date]:
    """List the first day of every month touched by the given date range.

    Both endpoints are snapped to the first of their month, so the result
    runs from the month containing *from_date* to the month containing
    *to_date* inclusive. It is empty when *from_date* falls in a later
    month than *to_date*.
    """
    cursor = from_date.replace(day=1)
    last = to_date.replace(day=1)
    firsts: list[date] = []
    while cursor <= last:
        firsts.append(cursor)
        # divmod(month, 12) is (1, 0) only for December, which rolls the
        # year forward and resets the month to January; every other month
        # simply advances by one.
        year_carry, month_index = divmod(cursor.month, 12)
        cursor = cursor.replace(
            year=cursor.year + year_carry, month=month_index + 1
        )
    return firsts


# ── Monthly aggregation ──────────────────────────────────────────────────


def _new_bucket() -> dict:
    """Return a zeroed accumulator for one month of scored races."""
    return {
        "races": 0,  # races with a usable prediction
        "with_winner": 0,  # subset of those with a recorded first place
        "winner_correct": 0,
        "top3_overlap": 0,
        "brier_sum": 0.0,
        "field_size_sum": 0,
        "bet_return": 0.0,  # gross return (stake included) of winning bets
    }


def _score_race(bucket: dict, predictions, starters: list) -> None:
    """Fold one race's predictions and actual placings into *bucket*.

    *predictions* is the non-empty collection taken from the engine's
    result; *starters* are the finishers of the same race.
    """
    bucket["races"] += 1
    bucket["field_size_sum"] += len(predictions)

    # Top-3 overlap: how many of the three best-ranked predictions actually
    # finished in the first three places.
    actual_top3 = {s.id for s in starters if s.placing and s.placing <= 3}
    predicted_top3 = {
        p.starter_id
        for p in sorted(predictions, key=lambda p: p.predicted_placing)[:3]
    }
    bucket["top3_overlap"] += len(actual_top3 & predicted_top3)

    actual_winner = next((s for s in starters if s.placing == 1), None)
    if actual_winner is None:
        # Without a recorded winner there is nothing to score the winner
        # pick, the Brier score or the bet against.
        return

    bucket["with_winner"] += 1
    winner_id = actual_winner.id

    # Brier score: mean squared error between each win probability and the
    # 0/1 outcome of that starter.
    bucket["brier_sum"] += sum(
        (p.win_probability - (1.0 if p.starter_id == winner_id else 0.0)) ** 2
        for p in predictions
    ) / len(predictions)

    # Betting simulation: $1 on the predicted winner, paid at the fair odds
    # implied by the model's own probability (clamped to avoid dividing by
    # a near-zero probability).
    predicted_winner = max(predictions, key=lambda p: p.win_probability)
    if predicted_winner.starter_id == winner_id:
        bucket["winner_correct"] += 1
        bucket["bet_return"] += 1.0 / max(predicted_winner.win_probability, 0.001)
    # A losing bet returns nothing. The lost stake is already accounted for
    # by the "- 1.0" in the ROI formula, so it must not be subtracted here
    # as well (that double-counting made an all-loss month report -200%).


def _count_ingested(session: Session, month_start: date) -> int:
    """Count every race whose meeting falls in the month of *month_start*.

    The count ignores whether the race has results.
    """
    if month_start.month == 12:
        next_month = month_start.replace(year=month_start.year + 1, month=1)
    else:
        next_month = month_start.replace(month=month_start.month + 1)
    return (
        session.query(func.count(Race.id))
        .join(Race.meeting)
        .filter(
            Meeting.meeting_date >= month_start,
            Meeting.meeting_date < next_month,
        )
        .scalar()
        or 0
    )


def _month_row(key: str, bucket: dict | None, ingested: int) -> dict:
    """Build the output row for one month.

    *bucket* is ``None`` for a month without any scored race; every metric
    is then ``None``. Winner accuracy, Brier score and ROI are averaged
    over races with a recorded winner, because only those races contribute
    to the corresponding sums.
    """
    row = {
        "month": key,
        "races": bucket["races"] if bucket else 0,
        "winner_accuracy": None,
        "top3_hit_rate": None,
        "avg_brier_score": None,
        "avg_field_size": None,
        "bet_roi_pct": None,
        "races_ingested": ingested,
    }
    if not bucket:
        return row

    races = bucket["races"]
    if races > 0:
        # Three places are assumed available per race.
        row["top3_hit_rate"] = round(bucket["top3_overlap"] / (races * 3), 4)
        row["avg_field_size"] = round(bucket["field_size_sum"] / races, 2)

    scored = bucket["with_winner"]
    if scored > 0:
        row["winner_accuracy"] = round(bucket["winner_correct"] / scored, 4)
        row["avg_brier_score"] = round(bucket["brier_sum"] / scored, 4)
        # Average gross return per $1 stake, minus the stake, as a percentage.
        row["bet_roi_pct"] = round((bucket["bet_return"] / scored - 1.0) * 100, 2)
    return row


def _compute_monthly(
    session: Session,
    date_from: date | None,
    date_to: date | None,
) -> list[dict]:
    """Compute monthly accuracy and volume metrics.

    Args:
        session: Open SQLAlchemy session used for all queries.
        date_from: Inclusive lower bound on the meeting date, or ``None``
            for no lower bound.
        date_to: Inclusive upper bound on the meeting date, or ``None``
            for no upper bound.

    Returns:
        One dict per calendar month in the range, in chronological order,
        with keys ``month``, ``races``, ``winner_accuracy``,
        ``top3_hit_rate``, ``avg_brier_score``, ``avg_field_size``,
        ``bet_roi_pct`` and ``races_ingested``. Metrics that cannot be
        computed for a month are ``None``. The list is empty when no race
        with a finisher exists in the range.
    """
    logger = logging.getLogger(__name__)

    # Select races with at least one finisher through a correlated EXISTS.
    # Joining Starter and de-duplicating with DISTINCT ON (race.id) is not
    # portable: PostgreSQL requires the DISTINCT ON expressions to lead the
    # ORDER BY, and other backends fall back to a plain DISTINCT.
    has_finisher = (
        session.query(Starter.id)
        .filter(
            Starter.race_id == Race.id,
            Starter.placing.isnot(None),
            Starter.did_not_finish.is_(False),
        )
        .exists()
    )
    query = session.query(Race).join(Race.meeting).filter(has_finisher)
    if date_from is not None:
        query = query.filter(Meeting.meeting_date >= date_from)
    if date_to is not None:
        query = query.filter(Meeting.meeting_date <= date_to)
    races = query.order_by(
        Meeting.meeting_date, Race.race_datetime, Race.race_number, Race.id
    ).all()

    if not races:
        return []

    predictor = PredictionEngine(session)
    monthly: dict[str, dict] = defaultdict(_new_bucket)

    for race in races:
        if not race.meeting:
            continue
        month_key = _year_month_key(race.meeting.meeting_date)

        starters = [
            s for s in race.starters if s.placing is not None and not s.did_not_finish
        ]
        if not starters:
            continue

        try:
            prediction = predictor.predict_race(race, starters)
        except Exception as exc:
            # Best effort: one unpredictable race must not abort the whole
            # report, but the failure should be visible.
            logger.warning(
                "Skipping race %s: prediction failed (%s)", race.id, exc
            )
            continue
        if not prediction or not prediction.predictions:
            continue

        _score_race(monthly[month_key], prediction.predictions, starters)

    # Month range to report. With an open bound, widen the default window
    # so that no bucket holding data is silently left out.
    range_start = date_from or date(2020, 1, 1)
    range_end = date_to or date.today()
    if monthly:
        if date_from is None:
            earliest = datetime.strptime(min(monthly), "%Y-%m").date()
            range_start = min(range_start, earliest)
        if date_to is None:
            latest = datetime.strptime(max(monthly), "%Y-%m").date()
            range_end = max(range_end, latest)

    results = []
    for month_start in _iter_months(range_start, range_end):
        key = _year_month_key(month_start)
        # The ingestion count is independent of results, so it is computed
        # for every month, including those without a scored race.
        ingested = _count_ingested(session, month_start)
        results.append(_month_row(key, monthly.get(key), ingested))

    return results


def _print_trend_table(results: list[dict]) -> None:
    """Render the per-month metric rows as a Rich table on the console.

    Prints a yellow notice instead when *results* is empty. Metrics that
    are ``None`` are shown as an em dash.
    """
    if not results:
        console.print("[yellow]No data found for the specified date range.[/yellow]")
        return

    def _cell(value, spec=".2f"):
        """Format *value* with *spec*, or return an em dash when missing."""
        return "—" if value is None else f"{value:{spec}}"

    table = Table(title="Monthly Prediction Trends")
    table.add_column("Month", style="cyan", no_wrap=True)
    right_aligned = (
        "Races",
        "Win Acc",
        "Top-3 Rate",
        "Brier",
        "Avg Field",
        "Bet ROI",
        "Ingested",
    )
    for heading in right_aligned:
        table.add_column(heading, justify="right")

    for row in results:
        roi = row.get("bet_roi_pct")
        # The percent sign is appended only when there is a number to show.
        roi_cell = "—" if roi is None else _cell(roi, "+.1f") + "%"
        cells = [
            row["month"],
            str(row["races"]),
            _cell(row["winner_accuracy"], ".1%"),
            _cell(row["top3_hit_rate"], ".1%"),
            _cell(row["avg_brier_score"]),
            _cell(row["avg_field_size"]),
            roi_cell,
            str(row["races_ingested"]),
        ]
        table.add_row(*cells)

    console.print()
    console.print(table)
    console.print()


# ── Unicode sparkline ────────────────────────────────────────────────────


# Eight block glyphs, ordered from the lowest bar to the highest.
_UNICODE_SPARK = "▁▂▃▄▅▆▇█"


def _sparkline(values: list[float], width: int = 20) -> str:
    """Generate a Unicode sparkline from a list of values.

    Args:
        values: Series to draw, in display order.
        width: Maximum number of characters in the result.

    Returns:
        One glyph per sampled value, scaled between the series minimum and
        maximum. Empty when *values* is empty or *width* is not positive.
        A constant series is drawn as full-height bars.
    """
    if not values or width <= 0:
        return ""
    lowest, highest = min(values), max(values)
    if highest == lowest:
        return _UNICODE_SPARK[-1] * min(len(values), width)

    # Downsample with a ceiling-division stride so that at most ``width``
    # samples remain and they span the whole series. A floor-division
    # stride leaves up to 2 * width - 1 samples, and truncating those to
    # ``width`` would cut off the most recent part of the trend.
    stride = -(-len(values) // width)
    sampled = values[::stride]

    span = highest - lowest
    top_index = len(_UNICODE_SPARK) - 1
    return "".join(
        _UNICODE_SPARK[int((v - lowest) / span * top_index)] for v in sampled
    )


def _print_sparklines(results: list[dict]) -> None:
    """Print one sparkline per key metric, each with its min/max.

    Months whose value for a metric is ``None`` are left out of that
    metric's line. Nothing is printed when no month has a winner accuracy.
    """

    def _present(field: str) -> list:
        """Collect the non-missing values of *field* in row order."""
        return [r[field] for r in results if r[field] is not None]

    win_acc = _present("winner_accuracy")
    brier = _present("avg_brier_score")
    roi = _present("bet_roi_pct")

    if not win_acc:
        return

    months = [r["month"] for r in results if r["winner_accuracy"] is not None]

    console.print("[bold]Trend sparklines[/bold]")

    acc_line = (
        f"  Win Acc : {_sparkline(win_acc)}  "
        f"(min={min(win_acc):.1%}, max={max(win_acc):.1%})"
    )
    console.print(acc_line)

    if brier:
        brier_line = (
            f"  Brier   : {_sparkline(brier)}  "
            f"(min={min(brier):.3f}, max={max(brier):.3f})"
        )
        console.print(brier_line)

    if roi:
        roi_line = (
            f"  Bet ROI : {_sparkline(roi)}  "
            f"(min={min(roi):+.1f}%, max={max(roi):+.1f}%)"
        )
        console.print(roi_line)

    first_month, last_month = months[0], months[-1]
    console.print(f"  Months  : {first_month} … {last_month} ({len(months)} months)")
    console.print()


# ── CLI ──────────────────────────────────────────────────────────────────


@click.command()
@click.option("--from", "date_from", default=None, help="Start date (YYYY-MM-DD)")
@click.option("--to", "date_to", default=None, help="End date (YYYY-MM-DD)")
@click.option(
    "--out",
    default=None,
    type=click.Path(writable=True),
    help="Output JSON file path",
)
@click.option(
    "--db-url",
    default=None,
    envvar="DATABASE_URL",
    help="Database URL (defaults to DATABASE_URL env var)",
)
def main(
    date_from: str | None,
    date_to: str | None,
    out: str | None,
    db_url: str | None,
) -> int:
    """Generate monthly prediction trend reports.

    For each month in the date range, computes winner accuracy, top-3 hit
    rate, average Brier score, average field size, betting ROI, and a total
    race-ingestion count.

    Requires **DATABASE_URL** environment variable or ``--db-url``.
    """
    # NOTE: the docstring above is the text click shows for --help, so
    # developer notes live in comments instead.
    #
    # Behaviour: --from defaults to 365 days ago and --to to today. The
    # monthly rows are printed as a table and sparklines, then written as
    # JSON to --out or printed when --out is not given. Returns 0 on
    # success; failures exit with a non-zero status.
    if not db_url:
        console.print(
            "[red]DATABASE_URL is required (set env var or pass --db-url)[/red]"
        )
        # click's standalone mode discards the callback's return value and
        # exits 0, so returning 1 here would report success to the shell.
        raise click.exceptions.Exit(1)

    # Turn a malformed date into a normal click usage error rather than a
    # ValueError traceback.
    try:
        parsed_from = _parse_date(date_from) or (
            date.today() - timedelta(days=365)
        )
    except ValueError as exc:
        raise click.BadParameter(
            f"expected YYYY-MM-DD, got {date_from!r}", param_hint="--from"
        ) from exc
    try:
        parsed_to = _parse_date(date_to) or date.today()
    except ValueError as exc:
        raise click.BadParameter(
            f"expected YYYY-MM-DD, got {date_to!r}", param_hint="--to"
        ) from exc

    if parsed_from > parsed_to:
        # An inverted range would otherwise just look like "no data".
        raise click.UsageError(
            f"--from ({parsed_from}) must not be after --to ({parsed_to})"
        )

    console.print("[dim]Connecting to database...[/dim]")
    engine = create_engine(db_url)
    try:
        with Session(engine) as session:
            results = _compute_monthly(session, parsed_from, parsed_to)
    finally:
        # Release pooled connections even when the computation fails.
        engine.dispose()

    if not results:
        console.print("[yellow]No data found.[/yellow]")
        return 0

    # Print table
    _print_trend_table(results)

    # Print sparklines
    _print_sparklines(results)

    # Write JSON
    report = {
        "parameters": {
            "date_from": str(parsed_from),
            "date_to": str(parsed_to),
        },
        "total_months": len(results),
        "monthly": results,
    }

    if out:
        with open(out, "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2, default=str)
        console.print(f"[green]Report written to {out}[/green]")
    else:
        console.print("[dim]JSON report (use --out to write to file):[/dim]")
        console.print(json.dumps(report, indent=2, default=str))

    return 0


# Script entry point: run the CLI and pass whatever main() returns to
# SystemExit as the process exit status.
# NOTE(review): main is a click command, which normally terminates the
# process itself — confirm whether its return value ever reaches SystemExit.
if __name__ == "__main__":
    raise SystemExit(main())
