"""WebSocket connection manager and live race updates for TipSharks API.

Provides a ConnectionManager for race-room-scoped WebSocket connections
and a background simulation task that generates mock odds/result updates
as a placeholder for future live data integration.
"""

import asyncio
import random
from collections import defaultdict
from datetime import UTC, datetime
from typing import Literal

from fastapi import WebSocket
from pydantic import BaseModel

from packages.core.common.logging import get_logger
from packages.core.storage.models import Race, Starter

# Module logger; calls in this module attach structured context via ``extra=``.
logger = get_logger(__name__)


# ── Pydantic schemas for WebSocket messages ──────────────────────────


class WSMessage(BaseModel):
    """Schema for incoming WebSocket messages from clients.

    Attributes:
        type: Message kind sent by the client. The accepted values are not
            enumerated here (NOTE(review): presumably checked by the
            endpoint that parses this model — confirm).
        race_id: Race room the message refers to.
        data: Optional free-form payload; ``None`` when omitted.
    """

    type: str
    race_id: int
    data: dict | None = None


class OddsUpdateMessage(BaseModel):
    """Odds update broadcast message.

    ``type`` is pinned to ``"odds_update"`` so the frame cannot be built
    with a mismatched discriminator; it serializes exactly as before.

    Attributes:
        type: Fixed message discriminator, always ``"odds_update"``.
        race_id: Race room the odds belong to.
        timestamp: Time the update was produced, as a string.
        odds: One entry per runner (shape defined by the producer).
    """

    type: Literal["odds_update"] = "odds_update"
    race_id: int
    timestamp: str
    odds: list[dict]


class ResultUpdateMessage(BaseModel):
    """Result update broadcast message.

    ``type`` is pinned to ``"result_update"`` so the frame cannot be built
    with a mismatched discriminator; it serializes exactly as before.

    Attributes:
        type: Fixed message discriminator, always ``"result_update"``.
        race_id: Race room the results belong to.
        timestamp: Time the results were produced, as a string.
        results: One entry per runner (shape defined by the producer).
    """

    type: Literal["result_update"] = "result_update"
    race_id: int
    timestamp: str
    results: list[dict]


class InitialStateMessage(BaseModel):
    """Initial race state sent on connection.

    ``type`` is pinned to ``"initial_state"`` so the frame cannot be built
    with a mismatched discriminator; it serializes exactly as before.

    Attributes:
        type: Fixed message discriminator, always ``"initial_state"``.
        race_id: Race room the snapshot describes.
        data: Snapshot payload (race details and starters).
    """

    type: Literal["initial_state"] = "initial_state"
    race_id: int
    data: dict


# ── Connection Manager ───────────────────────────────────────────────


