#!/usr/bin/env python3
"""Batch backfill CLI for historical TAB data ingestion and recomputation.

Backfills data in monthly chunks with progress logging and error handling.
Supports ingest-only, recompute-only, and full (ingest + recompute) modes.
"""

from __future__ import annotations

import asyncio
import sys
from calendar import monthrange
from dataclasses import dataclass, field
from datetime import date

import click
from rich.console import Console
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
)
from rich.table import Table

# Must be importable when running as script
from packages.core.common.logging import get_logger, setup_logging
from packages.core.common.settings import get_settings
from packages.core.storage.database import get_session
from packages.core.storage.ingestion import IngestionService

# Configure logging once at import time, then create the module-level logger
# and the shared Rich console that the commands and helpers below print to.
setup_logging()
logger = get_logger(__name__)
console = Console()


# ── Helpers ────────────────────────────────────────────────────────────


def _month_bounds(ref: date) -> tuple[date, date]:
    """Return (first_day, last_day) of the month containing *ref*."""
    first = ref.replace(day=1)
    last_day = monthrange(ref.year, ref.month)[1]
    last = ref.replace(day=last_day)
    return first, last


def _iter_month_chunks(
    date_from: date, date_to: date, chunk_months: int = 1
) -> list[tuple[date, date]]:
    """Split [date_from, date_to] into monthly (or multi-month) chunks.

    Each chunk is a (start, end) tuple inclusive on both ends.
    """
    chunks: list[tuple[date, date]] = []
    cursor = date_from.replace(day=1)
    while cursor <= date_to:
        # Advance by chunk_months
        chunk_end_month = cursor.month - 1 + chunk_months  # 0-indexed
        chunk_end_year = cursor.year + chunk_end_month // 12
        chunk_end_month = chunk_end_month % 12 + 1
        last_day = monthrange(chunk_end_year, chunk_end_month)[1]
        chunk_end = date(chunk_end_year, chunk_end_month, last_day)

        chunk_start = cursor
        chunk_end = min(chunk_end, date_to)

        if chunk_start <= date_to:
            chunks.append((chunk_start, chunk_end))

        # Move cursor to next month after chunk_end
        next_month = chunk_end.month  # 1-indexed
        next_year = chunk_end.year
        if next_month == 12:
            next_month = 1
            next_year += 1
        else:
            next_month += 1
        cursor = date(next_year, next_month, 1)

    return chunks


@dataclass
class ChunkResult:
    """Outcome record for one ingested date chunk.

    A freshly created instance represents "not yet succeeded": ``success``
    is False, all counters are zero and ``error_message`` is empty.
    """

    # Inclusive date bounds of the chunk this record describes.
    chunk_start: date
    chunk_end: date

    # Set to True once the chunk has been ingested without raising.
    success: bool = False

    # Counts reported for the chunk.
    meetings: int = 0
    races: int = 0
    starters: int = 0
    errors: int = 0

    # Text of the exception that aborted the chunk, if any.
    error_message: str = ""


