Coverage for apps / backend / api / websocket.py: 79%
122 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 08:37 +1200
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 08:37 +1200
1"""WebSocket connection manager and live race updates for TipSharks API.
3Provides a ConnectionManager for race-room-scoped WebSocket connections
4and a background simulation task that generates mock odds/result updates
5as a placeholder for future live data integration.
6"""
8import asyncio
9import random
10from collections import defaultdict
11from datetime import UTC, datetime
13from fastapi import WebSocket
14from pydantic import BaseModel
16from packages.core.common.logging import get_logger
17from packages.core.storage.models import Race, Starter
# Module-scoped structured logger; call sites attach context through ``extra=``.
logger = get_logger(__name__)
22# ── Pydantic schemas for WebSocket messages ──────────────────────────
class WSMessage(BaseModel):
    """Inbound message a client sends over a race WebSocket.

    Attributes:
        type: Message kind chosen by the client (not constrained here).
        race_id: Race the message refers to.
        data: Optional free-form payload; absent payloads parse as ``None``.
    """

    type: str
    race_id: int
    data: dict | None = None
class OddsUpdateMessage(BaseModel):
    """Server-to-client broadcast carrying a fresh set of odds for one race.

    Attributes:
        type: Message discriminator, ``"odds_update"`` unless overridden.
        race_id: Race whose odds are being published.
        timestamp: Time of the update as a string (the simulator in this
            module sends a UTC ISO-8601 value).
        odds: Per-horse odds entries.
    """

    type: str = "odds_update"
    race_id: int
    timestamp: str
    odds: list[dict]
class ResultUpdateMessage(BaseModel):
    """Server-to-client broadcast announcing the finishing order of a race.

    Attributes:
        type: Message discriminator, ``"result_update"`` unless overridden.
        race_id: Race whose results are being published.
        timestamp: Time of the update as a string (the simulator in this
            module sends a UTC ISO-8601 value).
        results: Per-horse result entries.
    """

    type: str = "result_update"
    race_id: int
    timestamp: str
    results: list[dict]
class InitialStateMessage(BaseModel):
    """Snapshot of a race pushed to a client when it first connects.

    Attributes:
        type: Message discriminator, ``"initial_state"`` unless overridden.
        race_id: Race the snapshot describes.
        data: Race and starter details (built by ``_build_initial_state``).
    """

    type: str = "initial_state"
    race_id: int
    data: dict