class ConnectionManager:
    """Manages WebSocket connections grouped by race room.

    Each race has its own "room" identified by race_id. Room membership is
    guarded by an ``asyncio.Lock``, which serializes access between
    coroutines on the same event loop; it does not make the manager safe to
    use from multiple threads.

    The manager also owns at most one background simulation task per race.
    A task removes itself from the registry when it finishes (normally,
    early, or by cancellation), so the race can be simulated again later.
    """

    def __init__(self) -> None:
        # race_id -> sockets currently in that room; empty rooms are deleted.
        self._connections: dict[int, set[WebSocket]] = defaultdict(set)
        self._lock = asyncio.Lock()
        # race_id -> running simulation task. Keeping the reference here also
        # prevents the task from being garbage-collected mid-run.
        self._simulation_tasks: dict[int, asyncio.Task] = {}

    async def connect(self, websocket: WebSocket, race_id: int) -> None:
        """Accept a WebSocket and add it to the race room.

        Args:
            websocket: The WebSocket connection to register.
            race_id: Race room identifier.
        """
        await websocket.accept()
        async with self._lock:
            self._connections[race_id].add(websocket)
        logger.info(
            "WebSocket connected for race",
            extra={
                "race_id": race_id,
                "connection_count": self.get_connection_count(race_id),
            },
        )

    async def disconnect(self, websocket: WebSocket, race_id: int) -> None:
        """Remove a WebSocket from the race room.

        Safe to call when the socket or the whole room is already gone (for
        example after a failed broadcast pruned it, or after ``close_all``).
        Cleans up the room entry when the last client disconnects.

        Args:
            websocket: The WebSocket connection to remove.
            race_id: Race room identifier.
        """
        async with self._lock:
            self._discard_locked(websocket, race_id)
        logger.info(
            "WebSocket disconnected from race",
            extra={"race_id": race_id, "remaining": self.get_connection_count(race_id)},
        )

    def _discard_locked(self, websocket: WebSocket, race_id: int) -> None:
        """Remove *websocket* from a room; the caller must hold ``self._lock``."""
        # Use .get() so a missing room is not re-created by the defaultdict.
        room = self._connections.get(race_id)
        if room is None:
            return
        room.discard(websocket)
        if not room:
            del self._connections[race_id]

    async def send_personal_message(self, message: str, websocket: WebSocket) -> None:
        """Send a JSON message to a single WebSocket client.

        Send failures are logged and not raised.

        Args:
            message: JSON-encoded message string.
            websocket: The target WebSocket connection.
        """
        try:
            await websocket.send_text(message)
        except Exception:
            logger.warning("Failed to send personal message", exc_info=True)

    async def broadcast_to_race(self, message: str, race_id: int) -> None:
        """Broadcast a JSON message to all clients in a race room.

        Clients whose send fails are logged and removed from the room, so a
        dead socket does not keep the room (and its simulation) alive.

        Args:
            message: JSON-encoded message string.
            race_id: Target race room identifier.
        """
        # Snapshot under the lock, then send without holding it so a slow
        # client does not block connects/disconnects.
        async with self._lock:
            connections = self._connections.get(race_id, set()).copy()
        failed: list[WebSocket] = []
        for ws in connections:
            try:
                await ws.send_text(message)
            except Exception:
                logger.warning(
                    "Failed to broadcast to client for race",
                    extra={"race_id": race_id},
                    exc_info=True,
                )
                failed.append(ws)
        if failed:
            async with self._lock:
                for ws in failed:
                    self._discard_locked(ws, race_id)

    def get_connection_count(self, race_id: int) -> int:
        """Return the number of connections for a given race."""
        return len(self._connections.get(race_id, set()))

    def is_simulation_running(self, race_id: int) -> bool:
        """Check if a simulation task is currently registered for a race."""
        return race_id in self._simulation_tasks

    def start_simulation(self, race_id: int) -> None:
        """Start the background simulation for a race if not already running.

        Must be called from within a running event loop.

        Args:
            race_id: Race identifier to simulate.
        """
        if race_id in self._simulation_tasks:
            return
        task = asyncio.create_task(simulate_race_updates(race_id))
        self._simulation_tasks[race_id] = task
        # Unregister the task once it ends so the race can be restarted.
        task.add_done_callback(lambda done: self._forget_simulation(race_id, done))
        logger.info("Started race simulation", extra={"race_id": race_id})

    def _forget_simulation(self, race_id: int, task: asyncio.Task) -> None:
        """Drop *task* from the registry if it is still the one for *race_id*."""
        # Identity check: a newer task may already be registered for this race
        # (e.g. stop_simulation popped this one and start_simulation ran again).
        if self._simulation_tasks.get(race_id) is task:
            del self._simulation_tasks[race_id]

    def stop_simulation(self, race_id: int) -> None:
        """Cancel the background simulation for a race.

        Args:
            race_id: Race identifier to stop simulating.
        """
        task = self._simulation_tasks.pop(race_id, None)
        if task is not None:
            task.cancel()
            logger.info("Stopped race simulation", extra={"race_id": race_id})

    async def close_all(self) -> None:
        """Close all active connections and cancel all simulations.

        Used for graceful shutdown. Errors while closing individual sockets
        are logged at debug level and do not stop the shutdown.
        """
        for race_id in list(self._simulation_tasks):
            self.stop_simulation(race_id)
        async with self._lock:
            sockets = [ws for room in self._connections.values() for ws in room]
            self._connections.clear()
        # Close outside the lock so slow handshakes do not block other callers.
        for ws in sockets:
            try:
                await ws.close()
            except Exception:
                logger.debug("Failed to close WebSocket during shutdown", exc_info=True)


# Module-level ConnectionManager singleton; simulate_race_updates reads room
# sizes from it and broadcasts through it.
manager: ConnectionManager = ConnectionManager()


# ── Helpers to build initial state ───────────────────────────────────


