"""Tests for rating engine."""

from datetime import date
from unittest.mock import Mock

import pytest

from packages.core.ratings.engine import RatingEngine, RatingState
from packages.core.storage.models import EntityType, Meeting, Race, Starter


# Environment every test needs so settings can be loaded; individual tests
# layer their own overrides on top via ``_engine_with_env``.
_BASE_ENV = {
    "HRNZ_USERNAME": "test",
    "HRNZ_PASSWORD": "test",
    "DATABASE_URL": "postgresql://test",
}


def _engine_with_env(monkeypatch, **env: str) -> RatingEngine:
    """Patch the environment, reload settings and return a fresh engine.

    Args:
        monkeypatch: pytest ``monkeypatch`` fixture; it restores the
            environment after the test.
        **env: Extra environment variables (name -> string value) applied on
            top of ``_BASE_ENV``; a key present in both overrides the base.

    Returns:
        A ``RatingEngine`` constructed after ``reload_settings()`` has run.
    """
    # Imported here rather than at module top; presumably so importing this
    # test module has no settings side effects -- confirm before hoisting.
    from packages.core.common.settings import reload_settings

    for name, value in {**_BASE_ENV, **env}.items():
        monkeypatch.setenv(name, value)
    reload_settings()
    return RatingEngine()


def _make_race(race_id: int) -> Mock:
    """Build a mocked 2000m mobile-start ``Race`` at "Test Track"."""
    meeting = Mock(spec=Meeting)
    meeting.venue = "Test Track"

    race = Mock(spec=Race)
    race.id = race_id
    race.meeting = meeting
    race.distance_m = 2000
    race.start_type = "mobile"
    return race


def _make_starter(
    starter_id: int,
    horse_id: int,
    placing: int | None,
    *,
    driver_id: int | None = None,
    trainer_id: int | None = None,
    did_not_finish: bool = False,
) -> Mock:
    """Build a mocked ``Starter`` with no barrier or handicap."""
    starter = Mock(spec=Starter)
    starter.id = starter_id
    starter.horse_id = horse_id
    starter.driver_id = driver_id
    starter.trainer_id = trainer_id
    starter.barrier = None
    starter.handicap_m = None
    starter.placing = placing
    starter.did_not_finish = did_not_finish
    return starter


def _horse_update(updates, horse_id: int):
    """Return the single horse update for ``horse_id``.

    Lookup is by identity rather than list position, so tests do not depend
    on the order in which the engine emits updates.
    """
    matches = [
        u
        for u in updates
        if u.entity_type == EntityType.HORSE and u.entity_id == horse_id
    ]
    assert len(matches) == 1, (
        f"expected exactly one update for horse {horse_id}, got {len(matches)}"
    )
    return matches[0]


