"""Performance regression tests for race predictions.

Usage:
    # Run benchmarks and save baseline:
    pytest tests/test_performance_predictions.py --benchmark-only --benchmark-json=tests/benchmark_baseline.json

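    # Run benchmarks; test_prediction_performance also fails when its mean is
    # at least 10% slower than the mean saved in tests/benchmark_baseline.json
    # (pytest-benchmark's own --benchmark-compare takes a saved run ID, not a
    # JSON path):
    pytest tests/test_performance_predictions.py --benchmark-only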

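    # Run each test once without collecting timings:
    pytest tests/test_performance_predictions.py --benchmark-disable

    # Every test in this file is marked `slow`, so this skips the whole file:
    pytest tests/test_performance_predictions.py -m "not slow"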

    # Run including benchmarks:
    pytest tests/test_performance_predictions.py -v
"""

import json
import os
import time
from datetime import date

import pytest

# Deterministic env vars for performance testing, applied before any settings
# are loaded. os.environ.setdefault is used, so a value that is already set in
# the environment takes precedence over these defaults.
_PERF_ENV_DEFAULTS = {
    "ELO_SCALE_C": "400.0",
    "ELO_K_BASE": "24.0",
    "INITIAL_RATING": "1500.0",
    "ENABLE_DRIVER": "true",
    "ENABLE_TRAINER": "true",
    "ENABLE_ADJUSTMENTS": "false",
    "ENABLE_RD": "false",
    "TAB_MOCK_MODE": "true",
}
for _env_name, _env_value in _PERF_ENV_DEFAULTS.items():
    os.environ.setdefault(_env_name, _env_value)

# Re-read settings now so they reflect the environment configured above.
from packages.core.common.settings import reload_settings  # noqa: E402

reload_settings()

from packages.core.ratings.predictions import PredictionEngine  # noqa: E402
from packages.core.storage.models import EntityType  # noqa: E402
from packages.core.storage.repositories import (  # noqa: E402
    DriverRepository,
    HorseRepository,
    MeetingRepository,
    RaceRepository,
    RatingSnapshotRepository,
    StarterRepository,
    TrainerRepository,
)

# Apply the `slow` marker to every test collected from this module.
pytestmark = [pytest.mark.slow]

# pytest-benchmark JSON report used as the regression baseline; it lives in
# the same directory as this test module.
_THIS_DIR = os.path.dirname(__file__)
BENCHMARK_BASELINE_FILE = os.path.join(_THIS_DIR, "benchmark_baseline.json")


# ── Baseline helpers ─────────────────────────────────────────────────


def _load_baseline(key: str) -> float | None:
    """Load a benchmark baseline value from JSON file if it exists."""
    if not os.path.exists(BENCHMARK_BASELINE_FILE):
        return None
    try:
        with open(BENCHMARK_BASELINE_FILE) as f:
            data = json.load(f)
        benchmarks = data.get("benchmarks", [])
        for b in benchmarks:
            name = b.get("name", "")
            if name == key:
                return b.get("stats", {}).get("mean", None)
    except (json.JSONDecodeError, KeyError, TypeError):
        return None
    return None


def _check_benchmark_baseline(
    test_name: str, current_mean: float, *, tolerance: float = 0.10
) -> None:
    """Fail if ``current_mean`` regressed past the saved baseline for ``test_name``.

    The baseline is the mean that ``_load_baseline`` returns for ``test_name``.
    When there is no baseline, or it is not a positive number, the check is
    skipped.

    Args:
        test_name: Benchmark name to look up in the baseline file.
        current_mean: Mean time of the current run, in the same unit as the
            baseline (seconds; the message below reports milliseconds).
        tolerance: Allowed fractional slowdown. The default 0.10 fails a run
            whose mean is 10% or more above the baseline.

    Raises:
        AssertionError: If ``current_mean / baseline >= 1 + tolerance``.
    """
    baseline = _load_baseline(test_name)
    # `not baseline > 0` (rather than `baseline <= 0`) also skips NaN.
    if baseline is None or not baseline > 0:
        return

    slowdown_ratio = current_mean / baseline
    assert slowdown_ratio < 1 + tolerance, (
        f"Performance regression detected for {test_name}: "
        f"{current_mean * 1000:.2f}ms vs baseline {baseline * 1000:.2f}ms "
        f"({(slowdown_ratio - 1) * 100:.1f}% slower, "
        f"limit is {tolerance * 100:g}%)"
    )


# ── Test data helpers ────────────────────────────────────────────────


def _create_test_meeting(session, meeting_id: str, meeting_date: date, venue: str):
    """Upsert a category ``"H"`` meeting and return what the repository returns.

    ``meeting_date`` is stored in ISO format and ``venue`` is used as the
    meeting name.
    """
    payload = {
        "meeting": meeting_id,
        "date": meeting_date.isoformat(),
        "name": venue,
        "category": "H",
    }
    return MeetingRepository.upsert(session, payload)


def _create_test_race(session, meeting, race_number: int, distance: int = 2000):
    """Upsert a mobile-start pace race under ``meeting`` and return the result.

    The advertised start hour is ``8 + race_number % 12``, which is always in
    08..19. NOTE(review): the date part is fixed at 2025-05-06 whatever the
    meeting's date, so races from different meetings can share an identical
    start time; confirm that does not matter for rating-snapshot ordering.
    """
    start_hour = 8 + race_number % 12
    race_fields = {
        "race_number": race_number,
        "distance": distance,
        "start_type": "mobile",
        "gait": "pace",
        "advertised_start_string": f"2025-05-06T{start_hour:02d}:00:00+12:00",
    }
    return RaceRepository.upsert(session, meeting.id, race_fields)


def _create_test_horses(session, count: int, start_id: int = 10000) -> list[int]:
    """Upsert ``count`` horses with consecutive IDs beginning at ``start_id``.

    Horse ``start_id + n`` is named ``PerfHorse_{n}``. Returns the IDs in
    creation order (empty when ``count`` is not positive).
    """
    ids = list(range(start_id, start_id + count))
    for offset, horse_id in enumerate(ids):
        HorseRepository.upsert(session, horse_id, f"PerfHorse_{offset}")
    return ids


def _create_test_drivers(session, count: int) -> dict[int, int]:
    """Upsert drivers ``PerfDriver_0`` .. ``PerfDriver_{count - 1}``.

    Returns:
        Mapping from each index ``n`` to the ``id`` of the object that
        ``DriverRepository.upsert`` returned for ``PerfDriver_{n}``.
    """
    names = [f"PerfDriver_{n}" for n in range(count)]
    ids_by_index: dict[int, int] = {}
    for index, name in enumerate(names):
        # Read .id right after each upsert, as before.
        ids_by_index[index] = DriverRepository.upsert(session, name).id
    return ids_by_index


def _create_test_trainers(session, count: int) -> dict[int, int]:
    """Upsert trainers ``PerfTrainer_0`` .. ``PerfTrainer_{count - 1}``.

    Returns:
        Mapping from each index ``n`` to the ``id`` of the object that
        ``TrainerRepository.upsert`` returned for ``PerfTrainer_{n}``.
    """
    names = [f"PerfTrainer_{n}" for n in range(count)]
    ids_by_index: dict[int, int] = {}
    for index, name in enumerate(names):
        # Read .id right after each upsert, as before.
        ids_by_index[index] = TrainerRepository.upsert(session, name).id
    return ids_by_index


def _create_test_starters(
    session,
    race_id: int,
    horse_ids: list[int],
    driver_ids: dict[int, int],
    trainer_ids: dict[int, int],
    count: int,
):
    """Upsert up to ``count`` starters, with placings, for ``race_id``.

    A random sample of ``horse_ids`` is drawn, using the global ``random``
    module, so it is unseeded. The n-th sampled horse gets driver/trainer
    index ``n % len(driver_ids)`` / ``n % len(trainer_ids)``. It also gets
    runner number, barrier and placing ``n + 1``.

    Names mirror the ``_create_test_*`` helpers.
    - Horse names are relative to the smallest ID in ``horse_ids``, because
      ``_create_test_horses`` names horse ``start_id + n`` "PerfHorse_{n}".
    - Driver and trainer names use the index that was actually assigned.
    """
    import random

    selected_horses = random.sample(horse_ids, min(count, len(horse_ids)))
    # Derive the name offset from the IDs instead of assuming start_id=10000.
    # The fixture creates its horses from 20000.
    horse_base_id = min(horse_ids, default=0)

    for i, horse_id in enumerate(selected_horses):
        driver_index = i % len(driver_ids)
        trainer_index = i % len(trainer_ids)
        horse_name = f"PerfHorse_{horse_id - horse_base_id}"
        StarterRepository.upsert(
            session,
            race_id,
            {
                "name": horse_name,
                "horse_id": horse_id,
                "horse_name": horse_name,
                "driver_id": driver_ids[driver_index],
                "driver_name": f"PerfDriver_{driver_index}",
                "trainer_id": trainer_ids[trainer_index],
                "trainer_name": f"PerfTrainer_{trainer_index}",
                "runner_number": i + 1,
                "barrier": i + 1,
            },
            placing=i + 1,
        )


def _create_rating_snapshots(
    session,
    race_id: int,
    horse_ids: list[int],
    driver_ids: dict[int, int],
    trainer_ids: dict[int, int],
):
    """Upsert an identical starting snapshot, as of ``race_id``, for each entity.

    Horses come first, then drivers, then trainers (the dict values of
    ``driver_ids`` / ``trainer_ids``). Each entity gets rating 1500.0,
    rd 350.0 and meta ``{"race_count": 5}``.
    """
    entity_groups = (
        (EntityType.HORSE, horse_ids),
        (EntityType.DRIVER, driver_ids.values()),
        (EntityType.TRAINER, trainer_ids.values()),
    )
    for entity_type, entity_ids in entity_groups:
        for entity_id in entity_ids:
            RatingSnapshotRepository.upsert(
                session,
                entity_type=entity_type,
                entity_id=entity_id,
                as_of_race_id=race_id,
                rating=1500.0,
                rd=350.0,
                meta={"race_count": 5},
            )


# ── Fixtures ─────────────────────────────────────────────────────────


@pytest.fixture
def fixture_prediction_data(db_session):
    """Create a 10-starter race to predict, plus an earlier race with rating history.

    Uses the same db_session as the test. All changes are rolled back
    after the test via the db_session fixture.

    Sets up:
    - A previous meeting/race (2025-04-01) with placed starters and a rating
      snapshot for every horse, driver and trainer, so the engine can load
      pre-race ratings
    - A current meeting/race (2025-05-01) with unplaced starters to predict

    Yields:
        ``(current_race_id, horse_ids)``.
    """
    import random

    num_starters = 10
    horse_start_id = 20000
    # Seeded so the current race's horse/driver/trainer pairings are the same
    # on every run, keeping benchmark inputs comparable with a saved baseline.
    # The previous race's starters come from _create_test_starters, which
    # samples with the global random module.
    rng = random.Random(20250501)

    # ── Earlier meeting (for pre-race snapshots) ──
    prev_meeting = _create_test_meeting(
        db_session, "perf_pred_prev", date(2025, 4, 1), "Auckland"
    )
    prev_race = _create_test_race(db_session, prev_meeting, 1, 2000)
    db_session.flush()

    # Create entities
    horse_ids = _create_test_horses(db_session, num_starters, start_id=horse_start_id)
    driver_ids = _create_test_drivers(db_session, num_starters)
    trainer_ids = _create_test_trainers(db_session, num_starters)
    db_session.flush()

    # Create starters for prev race (with results)
    _create_test_starters(
        db_session, prev_race.id, horse_ids, driver_ids, trainer_ids, num_starters
    )
    db_session.flush()

    # Create rating snapshots (so PredictionEngine can find pre-race ratings)
    _create_rating_snapshots(
        db_session, prev_race.id, horse_ids, driver_ids, trainer_ids
    )
    db_session.flush()

    # ── Current meeting (for prediction) ──
    curr_meeting = _create_test_meeting(
        db_session, "perf_pred_curr", date(2025, 5, 1), "Auckland"
    )
    curr_race = _create_test_race(db_session, curr_meeting, 1, 2000)
    db_session.flush()

    # Create starters for current race (without placings — upcoming race)
    selected_horses = rng.sample(horse_ids, min(num_starters, len(horse_ids)))
    for i, horse_id in enumerate(selected_horses):
        driver_index = i % len(driver_ids)
        trainer_index = i % len(trainer_ids)
        # _create_test_horses named horse `horse_start_id + n` "PerfHorse_{n}".
        horse_name = f"PerfHorse_{horse_id - horse_start_id}"
        StarterRepository.upsert(
            db_session,
            curr_race.id,
            {
                "name": horse_name,
                "horse_id": horse_id,
                "horse_name": horse_name,
                "driver_id": driver_ids[driver_index],
                "driver_name": f"PerfDriver_{driver_index}",
                "trainer_id": trainer_ids[trainer_index],
                "trainer_name": f"PerfTrainer_{trainer_index}",
                "runner_number": i + 1,
                "barrier": i + 1,
            },
            placing=None,  # No result yet — upcoming race
        )

    db_session.flush()

    yield curr_race.id, horse_ids

    # No cleanup needed - db_session fixture rolls back


# ── Tests ────────────────────────────────────────────────────────────


def test_prediction_performance(benchmark, db_session, fixture_prediction_data):
    """Time ``PredictionEngine.predict_race()`` on the fixture's 10-starter race.

    pytest-benchmark invokes the call repeatedly and records timing stats.
    When those stats carry a mean, it is checked against the baseline in
    tests/benchmark_baseline.json (if that file exists).
    """
    from sqlalchemy.orm import joinedload

    from packages.core.storage.models import Race, Starter

    race_id, _unused_horse_ids = fixture_prediction_data

    race_query = db_session.query(Race).options(joinedload(Race.meeting))
    race = race_query.filter(Race.id == race_id).first()
    assert race is not None, "Race not found in test data"

    starters = db_session.query(Starter).filter(Starter.race_id == race_id).all()
    starter_count = len(starters)
    assert starter_count >= 8, f"Expected >=8 starters, got {starter_count}"

    predict = PredictionEngine(db_session).predict_race
    result = benchmark(predict, race, starters)

    assert result is not None
    assert len(result.predictions) > 0, "Should generate predictions"

    # Guard against benchmark.stats being missing or None (presumably the
    # case when benchmarking is disabled).
    stats = getattr(benchmark, "stats", None)
    mean_val = None if stats is None else stats.get("mean")
    if mean_val is not None:
        _check_benchmark_baseline("test_prediction_performance", mean_val)


def test_prediction_baseline_threshold(db_session, fixture_prediction_data):
    """Assert a single race prediction completes within 200ms.

    Times one ``predict_race`` call on a freshly constructed engine with
    ``time.perf_counter``, so any first-call overhead is included.
    """
    max_seconds = 0.20
    race_id, _ = fixture_prediction_data

    from sqlalchemy.orm import joinedload

    from packages.core.storage.models import Race, Starter

    race = (
        db_session.query(Race)
        .options(joinedload(Race.meeting))
        .filter(Race.id == race_id)
        .first()
    )
    assert race is not None, "Race not found in test data"

    starters = db_session.query(Starter).filter(Starter.race_id == race_id).all()
    assert starters, "No starters found in test data"

    engine = PredictionEngine(db_session)

    start = time.perf_counter()
    result = engine.predict_race(race, starters)
    elapsed = time.perf_counter() - start

    assert result is not None
    assert len(result.predictions) > 0
    assert elapsed < max_seconds, (
        f"Prediction took {elapsed * 1000:.1f}ms, "
        f"expected < {max_seconds * 1000:.0f}ms"
    )


def test_prediction_api_endpoint(db_session, fixture_prediction_data):
    """Test the /v1/races/{race_id}/predictions API endpoint response time.

    Uses FastAPI TestClient to call the full API stack including
    serialization, with ``get_db`` overridden to yield the test's
    ``db_session``. Asserts a 200 response with at least one prediction,
    returned in < 200ms.

    Only the ``get_db`` override is touched. Whatever override was registered
    for it before the test (if any) is restored afterwards, so overrides that
    other fixtures installed on the shared ``app`` survive.
    """
    race_id, _ = fixture_prediction_data

    # Import here to avoid affecting other tests' env
    from fastapi.testclient import TestClient

    from apps.backend.api.main import app, get_db

    # Override the database dependency to use the test session
    def override_get_db():
        yield db_session

    had_previous_override = get_db in app.dependency_overrides
    previous_override = app.dependency_overrides.get(get_db)
    app.dependency_overrides[get_db] = override_get_db

    try:
        client = TestClient(app)

        start = time.perf_counter()
        response = client.get(f"/v1/races/{race_id}/predictions")
        elapsed = time.perf_counter() - start

        assert (
            response.status_code == 200
        ), f"Expected 200, got {response.status_code}: {response.text[:200]}"

        data = response.json()
        assert "predictions" in data
        assert len(data["predictions"]) > 0

        assert (
            elapsed < 0.20
        ), f"API prediction endpoint took {elapsed * 1000:.1f}ms, expected < 200ms"

    finally:
        # Restore instead of clear(): clear() would also drop overrides that
        # other fixtures registered on the shared app.
        if had_previous_override:
            app.dependency_overrides[get_db] = previous_override
        else:
            app.dependency_overrides.pop(get_db, None)