def _build_initial_state(race: Race, starters: list[Starter]) -> str:
    """Serialize a race and its starters into an ``initial_state`` message.

    Missing related objects (horse, driver, trainer, meeting) and missing
    dates produce ``None`` in the corresponding output fields.

    Args:
        race: The Race ORM instance to describe.
        starters: The Starter ORM instances entered in ``race``.

    Returns:
        The ``InitialStateMessage`` encoded as a JSON string.
    """

    def name_or_none(related):
        # Absent relationships serialize as null instead of raising.
        return related.name if related else None

    # Key order is significant: it is the order clients see in the JSON.
    starter_rows = [
        {
            "id": starter.id,
            "horse_id": starter.horse_id,
            "horse_name": name_or_none(starter.horse),
            "driver_id": starter.driver_id,
            "driver_name": name_or_none(starter.driver),
            "trainer_id": starter.trainer_id,
            "trainer_name": name_or_none(starter.trainer),
            "runner_number": starter.runner_number,
            "barrier": starter.barrier,
            "handicap_m": starter.handicap_m,
            "placing": starter.placing,
            "did_not_finish": starter.did_not_finish,
        }
        for starter in starters
    ]

    meeting = race.meeting
    if meeting:
        venue = meeting.venue
        meeting_date = meeting.meeting_date.isoformat() if meeting.meeting_date else None  # type: ignore[union-attr]
    else:
        venue = None
        meeting_date = None

    start_at = race.race_datetime
    start_at_iso = start_at.isoformat() if start_at else None  # type: ignore[union-attr]

    race_info = {
        "id": race.id,
        "meeting_id": race.meeting_id,
        "race_number": race.race_number,
        "distance_m": race.distance_m,
        "start_type": race.start_type,
        "gait": race.gait,
        "weather": race.weather,
        "track_condition": race.track_condition,
        "race_datetime": start_at_iso,
        "venue": venue,
        "meeting_date": meeting_date,
    }
    payload = {
        "race": race_info,
        "starters": starter_rows,
        "starter_count": len(starters),
    }

    return InitialStateMessage(race_id=int(race.id), data=payload).model_dump_json()  # type: ignore[arg-type]


# ── Background Simulation ────────────────────────────────────────────


async def simulate_race_updates(race_id: int) -> None:
    """Simulate live race updates for a race room.

    Broadcasts mock odds updates every 5-10 seconds for roughly 60 seconds
    (the last sleep may run slightly past the mark), then sends a final
    result update. This is a placeholder until real live data integration
    is built.

    The field size is chosen once per run, so every odds update and the
    final result describe the same horses.

    The function checks for active connections before each broadcast and
    exits early if the room is empty.

    Args:
        race_id: Race identifier to simulate.

    Raises:
        asyncio.CancelledError: Re-raised after logging when the task is
            cancelled, so the task finishes in the cancelled state. Other
            exceptions are logged and not raised.
    """
    logger.info("Race simulation task started", extra={"race_id": race_id})
    # Monotonic loop clock: immune to wall-clock adjustments.
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    elapsed = 0.0
    # Pick the field size once so all updates agree on the set of horses.
    num_horses = random.randint(6, 12)

    try:
        while elapsed < 60:
            if manager.get_connection_count(race_id) == 0:
                logger.info(
                    "No more connections for race, simulation exiting",
                    extra={"race_id": race_id},
                )
                return

            # Fresh mock odds for the same field each cycle.
            odds_list = [
                {"horse_id": horse_id, "odds": round(random.uniform(1.5, 50.0), 2)}
                for horse_id in range(1, num_horses + 1)
            ]

            message = OddsUpdateMessage(
                type="odds_update",
                race_id=race_id,
                timestamp=datetime.now(UTC).isoformat(),
                odds=odds_list,
            )
            await manager.broadcast_to_race(message.model_dump_json(), race_id)

            # Wait 5-10 seconds before next update
            await asyncio.sleep(random.uniform(5.0, 10.0))
            elapsed = loop.time() - start_time

        # Send result update once the simulated race has run its course
        if manager.get_connection_count(race_id) > 0:
            result_message = ResultUpdateMessage(
                type="result_update",
                race_id=race_id,
                timestamp=datetime.now(UTC).isoformat(),
                results=[
                    {"horse_id": horse_id, "placing": horse_id, "finished": True}
                    for horse_id in range(1, num_horses + 1)
                ],
            )
            await manager.broadcast_to_race(result_message.model_dump_json(), race_id)

        logger.info("Race simulation completed", extra={"race_id": race_id})
    except asyncio.CancelledError:
        logger.info("Race simulation cancelled", extra={"race_id": race_id})
        # Propagate so the task is marked cancelled rather than completed.
        raise
    except Exception:
        logger.exception("Race simulation error", extra={"race_id": race_id})
