"""Tests for WebSocket live race update functionality."""

import json
from datetime import date
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from fastapi.testclient import TestClient

from apps.backend.api.main import app
from apps.backend.api.websocket import ConnectionManager
from packages.core.storage.models import Driver, Horse, Meeting, Race, Starter, Trainer


class TestWebSocket:
    """Tests for WebSocket race update endpoints.

    The real FastAPI app is exercised through ``TestClient`` with
    ``apps.backend.api.main.get_session`` patched, so no database is needed.
    """

    @staticmethod
    def _configure_session(mock_get_session, race):
        """Make a patched ``get_session`` return a session whose race lookup yields ``race``.

        Both query shapes are wired to the same result:
        ``session.query(Race).filter(...).first()`` and
        ``session.query(Race).options(...).filter(...).first()``.

        Args:
            mock_get_session: The mock installed by ``patch`` for
                ``apps.backend.api.main.get_session``.
            race: The object ``.first()`` returns; ``None`` simulates a
                race that does not exist.
        """
        # get_session() is used as a context manager, so MagicMock is needed
        # for __enter__/__exit__ support.
        mock_session_ctx = MagicMock()
        mock_session = MagicMock()
        mock_session_ctx.__enter__.return_value = mock_session
        mock_get_session.return_value = mock_session_ctx

        mock_query = Mock()
        mock_query.filter.return_value.first.return_value = race
        # .options(...) hands back the same query object so an eager-loading
        # chain resolves to the same .filter().first() result.
        mock_query.options.return_value = mock_query
        mock_session.query.return_value = mock_query

    @pytest.fixture
    def mock_race(self):
        """Create a mock race with related objects."""
        mock_meeting = Mock(spec=Meeting)
        mock_meeting.venue = "Test Venue"
        mock_meeting.meeting_date = date(2025, 6, 1)

        mock_horse = Mock(spec=Horse)
        mock_horse.id = 100
        mock_horse.name = "Test Horse"

        mock_driver = Mock(spec=Driver)
        mock_driver.id = 10
        mock_driver.name = "Test Driver"

        mock_trainer = Mock(spec=Trainer)
        mock_trainer.id = 20
        mock_trainer.name = "Test Trainer"

        mock_starter = Mock(spec=Starter)
        mock_starter.id = 1
        mock_starter.horse_id = 100
        mock_starter.driver_id = 10
        mock_starter.trainer_id = 20
        mock_starter.runner_number = 1
        mock_starter.barrier = 2
        mock_starter.handicap_m = 10
        mock_starter.placing = None
        mock_starter.did_not_finish = False
        mock_starter.horse = mock_horse
        mock_starter.driver = mock_driver
        mock_starter.trainer = mock_trainer

        mock_race = Mock(spec=Race)
        mock_race.id = 42
        mock_race.meeting_id = "M1"
        mock_race.race_number = 5
        mock_race.distance_m = 2000
        mock_race.start_type = "Mobile"
        mock_race.gait = "Pace"
        mock_race.weather = "Fine"
        mock_race.track_condition = "Good"
        mock_race.race_datetime = None
        mock_race.meeting = mock_meeting
        mock_race.starters = [mock_starter]

        return mock_race

    @pytest.fixture
    def client(self, mock_race):
        """Create a test client with mocked DB session and real WebSocket manager.

        Uses a real ConnectionManager (to properly accept WebSocket connections)
        but prevents the background simulation from starting.
        """

        test_manager = ConnectionManager()
        # Prevent background simulation from starting during tests
        test_manager.is_simulation_running = lambda race_id: True  # type: ignore[method-assign]

        with (
            patch("apps.backend.api.main.get_session") as mock_get_session,
            patch("apps.backend.api.websocket.manager", test_manager),
        ):
            self._configure_session(mock_get_session, mock_race)

            with TestClient(app) as test_client:
                yield test_client

    def test_websocket_connect_race_not_found(self, client):
        """Test WebSocket connection with non-existent race closes with code 4004."""
        from fastapi import WebSocketDisconnect

        # Override the fixture's session patch so the race lookup finds nothing.
        with patch("apps.backend.api.main.get_session") as mock_get_session:
            self._configure_session(mock_get_session, None)

            with pytest.raises(WebSocketDisconnect) as exc_info:
                with client.websocket_connect("/ws/races/999") as ws:
                    # If the server accepts before closing, the close frame
                    # surfaces on the first receive rather than on connect.
                    ws.receive_text()

        assert exc_info.value.code == 4004

    def test_websocket_connect_sends_initial_state(self, client):
        """Test that connecting to a valid race sends initial state."""
        with client.websocket_connect("/ws/races/42") as ws:
            # Should receive initial state message
            initial_data = ws.receive_text()
            msg = json.loads(initial_data)
            assert msg["type"] == "initial_state"
            assert msg["race_id"] == 42
            assert "data" in msg
            assert "race" in msg["data"]
            assert "starters" in msg["data"]
            assert msg["data"]["starter_count"] == 1
            # Verify race details are present
            race_data = msg["data"]["race"]
            assert race_data["race_number"] == 5
            assert race_data["venue"] == "Test Venue"
            # Verify starter details
            assert len(msg["data"]["starters"]) == 1
            starter = msg["data"]["starters"][0]
            assert starter["horse_id"] == 100
            assert starter["horse_name"] == "Test Horse"

    def test_websocket_subscribe_message(self, client):
        """Test handling subscribe message from client."""
        with client.websocket_connect("/ws/races/42") as ws:
            # Read initial state
            ws.receive_text()

            # Send subscribe message
            ws.send_text(json.dumps({"type": "subscribe", "race_id": 42}))

            # Should receive confirmation
            response = ws.receive_text()
            msg = json.loads(response)
            assert msg["type"] == "subscribed"
            assert msg["race_id"] == 42

    def test_websocket_disconnect_cleanup(self, client):
        """Test that disconnect removes client from race room without errors."""
        # The client fixture patches this module attribute with its test
        # manager for the duration of the test.
        from apps.backend.api import websocket as ws_module

        with client.websocket_connect("/ws/races/42") as ws:
            # Read initial state; by now the manager has accepted the client.
            ws.receive_text()
            assert ws_module.manager.get_connection_count(42) == 1

        # Leaving the context closes the socket and waits for the endpoint to
        # finish, so its cleanup must have removed the client from the room.
        assert ws_module.manager.get_connection_count(42) == 0