59# ── Connection Manager ───────────────────────────────────────────────
class ConnectionManager:
    """Manages WebSocket connections grouped by race room.

    Each race has its own "room" identified by ``race_id``. Room membership
    is mutated under an ``asyncio.Lock``, which serialises coroutines running
    on the same event loop; it does not make the manager safe to use from
    multiple OS threads.

    The manager also tracks at most one live background simulation task per
    race. Finished tasks are dropped from the registry automatically so a
    simulation can be started again for the same race later.
    """

    def __init__(self) -> None:
        self._connections: dict[int, set[WebSocket]] = defaultdict(set)
        self._lock = asyncio.Lock()
        # Holding the task object here also keeps a strong reference to it,
        # so a running simulation cannot be garbage-collected mid-flight.
        self._simulation_tasks: dict[int, asyncio.Task] = {}

    async def connect(self, websocket: WebSocket, race_id: int) -> None:
        """Accept a WebSocket and add it to the race room.

        Args:
            websocket: The WebSocket connection to register.
            race_id: Race room identifier.
        """
        await websocket.accept()
        async with self._lock:
            self._connections[race_id].add(websocket)
            logger.info(
                "WebSocket connected for race",
                extra={
                    "race_id": race_id,
                    "connection_count": self.get_connection_count(race_id),
                },
            )

    async def disconnect(self, websocket: WebSocket, race_id: int) -> None:
        """Remove a WebSocket from the race room.

        Safe to call more than once for the same socket or for a race with no
        registered room. The room entry is removed when its last client
        leaves.

        Args:
            websocket: The WebSocket connection to remove.
            race_id: Race room identifier.
        """
        async with self._lock:
            # Use .get() so an unknown race does not create an empty room.
            room = self._connections.get(race_id)
            if room is not None:
                room.discard(websocket)
                if not room:
                    del self._connections[race_id]
            logger.info(
                "WebSocket disconnected from race",
                extra={"race_id": race_id, "remaining": self.get_connection_count(race_id)},
            )

    async def send_personal_message(self, message: str, websocket: WebSocket) -> None:
        """Send a JSON message to a single WebSocket client.

        Send failures are logged and otherwise ignored.

        Args:
            message: JSON-encoded message string.
            websocket: The target WebSocket connection.
        """
        try:
            await websocket.send_text(message)
        except Exception:
            logger.warning("Failed to send personal message", exc_info=True)

    async def broadcast_to_race(self, message: str, race_id: int) -> None:
        """Broadcast a JSON message to all clients in a race room.

        Sends happen outside the lock on a snapshot of the room. Clients whose
        send raises are logged and then removed from the room, so a dead
        socket is not retried on every broadcast and does not keep the room
        (and its simulation) alive.

        Args:
            message: JSON-encoded message string.
            race_id: Target race room identifier.
        """
        async with self._lock:
            connections = set(self._connections.get(race_id, ()))

        failed: list[WebSocket] = []
        for ws in connections:
            try:
                await ws.send_text(message)
            except Exception:
                logger.warning(
                    "Failed to broadcast to client for race",
                    extra={"race_id": race_id},
                    exc_info=True,
                )
                failed.append(ws)

        for ws in failed:
            await self.disconnect(ws, race_id)

    def get_connection_count(self, race_id: int) -> int:
        """Return the number of connections for a given race."""
        return len(self._connections.get(race_id, set()))

    def is_simulation_running(self, race_id: int) -> bool:
        """Return True if a simulation task for the race exists and is not done."""
        task = self._simulation_tasks.get(race_id)
        return task is not None and not task.done()

    def start_simulation(self, race_id: int) -> None:
        """Start the background simulation for a race if not already running.

        A task that has already finished (completed, errored or cancelled)
        does not block a new simulation from being started.

        Args:
            race_id: Race identifier to simulate.
        """
        if self.is_simulation_running(race_id):
            return
        task = asyncio.create_task(simulate_race_updates(race_id))
        self._simulation_tasks[race_id] = task
        # Unregister automatically once the task ends for any reason.
        task.add_done_callback(
            lambda finished, rid=race_id: self._on_simulation_done(rid, finished)
        )
        logger.info("Started race simulation", extra={"race_id": race_id})

    def _on_simulation_done(self, race_id: int, task: asyncio.Task) -> None:
        """Drop a finished simulation task from the registry.

        Only removes the entry if it still refers to this exact task, so a
        newer simulation registered for the same race is left alone.
        """
        if self._simulation_tasks.get(race_id) is task:
            del self._simulation_tasks[race_id]

    def stop_simulation(self, race_id: int) -> None:
        """Cancel the background simulation for a race.

        Args:
            race_id: Race identifier to stop simulating.
        """
        task = self._simulation_tasks.pop(race_id, None)
        if task is not None:
            task.cancel()
            logger.info("Stopped race simulation", extra={"race_id": race_id})

    async def close_all(self) -> None:
        """Close all active connections and cancel all simulations.

        Used for graceful shutdown. Closing is best-effort: a socket that
        fails to close is logged at debug level and skipped.
        """
        async with self._lock:
            for race_id in list(self._simulation_tasks):
                self.stop_simulation(race_id)
            for race_id, room in list(self._connections.items()):
                for ws in list(room):
                    try:
                        await ws.close()
                    except Exception:
                        # The peer may already be gone; shutdown must continue.
                        logger.debug(
                            "Failed to close WebSocket during shutdown",
                            extra={"race_id": race_id},
                            exc_info=True,
                        )
            self._connections.clear()