@dataclass
class BackfillSummary:
    """Totals collected over one backfill invocation, with console rendering.

    The first five fields identify the run and are required; the remaining
    counters start at zero and are filled in by the command that owns the
    summary.
    """

    # What was run and over which range.
    command: str
    date_from: date
    date_to: date
    category: str
    dry_run: bool

    # Chunk accounting.
    total_chunks: int = 0
    successful_chunks: int = 0
    failed_chunks: int = 0

    # Totals accumulated across chunks.
    total_meetings: int = 0
    total_races: int = 0
    total_starters: int = 0
    total_errors: int = 0

    # Number of rating snapshots reported by a recompute.
    snapshots_created: int = 0

    # One (start, end, error) string triple per failed unit of work.
    failed_details: list[tuple[str, str, str]] = field(default_factory=list)

    def print(self) -> None:
        """Print the summary table and the trailing status lines.

        Rows for failed chunks, errors and rating snapshots are only shown
        when their counter is non-zero.
        """
        report = Table(title="Backfill Summary")
        report.add_column("Metric", style="cyan")
        report.add_column("Value", style="green", justify="right")

        report.add_row("Command", self.command)
        report.add_row("Date range", f"{self.date_from} to {self.date_to}")
        report.add_row("Category", self.category)

        for label, count in (
            ("Total chunks", self.total_chunks),
            ("Successful chunks", self.successful_chunks),
        ):
            report.add_row(label, str(count))
        if self.failed_chunks:
            report.add_row("Failed chunks", str(self.failed_chunks), style="red")

        for label, count in (
            ("Total meetings", self.total_meetings),
            ("Total races", self.total_races),
            ("Total starters", self.total_starters),
        ):
            report.add_row(label, str(count))
        if self.total_errors:
            report.add_row("Errors", str(self.total_errors), style="red")
        if self.snapshots_created:
            report.add_row("Rating snapshots", str(self.snapshots_created))

        console.print("\n")
        console.print(report)

        if self.dry_run:
            console.print("\n[yellow]⚠ DRY RUN — no data was modified[/yellow]")

        if self.failed_details:
            failure_count = len(self.failed_details)
            console.print(
                f"\n[yellow]⚠ {failure_count} chunk(s) failed:[/yellow]"
            )
            for started, ended, reason in self.failed_details:
                console.print(f"  [yellow]{started} to {ended}: {reason}[/yellow]")

        # Closing status line: a warning when anything failed, a success
        # note for a clean real run, and nothing for a clean dry run.
        if self.failed_chunks > 0:
            console.print(
                f"\n[yellow]⚠ Completed with {self.failed_chunks} failed chunk(s)[/yellow]"
            )
        elif self.failed_chunks == 0 and not self.dry_run:
            console.print("\n[green]✓ Backfill completed successfully[/green]")


# ── Shared click options ──────────────────────────────────────────────

# Reusable option decorators shared by several commands below.

# `--from` is bound to the parameter name `date_from` because `from` is a
# Python keyword. The value arrives as a raw string; the commands convert it
# with `_parse_date_arg`.
_opt_from = click.option(
    "--from",
    "date_from",
    required=True,
    help="Start date (YYYY-MM-DD)",
)
# `--to` is bound to `date_to` and is likewise parsed by the commands.
_opt_to = click.option(
    "--to",
    "date_to",
    required=True,
    help="End date (YYYY-MM-DD)",
)
# Optional, case-insensitive category code. When omitted (None) the commands
# fall back to `settings.tab.default_category`.
_opt_category = click.option(
    "--category",
    default=None,
    type=click.Choice(["T", "H", "G"], case_sensitive=False),
    help="Racing category: T (Thoroughbred), H (Harness), G (Greyhound)",
)
# Boolean flag; when set, the commands print their plan and return without
# ingesting or recomputing.
_opt_dry_run = click.option(
    "--dry-run",
    is_flag=True,
    default=False,
    help="Show what would be done without executing",
)


# ── CLI group ──────────────────────────────────────────────────────────


@click.group()
def cli() -> None:
    """Batch backfill historical TAB data in monthly chunks.

    \b
    Commands:
      ingest     - Ingest data in monthly chunks
      recompute  - Recompute ratings for a date range
      full       - Run ingest + recompute sequentially
    """
    # Intentionally empty: the group only exists to host the sub-commands
    # registered with @cli.command(). The docstring above is the text click
    # shows for --help, so it is user-facing output.


# ── Ingest ─────────────────────────────────────────────────────────────