class TestConnectionManager:
    """Unit tests for ConnectionManager."""

    @staticmethod
    def _make_ws(send_side_effect=None):
        """Build a mock WebSocket whose I/O methods are awaitable.

        Starlette's ``WebSocket.accept``, ``send_text`` and ``close`` are
        coroutine functions, so they are mocked with ``AsyncMock``. Awaiting a
        plain ``Mock`` return value raises ``TypeError``. That error would
        likely be swallowed by broadcast's per-client error handling, making a
        failed send indistinguishable from a successful one.

        Args:
            send_side_effect: Optional ``side_effect`` (e.g. an exception)
                applied when ``send_text`` is awaited.
        """
        ws = Mock()
        ws.accept = AsyncMock()
        ws.send_text = AsyncMock(side_effect=send_side_effect)
        ws.close = AsyncMock()
        return ws

    @pytest.mark.asyncio
    async def test_connect_and_disconnect(self):
        """Test basic connect/disconnect flow."""
        cm = ConnectionManager()
        mock_ws = self._make_ws()

        await cm.connect(mock_ws, 1)
        assert cm.get_connection_count(1) == 1
        mock_ws.accept.assert_awaited_once()

        await cm.disconnect(mock_ws, 1)
        assert cm.get_connection_count(1) == 0

    @pytest.mark.asyncio
    async def test_broadcast_to_race(self):
        """Test broadcasting to all clients in a race room."""
        cm = ConnectionManager()
        ws1 = self._make_ws()
        ws2 = self._make_ws()

        await cm.connect(ws1, 1)
        await cm.connect(ws2, 1)

        await cm.broadcast_to_race('{"test": true}', 1)

        ws1.send_text.assert_awaited_once_with('{"test": true}')
        ws2.send_text.assert_awaited_once_with('{"test": true}')

    @pytest.mark.asyncio
    async def test_broadcast_handles_disconnected_client(self):
        """Test broadcast continues even if one client fails."""
        cm = ConnectionManager()
        ws1 = self._make_ws()
        ws2 = self._make_ws(send_side_effect=Exception("Connection lost"))

        await cm.connect(ws1, 1)
        await cm.connect(ws2, 1)

        # Should not raise despite ws2 failing
        await cm.broadcast_to_race('{"test": true}', 1)

        # ws1's send must genuinely succeed (awaited), not just be called.
        ws1.send_text.assert_awaited_once_with('{"test": true}')
        ws2.send_text.assert_awaited_once_with('{"test": true}')

    @pytest.mark.asyncio
    async def test_simulation_lifecycle(self):
        """Test simulation task start/stop."""
        cm = ConnectionManager()

        assert not cm.is_simulation_running(1)
        cm.start_simulation(1)
        assert cm.is_simulation_running(1)

        cm.stop_simulation(1)
        assert not cm.is_simulation_running(1)

    @pytest.mark.asyncio
    async def test_multiple_race_rooms(self):
        """Test connections to different races are isolated."""
        cm = ConnectionManager()
        ws1 = self._make_ws()
        ws2 = self._make_ws()

        await cm.connect(ws1, 1)
        await cm.connect(ws2, 2)

        assert cm.get_connection_count(1) == 1
        assert cm.get_connection_count(2) == 1

        await cm.disconnect(ws1, 1)
        assert cm.get_connection_count(1) == 0
        assert cm.get_connection_count(2) == 1

    @pytest.mark.asyncio
    async def test_send_personal_message(self):
        """Test sending a message to a specific client."""
        cm = ConnectionManager()
        mock_ws = self._make_ws()

        await cm.send_personal_message('{"hello": "world"}', mock_ws)
        mock_ws.send_text.assert_awaited_once_with('{"hello": "world"}')

    @pytest.mark.asyncio
    async def test_broadcast_to_nonexistent_race(self):
        """Test broadcasting to a race with no connections does not error."""
        cm = ConnectionManager()
        # Should not raise
        await cm.broadcast_to_race('{"test": true}', 999)

    @pytest.mark.asyncio
    async def test_close_all(self):
        """Test close_all cleans up all connections and simulations."""
        cm = ConnectionManager()
        ws1 = self._make_ws()

        await cm.connect(ws1, 1)
        cm.start_simulation(1)

        await cm.close_all()

        assert cm.get_connection_count(1) == 0
        assert not cm.is_simulation_running(1)


