#!/usr/bin/env python3
"""Validate backfilled data for correctness and completeness.

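Produces a JSON report with:
  - Meeting count per month
  - Race count per month
  - Starter count per month
  - Entity totals (horses, drivers, trainers, meetings, races, starters)
  - Rating convergence (snapshot count, rating distribution)
  - Duplicate race detection
  - Overall health summary (pass / warn / fail)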
"""

from __future__ import annotations

import json
import sys
from dataclasses import asdict, dataclass, field
from datetime import date, datetime
from typing import Any

import click
from rich.console import Console
from rich.table import Table

from packages.core.common.logging import setup_logging
from packages.core.storage.database import get_session

# Configure logging at import time and create the module-wide rich console
# used for the human-readable tables and status messages.
setup_logging()
console = Console()


# ── Data classes ────────────────────────────────────────────────────────


@dataclass
class MonthlyCounts:
    """Meeting, race and starter counts for a single calendar month.

    NOTE(review): not referenced elsewhere in this file; the monthly query
    helpers return plain ``{"month", "count"}`` dicts instead.
    """

    month: str  # YYYY-MM
    meetings: int = 0
    races: int = 0
    starters: int = 0


@dataclass
class RatingDistribution:
    """Summary statistics over the rating snapshots of one entity type.

    Every field keeps its zero default when the entity type has no snapshots.
    """

    min: float = 0.0  # lowest rating seen
    max: float = 0.0  # highest rating seen
    mean: float = 0.0
    median: float = 0.0
    std: float = 0.0  # sample standard deviation; 0.0 with fewer than two snapshots
    p5: float = 0.0  # 5th percentile
    p25: float = 0.0  # 25th percentile
    p75: float = 0.0  # 75th percentile
    p95: float = 0.0  # 95th percentile
    total_snapshots: int = 0  # number of snapshots summarised


@dataclass
class DuplicateInfo:
    """Information about duplicate races found.

    The two ``duplicate_*`` fields count offending *groups* (one per
    repeated key), not the number of duplicated races.
    """

    total_races: int = 0  # races belonging to meetings inside the date range
    duplicate_race_numbers: int = 0  # (meeting_id, race_number) pairs with >1 race
    duplicate_ids: int = 0  # tab_event_id values shared by >1 race
    # One dict per offending group: the key(s), "count" and "race_ids"
    duplicates: list[dict[str, Any]] = field(default_factory=list)


@dataclass
class ValidationReport:
    """Complete validation report, populated by the ``validate_backfill`` command."""

    date_from: str  # start date string exactly as passed on the command line
    date_to: str  # end date string exactly as passed on the command line
    generated_at: str  # ISO-8601 local timestamp from datetime.now()
    # Per-month rows shaped {"month": "YYYY-MM", "count": int}
    monthly_meetings: list[dict[str, Any]] = field(default_factory=list)
    monthly_races: list[dict[str, Any]] = field(default_factory=list)
    monthly_starters: list[dict[str, Any]] = field(default_factory=list)
    # Sums of the monthly counts above (date-range scoped)
    total_meetings: int = 0
    total_races: int = 0
    total_starters: int = 0
    # Whole-table row counts keyed by entity name (not date-range scoped)
    total_entities: dict[str, int] = field(default_factory=dict)
    # Keyed by entity type: "horse", "driver", "trainer"
    rating_distribution: dict[str, RatingDistribution] = field(default_factory=dict)
    duplicates: DuplicateInfo = field(default_factory=DuplicateInfo)
    # Each issue dict has "severity", "category" and "message" keys
    issues: list[dict[str, str]] = field(default_factory=list)
    status: str = "pass"  # pass, warn, fail


# ── Validation logic ────────────────────────────────────────────────────


def _get_monthly_meetings(
    session, date_from: date, date_to: date
) -> list[dict[str, Any]]:
    """Count meetings per month."""
    from sqlalchemy import func

    from packages.core.storage.models import Meeting

    rows = (
        session.query(
            func.to_char(Meeting.meeting_date, "YYYY-MM").label("month"),
            func.count(Meeting.id).label("count"),
        )
        .filter(Meeting.meeting_date >= date_from, Meeting.meeting_date <= date_to)
        .group_by("month")
        .order_by("month")
        .all()
    )
    return [{"month": r.month, "count": r.count} for r in rows]


def _get_monthly_races(session, date_from: date, date_to: date) -> list[dict[str, Any]]:
    """Count races per month."""
    from sqlalchemy import func

    from packages.core.storage.models import Meeting, Race

    rows = (
        session.query(
            func.to_char(Meeting.meeting_date, "YYYY-MM").label("month"),
            func.count(Race.id).label("count"),
        )
        .join(Race, Race.meeting_id == Meeting.id)
        .filter(Meeting.meeting_date >= date_from, Meeting.meeting_date <= date_to)
        .group_by("month")
        .order_by("month")
        .all()
    )
    return [{"month": r.month, "count": r.count} for r in rows]


def _get_monthly_starters(
    session, date_from: date, date_to: date
) -> list[dict[str, Any]]:
    """Count starters per month."""
    from sqlalchemy import func

    from packages.core.storage.models import Meeting, Race, Starter

    rows = (
        session.query(
            func.to_char(Meeting.meeting_date, "YYYY-MM").label("month"),
            func.count(Starter.id).label("count"),
        )
        .join(Race, Race.meeting_id == Meeting.id)
        .join(Starter, Starter.race_id == Race.id)
        .filter(Meeting.meeting_date >= date_from, Meeting.meeting_date <= date_to)
        .group_by("month")
        .order_by("month")
        .all()
    )
    return [{"month": r.month, "count": r.count} for r in rows]


def _get_entity_counts(session) -> dict[str, int]:
    """Return the total row count of each core entity table.

    The counts cover the whole database, not just the report's date range.
    Keys, in order: horses, drivers, trainers, meetings, races, starters.
    """
    from packages.core.storage.models import (
        Driver,
        Horse,
        Meeting,
        Race,
        Starter,
        Trainer,
    )

    # (report key, ORM model) pairs; tuple order fixes both query order and
    # the key order of the returned dict.
    tables = (
        ("horses", Horse),
        ("drivers", Driver),
        ("trainers", Trainer),
        ("meetings", Meeting),
        ("races", Race),
        ("starters", Starter),
    )
    return {label: session.query(model).count() for label, model in tables}


def _get_rating_distribution(session) -> dict[str, RatingDistribution]:
    """Get rating distribution per entity type."""
    import statistics

    from packages.core.storage.models import RatingSnapshot

    entity_types = ["horse", "driver", "trainer"]
    distributions: dict[str, RatingDistribution] = {}

    for et in entity_types:
        snapshots = (
            session.query(RatingSnapshot.rating)
            .filter(RatingSnapshot.entity_type == et)
            .all()
        )
        ratings = [float(s.rating) for s in snapshots]
        if not ratings:
            distributions[et] = RatingDistribution()
            continue

        sorted_ratings = sorted(ratings)
        n = len(sorted_ratings)

        distributions[et] = RatingDistribution(
            min=min(ratings),
            max=max(ratings),
            mean=statistics.mean(ratings),
            median=statistics.median(ratings),
            std=statistics.stdev(ratings) if len(ratings) > 1 else 0.0,
            p5=sorted_ratings[max(0, int(n * 0.05))],
            p25=sorted_ratings[max(0, int(n * 0.25))],
            p75=sorted_ratings[min(n - 1, int(n * 0.75))],
            p95=sorted_ratings[min(n - 1, int(n * 0.95))],
            total_snapshots=n,
        )

    return distributions


def _find_duplicate_races(session, date_from: date, date_to: date) -> DuplicateInfo:
    """Find duplicate races that touch the ``[date_from, date_to]`` meeting range.

    Two checks are made:

    * more than one race with the same ``race_number`` under one meeting, for
      meetings dated inside the range;
    * more than one race sharing a non-null ``tab_event_id``, where at least
      one of those races belongs to an in-range meeting. The other copies may
      sit under meetings outside the range and are still reported.

    Returns:
        A DuplicateInfo with one count per offending group (not per race), a
        detail dict per group listing the affected race ids in id order, and
        the total number of races in the range.
    """
    from sqlalchemy import func

    from packages.core.storage.models import Meeting, Race

    in_range = (Meeting.meeting_date >= date_from, Meeting.meeting_date <= date_to)

    def _race_ids(*criteria: Any) -> list[Any]:
        # Only the primary keys are needed; order them so reports are stable.
        return [
            row.id
            for row in session.query(Race.id).filter(*criteria).order_by(Race.id)
        ]

    # Duplicate: same meeting_id + race_number (should be unique per meeting)
    dup_race_numbers = (
        session.query(
            Race.meeting_id,
            Race.race_number,
            func.count(Race.id).label("cnt"),
        )
        .join(Meeting)
        .filter(*in_range)
        .group_by(Race.meeting_id, Race.race_number)
        .having(func.count(Race.id) > 1)
        .all()
    )

    dup_details: list[dict[str, Any]] = [
        {
            "meeting_id": str(meeting_id),
            "race_number": race_number,
            "count": cnt,
            "race_ids": _race_ids(
                Race.meeting_id == meeting_id, Race.race_number == race_number
            ),
        }
        for meeting_id, race_number, cnt in dup_race_numbers
    ]

    # Duplicate by tab_event_id. Limit to ids carried by an in-range race,
    # but count each group across the whole table so that a copy stored under
    # an out-of-range meeting still shows up.
    in_range_event_ids = (
        session.query(Race.tab_event_id)
        .join(Meeting)
        .filter(*in_range, Race.tab_event_id.isnot(None))
    )
    dup_event_ids = (
        session.query(
            Race.tab_event_id,
            func.count(Race.id).label("cnt"),
        )
        .filter(Race.tab_event_id.in_(in_range_event_ids))
        .group_by(Race.tab_event_id)
        .having(func.count(Race.id) > 1)
        .all()
    )

    dup_details.extend(
        {
            "tab_event_id": str(tab_event_id),
            "count": cnt,
            "race_ids": _race_ids(Race.tab_event_id == tab_event_id),
        }
        for tab_event_id, cnt in dup_event_ids
    )

    total_races = session.query(Race).join(Meeting).filter(*in_range).count()

    return DuplicateInfo(
        total_races=total_races,
        duplicate_race_numbers=len(dup_race_numbers),
        duplicate_ids=len(dup_event_ids),
        duplicates=dup_details,
    )


def _check_issues(
    monthly_meetings: list[dict[str, int]],
    monthly_races: list[dict[str, int]],
    duplicates: DuplicateInfo,
    rating_distribution: dict[str, RatingDistribution],
) -> list[dict[str, str]]:
    """Check for data quality issues."""
    issues: list[dict[str, str]] = []

    # Check for months with no data
    for row in monthly_meetings:
        if row["count"] == 0:
            issues.append(
                {
                    "severity": "warning",
                    "category": "missing_data",
                    "message": f"No meetings found for {row['month']}",
                }
            )

    for row in monthly_races:
        if row["count"] == 0:
            issues.append(
                {
                    "severity": "warning",
                    "category": "missing_data",
                    "message": f"No races found for {row['month']}",
                }
            )

    # Check for duplicate races
    total_dups = duplicates.duplicate_race_numbers + duplicates.duplicate_ids
    if total_dups > 0:
        issues.append(
            {
                "severity": "error" if total_dups > 10 else "warning",
                "category": "duplicates",
                "message": (
                    f"Found {total_dups} duplicate race group(s) "
                    f"(race_number: {duplicates.duplicate_race_numbers}, "
                    f"tab_event_id: {duplicates.duplicate_ids})"
                ),
            }
        )

    # Check rating distribution sanity
    for entity_type, dist in rating_distribution.items():
        if dist.total_snapshots == 0:
            issues.append(
                {
                    "severity": "warning",
                    "category": "no_ratings",
                    "message": f"No rating snapshots for {entity_type} entities",
                }
            )
        else:
            # Check for extreme rating values (likely data issues)
            if dist.min < 500 or dist.max > 3000:
                issues.append(
                    {
                        "severity": "warning",
                        "category": "rating_range",
                        "message": (
                            f"{entity_type} ratings have extreme range: "
                            f"[{dist.min:.1f}, {dist.max:.1f}] "
                            f"(expected ~500-3000)"
                        ),
                    }
                )

    return issues


# ── CLI ─────────────────────────────────────────────────────────────────


@click.command()
@click.option(
    "--from",
    "date_from",
    required=True,
    help="Start date (YYYY-MM-DD)",
)
@click.option(
    "--to",
    "date_to",
    required=True,
    help="End date (YYYY-MM-DD)",
)
@click.option(
    "--out",
    "output_file",
    default=None,
    help="Output JSON file path (optional, prints to stdout if omitted)",
)
@click.option(
    "--verbose",
    is_flag=True,
    default=False,
    help="Show detailed per-month breakdown",
)
def validate_backfill(
    date_from: str, date_to: str, output_file: str | None, verbose: bool
):
    """Validate backfilled data in a date range.

    Checks meeting/race/starter counts per month, rating convergence,
    and duplicate races. Outputs a JSON report.

    Exits with status 1 on invalid dates or when any error-level issue is
    found, otherwise 0. When --out is omitted, stdout carries only the JSON
    report and the human-readable summary is written to stderr.
    """
    # If the JSON report goes to stdout, human-readable output must not be
    # interleaved with it, or `... > report.json` / `| jq` would get invalid JSON.
    ui = console if output_file else Console(stderr=True)

    try:
        start_date = date.fromisoformat(date_from)
        end_date = date.fromisoformat(date_to)
    except ValueError as e:
        ui.print(f"[red]Invalid date format: {e}[/red]")
        sys.exit(1)

    # A single-day window (start == end) is valid.
    if start_date > end_date:
        ui.print("[red]Error: start date must not be after end date[/red]")
        sys.exit(1)

    ui.print(
        f"\n[bold]Validating backfilled data from {start_date} to {end_date}[/bold]\n"
    )

    report = ValidationReport(
        date_from=date_from,
        date_to=date_to,
        generated_at=datetime.now().isoformat(),
    )

    with get_session() as session:
        # Monthly breakdowns (date-range scoped)
        report.monthly_meetings = _get_monthly_meetings(session, start_date, end_date)
        report.monthly_races = _get_monthly_races(session, start_date, end_date)
        report.monthly_starters = _get_monthly_starters(session, start_date, end_date)

        # Range totals, plus whole-table entity counts
        report.total_meetings = sum(r["count"] for r in report.monthly_meetings)
        report.total_races = sum(r["count"] for r in report.monthly_races)
        report.total_starters = sum(r["count"] for r in report.monthly_starters)
        report.total_entities = _get_entity_counts(session)

        report.rating_distribution = _get_rating_distribution(session)
        report.duplicates = _find_duplicate_races(session, start_date, end_date)

        report.issues = _check_issues(
            report.monthly_meetings,
            report.monthly_races,
            report.duplicates,
            report.rating_distribution,
        )

    # Any error fails the run; otherwise any warning downgrades it to "warn".
    errors, warnings = _partition_issues(report.issues)
    if errors:
        report.status = "fail"
    elif warnings:
        report.status = "warn"
    else:
        report.status = "pass"

    # ── Display ─────────────────────────────────────────────────────
    _print_summary(ui, report, errors, warnings)
    _print_rating_distribution(ui, report)
    if verbose:
        _print_monthly_breakdown(ui, report)
    _print_issues(ui, report)

    if report.status == "fail":
        ui.print("\n[bold red]✗ Validation FAILED[/bold red]")
    elif report.status == "warn":
        ui.print("\n[bold yellow]⚠ Validation passed with warnings[/bold yellow]")
    else:
        ui.print("\n[bold green]✓ Validation PASSED[/bold green]")

    # ── Serialise report ─────────────────────────────────────────────
    report_dict = _serialise_report(report)

    if output_file:
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(report_dict, f, indent=2, default=str)
        ui.print(f"\n[green]Report saved to {output_file}[/green]")
    else:
        ui.print("\n[dim]JSON report (stdout):[/dim]")
        click.echo(json.dumps(report_dict, indent=2, default=str))

    # Only "fail" is a non-zero exit; warnings still exit 0.
    sys.exit(1 if report.status == "fail" else 0)


def _partition_issues(
    issues: list[dict[str, str]],
) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
    """Split *issues* into ``(errors, warnings)`` by their ``severity`` key.

    Issues with any other severity (e.g. ``"info"``) appear in neither list.
    """
    errors = [i for i in issues if i["severity"] == "error"]
    warnings = [i for i in issues if i["severity"] == "warning"]
    return errors, warnings


def _print_summary(
    out: Console,
    report: ValidationReport,
    errors: list[dict[str, str]],
    warnings: list[dict[str, str]],
) -> None:
    """Render the headline status / totals table to *out*."""
    summary_table = Table(title="Validation Summary")
    summary_table.add_column("Metric", style="cyan")
    summary_table.add_column("Value", style="green", justify="right")

    summary_table.add_row(
        "Status", report.status.upper(), style=_style_for_status(report.status)
    )
    summary_table.add_row("Total meetings", str(report.total_meetings))
    summary_table.add_row("Total races", str(report.total_races))
    summary_table.add_row("Total starters", str(report.total_starters))
    summary_table.add_row("Horses", str(report.total_entities.get("horses", 0)))
    summary_table.add_row("Drivers", str(report.total_entities.get("drivers", 0)))
    summary_table.add_row("Trainers", str(report.total_entities.get("trainers", 0)))
    summary_table.add_row(
        "Duplicate groups",
        str(report.duplicates.duplicate_race_numbers + report.duplicates.duplicate_ids),
    )
    summary_table.add_row("Issues", str(len(report.issues)))
    summary_table.add_row(
        "Errors", str(len(errors)), style="red" if errors else "green"
    )
    summary_table.add_row(
        "Warnings", str(len(warnings)), style="yellow" if warnings else "green"
    )

    out.print(summary_table)


def _print_rating_distribution(out: Console, report: ValidationReport) -> None:
    """Render per-entity rating statistics; entity types without snapshots get no row."""
    if not report.rating_distribution:
        return

    rd_table = Table(title="Rating Distribution by Entity Type")
    rd_table.add_column("Entity Type", style="cyan")
    rd_table.add_column("Snapshots", justify="right")
    rd_table.add_column("Mean", justify="right")
    rd_table.add_column("Median", justify="right")
    rd_table.add_column("Std", justify="right")
    rd_table.add_column("Min", justify="right")
    rd_table.add_column("Max", justify="right")

    for et, dist in report.rating_distribution.items():
        if dist.total_snapshots > 0:
            rd_table.add_row(
                et,
                str(dist.total_snapshots),
                f"{dist.mean:.1f}",
                f"{dist.median:.1f}",
                f"{dist.std:.1f}",
                f"{dist.min:.1f}",
                f"{dist.max:.1f}",
            )
    out.print("\n")
    out.print(rd_table)


def _print_monthly_breakdown(out: Console, report: ValidationReport) -> None:
    """Render meetings/races/starters side by side per month (skipped if no meetings)."""
    if not report.monthly_meetings:
        return

    monthly_table = Table(title="Monthly Breakdown")
    monthly_table.add_column("Month", style="cyan")
    monthly_table.add_column("Meetings", justify="right")
    monthly_table.add_column("Races", justify="right")
    monthly_table.add_column("Starters", justify="right")

    # Merge the three per-month lists on their "month" key.
    monthly_map: dict[str, dict] = {}
    for row in report.monthly_meetings:
        monthly_map.setdefault(row["month"], {})["meetings"] = row["count"]
    for row in report.monthly_races:
        monthly_map.setdefault(row["month"], {})["races"] = row["count"]
    for row in report.monthly_starters:
        monthly_map.setdefault(row["month"], {})["starters"] = row["count"]

    for month in sorted(monthly_map.keys()):
        data = monthly_map[month]
        monthly_table.add_row(
            month,
            str(data.get("meetings", 0)),
            str(data.get("races", 0)),
            str(data.get("starters", 0)),
        )
    out.print("\n")
    out.print(monthly_table)


def _print_issues(out: Console, report: ValidationReport) -> None:
    """Render the issue list, colouring each row by severity (nothing if no issues)."""
    if not report.issues:
        return

    issues_table = Table(title="Issues")
    issues_table.add_column("Severity", style="bold")
    issues_table.add_column("Category")
    issues_table.add_column("Message")

    for issue in report.issues:
        sev_style = {
            "error": "red",
            "warning": "yellow",
            "info": "blue",
        }.get(issue["severity"], "white")
        issues_table.add_row(
            f"[{sev_style}]{issue['severity'].upper()}[/{sev_style}]",
            issue["category"],
            issue["message"],
        )
    out.print("\n")
    out.print(issues_table)


# ── Helpers ─────────────────────────────────────────────────────────────


def _style_for_status(status: str) -> str:
    if status == "pass":
        return "green"
    elif status == "warn":
        return "yellow"
    return "red"


def _serialise_report(report: ValidationReport) -> dict[str, Any]:
    """Convert report to a JSON-serialisable dict."""
    return {
        "date_from": report.date_from,
        "date_to": report.date_to,
        "generated_at": report.generated_at,
        "status": report.status,
        "summary": {
            "total_meetings": report.total_meetings,
            "total_races": report.total_races,
            "total_starters": report.total_starters,
            "total_entities": report.total_entities,
            "duplicate_groups": report.duplicates.duplicate_race_numbers
            + report.duplicates.duplicate_ids,
        },
        "monthly_meetings": report.monthly_meetings,
        "monthly_races": report.monthly_races,
        "monthly_starters": report.monthly_starters,
        "rating_distribution": {
            et: asdict(dist) for et, dist in report.rating_distribution.items()
        },
        "duplicates": asdict(report.duplicates),
        "issues": report.issues,
    }


# Script entry point: click parses sys.argv and runs the command.
if __name__ == "__main__":
    validate_backfill()