@cli.command()
@_opt_from
@_opt_to
@_opt_category
@_opt_dry_run
@click.option(
    "--chunk-months",
    default=1,
    type=click.IntRange(1, 6),
    show_default=True,
    help="Number of months per chunk",
)
@click.option(
    "--source",
    type=click.Choice(["tab", "ingest"], case_sensitive=False),
    default="tab",
    show_default=True,
    help="Data source",
)
def ingest(
    date_from: str,
    date_to: str,
    category: str | None,
    dry_run: bool,
    chunk_months: int,
    source: str,
) -> None:
    """Ingest TAB data in monthly chunks.

    Splits the date range into monthly (or multi-month) chunks and ingests
    each chunk sequentially. Continues on chunk failure and reports summary.
    """
    # NOTE: the docstring above is the click --help text; keep developer
    # notes in comments.
    #
    # Parameters mirror the CLI options: date_from/date_to are raw
    # YYYY-MM-DD strings, category is None when --category was omitted,
    # chunk_months is 1..6 (enforced by click), source is "tab" or "ingest".
    start_date = _parse_date_arg(date_from)
    end_date = _parse_date_arg(date_to)
    if start_date > end_date:
        # A reversed range would otherwise run zero chunks and still be
        # reported as a successful backfill.
        raise click.UsageError(
            f"--from ({start_date}) must not be later than --to ({end_date})"
        )

    settings = get_settings()
    effective_category = category.upper() if category else settings.tab.default_category

    chunks = _iter_month_chunks(start_date, end_date, chunk_months)
    summary = BackfillSummary(
        command=f"ingest ({source})",
        date_from=start_date,
        date_to=end_date,
        category=effective_category,
        dry_run=dry_run,
        total_chunks=len(chunks),
    )

    console.print(
        f"\n[bold]Ingesting {effective_category} data from "
        f"{start_date} to {end_date} in {len(chunks)} chunk(s)[/bold]"
    )
    if dry_run:
        console.print("[yellow]DRY RUN — no data will be modified[/yellow]\n")
        _show_chunks(chunks)
        summary.print()
        return

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        task = progress.add_task("Ingesting...", total=len(chunks))

        for chunk_start, chunk_end in chunks:
            # Label the bar with the chunk being worked on *before* the
            # (potentially long) ingest call, not after it has finished.
            progress.update(
                task, description=f"Ingesting {chunk_start} to {chunk_end}"
            )
            result = _do_ingest_chunk(
                chunk_start, chunk_end, effective_category, source
            )
            # summary.total_chunks was already set to len(chunks) above, so
            # it must not be incremented again here.

            if result.success:
                summary.successful_chunks += 1
                summary.total_meetings += result.meetings
                summary.total_races += result.races
                summary.total_starters += result.starters
                summary.total_errors += result.errors
            else:
                # A failed chunk counts as one error and is listed in the
                # summary; processing continues with the next chunk.
                summary.failed_chunks += 1
                summary.total_errors += 1
                summary.failed_details.append(
                    (str(chunk_start), str(chunk_end), result.error_message)
                )

            progress.update(task, advance=1)

    summary.print()


# ── Recompute ──────────────────────────────────────────────────────────


@cli.command()
@_opt_from
@_opt_to
@_opt_dry_run
@click.option(
    "--clear",
    is_flag=True,
    default=False,
    help="Clear existing ratings before recompute",
)
@click.option(
    "--learn-adjustments",
    is_flag=True,
    default=False,
    help="Learn barrier/handicap adjustments during recompute",
)
def recompute(
    date_from: str,
    date_to: str,
    dry_run: bool,
    clear: bool,
    learn_adjustments: bool,
) -> None:
    """Recompute ratings for a date range.

    Runs a single recompute over the full date range (no chunking since
    recompute is incremental and deterministic).
    """
    # NOTE: the docstring above is the click --help text; keep developer
    # notes in comments.
    #
    # The whole recompute is tracked as a single "chunk" in the summary
    # (total_chunks=1), which then ends up either successful or failed.
    start_date = _parse_date_arg(date_from)
    end_date = _parse_date_arg(date_to)
    if start_date > end_date:
        # Reject a reversed range up front instead of handing it to the
        # recompute routine.
        raise click.UsageError(
            f"--from ({start_date}) must not be later than --to ({end_date})"
        )

    summary = BackfillSummary(
        command="recompute",
        date_from=start_date,
        date_to=end_date,
        category="N/A",
        dry_run=dry_run,
        total_chunks=1,
    )

    console.print(f"\n[bold]Recomputing ratings from {start_date} to {end_date}[/bold]")
    if clear:
        console.print("[yellow]Will clear existing ratings before recompute[/yellow]")
    if learn_adjustments:
        console.print("[yellow]Will learn barrier/handicap adjustments[/yellow]")

    if dry_run:
        console.print("\n[yellow]DRY RUN — no data will be modified[/yellow]")
        summary.print()
        return

    try:
        # Imported here rather than at module level; an import failure is
        # therefore reported like any other recompute failure.
        from packages.core.ratings.recompute import recompute_ratings

        with get_session() as session:
            snapshots = recompute_ratings(
                session,
                start_date,
                end_date,
                clear_existing=clear,
                learn_adjustments=learn_adjustments,
            )

        summary.successful_chunks = 1
        summary.snapshots_created = snapshots
        console.print(f"\n[green]✓ {snapshots} rating snapshots created[/green]")
    except Exception as e:
        summary.failed_chunks = 1
        # Count the failure as an error, matching how the ingest commands
        # account for a failed chunk.
        summary.total_errors += 1
        summary.failed_details.append((str(start_date), str(end_date), str(e)))
        logger.error(f"Recompute failed: {e}", exc_info=True)
        console.print(f"\n[red]✗ Recompute failed: {e}[/red]")

    summary.print()