class TestRatingEngine:
    """Tests for multi-runner Elo rating engine."""

    @pytest.fixture
    def engine(self, monkeypatch):
        """Create rating engine with test settings.

        Driver/trainer ratings and adjustments are disabled, so only horse
        updates are produced.
        """
        return _engine_with_env(
            monkeypatch,
            ELO_SCALE_C="400.0",
            ELO_K_BASE="24.0",
            INITIAL_RATING="1500.0",
            ENABLE_DRIVER="false",
            ENABLE_TRAINER="false",
            ENABLE_ADJUSTMENTS="false",
        )

    def test_sigmoid(self, engine):
        """Test sigmoid function."""
        assert engine.sigmoid(0.0) == 0.5
        assert engine.sigmoid(1000.0) > 0.99
        assert engine.sigmoid(-1000.0) < 0.01

    def test_get_or_init_rating(self, engine):
        """Test rating initialization."""
        state = engine.get_or_init_rating(EntityType.HORSE, 123)

        assert state.rating == 1500.0
        assert state.race_count == 0

        # Should return same state on second call
        state2 = engine.get_or_init_rating(EntityType.HORSE, 123)
        assert state is state2

    def test_two_runner_race_equal_ratings(self, engine):
        """Test 2-runner race with equal ratings.

        When ratings are equal, winner should gain and loser should lose.
        """
        race = _make_race(1)
        winner = _make_starter(1, 101, placing=1)
        loser = _make_starter(2, 102, placing=2)

        updates = engine.process_race(race, [winner, loser])

        # Should have 2 updates (one per horse)
        assert len(updates) == 2

        # Winner should gain rating
        winner_update = _horse_update(updates, 101)
        assert winner_update.delta > 0
        assert winner_update.new_rating > 1500.0

        # Loser should lose rating
        loser_update = _horse_update(updates, 102)
        assert loser_update.delta < 0
        assert loser_update.new_rating < 1500.0

        # Gains and losses should be symmetric for equal prior ratings
        assert abs(winner_update.delta + loser_update.delta) < 0.01

    def test_two_runner_race_unequal_ratings(self, engine):
        """Test 2-runner race with unequal ratings.

        Favorite wins: small rating change
        Underdog wins: large rating change
        """
        # Set up unequal ratings
        engine.states[(EntityType.HORSE, 201)] = RatingState(rating=1700.0)  # Strong
        engine.states[(EntityType.HORSE, 202)] = RatingState(rating=1300.0)  # Weak

        race = _make_race(2)

        # Test 1: Favorite wins (expected outcome)
        starter_strong = _make_starter(1, 201, placing=1)
        starter_weak = _make_starter(2, 202, placing=2)

        updates = engine.process_race(race, [starter_strong, starter_weak])

        # Favorite should gain small amount
        favorite_update = _horse_update(updates, 201)
        assert 0 < favorite_update.delta < 7.0  # Small gain

        # Reset ratings so the second race starts from the same priors
        engine.states[(EntityType.HORSE, 201)] = RatingState(rating=1700.0)
        engine.states[(EntityType.HORSE, 202)] = RatingState(rating=1300.0)

        # Test 2: Underdog wins (upset)
        starter_strong.placing = 2
        starter_weak.placing = 1

        race.id = 3  # Different race
        updates = engine.process_race(race, [starter_strong, starter_weak])

        # Underdog should gain large amount
        underdog_update = _horse_update(updates, 202)
        assert underdog_update.delta > 15.0  # Large gain

    def test_multi_runner_race(self, engine):
        """Test race with 10 runners.

        Rating updates should sum to approximately zero (no rating inflation).
        """
        race = _make_race(10)

        # 10 starters with equal (initial) ratings finishing 1st..10th
        starters = [_make_starter(i, 1000 + i, placing=i + 1) for i in range(10)]

        updates = engine.process_race(race, starters)

        # Should have 10 updates
        assert len(updates) == 10

        # Winner should gain (looked up by horse, not by list position)
        assert _horse_update(updates, 1000).delta > 0

        # Last place should lose
        assert _horse_update(updates, 1009).delta < 0

        # Sum of deltas should be near zero (no rating inflation)
        total_delta = sum(u.delta for u in updates)
        assert abs(total_delta) < 1.0

    def test_incomplete_race_skipped(self, engine):
        """Test that races with fewer than 2 finishers are skipped."""
        race = _make_race(20)

        # Only one finisher
        updates = engine.process_race(race, [_make_starter(1, 301, placing=1)])

        # Should return empty updates
        assert len(updates) == 0

    def test_dnf_excluded_from_calculations(self, engine):
        """Test that DNF starters are excluded from rating updates."""
        race = _make_race(30)

        starters = [
            _make_starter(1, 401, placing=1),
            _make_starter(2, 402, placing=2),
            # One DNF
            _make_starter(3, 403, placing=None, did_not_finish=True),
        ]

        updates = engine.process_race(race, starters)

        # Should only have 2 updates (DNF excluded), one for each finisher
        assert len(updates) == 2
        assert {u.entity_id for u in updates} == {401, 402}

    def test_load_existing_rating(self, engine):
        """Test loading existing rating state."""
        engine.load_rating_state(EntityType.HORSE, 500, rating=1650.0, rd=250.0)

        state = engine.get_or_init_rating(EntityType.HORSE, 500)
        assert state.rating == 1650.0
        assert state.rd == 250.0

    def test_k_factor_clamps(self, monkeypatch):
        """Test K-factor clamping."""
        engine = _engine_with_env(
            monkeypatch,
            ENABLE_RD="false",
            ELO_K_BASE="24.0",
            ELO_K_MIN="12.0",
            ELO_K_MAX="20.0",
        )

        k_eff = engine.get_effective_k_factor(EntityType.HORSE, 1)
        # Base K of 24 exceeds the configured maximum of 20.
        assert k_eff == pytest.approx(20.0)

    def test_rd_scaling_sqrt(self, monkeypatch):
        """Test square-root RD scaling mode."""
        engine = _engine_with_env(
            monkeypatch,
            ENABLE_RD="true",
            ELO_K_BASE="24.0",
            INITIAL_RD="400.0",
            RD_SCALING_MODE="sqrt",
        )

        engine.load_rating_state(EntityType.HORSE, 1, rating=1500.0, rd=100.0)
        k_eff = engine.get_effective_k_factor(EntityType.HORSE, 1)
        assert k_eff == pytest.approx(12.0)

    def test_rd_scaling_none(self, monkeypatch):
        """Test RD scaling disabled while RD tracking remains enabled."""
        engine = _engine_with_env(
            monkeypatch,
            ENABLE_RD="true",
            ELO_K_BASE="24.0",
            INITIAL_RD="350.0",
            RD_SCALING_MODE="none",
        )

        engine.load_rating_state(EntityType.HORSE, 1, rating=1500.0, rd=100.0)
        k_eff = engine.get_effective_k_factor(EntityType.HORSE, 1)
        assert k_eff == pytest.approx(24.0)

    def test_pairwise_normalizer_comparisons_with_ties(self, monkeypatch):
        """Test comparisons normalizer with skipped ties."""
        engine = _engine_with_env(
            monkeypatch,
            PAIRWISE_NORMALIZER="comparisons",
            TIE_HANDLING="skip",
            ENABLE_ADJUSTMENTS="false",
        )

        race = _make_race(40)
        starters = [
            _make_starter(1, 501, placing=1),
            _make_starter(2, 502, placing=1),  # dead-heat with 501
            _make_starter(3, 503, placing=2),
        ]

        updates = engine.process_race(race, starters)
        assert len(updates) == 3

    def test_pairwise_normalizer_n(self, monkeypatch):
        """Test normalizing by n in pairwise updates."""
        engine = _engine_with_env(
            monkeypatch,
            PAIRWISE_NORMALIZER="n",
            ENABLE_ADJUSTMENTS="false",
        )

        race = _make_race(42)
        starters = [
            _make_starter(1, 701, placing=1),
            _make_starter(2, 702, placing=2),
        ]

        updates = engine.process_race(race, starters)
        assert len(updates) == 2

    def test_tie_handling_half(self, monkeypatch):
        """Test half-score tie handling."""
        engine = _engine_with_env(
            monkeypatch,
            TIE_HANDLING="half",
            ENABLE_ADJUSTMENTS="false",
        )

        race = _make_race(43)
        starters = [
            _make_starter(1, 801, placing=1),
            _make_starter(2, 802, placing=1),  # dead-heat with 801
        ]

        updates = engine.process_race(race, starters)
        assert len(updates) == 2

    def test_rating_bounds(self, monkeypatch):
        """Test rating min/max bounds."""
        engine = _engine_with_env(
            monkeypatch,
            RATING_MIN="1400.0",
            RATING_MAX="1600.0",
        )

        updates = []
        engine._apply_update(
            EntityType.HORSE,
            600,
            delta=200.0,
            race_id=1,
            race_date=None,
            updates=updates,
        )
        # The +200 delta would overshoot RATING_MAX, so it is clamped.
        assert updates[0].new_rating == pytest.approx(1600.0)

    def test_rd_decay_floor(self, monkeypatch):
        """Test RD decay floor is applied."""
        engine = _engine_with_env(
            monkeypatch,
            ENABLE_RD="true",
            RD_DECAY_PER_RACE="0.0",
            RD_DECAY_FLOOR="5.0",
        )

        state = engine.get_or_init_rating(EntityType.HORSE, 1)
        initial_rd = state.rd
        updates = []
        engine._apply_update(
            EntityType.HORSE,
            1,
            delta=1.0,
            race_id=1,
            race_date=date(2025, 1, 26),
            updates=updates,
        )
        assert state.rd == pytest.approx(initial_rd - 5.0)

    def test_rd_inflation_cap_days(self, monkeypatch):
        """Test RD inflation cap days."""
        engine = _engine_with_env(
            monkeypatch,
            ENABLE_RD="true",
            RD_INFLATION_PER_DAY="1.0",
            RD_INFLATION_CAP_DAYS="10",
        )

        state = engine.get_or_init_rating(EntityType.HORSE, 2)
        state.last_race_date = date(2025, 1, 1)
        updates = []
        engine._apply_update(
            EntityType.HORSE,
            2,
            delta=1.0,
            race_id=2,
            race_date=date(2025, 2, 1),  # 31 days idle, beyond the 10-day cap
            updates=updates,
        )
        # NOTE(review): only checks the rd_max ceiling, not the 10-day cap
        # itself; tighten once the expected post-update RD is pinned down.
        assert state.rd <= engine.settings.rd_max

    def test_entity_k_scales(self, monkeypatch):
        """Test entity-specific K scaling."""
        engine = _engine_with_env(
            monkeypatch,
            HORSE_K_SCALE="1.0",
            DRIVER_K_SCALE="2.0",
            TRAINER_K_SCALE="0.5",
        )

        race = _make_race(44)
        starters = [
            _make_starter(1, 901, placing=1, driver_id=1001, trainer_id=1101),
            _make_starter(2, 902, placing=2, driver_id=1002, trainer_id=1102),
        ]

        updates = engine.process_race(race, starters)
        driver_updates = [u for u in updates if u.entity_type == EntityType.DRIVER]
        trainer_updates = [u for u in updates if u.entity_type == EntityType.TRAINER]
        # NOTE(review): asserts presence only; the relative magnitudes implied
        # by the per-entity K scales are not checked here.
        assert driver_updates
        assert trainer_updates

    def test_min_finishers_with_dnf_as_last(self, monkeypatch):
        """Test DNF treated as last for minimum finishers."""
        engine = _engine_with_env(
            monkeypatch,
            MIN_FINISHERS="2",
            DNF_TREATED_AS_LAST="true",
        )

        race = _make_race(41)
        finisher = _make_starter(1, 601, placing=1)
        dnf = _make_starter(2, 602, placing=None, did_not_finish=True)

        updates = engine.process_race(race, [finisher, dnf])
        assert len(updates) == 2
