Coverage for apps / backend / api / websocket.py: 79%

122 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 08:37 +1200

1"""WebSocket connection manager and live race updates for TipSharks API. 

2 

3Provides a ConnectionManager for race-room-scoped WebSocket connections 

4and a background simulation task that generates mock odds/result updates 

5as a placeholder for future live data integration. 

6""" 

7 

import asyncio
import random
from collections import defaultdict
from datetime import UTC, datetime
from typing import Literal

from fastapi import WebSocket
from pydantic import BaseModel

from packages.core.common.logging import get_logger
from packages.core.storage.models import Race, Starter

18 

# Module-level logger from the project's logging helper. Connection,
# broadcast and simulation events in this module are all logged through it.
logger = get_logger(__name__)

20 

21 

22# ── Pydantic schemas for WebSocket messages ────────────────────────── 

23 

24 

class WSMessage(BaseModel):
    """Schema for incoming WebSocket messages from clients.

    Validates the shape of a client -> server frame. The set of ``type``
    values that are actually acted on is decided by the endpoint handler,
    which is not part of this module view (TODO confirm against the handler).
    """

    # Client-chosen message kind; any string is accepted at the schema level.
    type: str
    # Presumably the race room the message refers to; verify against caller.
    race_id: int
    # Optional free-form JSON object payload; None when the client omits it.
    data: dict | None = None

31 

32 

class OddsUpdateMessage(BaseModel):
    """Odds update broadcast message (server -> all clients in a race room)."""

    # Discriminator that clients switch on. It is pinned with ``Literal`` so
    # this model cannot be constructed carrying another message's type tag.
    type: Literal["odds_update"] = "odds_update"
    race_id: int
    # ISO-8601 timestamp string (the simulator uses datetime.isoformat()).
    timestamp: str
    # One entry per horse; the simulator emits {"horse_id": int, "odds": float}.
    odds: list[dict]

40 

41 

class ResultUpdateMessage(BaseModel):
    """Result update broadcast message (server -> all clients in a race room)."""

    # Discriminator that clients switch on. It is pinned with ``Literal`` so
    # this model cannot be constructed carrying another message's type tag.
    type: Literal["result_update"] = "result_update"
    race_id: int
    # ISO-8601 timestamp string (the simulator uses datetime.isoformat()).
    timestamp: str
    # One entry per horse; the simulator emits
    # {"horse_id": int, "placing": int, "finished": bool}.
    results: list[dict]

49 

50 

class InitialStateMessage(BaseModel):
    """Initial race state sent to a client right after it connects."""

    # Discriminator that clients switch on. It is pinned with ``Literal`` so
    # this model cannot be constructed carrying another message's type tag.
    type: Literal["initial_state"] = "initial_state"
    race_id: int
    # Snapshot payload; _build_initial_state fills it with the keys
    # "race", "starters" and "starter_count".
    data: dict

57 

58 

59# ── Connection Manager ─────────────────────────────────────────────── 

60 

61 

