#!/usr/bin/env python3
"""Simulate theoretical betting returns based on model predictions vs market odds.

Simulates flat-stake betting strategies to evaluate the profitability of
model predictions. Supports two strategies:

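  1. **Predicted winner**: Bet on the horse with the highest win probability.
     Winning bets are settled at the market price when one is available,
     otherwise at the model's fair odds (1 / win probability).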
  2. **Value bets**: Bet on horses where model probability exceeds the implied
     market probability (i.e., positive expected value).

Outputs a Rich table and, with ``--out``, a JSON report. When ``--from`` /
``--to`` are omitted, the last 365 days up to today are simulated.

Usage:
    python scripts/roi_simulation.py --from 2026-01-01 --to 2026-03-31 --stake 2.0 --out roi.json
"""

from __future__ import annotations

import json
from datetime import date, datetime, timedelta

import click
from rich.console import Console
from rich.table import Table
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from packages.core.ratings.predictions import PredictionEngine
from packages.core.storage.models import Meeting, Race, Starter

# Shared Rich console; _print_rich() and main() write all terminal output here.
console = Console()


# ── Helpers ──────────────────────────────────────────────────────────────


def _parse_date(value: str | None) -> date | None:
    """Convert a ``YYYY-MM-DD`` string into a :class:`~datetime.date`.

    ``None`` and the empty string both yield ``None``. Any other string that
    does not match the format raises :class:`ValueError` from ``strptime``.
    """
    if value:
        parsed = datetime.strptime(value, "%Y-%m-%d")
        return parsed.date()
    return None


def _first_usable_price(fields: dict, keys: tuple[str, ...]) -> float | None:
    """Return the first value under *keys* that is usable as decimal odds.

    A value is usable when ``float()`` accepts it and the result is finite
    and strictly greater than 1.0. Anything else (missing, unparsable, zero,
    ``nan``, ``inf``) is skipped so that later keys still get a chance.
    """
    for key in keys:
        val = fields.get(key)
        if val is None:
            continue
        try:
            odds = float(val)
        except (ValueError, TypeError):
            continue
        # The chained comparison is False for NaN and excludes +inf.
        if 1.0 < odds < float("inf"):
            return odds
    return None


def _extract_market_odds(starter: Starter) -> float | None:
    """Extract win market odds from a starter's raw JSON if available.

    Checks common TAB / HRNZ field names for a win dividend or fixed price,
    first at the top level of ``starter.raw_json`` and then inside a nested
    ``"prices"`` (or ``"price"``) mapping, e.g.
    ``{"prices": {"win": 5.0, "place": 1.5}}``.

    Returns:
        The first usable decimal odds (finite and > 1.0), or *None* if
        ``raw_json`` is missing, is not a dict, or holds no usable price.
    """
    raw = starter.raw_json
    if not raw or not isinstance(raw, dict):
        return None

    odds = _first_usable_price(
        raw, ("win_dividend", "fixed_price", "win_price", "odds", "win_odds")
    )
    if odds is not None:
        return odds

    prices = raw.get("prices") or raw.get("price") or {}
    if isinstance(prices, dict):
        return _first_usable_price(prices, ("win", "win_price", "win_odds"))
    return None


def _implied_probability(decimal_odds: float) -> float:
    """Return the market-implied win probability ``1 / decimal_odds``.

    Odds of 1.0 or below cannot represent a real price and are clamped to a
    probability of 1.0.
    """
    return 1.0 if decimal_odds <= 1.0 else 1.0 / decimal_odds


# ── Simulation logic ─────────────────────────────────────────────────────


def _completed_races_query(
    session: Session, date_from: date | None, date_to: date | None
):
    """Build the ordered query for races with at least one placed finisher.

    An EXISTS filter on ``Race.starters`` is used instead of joining starters
    and de-duplicating. ``DISTINCT ON (race.id)`` is PostgreSQL-only, and
    PostgreSQL rejects it unless ``race.id`` leads the ORDER BY, which would
    break the chronological ordering below.
    """
    # NOTE(review): assumes Race.starters is the Race -> Starter relationship
    # (it is iterated that way in _simulate); confirm against the models.
    finished = Starter.placing.isnot(None) & Starter.did_not_finish.is_(False)
    query = session.query(Race).join(Race.meeting).filter(Race.starters.any(finished))
    if date_from:
        query = query.filter(Meeting.meeting_date >= date_from)
    if date_to:
        query = query.filter(Meeting.meeting_date <= date_to)
    return query.order_by(
        Meeting.meeting_date, Race.race_datetime, Race.race_number, Race.id
    )


def _record_bet(tally: dict, won: bool, odds: float, stake: float) -> None:
    """Add one flat-stake bet to *tally*; a winner returns ``odds * stake``."""
    tally["bets"] += 1
    if won:
        tally["wins"] += 1
        # Gross payout (stake included); losers add nothing here because all
        # stakes are subtracted once in _summarize_strategy.
        tally["gross_return"] += odds * stake
        tally["odds_sum"] += odds


def _summarize_strategy(tally: dict, stake: float) -> dict:
    """Turn a strategy tally into the reported metrics.

    *tally* holds ``bets``, ``wins``, ``gross_return`` (sum of winning
    payouts) and ``odds_sum`` (sum of odds on winning bets). Profit/loss is
    gross return minus total stake, so each losing bet counts exactly once.
    """
    bets = tally["bets"]
    wins = tally["wins"]
    if bets == 0:
        return {
            "bets": 0,
            "wins": 0,
            "strike_rate": 0.0,
            "total_return": 0.0,
            "total_staked": 0.0,
            "profit_loss": 0.0,
            "roi_pct": 0.0,
            "avg_odds": 0.0,
        }
    total_return = tally["gross_return"]
    total_staked = bets * stake
    profit_loss = total_return - total_staked
    return {
        "bets": bets,
        "wins": wins,
        "strike_rate": round(wins / bets * 100, 2),
        "total_return": round(total_return, 2),
        "total_staked": round(total_staked, 2),
        "profit_loss": round(profit_loss, 2),
        "roi_pct": (
            round(profit_loss / total_staked * 100, 2) if total_staked > 0 else 0.0
        ),
        "avg_odds": round(tally["odds_sum"] / wins, 2) if wins > 0 else 0.0,
    }


def _simulate(
    session: Session,
    date_from: date | None,
    date_to: date | None,
    min_confidence: float,
    stake: float,
) -> dict:
    """Run the flat-stake betting simulation over completed races.

    For each race in the window, the prediction engine is run over the
    starters that have a placing and did finish. Two strategies are then
    settled against the actual placings:

    * **Predicted winner** - one bet per race on the highest-probability
      runner, if it meets *min_confidence*. Winners are paid at the market
      price from ``raw_json`` when usable, else at the model's fair odds.
    * **Value bets** - one bet on every runner meeting *min_confidence*
      whose model probability exceeds the market-implied probability.
      Runners without usable market odds are never bet.

    Args:
        session: Open SQLAlchemy session.
        date_from: Inclusive lower bound on ``Meeting.meeting_date``, or None.
        date_to: Inclusive upper bound on ``Meeting.meeting_date``, or None.
        min_confidence: Minimum model win probability required to bet.
        stake: Flat stake per bet.

    Returns:
        A dict with ``simulation_params``, ``summary`` and one per-strategy
        result under ``predicted_winner`` / ``value_bets``. The summary holds
        ``races_processed`` (races that were predicted and have a recorded
        winner), ``races_with_market_odds`` (of those, races where at least
        one starter had usable odds) and ``prediction_errors``. If no races
        match, ``{"error": ...}`` is returned instead.
    """
    races = _completed_races_query(session, date_from, date_to).all()
    if not races:
        return {"error": "No completed races found in the specified date range"}

    engine = PredictionEngine(session)

    pw = {"bets": 0, "wins": 0, "gross_return": 0.0, "odds_sum": 0.0}
    vb = {"bets": 0, "wins": 0, "gross_return": 0.0, "odds_sum": 0.0}

    races_processed = 0
    races_with_odds = 0
    prediction_errors = 0

    for race in races:
        starters = [
            s for s in race.starters if s.placing is not None and not s.did_not_finish
        ]
        if not starters:
            continue

        try:
            prediction = engine.predict_race(race, starters)
        except Exception:  # noqa: BLE001 - one bad race must not abort the run
            prediction_errors += 1
            continue
        if not prediction or not prediction.predictions:
            continue

        placings = {s.id: s.placing for s in starters}
        if 1 not in placings.values():
            continue  # No recorded winner, so no bet can be settled.

        races_processed += 1

        # Usable market odds per starter; _extract_market_odds results that
        # are None, <= 1.0 or NaN are dropped (NaN fails the > comparison).
        market_odds = {}
        for s in starters:
            odds = _extract_market_odds(s)
            if odds is not None and odds > 1.0:
                market_odds[s.id] = odds
        if market_odds:
            races_with_odds += 1

        # ── Strategy 1: Predicted winner ─────────────────────────────
        top = max(prediction.predictions, key=lambda p: p.win_probability)
        if top.win_probability >= min_confidence:
            odds = market_odds.get(top.starter_id)
            if odds is None:
                # No market price: settle at the model's own fair odds. In
                # expectation a perfectly calibrated model breaks even on
                # these bets, so they measure calibration, not market edge.
                odds = 1.0 / max(top.win_probability, 0.001)
            _record_bet(pw, placings.get(top.starter_id) == 1, odds, stake)

        # ── Strategy 2: Value bets ───────────────────────────────────
        for pred in prediction.predictions:
            if pred.win_probability < min_confidence:
                continue
            odds = market_odds.get(pred.starter_id)
            if odds is None:
                continue  # Need market odds for value comparison.
            if pred.win_probability > _implied_probability(odds):
                _record_bet(vb, placings.get(pred.starter_id) == 1, odds, stake)

    return {
        "simulation_params": {
            "date_from": str(date_from) if date_from else None,
            "date_to": str(date_to) if date_to else None,
            "min_confidence": min_confidence,
            "stake_per_bet": stake,
        },
        "summary": {
            "races_processed": races_processed,
            "races_with_market_odds": races_with_odds,
            "prediction_errors": prediction_errors,
        },
        "predicted_winner": _summarize_strategy(pw, stake),
        "value_bets": _summarize_strategy(vb, stake),
    }


def _print_rich(results: dict) -> None:
    """Render a ``_simulate`` result dict to the console.

    Prints the simulation parameters and race counts, then a Rich table
    comparing the two strategies side by side. If *results* carries an
    ``"error"`` key, only that message is printed.
    """
    if "error" in results:
        console.print(f"[red]Error: {results['error']}[/red]")
        return

    params = results["simulation_params"]
    summary = results["summary"]

    console.print()
    header_lines = (
        "[bold]ROI Simulation Results[/bold]",
        f"  Date range: {params['date_from']} → {params['date_to']}",
        f"  Stake per bet: ${params['stake_per_bet']:.2f}",
        f"  Min confidence: {params['min_confidence']:.0%}",
        f"  Races processed: {summary['races_processed']}",
        f"  Races with market odds: {summary['races_with_market_odds']}",
    )
    for line in header_lines:
        console.print(line)
    console.print()

    table = Table(title="Betting Strategy Comparison")
    for heading, colour in (
        ("Metric", "cyan"),
        ("Predicted Winner", "green"),
        ("Value Bets", "yellow"),
    ):
        table.add_column(heading, style=colour)

    winner_stats = results["predicted_winner"]
    value_stats = results["value_bets"]

    # (row label, formatter applied to each strategy's stats dict)
    row_specs = (
        ("Total Bets", lambda st: str(st["bets"])),
        ("Winners", lambda st: str(st["wins"])),
        ("Strike Rate", lambda st: f"{st['strike_rate']:.2f}%"),
        ("Total Staked", lambda st: f"${st['total_staked']:.2f}"),
        ("Total Return", lambda st: f"${st['total_return']:.2f}"),
        ("Profit / Loss", lambda st: f"${st['profit_loss']:+.2f}"),
        ("ROI", lambda st: f"{st['roi_pct']:+.2f}%"),
        ("Avg Winner Odds", lambda st: f"{st['avg_odds']:.2f}"),
    )
    for label, render in row_specs:
        table.add_row(label, render(winner_stats), render(value_stats))

    console.print(table)
    console.print()


# ── CLI ──────────────────────────────────────────────────────────────────


@click.command()
@click.option("--from", "date_from", default=None, help="Start date (YYYY-MM-DD)")
@click.option("--to", "date_to", default=None, help="End date (YYYY-MM-DD)")
@click.option(
    "--min-confidence",
    default=0.0,
    show_default=True,
    type=float,
    help="Minimum win probability threshold for bets",
)
@click.option(
    "--stake",
    default=1.0,
    show_default=True,
    type=float,
    help="Dollar amount per bet",
)
@click.option(
    "--out",
    default=None,
    type=click.Path(writable=True),
    help="Output JSON file path",
)
@click.option(
    "--db-url",
    default=None,
    envvar="DATABASE_URL",
    help="Database URL (defaults to DATABASE_URL env var)",
)
def main(
    date_from: str | None,
    date_to: str | None,
    min_confidence: float,
    stake: float,
    out: str | None,
    db_url: str | None,
) -> int:
    """Simulate theoretical betting returns from model predictions.

    Evaluates two betting strategies:

    \b
    - **Predicted winner**: bets on the horse with the highest win probability.
    - **Value bets**: bets on horses where the model's probability exceeds the
      implied probability from market odds.

    Dates default to the last 365 days up to today. Requires **DATABASE_URL**
    environment variable or ``--db-url``.
    """
    if not db_url:
        console.print(
            "[red]DATABASE_URL is required (set env var or pass --db-url)[/red]"
        )
        # In standalone mode click discards the callback's return value, so a
        # bare ``return 1`` would still exit 0; ctx.exit() carries the status.
        click.get_current_context().exit(1)

    if stake <= 0:
        raise click.BadParameter("must be greater than 0", param_hint="'--stake'")
    if not 0.0 <= min_confidence <= 1.0:
        raise click.BadParameter(
            "must be between 0 and 1", param_hint="'--min-confidence'"
        )

    today = date.today()
    try:
        parsed_from = (
            _parse_date(date_from) if date_from else today - timedelta(days=365)
        )
    except ValueError as exc:
        raise click.BadParameter(
            f"expected YYYY-MM-DD, got {date_from!r}", param_hint="'--from'"
        ) from exc
    try:
        parsed_to = _parse_date(date_to) if date_to else today
    except ValueError as exc:
        raise click.BadParameter(
            f"expected YYYY-MM-DD, got {date_to!r}", param_hint="'--to'"
        ) from exc
    if parsed_from > parsed_to:
        raise click.BadParameter(
            f"start date {parsed_from} is after end date {parsed_to}",
            param_hint="'--from'",
        )

    console.print("[dim]Connecting to database...[/dim]")
    engine = create_engine(db_url)
    try:
        with Session(engine) as session:
            results = _simulate(session, parsed_from, parsed_to, min_confidence, stake)
    finally:
        # Release pooled connections; the engine is not reused after this.
        engine.dispose()

    # Output
    _print_rich(results)

    if out:
        with open(out, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, default=str)
        console.print(f"[green]Results written to {out}[/green]")

    return 0


if __name__ == "__main__":
    # NOTE(review): with click's default standalone mode, main() is expected
    # to terminate the process itself; SystemExit here only covers a return.
    exit_status = main()
    raise SystemExit(exit_status)