# ── Full (ingest + recompute) ──────────────────────────────────────────


@cli.command()
@_opt_from
@_opt_to
@_opt_category
@_opt_dry_run
@click.option(
    "--source",
    type=click.Choice(["tab", "ingest"], case_sensitive=False),
    default="tab",
    show_default=True,
    help="Data source for ingestion",
)
@click.option(
    "--clear",
    is_flag=True,
    default=False,
    help="Clear existing ratings before recompute",
)
@click.option(
    "--learn-adjustments",
    is_flag=True,
    default=False,
    help="Learn barrier/handicap adjustments during recompute",
)
def full(
    date_from: str,
    date_to: str,
    category: str | None,
    dry_run: bool,
    source: str,
    clear: bool,
    learn_adjustments: bool,
) -> None:
    """Run ingest + recompute sequentially for a date range.

    Ingests data in monthly chunks, then recomputes ratings over the
    full date range. Reports a combined summary.
    """
    # NOTE: the docstring above is the click --help text; keep developer
    # notes in comments.
    #
    # Phase 1 ingests one-month chunks and keeps going when a chunk fails.
    # Phase 2 runs one recompute over the whole range, but only if at least
    # one chunk was ingested successfully.
    start_date = _parse_date_arg(date_from)
    end_date = _parse_date_arg(date_to)
    if start_date > end_date:
        # A reversed range would otherwise run nothing and still be
        # reported as a successful backfill.
        raise click.UsageError(
            f"--from ({start_date}) must not be later than --to ({end_date})"
        )

    settings = get_settings()
    effective_category = category.upper() if category else settings.tab.default_category

    chunks = _iter_month_chunks(start_date, end_date, 1)
    summary = BackfillSummary(
        command=f"full (ingest={source} + recompute)",
        date_from=start_date,
        date_to=end_date,
        category=effective_category,
        dry_run=dry_run,
        total_chunks=len(chunks),
    )

    console.print(
        f"\n[bold]Full backfill: {effective_category} data from "
        f"{start_date} to {end_date} ({len(chunks)} chunk(s))[/bold]"
    )

    if dry_run:
        console.print("[yellow]DRY RUN — no data will be modified[/yellow]\n")
        _show_chunks(chunks)
        # Only mention clearing / learning adjustments when the matching
        # flag was actually passed.
        console.print(
            "[dim]After ingest: will recompute ratings"
            f"{' (with clear)' if clear else ''}"
            f"{' (learn adjustments)' if learn_adjustments else ''}[/dim]"
        )
        summary.print()
        return

    # ── Phase 1: Ingest ──────────────────────────────────────────────
    console.print("\n[bold cyan]Phase 1: Ingestion[/bold cyan]")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        task = progress.add_task("Ingesting...", total=len(chunks))

        for chunk_start, chunk_end in chunks:
            # Label the bar with the chunk being worked on *before* the
            # (potentially long) ingest call, not after it has finished.
            progress.update(
                task, description=f"Ingesting {chunk_start} to {chunk_end}"
            )
            result = _do_ingest_chunk(
                chunk_start, chunk_end, effective_category, source
            )
            # summary.total_chunks was already set to len(chunks) above, so
            # it must not be incremented again here.

            if result.success:
                summary.successful_chunks += 1
                summary.total_meetings += result.meetings
                summary.total_races += result.races
                summary.total_starters += result.starters
                summary.total_errors += result.errors
            else:
                summary.failed_chunks += 1
                summary.total_errors += 1
                summary.failed_details.append(
                    (str(chunk_start), str(chunk_end), result.error_message)
                )

            progress.update(task, advance=1)

    # ── Phase 2: Recompute ───────────────────────────────────────────
    console.print("\n[bold cyan]Phase 2: Recompute[/bold cyan]")

    if summary.successful_chunks == 0:
        console.print("[yellow]No data ingested — skipping recompute.[/yellow]")
    else:
        try:
            # Imported here rather than at module level; an import failure
            # is therefore reported like any other recompute failure.
            from packages.core.ratings.recompute import recompute_ratings

            with get_session() as session:
                snapshots = recompute_ratings(
                    session,
                    start_date,
                    end_date,
                    clear_existing=clear,
                    learn_adjustments=learn_adjustments,
                )
            summary.snapshots_created = snapshots
            console.print(f"[green]✓ {snapshots} rating snapshots created[/green]")
        except Exception as e:
            # Record the failure in the summary; otherwise a run whose
            # ingest succeeded but whose recompute crashed would still end
            # with "Backfill completed successfully".
            summary.failed_chunks += 1
            summary.total_errors += 1
            summary.failed_details.append(
                (str(start_date), str(end_date), f"recompute: {e}")
            )
            logger.error(f"Recompute failed: {e}", exc_info=True)
            console.print(f"[red]✗ Recompute failed: {e}[/red]")

    summary.print()