class ConnectionManager:
    """Manages WebSocket connections grouped by race room.

    Each race has its own "room" identified by ``race_id``. Room mutations
    are serialised with an ``asyncio.Lock``. The lock prevents interleaving
    between coroutines on one event loop; it does not make the manager safe
    to use from several OS threads.

    The manager also owns at most one background simulation task per race.
    Finished tasks remove themselves from the registry. A race whose
    simulation ended (for example because its room emptied) can therefore be
    simulated again when a new client joins.
    """

    def __init__(self) -> None:
        self._connections: dict[int, set[WebSocket]] = defaultdict(set)
        self._lock = asyncio.Lock()
        self._simulation_tasks: dict[int, asyncio.Task] = {}

    async def connect(self, websocket: WebSocket, race_id: int) -> None:
        """Accept a WebSocket and add it to the race room.

        Args:
            websocket: The WebSocket connection to register.
            race_id: Race room identifier.
        """
        await websocket.accept()
        async with self._lock:
            self._connections[race_id].add(websocket)
            count = len(self._connections[race_id])
        logger.info(
            "WebSocket connected for race",
            extra={"race_id": race_id, "connection_count": count},
        )

    async def disconnect(self, websocket: WebSocket, race_id: int) -> None:
        """Remove a WebSocket from the race room.

        Safe to call for a socket or room that is already gone. The room
        entry is deleted when its last client leaves.

        Args:
            websocket: The WebSocket connection to remove.
            race_id: Race room identifier.
        """
        async with self._lock:
            self._discard_locked(race_id, [websocket])
            remaining = self.get_connection_count(race_id)
        logger.info(
            "WebSocket disconnected from race",
            extra={"race_id": race_id, "remaining": remaining},
        )

    async def send_personal_message(self, message: str, websocket: WebSocket) -> None:
        """Send a JSON message to a single WebSocket client.

        Send failures are logged as warnings and not raised.

        Args:
            message: JSON-encoded message string.
            websocket: The target WebSocket connection.
        """
        try:
            await websocket.send_text(message)
        except Exception:
            logger.warning("Failed to send personal message", exc_info=True)

    async def broadcast_to_race(self, message: str, race_id: int) -> None:
        """Broadcast a JSON message to all clients in a race room.

        The recipient list is snapshotted under the lock and the sends happen
        outside it, so a slow client does not block connects/disconnects.
        Clients whose send fails are logged and removed from the room.

        Args:
            message: JSON-encoded message string.
            race_id: Target race room identifier.
        """
        async with self._lock:
            recipients = list(self._connections.get(race_id, ()))

        failed: list[WebSocket] = []
        for ws in recipients:
            try:
                await ws.send_text(message)
            except Exception:
                logger.warning(
                    "Failed to broadcast to client for race",
                    extra={"race_id": race_id},
                    exc_info=True,
                )
                failed.append(ws)

        if failed:
            # Assumes a socket that rejected a send is gone. Dropping it stops
            # it from counting toward the room (which keeps the simulation
            # alive) and from producing a warning on every later broadcast.
            # A later disconnect() for the same socket is a harmless no-op.
            async with self._lock:
                self._discard_locked(race_id, failed)

    def _discard_locked(self, race_id: int, sockets: list[WebSocket]) -> None:
        """Remove *sockets* from a room and drop the room if it is now empty.

        The caller must hold ``self._lock``.
        """
        room = self._connections.get(race_id)
        if room is None:
            return
        room.difference_update(sockets)
        if not room:
            del self._connections[race_id]

    def get_connection_count(self, race_id: int) -> int:
        """Return the number of connections for a given race."""
        return len(self._connections.get(race_id, ()))

    def is_simulation_running(self, race_id: int) -> bool:
        """Check whether an unfinished simulation task exists for a race."""
        task = self._simulation_tasks.get(race_id)
        # done() covers the short window between a task finishing and its
        # done-callback removing it from the registry.
        return task is not None and not task.done()

    def start_simulation(self, race_id: int) -> None:
        """Start the background simulation for a race if not already running.

        Must be called from inside a running event loop (uses
        ``asyncio.create_task``).

        Args:
            race_id: Race identifier to simulate.
        """
        if self.is_simulation_running(race_id):
            return
        task = asyncio.create_task(simulate_race_updates(race_id))
        self._simulation_tasks[race_id] = task
        # Unregister the task when it finishes so the race can be restarted.
        task.add_done_callback(lambda t, rid=race_id: self._forget_simulation(rid, t))
        logger.info("Started race simulation", extra={"race_id": race_id})

    def _forget_simulation(self, race_id: int, task: asyncio.Task) -> None:
        """Drop *task* from the registry if it is still the race's current task."""
        # The identity check avoids removing a newer task started for the same race.
        if self._simulation_tasks.get(race_id) is task:
            del self._simulation_tasks[race_id]

    def stop_simulation(self, race_id: int) -> None:
        """Cancel the background simulation for a race.

        Does not wait for the task to finish unwinding.

        Args:
            race_id: Race identifier to stop simulating.
        """
        task = self._simulation_tasks.pop(race_id, None)
        if task is not None:
            task.cancel()
            logger.info("Stopped race simulation", extra={"race_id": race_id})

    async def close_all(self) -> None:
        """Close all active connections and cancel all simulations.

        Used for graceful shutdown. Simulation tasks are cancelled and awaited
        first, so none is still broadcasting while sockets are being closed.
        A failure to close an individual socket is logged at debug level and
        does not stop the remaining closes.
        """
        tasks = list(self._simulation_tasks.values())
        for race_id in list(self._simulation_tasks):
            self.stop_simulation(race_id)
        if tasks:
            # return_exceptions=True so one task's CancelledError/error does
            # not abort waiting for the others.
            await asyncio.gather(*tasks, return_exceptions=True)

        async with self._lock:
            sockets = [ws for room in self._connections.values() for ws in room]
            self._connections.clear()

        for ws in sockets:
            try:
                await ws.close()
            except Exception:
                logger.debug("Failed to close WebSocket during shutdown", exc_info=True)

187 

188 

# Global connection manager singleton, created at import time.
# simulate_race_updates reads room sizes from this instance and broadcasts
# through it, so endpoints must register clients here for them to receive
# simulated updates.
manager = ConnectionManager()

191 

192 

193# ── Helpers to build initial state ─────────────────────────────────── 

194 

195 

