#!/usr/bin/env python3
"""Time-series analysis of rating volatility.

Computes rolling-window standard deviations, maximum drawdown, rating
velocity, and peak/trough ratings (with their dates) from rating snapshots.

Usage:
    python scripts/time_series_volatility.py --entity-type horse
    python scripts/time_series_volatility.py --entity-type driver --entity-id 42
    python scripts/time_series_volatility.py --out volatility.json
"""

from __future__ import annotations

import json
import math
import os
from collections.abc import Sequence
from datetime import date, datetime, timedelta
from typing import Any

import click
from rich.console import Console
from rich.table import Table
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from packages.core.storage.models import EntityType, Meeting, Race, RatingSnapshot

# Rich console shared by every function in this script for terminal output.
console = Console()

# Maps the --entity-type CLI choice to the storage-layer EntityType enum.
ENTITY_TYPE_MAP = dict(
    horse=EntityType.HORSE,
    driver=EntityType.DRIVER,
    trainer=EntityType.TRAINER,
)

# Block glyphs ordered from lowest to highest level; used by the sparkline renderer.
_SPARKLINE_CHARS = "▁▂▃▄▅▆▇█"


def _compute_sparkline(values: Sequence[float], width: int = 30) -> str:
    """Render *values* as a sparkline built from Unicode block characters.

    When there are more values than ``width``, the series is down-sampled to
    ``width`` evenly spaced points that always include the first and the last
    value, so the sparkline ends where the series ends.

    Args:
        values: Series to plot, in display order.
        width: Maximum number of characters to emit. Values below 1 are
            treated as 1.

    Returns:
        The sparkline, or ``""`` for empty input. A single value or a flat
        series is drawn with the middle glyph.
    """
    if not values:
        return ""
    glyphs = _SPARKLINE_CHARS
    top = len(glyphs) - 1
    middle = glyphs[len(glyphs) // 2]
    if len(values) == 1:
        return middle

    # Clamp the width so a non-positive value cannot cause a division by zero.
    n = max(1, min(len(values), width))
    if n == 1:
        return middle

    # Pick n evenly spaced indices from 0 to the last index (inclusive).
    # Integer arithmetic keeps them exact. When len(values) <= width this
    # selects every value, in order.
    last = len(values) - 1
    sampled = [values[i * last // (n - 1)] for i in range(n)]

    lo, hi = min(sampled), max(sampled)
    if hi == lo:
        return middle * n
    span = hi - lo
    # Scale each value into [0, top]; min() guards against float round-up.
    return "".join(glyphs[min(int((v - lo) / span * top), top)] for v in sampled)


def _compute_max_drawdown(ratings: Sequence[float]) -> float:
    """Compute maximum drawdown from peak rating."""
    if len(ratings) < 2:
        return 0.0
    peak = ratings[0]
    max_dd = 0.0
    for r in ratings:
        if r > peak:
            peak = r
        dd = (peak - r) / peak if peak != 0 else 0.0
        if dd > max_dd:
            max_dd = dd
    return max_dd


def _compute_rolling_std(
    dates: Sequence[date],
    ratings: Sequence[float],
    window_days: int,
) -> list[float]:
    """Compute rolling standard deviation over a sliding window."""
    if len(ratings) < 2:
        return [0.0] * len(ratings)

    result: list[float] = [0.0] * len(ratings)
    for i in range(len(ratings)):
        window_start = dates[i] - timedelta(days=window_days)
        window_ratings = [ratings[j] for j in range(i + 1) if dates[j] >= window_start]
        if len(window_ratings) < 2:
            result[i] = 0.0
        else:
            mean = sum(window_ratings) / len(window_ratings)
            variance = sum((r - mean) ** 2 for r in window_ratings) / len(
                window_ratings
            )
            result[i] = math.sqrt(variance)
    return result


def _find_peak_trough(
    dates: Sequence[date],
    ratings: Sequence[float],
) -> dict[str, Any]:
    """Identify peak and trough periods."""
    if not ratings:
        return {"peak": None, "trough": None}

    max_r = max(ratings)
    min_r = min(ratings)
    peak_idx = ratings.index(max_r)
    trough_idx = ratings.index(min_r)

    return {
        "peak": {
            "rating": round(max_r, 2),
            "date": str(dates[peak_idx]) if peak_idx < len(dates) else None,
        },
        "trough": {
            "rating": round(min_r, 2),
            "date": str(dates[trough_idx]) if trough_idx < len(dates) else None,
        },
    }


def _lookup_entity_name(
    session: Session, entity_type: EntityType, entity_id: int
) -> str | None:
    """Return the display name of one horse/driver/trainer, or ``None``.

    The lookup is best-effort because the name is only cosmetic. Any failure
    is reported as a console warning, and ``None`` is returned instead of
    aborting the report.
    """
    try:
        # Models are imported lazily, only when a name is actually requested.
        if entity_type == EntityType.HORSE:
            from packages.core.storage.models import Horse as model
        elif entity_type == EntityType.DRIVER:
            from packages.core.storage.models import Driver as model
        else:
            from packages.core.storage.models import Trainer as model
        ent = session.query(model).filter(model.id == entity_id).first()
        return str(ent.name) if ent else None
    except Exception as exc:  # noqa: BLE001 - deliberate best-effort lookup
        # markup=False: exception text can contain "[...]", which Rich would
        # otherwise try to interpret as markup.
        console.print(
            f"Warning: could not look up {entity_type.value} name "
            f"for id {entity_id}: {exc}",
            style="yellow",
            markup=False,
        )
        return None


def _analyze_entity(
    session: Session,
    entity_type: EntityType,
    entity_id: int | None,
    date_from: date | None,
    date_to: date | None,
    window_days: int,
) -> dict[str, Any]:
    """Compute volatility metrics for an entity type (or one specific entity).

    Loads the rating snapshots for ``entity_type``, optionally restricted to
    ``entity_id`` and to meetings dated within ``[date_from, date_to]``. The
    snapshots are ordered by meeting date, race datetime and race number, and
    the resulting rating series is summarised.

    NOTE(review): without ``entity_id``, the snapshots of *all* entities of
    the type are merged into one series. Sequence-based metrics (net change,
    drawdown, velocity, rolling std) therefore mix different entities.

    Args:
        session: Open SQLAlchemy session.
        entity_type: Which entity kind's snapshots to load.
        entity_id: Restrict to one entity; ``None`` means all of them.
        date_from: Inclusive lower bound on the meeting date, or ``None``.
        date_to: Inclusive upper bound on the meeting date, or ``None``.
        window_days: Extra rolling window, in days. Its average rolling std
            is reported as ``avg_rolling_window_std``. The 30- and 90-day
            windows are always reported.

    Returns:
        A metrics dict, or ``{"error": "no_data", "detail": ...}`` when no
        snapshots match.
    """
    query = (
        session.query(RatingSnapshot, Race, Meeting)
        .join(Race, RatingSnapshot.as_of_race_id == Race.id)
        .join(Meeting, Race.meeting_id == Meeting.id)
        .filter(RatingSnapshot.entity_type == entity_type)
    )
    if entity_id is not None:
        query = query.filter(RatingSnapshot.entity_id == entity_id)
    if date_from:
        query = query.filter(Meeting.meeting_date >= date_from)
    if date_to:
        query = query.filter(Meeting.meeting_date <= date_to)
    query = query.order_by(Meeting.meeting_date, Race.race_datetime, Race.race_number)

    rows = query.all()
    if not rows:
        return {"error": "no_data", "detail": "No rating snapshots found"}

    # Snapshots without a meeting date get date.min as a placeholder, so
    # dates and ratings stay index-aligned.
    dates: list[date] = [meeting.meeting_date or date.min for _s, _r, meeting in rows]
    ratings: list[float] = [float(snap.rating) for snap, _r, _m in rows]  # type: ignore[arg-type]

    entity_name = (
        _lookup_entity_name(session, entity_type, entity_id)
        if entity_id is not None
        else None
    )

    # ── Compute metrics ──
    num_races = len(ratings)
    mean_rating = sum(ratings) / num_races
    # Population std dev. The mean is computed once rather than per element.
    overall_std = math.sqrt(sum((r - mean_rating) ** 2 for r in ratings) / num_races)

    rolling_30 = _compute_rolling_std(dates, ratings, 30)
    rolling_90 = _compute_rolling_std(dates, ratings, 90)
    fixed_windows = {30: rolling_30, 90: rolling_90}
    # Honour --window-days, reusing a fixed window when the sizes coincide.
    rolling_window = (
        fixed_windows[window_days]
        if window_days in fixed_windows
        else _compute_rolling_std(dates, ratings, window_days)
    )

    max_drawdown = _compute_max_drawdown(ratings)

    # Rating velocity: mean absolute change between consecutive ratings.
    # n ratings produce n - 1 changes.
    changes = [abs(b - a) for a, b in zip(ratings, ratings[1:])]
    velocity = sum(changes) / len(changes) if changes else 0.0

    net_change = ratings[-1] - ratings[0]

    return {
        "entity_type": entity_type.value,
        "entity_id": entity_id,
        "entity_name": entity_name,
        "num_ratings": num_races,
        "date_from": str(dates[0]),
        "date_to": str(dates[-1]),
        "first_rating": round(ratings[0], 2),
        "last_rating": round(ratings[-1], 2),
        "net_change": round(net_change, 2),
        "avg_rating": round(mean_rating, 2),
        "rating_std_dev": round(overall_std, 2),
        "avg_rolling_30d_std": round(sum(rolling_30) / len(rolling_30), 2),
        "avg_rolling_90d_std": round(sum(rolling_90) / len(rolling_90), 2),
        "window_days": window_days,
        "avg_rolling_window_std": round(sum(rolling_window) / len(rolling_window), 2),
        "max_drawdown_pct": round(max_drawdown * 100, 2),
        "rating_velocity": round(velocity, 2),
        "peak_trough": _find_peak_trough(dates, ratings),
        "sparkline": _compute_sparkline(ratings),
    }


def _parse_date_option(value: str | None, option_name: str) -> date | None:
    """Parse an optional ``YYYY-MM-DD`` CLI value.

    Raises:
        click.BadParameter: if *value* is not a valid ``YYYY-MM-DD`` date, so
            click reports a usage error instead of a traceback.
    """
    if not value:
        return None
    try:
        return datetime.strptime(value, "%Y-%m-%d").date()
    except ValueError as exc:
        raise click.BadParameter(
            f"expected YYYY-MM-DD, got {value!r}", param_hint=option_name
        ) from exc


@click.command()
@click.option(
    "--entity-type",
    type=click.Choice(["horse", "driver", "trainer"]),
    default="horse",
    show_default=True,
    help="Entity type to analyze",
)
@click.option(
    "--entity-id", type=int, default=None, help="Specific entity ID (optional)"
)
@click.option("--from", "date_from", default=None, help="Start date (YYYY-MM-DD)")
@click.option("--to", "date_to", default=None, help="End date (YYYY-MM-DD)")
@click.option(
    "--window-days",
    type=click.IntRange(min=0),
    default=90,
    show_default=True,
    help="Rolling window size in days for std dev (also 30d computed)",
)
@click.option("--out", "output", default=None, help="Write JSON report to file")
@click.option("--db-url", default=None, help="Database URL (default: $DATABASE_URL)")
def cli(
    entity_type: str,
    entity_id: int | None,
    date_from: str | None,
    date_to: str | None,
    window_days: int,
    output: str | None,
    db_url: str | None,
):
    """Time-series analysis of rating volatility.

    Computes rolling standard deviation, max drawdown, rating velocity,
    and peak/trough periods for horse/driver/trainer ratings.
    """
    db_url = db_url or os.getenv("DATABASE_URL")
    if not db_url:
        console.print("[red]ERROR: --db-url or DATABASE_URL is required[/]")
        raise SystemExit(1)

    et = ENTITY_TYPE_MAP[entity_type]

    # Validate dates before touching the database.
    parsed_from = _parse_date_option(date_from, "--from")
    parsed_to = _parse_date_option(date_to, "--to")
    if parsed_from and parsed_to and parsed_from > parsed_to:
        raise click.UsageError("--from date must not be after --to date")

    engine = create_engine(db_url)
    try:
        with Session(engine) as session:
            results = _analyze_entity(
                session=session,
                entity_type=et,
                entity_id=entity_id,
                date_from=parsed_from,
                date_to=parsed_to,
                window_days=window_days,
            )
    finally:
        # Release pooled connections once the analysis is finished.
        engine.dispose()

    if "error" in results:
        console.print(f"[yellow]Warning: {results.get('detail', results['error'])}[/]")
        if output:
            with open(output, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2, sort_keys=True)
        return

    # ── Rich table ──
    table = Table(
        title=f"Rating Volatility — {entity_type.title()}", title_style="bold cyan"
    )
    table.add_column("Metric", style="cyan")
    table.add_column("Value", justify="right")

    # Compare against None so an entity id of 0 is not shown as "All".
    name_display = results.get("entity_name") or (
        f"ID {results['entity_id']}" if results.get("entity_id") is not None else "All"
    )
    table.add_row("Entity", name_display)
    table.add_row("Ratings Count", str(results["num_ratings"]))
    table.add_row("First Rating", str(results["first_rating"]))
    table.add_row("Last Rating", str(results["last_rating"]))
    table.add_row("Net Change", f"{results['net_change']:+.2f}")
    table.add_row("Avg Rating", str(results["avg_rating"]))
    table.add_row("Overall Std Dev", str(results["rating_std_dev"]))
    table.add_row("Avg 30d Rolling Std", str(results["avg_rolling_30d_std"]))
    table.add_row("Avg 90d Rolling Std", str(results["avg_rolling_90d_std"]))
    table.add_row("Max Drawdown", f"{results['max_drawdown_pct']:.2f}%")
    table.add_row("Rating Velocity", str(results["rating_velocity"]))
    table.add_row("Peak Date", str(results["peak_trough"]["peak"]["date"]))
    table.add_row("Peak Rating", str(results["peak_trough"]["peak"]["rating"]))
    table.add_row("Trough Date", str(results["peak_trough"]["trough"]["date"]))
    table.add_row("Trough Rating", str(results["peak_trough"]["trough"]["rating"]))

    console.print(table)

    # ── Sparkline ──
    sparkline = results.get("sparkline", "")
    if sparkline:
        console.print("\n[bold cyan]Rating Trajectory (Sparkline):[/]")
        console.print(f"  [green]{sparkline}[/]")
        console.print(
            f"  {results['first_rating']} {'→':>25} {results['last_rating']}\n"
        )

    # ── JSON output ──
    if output:
        with open(output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, sort_keys=True)
        console.print(f"[green]Report written to {output}[/]")
    else:
        console.print_json(json.dumps(results, indent=2, sort_keys=True))


if __name__ == "__main__":
    # Run the click command when this file is executed as a script.
    cli()
