"""Tests for barrier/handicap adjustment learning and rating deviation (RD)."""

from datetime import date
from unittest.mock import Mock, patch

import pytest

from packages.core.ratings.engine import RatingEngine
from packages.core.storage.models import EntityType, Meeting, Race, Starter
from packages.core.storage.repositories import (
    BarrierAdjustmentRepository,
    HandicapAdjustmentRepository,
)


class TestAdjustmentLearning:
    """Tests for adjustment learning functionality."""

    @staticmethod
    def _configure(monkeypatch, **env):
        """Set the baseline environment plus ``env`` overrides, then reload settings.

        Every test in this class writes the same required variables so that no
        test depends on whatever happens to be exported in the calling shell.
        """
        from packages.core.common.settings import reload_settings

        base = {
            "HRNZ_USERNAME": "test",
            "HRNZ_PASSWORD": "test",
            "DATABASE_URL": "postgresql://test",
            "ENABLE_ADJUSTMENTS": "true",
        }
        for name, value in {**base, **env}.items():
            monkeypatch.setenv(name, value)
        reload_settings()

    @staticmethod
    def _build_engine(db_session, barrier_rows=(), handicap_rows=()):
        """Construct a RatingEngine whose adjustment tables load from the given rows.

        ``get_all`` is patched on both repositories only for the duration of
        construction, so the engine never touches the (mock) database for them.
        """
        with patch.object(
            BarrierAdjustmentRepository, "get_all", return_value=list(barrier_rows)
        ), patch.object(
            HandicapAdjustmentRepository, "get_all", return_value=list(handicap_rows)
        ):
            return RatingEngine(db_session=db_session)

    @staticmethod
    def _make_race(race_id, venue="Cambridge", distance_m=1800, start_type="mobile"):
        """Build a mock Race (with a mock Meeting) for adjustment learning."""
        meeting = Mock(spec=Meeting)
        meeting.venue = venue
        meeting.meeting_date = date(2025, 1, 26)

        race = Mock(spec=Race)
        race.id = race_id
        race.distance_m = distance_m
        race.start_type = start_type
        race.meeting = meeting
        return race

    @staticmethod
    def _make_starter(starter_id, horse_id, barrier, placing, handicap_m=None):
        """Build a mock finishing Starter with no driver or trainer."""
        starter = Mock(spec=Starter)
        starter.id = starter_id
        starter.horse_id = horse_id
        starter.driver_id = None
        starter.trainer_id = None
        starter.barrier = barrier
        starter.handicap_m = handicap_m
        starter.placing = placing
        starter.did_not_finish = False
        return starter

    @pytest.fixture
    def mock_db(self):
        """Create mock database session."""
        return Mock()

    @pytest.fixture
    def engine(self, mock_db, monkeypatch):
        """Create rating engine with adjustments enabled and empty adjustment tables."""
        self._configure(monkeypatch, ADJ_LEARNING_RATE="0.5")
        return self._build_engine(mock_db)

    def test_barrier_adjustment_loading(self, mock_db, monkeypatch):
        """Test loading barrier adjustments from database."""
        self._configure(monkeypatch)

        row = Mock()
        row.venue = "Cambridge"
        row.start_type = "mobile"
        row.distance_bucket = "1700-2000"
        row.barrier = 1
        row.adjustment = 5.0

        engine = self._build_engine(mock_db, barrier_rows=[row])

        # Rows are keyed by (venue, start_type, distance_bucket, barrier).
        key = ("Cambridge", "mobile", "1700-2000", 1)
        assert key in engine.barrier_adjustments
        assert engine.barrier_adjustments[key] == 5.0

    def test_handicap_adjustment_loading(self, mock_db, monkeypatch):
        """Test loading handicap adjustments from database."""
        self._configure(monkeypatch)

        row = Mock()
        row.venue = None
        row.start_type = None
        row.distance_bucket = "1700-2000"
        row.handicap_m = 10
        row.adjustment = -8.0

        engine = self._build_engine(mock_db, handicap_rows=[row])

        # A global row (no venue/start type) keeps None in those key slots.
        key = (None, None, "1700-2000", 10)
        assert key in engine.handicap_adjustments
        assert engine.handicap_adjustments[key] == -8.0

    def test_learn_adjustments_from_race(self, engine, mock_db):
        """Test learning adjustments from race results."""
        race = self._make_race(race_id=1)
        # Two finishers from different barriers, neither with a handicap.
        starters = [
            self._make_starter(1, horse_id=100, barrier=1, placing=1),
            self._make_starter(2, horse_id=200, barrier=8, placing=2),
        ]
        for starter in starters:
            engine.get_or_init_rating(EntityType.HORSE, starter.horse_id)

        with patch.object(
            BarrierAdjustmentRepository, "increment_sample"
        ) as mock_increment:
            engine.learn_adjustments_from_race(race, starters, use_global_only=True)

        # One barrier sample recorded per starter.
        assert mock_increment.call_count == 2

    def test_adjustment_repository_upsert(self):
        """Test that both adjustment repositories expose the persistence API.

        Exercising the real upsert behaviour needs a database; this only checks
        that the methods the engine relies on exist and are callable.
        """
        for repo in (BarrierAdjustmentRepository, HandicapAdjustmentRepository):
            for method_name in ("upsert", "increment_sample"):
                assert callable(getattr(repo, method_name, None)), (
                    f"{repo.__name__}.{method_name} is missing or not callable"
                )

    def test_adjustment_min_samples_and_clamp(self, monkeypatch, mock_db):
        """Test adjustment minimum samples and clamping."""
        self._configure(
            monkeypatch,
            ADJ_MIN_SAMPLES="2",
            ADJ_CLAMP_MIN="-5.0",
            ADJ_CLAMP_MAX="5.0",
        )
        engine = self._build_engine(mock_db)

        key = ("Test", "mobile", "1700-2000", 1)
        engine.barrier_adjustments[key] = 12.0

        # Below ADJ_MIN_SAMPLES the learned value is ignored.
        engine.barrier_adjustment_samples[key] = 1
        assert engine._get_barrier_adjustment("Test", "mobile", 1800, 1) == 0.0

        # With enough samples the value applies, clamped to ADJ_CLAMP_MAX.
        engine.barrier_adjustment_samples[key] = 2
        assert engine._get_barrier_adjustment("Test", "mobile", 1800, 1) == 5.0

    def test_adjustment_global_only_and_update_scale(self, monkeypatch, mock_db):
        """Test that global-only mode records samples without venue/start type.

        ADJ_UPDATE_SCALE is configured so the code path runs with a non-default
        scale, but the scaled delta magnitude itself is not asserted here.
        """
        self._configure(monkeypatch, ADJ_GLOBAL_ONLY="true", ADJ_UPDATE_SCALE="2.0")
        engine = self._build_engine(mock_db)

        race = self._make_race(race_id=2)
        starters = [
            self._make_starter(1, horse_id=100, barrier=1, placing=1, handicap_m=10),
            self._make_starter(2, horse_id=200, barrier=2, placing=2, handicap_m=0),
        ]

        with patch.object(
            BarrierAdjustmentRepository, "increment_sample"
        ) as mock_barrier, patch.object(
            HandicapAdjustmentRepository, "increment_sample"
        ) as mock_handicap:
            engine.learn_adjustments_from_race(race, starters)

        assert mock_barrier.call_count == 2
        # The scratch (0m) starter is expected not to produce a handicap sample.
        assert mock_handicap.call_count == 1
        for _, kwargs in mock_barrier.call_args_list:
            assert kwargs["venue"] is None
            assert kwargs["start_type"] is None

    def test_adjustment_feature_toggles(self, monkeypatch, mock_db):
        """Test barrier/handicap toggle flags."""
        self._configure(
            monkeypatch,
            ADJ_BARRIER_ENABLED="false",
            ADJ_HANDICAP_ENABLED="true",
        )
        engine = self._build_engine(mock_db)

        barrier_key = ("V", "mobile", "1700-2000", 1)
        handicap_key = ("V", "mobile", "1700-2000", 10)
        engine.barrier_adjustments[barrier_key] = 10.0
        engine.barrier_adjustment_samples[barrier_key] = 5
        engine.handicap_adjustments[handicap_key] = -5.0
        engine.handicap_adjustment_samples[handicap_key] = 5

        # Barrier adjustments are disabled, so the learned value is ignored.
        assert engine._get_barrier_adjustment("V", "mobile", 1800, 1) == 0.0
        assert engine._get_handicap_adjustment("V", "mobile", 1800, 10) == -5.0