# Shared, module-level ConnectionManager instance; simulate_race_updates
# broadcasts through it.
manager = ConnectionManager()
193# ── Helpers to build initial state ───────────────────────────────────
def _build_initial_state(race: Race, starters: list[Starter]) -> str:
    """Serialise a race and its starters into an ``initial_state`` JSON payload.

    The payload holds the race's own attributes (plus venue and meeting date
    when a meeting is attached) and one entry per starter, with the related
    horse, driver and trainer names resolved (``None`` when the relation is
    missing).

    Args:
        race: Race ORM instance to describe.
        starters: Starter ORM instances to include.

    Returns:
        The ``InitialStateMessage`` encoded as a JSON string.
    """

    def name_or_none(related):
        # Related rows are optional; a missing one serialises as null.
        return related.name if related else None

    starter_rows = [
        {
            "id": starter.id,
            "horse_id": starter.horse_id,
            "horse_name": name_or_none(starter.horse),
            "driver_id": starter.driver_id,
            "driver_name": name_or_none(starter.driver),
            "trainer_id": starter.trainer_id,
            "trainer_name": name_or_none(starter.trainer),
            "runner_number": starter.runner_number,
            "barrier": starter.barrier,
            "handicap_m": starter.handicap_m,
            "placing": starter.placing,
            "did_not_finish": starter.did_not_finish,
        }
        for starter in starters
    ]

    meeting = race.meeting
    meeting_date = meeting.meeting_date if meeting else None
    race_dt = race.race_datetime

    payload = {
        "race": {
            "id": race.id,
            "meeting_id": race.meeting_id,
            "race_number": race.race_number,
            "distance_m": race.distance_m,
            "start_type": race.start_type,
            "gait": race.gait,
            "weather": race.weather,
            "track_condition": race.track_condition,
            "race_datetime": race_dt.isoformat() if race_dt else None,  # type: ignore[union-attr]
            "venue": meeting.venue if meeting else None,
            "meeting_date": meeting_date.isoformat() if meeting_date else None,  # type: ignore[union-attr]
        },
        "starters": starter_rows,
        "starter_count": len(starters),
    }

    return InitialStateMessage(race_id=int(race.id), data=payload).model_dump_json()  # type: ignore[arg-type]
256# ── Background Simulation ────────────────────────────────────────────
async def simulate_race_updates(
    race_id: int,
    *,
    duration_s: float = 60.0,
    min_interval_s: float = 5.0,
    max_interval_s: float = 10.0,
) -> None:
    """Simulate live race updates for a race room.

    Broadcasts mock odds updates at random intervals for ``duration_s``
    seconds, then sends a final result update. This is a placeholder until
    real live data integration is built.

    Before each odds broadcast the room is checked and the function returns
    early if it is empty. The field size (number of horses) is chosen once,
    so every odds update and the final result describe the same horses.

    Args:
        race_id: Race identifier to simulate.
        duration_s: How long to keep sending odds updates, in seconds.
        min_interval_s: Lower bound of the random delay between updates.
        max_interval_s: Upper bound of the random delay between updates.

    Raises:
        asyncio.CancelledError: Re-raised after logging when the task is
            cancelled, so the task is correctly reported as cancelled.
    """
    logger.info("Race simulation task started", extra={"race_id": race_id})
    loop = asyncio.get_running_loop()
    # Monotonic clock: immune to wall-clock adjustments during the run.
    deadline = loop.time() + duration_s
    num_horses = random.randint(6, 12)

    try:
        while loop.time() < deadline:
            if manager.get_connection_count(race_id) == 0:
                logger.info(
                    "No more connections for race, simulation exiting",
                    extra={"race_id": race_id},
                )
                return

            odds_list = [
                {"horse_id": i + 1, "odds": round(random.uniform(1.5, 50.0), 2)}
                for i in range(num_horses)
            ]
            message = OddsUpdateMessage(
                type="odds_update",
                race_id=race_id,
                timestamp=datetime.now(UTC).isoformat(),
                odds=odds_list,
            )
            await manager.broadcast_to_race(message.model_dump_json(), race_id)

            await asyncio.sleep(random.uniform(min_interval_s, max_interval_s))

        # Odds window over: publish a final result if anyone is still listening.
        if manager.get_connection_count(race_id) > 0:
            result_message = ResultUpdateMessage(
                type="result_update",
                race_id=race_id,
                timestamp=datetime.now(UTC).isoformat(),
                results=[
                    {"horse_id": i + 1, "placing": i + 1, "finished": True}
                    for i in range(num_horses)
                ],
            )
            await manager.broadcast_to_race(result_message.model_dump_json(), race_id)

        logger.info("Race simulation completed", extra={"race_id": race_id})
    except asyncio.CancelledError:
        logger.info("Race simulation cancelled", extra={"race_id": race_id})
        # Propagate so the task ends in the cancelled state, not as a success.
        raise
    except Exception:
        logger.exception("Race simulation error", extra={"race_id": race_id})