# ── Internal helpers ───────────────────────────────────────────────────


def _parse_date_arg(value: str) -> date:
    """Parse a YYYY-MM-DD string into a date."""
    try:
        return date.fromisoformat(value)
    except ValueError:
        console.print(
            f"[red]Error: Invalid date '{value}'. Use YYYY-MM-DD format.[/red]"
        )
        sys.exit(1)


def _show_chunks(chunks: list[tuple[date, date]]) -> None:
    """Print a numbered table of the chunks a dry run would process.

    Rows are numbered from 1 in the order the chunks are given; each row
    shows the chunk's start and end date.
    """
    planned = Table(title="Planned Chunks (DRY RUN)")
    for heading, colour in (("#", "dim"), ("Start", "cyan"), ("End", "cyan")):
        planned.add_column(heading, style=colour)

    for number, (starts_on, ends_on) in enumerate(chunks, start=1):
        planned.add_row(str(number), str(starts_on), str(ends_on))

    console.print(planned)


def _do_ingest_chunk(
    chunk_start: date,
    chunk_end: date,
    category: str,
    source: str,
) -> ChunkResult:
    """Run ingestion for one chunk and report the outcome without raising.

    Opens a session, builds an ``IngestionService`` for *source* and runs its
    ``ingest_date_range`` coroutine to completion with ``asyncio.run``. Any
    exception is caught, logged with traceback and converted into a failed
    ``ChunkResult`` so the caller can continue with the next chunk.

    Returns:
        A ``ChunkResult`` for the chunk. On success it carries the
        meeting/race/starter counts and the service's ``stats["errors"]``
        value; on failure ``success`` is False and ``error_message`` holds
        the exception text.
    """
    outcome = ChunkResult(chunk_start=chunk_start, chunk_end=chunk_end)

    try:
        with get_session() as session:
            service = IngestionService(session, source=source)
            pending = service.ingest_date_range(
                chunk_start, chunk_end, category=category
            )
            meetings, races, starters = asyncio.run(pending)

        outcome.success = True
        outcome.meetings, outcome.races, outcome.starters = meetings, races, starters
        outcome.errors = service.stats["errors"]

        log_fields = {
            "chunk_start": str(chunk_start),
            "chunk_end": str(chunk_end),
            "meetings": meetings,
            "races": races,
            "starters": starters,
            "errors": outcome.errors,
        }
        logger.info("Chunk ingested", extra=log_fields)
    except Exception as exc:
        # Anything raised above (including while logging) marks the chunk
        # as failed; the caller decides whether to continue.
        outcome.success = False
        outcome.error_message = str(exc)
        logger.error(
            f"Chunk {chunk_start} to {chunk_end} failed: {exc}",
            exc_info=True,
        )

    return outcome


# Script entry point: hand control to the click command group when this
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    cli()