class TestRatingDeviation:
    """Tests for Rating Deviation (RD) functionality."""

    # RD parameters written to the environment by ``engine_with_rd``; the
    # assertions derive expected values from these instead of repeating them.
    INITIAL_RD = 350.0
    RD_MIN = 50.0
    RD_MAX = 350.0
    RD_DECAY_PER_RACE = 15.0
    RD_INFLATION_PER_DAY = 0.5
    # Base K-factor the engine is expected to use. It is not set via the
    # environment here, so this assumes the settings default — TODO confirm.
    BASE_K = 24.0

    @staticmethod
    def _configure(monkeypatch, **env):
        """Set the baseline environment plus ``env`` overrides, then reload settings."""
        from packages.core.common.settings import reload_settings

        base = {
            "HRNZ_USERNAME": "test",
            "HRNZ_PASSWORD": "test",
            "DATABASE_URL": "postgresql://test",
        }
        for name, value in {**base, **env}.items():
            monkeypatch.setenv(name, value)
        reload_settings()

    @staticmethod
    def _race(engine, horse_id, race_id, race_date, updates, delta=1.0):
        """Apply one race's rating update for ``horse_id`` through the engine."""
        engine._apply_update(
            EntityType.HORSE,
            horse_id,
            delta=delta,
            race_id=race_id,
            race_date=race_date,
            updates=updates,
        )

    @pytest.fixture
    def engine_with_rd(self, monkeypatch):
        """Create rating engine with RD enabled."""
        self._configure(
            monkeypatch,
            ENABLE_RD="true",
            INITIAL_RD=str(self.INITIAL_RD),
            RD_MIN=str(self.RD_MIN),
            RD_MAX=str(self.RD_MAX),
            RD_DECAY_PER_RACE=str(self.RD_DECAY_PER_RACE),
            RD_INFLATION_PER_DAY=str(self.RD_INFLATION_PER_DAY),
        )
        return RatingEngine()

    def test_initial_rd(self, engine_with_rd):
        """Test initial RD is set correctly."""
        state = engine_with_rd.get_or_init_rating(EntityType.HORSE, 1)
        assert state.rd == self.INITIAL_RD

    def test_rd_decay_after_race(self, engine_with_rd):
        """Test RD decreases after participating in a race."""
        state = engine_with_rd.get_or_init_rating(EntityType.HORSE, 1)
        initial_rd = state.rd

        self._race(
            engine_with_rd, 1, race_id=1, race_date=date(2025, 1, 26),
            updates=[], delta=10.0,
        )

        assert state.rd < initial_rd
        assert state.rd == max(initial_rd - self.RD_DECAY_PER_RACE, self.RD_MIN)

    def test_rd_inflation_for_inactivity(self, engine_with_rd):
        """Test RD increases with inactivity between races."""
        inactive = engine_with_rd.get_or_init_rating(EntityType.HORSE, 1)
        control = engine_with_rd.get_or_init_rating(EntityType.HORSE, 2)
        updates = []

        # Horse 1: two races 30 days apart.
        self._race(
            engine_with_rd, 1, race_id=1, race_date=date(2025, 1, 1),
            updates=updates, delta=10.0,
        )
        self._race(
            engine_with_rd, 1, race_id=2, race_date=date(2025, 1, 31),
            updates=updates, delta=5.0,
        )

        # Horse 2 (control): the same two results on a single day, so it gets
        # identical per-race decay but no inactivity inflation.
        self._race(
            engine_with_rd, 2, race_id=3, race_date=date(2025, 1, 1),
            updates=updates, delta=10.0,
        )
        self._race(
            engine_with_rd, 2, race_id=4, race_date=date(2025, 1, 1),
            updates=updates, delta=5.0,
        )

        # The 30-day gap (30 days * 0.5/day of inflation) must leave horse 1
        # with more uncertainty than the control, while staying within bounds.
        assert inactive.rd > control.rd
        assert inactive.rd <= self.RD_MAX

    def test_rd_min_max_bounds(self, engine_with_rd):
        """Test RD respects min and max bounds."""
        state = engine_with_rd.get_or_init_rating(EntityType.HORSE, 1)
        updates = []

        # 50 same-day races decay far more than INITIAL_RD - RD_MIN, so RD must
        # end up exactly on the floor rather than anywhere below or above it.
        for race_id in range(50):
            self._race(
                engine_with_rd, 1, race_id=race_id, race_date=date(2025, 1, 26),
                updates=updates,
            )
        assert state.rd == pytest.approx(self.RD_MIN)

        # Maximum: ~2 years of inactivity inflates well past RD_MAX if unclamped.
        long_idle = engine_with_rd.get_or_init_rating(EntityType.HORSE, 2)
        long_idle.last_race_date = date(2024, 1, 1)
        self._race(
            engine_with_rd, 2, race_id=100, race_date=date(2025, 12, 31),
            updates=updates,
        )
        assert long_idle.rd <= self.RD_MAX

    def test_effective_k_factor_with_rd(self, engine_with_rd):
        """Test K-factor adjustment based on RD."""
        # ``state`` is the live rating state; it is mutated by the races below.
        state = engine_with_rd.get_or_init_rating(EntityType.HORSE, 1)

        # A new entity (RD == INITIAL_RD) uses the full base K-factor.
        k_new = engine_with_rd.get_effective_k_factor(EntityType.HORSE, 1)
        assert k_new == pytest.approx(self.BASE_K)

        # Many races reduce RD.
        updates = []
        for race_id in range(20):
            self._race(
                engine_with_rd, 1, race_id=race_id, race_date=date(2025, 1, 26),
                updates=updates,
            )

        # An established entity with low RD should have a much lower K-factor.
        k_established = engine_with_rd.get_effective_k_factor(EntityType.HORSE, 1)
        assert k_established < k_new
        assert k_established < 10.0

        # K scales proportionally with the current RD:
        # K_eff = K_base * (RD / RD_initial)
        expected_k = self.BASE_K * (state.rd / self.INITIAL_RD)
        assert k_established == pytest.approx(expected_k)

    def test_effective_k_factor_without_rd(self, monkeypatch):
        """Test K-factor remains constant when RD is disabled."""
        self._configure(monkeypatch, ENABLE_RD="false")

        engine_no_rd = RatingEngine()

        # Without RD every entity uses the base K-factor.
        k1 = engine_no_rd.get_effective_k_factor(EntityType.HORSE, 1)
        k2 = engine_no_rd.get_effective_k_factor(EntityType.HORSE, 2)

        assert k1 == pytest.approx(self.BASE_K)
        assert k2 == pytest.approx(self.BASE_K)