class TestWebSocketMessageSchemas:
    """Round-trip and validation checks for the WebSocket Pydantic message models."""

    def test_odds_update_message(self):
        """OddsUpdateMessage fields survive a JSON serialization round trip."""
        from apps.backend.api.websocket import OddsUpdateMessage

        payload = {
            "type": "odds_update",
            "race_id": 42,
            "timestamp": "2025-06-01T12:00:00",
            "odds": [{"horse_id": 1, "odds": 3.5}],
        }
        serialized = json.loads(OddsUpdateMessage(**payload).model_dump_json())

        first_entry = serialized["odds"][0]
        assert (serialized["type"], serialized["race_id"]) == ("odds_update", 42)
        assert (first_entry["horse_id"], first_entry["odds"]) == (1, 3.5)

    def test_result_update_message(self):
        """ResultUpdateMessage fields survive a JSON serialization round trip."""
        from apps.backend.api.websocket import ResultUpdateMessage

        payload = {
            "type": "result_update",
            "race_id": 42,
            "timestamp": "2025-06-01T12:05:00",
            "results": [{"horse_id": 1, "placing": 1, "finished": True}],
        }
        serialized = json.loads(ResultUpdateMessage(**payload).model_dump_json())

        assert (serialized["type"], serialized["race_id"]) == ("result_update", 42)
        assert serialized["results"][0]["placing"] == 1

    def test_ws_message_schema(self):
        """WSMessage keeps supplied fields and defaults ``data`` to None."""
        from apps.backend.api.websocket import WSMessage

        with_data = WSMessage(type="subscribe", race_id=42, data={"foo": "bar"})
        assert (with_data.type, with_data.race_id, with_data.data) == (
            "subscribe",
            42,
            {"foo": "bar"},
        )

        # Omitting ``data`` entirely must leave it unset (None).
        without_data = WSMessage(type="ping", race_id=1)
        assert without_data.data is None