def _build_initial_state(race: Race, starters: list[Starter]) -> str:
    """Serialise a race and its field into an ``initial_state`` message.

    Related objects that are missing (a starter without a horse, driver or
    trainer, or a race without a meeting) are rendered as ``None``. The same
    applies to missing dates.

    Args:
        race: The Race ORM instance.
        starters: List of Starter ORM instances.

    Returns:
        JSON-encoded InitialStateMessage.
    """

    def name_or_none(related):
        # Falsy related object -> no name available.
        return related.name if related else None

    def starter_payload(starter):
        return {
            "id": starter.id,
            "horse_id": starter.horse_id,
            "horse_name": name_or_none(starter.horse),
            "driver_id": starter.driver_id,
            "driver_name": name_or_none(starter.driver),
            "trainer_id": starter.trainer_id,
            "trainer_name": name_or_none(starter.trainer),
            "runner_number": starter.runner_number,
            "barrier": starter.barrier,
            "handicap_m": starter.handicap_m,
            "placing": starter.placing,
            "did_not_finish": starter.did_not_finish,
        }

    meeting = race.meeting
    if meeting:
        venue = meeting.venue
        meeting_date = meeting.meeting_date.isoformat() if meeting.meeting_date else None  # type: ignore[union-attr]
    else:
        venue = None
        meeting_date = None

    race_dt_iso = race.race_datetime.isoformat() if race.race_datetime else None  # type: ignore[union-attr]

    race_summary = {
        "id": race.id,
        "meeting_id": race.meeting_id,
        "race_number": race.race_number,
        "distance_m": race.distance_m,
        "start_type": race.start_type,
        "gait": race.gait,
        "weather": race.weather,
        "track_condition": race.track_condition,
        "race_datetime": race_dt_iso,
        "venue": venue,
        "meeting_date": meeting_date,
    }
    payload = {
        "race": race_summary,
        "starters": [starter_payload(starter) for starter in starters],
        "starter_count": len(starters),
    }

    return InitialStateMessage(race_id=int(race.id), data=payload).model_dump_json()  # type: ignore[arg-type]

254 

255 

256# ── Background Simulation ──────────────────────────────────────────── 

257 

258 

async def simulate_race_updates(
    race_id: int,
    *,
    duration_s: float = 60.0,
    min_interval_s: float = 5.0,
    max_interval_s: float = 10.0,
) -> None:
    """Simulate live race updates for a race room.

    Broadcasts mock odds updates at random intervals (5-10 seconds by default)
    for ``duration_s`` seconds (60 by default). It then sends a final result
    update. This is a placeholder until real live data integration is built.

    The function checks for active connections before each broadcast and
    exits early if the room is empty. A mock field size is chosen once, so
    every odds update and the final result describe the same horses.

    Args:
        race_id: Race identifier to simulate.
        duration_s: How long to keep broadcasting odds before the result.
        min_interval_s: Lower bound of the random delay between odds updates.
        max_interval_s: Upper bound of the random delay between odds updates.

    Raises:
        asyncio.CancelledError: Re-raised after logging when the task is
            cancelled, so awaiting callers see a properly cancelled task.
    """
    logger.info("Race simulation task started", extra={"race_id": race_id})
    # The loop's monotonic clock is immune to wall-clock adjustments (NTP, DST).
    loop = asyncio.get_running_loop()
    started_at = loop.time()
    # Fixed for the whole simulation: a race's field does not change mid-race.
    num_horses = random.randint(6, 12)

    try:
        while loop.time() - started_at < duration_s:
            if manager.get_connection_count(race_id) == 0:
                logger.info(
                    "No more connections for race, simulation exiting",
                    extra={"race_id": race_id},
                )
                return

            odds_list = [
                {"horse_id": horse_id, "odds": round(random.uniform(1.5, 50.0), 2)}
                for horse_id in range(1, num_horses + 1)
            ]
            message = OddsUpdateMessage(
                type="odds_update",
                race_id=race_id,
                timestamp=datetime.now(UTC).isoformat(),
                odds=odds_list,
            )
            await manager.broadcast_to_race(message.model_dump_json(), race_id)

            await asyncio.sleep(random.uniform(min_interval_s, max_interval_s))

        # The room may have emptied during the last sleep.
        if manager.get_connection_count(race_id) > 0:
            result_message = ResultUpdateMessage(
                type="result_update",
                race_id=race_id,
                timestamp=datetime.now(UTC).isoformat(),
                results=[
                    {"horse_id": horse_id, "placing": horse_id, "finished": True}
                    for horse_id in range(1, num_horses + 1)
                ],
            )
            await manager.broadcast_to_race(result_message.model_dump_json(), race_id)

        logger.info("Race simulation completed", extra={"race_id": race_id})
    except asyncio.CancelledError:
        logger.info("Race simulation cancelled", extra={"race_id": race_id})
        # Propagate so the task ends in the cancelled state; swallowing it
        # breaks TaskGroup/timeout semantics and hides cancellation from awaiters.
        raise
    except Exception:
        logger.exception("Race simulation error", extra={"race_id": race_id})