Coverage for packages / core / ratings / predictions.py: 57%

206 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 08:37 +1200

1"""Enhanced race prediction functionality with confidence intervals and tracking.""" 

2 

3import math 

4from dataclasses import dataclass 

5from datetime import date, datetime 

6 

7from sqlalchemy import and_, func, or_ 

8from sqlalchemy.orm import Session 

9 

10from packages.core.common.logging import get_logger 

11from packages.core.common.settings import get_settings 

12from packages.core.ratings.engine import RatingEngine 

13from packages.core.storage.models import EntityType, Race, Starter 

14from packages.core.storage.repositories import RatingSnapshotRepository 

15 

16logger = get_logger(__name__) 

17 

18 

@dataclass
class PredictionResult:
    """Prediction for a single starter.

    Field order is part of the public interface (positional construction),
    so it is kept exactly as declared.
    """

    # Identity of the starter and its connections.
    starter_id: int
    horse_id: int
    horse_name: str | None
    driver_id: int | None
    driver_name: str | None
    trainer_id: int | None
    trainer_name: str | None

    # Starting conditions copied from the starter record.
    barrier: int | None
    handicap_m: float | None

    # Model outputs.
    effective_rating: float
    win_probability: float
    place_probability: float  # Top 3
    place_score: float
    confidence_interval_low: float  # 95% CI lower bound
    confidence_interval_high: float  # 95% CI upper bound
    predicted_placing: int

39 

40 

@dataclass
class RacePrediction:
    """Complete prediction for a race.

    Bundles the race's identifying details with the per-starter predictions
    and the moment the prediction was generated.
    """

    # Race identification.
    race_id: int
    race_number: int | None
    venue: str | None
    distance_m: int | None
    race_date: date | None

    # One entry per starter that could be rated; may be empty.
    predictions: list[PredictionResult]

    # When this prediction object was built.
    prediction_timestamp: datetime

    # Free-form summary values describing the field.
    metadata: dict

53 

54 

class PredictionEngine:
    """Enhanced prediction engine with confidence intervals and tracking.

    Wraps a ``RatingEngine`` bound to the same database session and turns
    per-starter effective ratings into win probabilities, top-3 ("place")
    probabilities, rating confidence intervals and a predicted finishing
    order. It can also score a prediction against a completed race.
    """

    def __init__(self, session: Session):
        """Initialize prediction engine.

        Args:
            session: Database session
        """
        self.session = session
        self.rating_engine = RatingEngine(session)

        # Place-model tuning values, copied from the ``rating`` settings section.
        settings = get_settings().rating
        self.place_history_limit = settings.place_history_limit
        self.place_prior_rate = settings.place_prior_rate
        self.place_prior_weight = settings.place_prior_weight
        self.place_top3_weight = settings.place_top3_weight
        self.place_consistency_weight = settings.place_consistency_weight

    def predict_race(self, race: Race, starters: list[Starter]) -> RacePrediction:
        """Generate predictions for a race.

        Starters without a ``horse_id`` are skipped. When there are no
        starters, or none of them can be rated, a ``RacePrediction`` with an
        empty ``predictions`` list and empty ``metadata`` is returned.

        Args:
            race: Race instance
            starters: List of starters

        Returns:
            Complete race prediction with probabilities, sorted by
            predicted placing.
        """
        if not starters:
            logger.warning(f"No starters for race {race.id}")
            return self._build_race_prediction(race, [], {})

        # Load latest ratings for all starters (use pre-race ratings if available)
        for starter in starters:
            self._load_latest_rating(EntityType.HORSE, starter.horse_id, race)
            self._load_latest_rating(EntityType.DRIVER, starter.driver_id, race)
            self._load_latest_rating(EntityType.TRAINER, starter.trainer_id, race)

        # Compute effective ratings for all starters
        effective_ratings = {}
        rating_uncertainties = {}

        for starter in starters:
            if not starter.horse_id:
                continue

            effective_ratings[starter.id] = self.rating_engine.compute_effective_rating(
                starter, race
            )

            # Get rating uncertainty (RD) for confidence interval; a falsy RD
            # (None or 0) falls back to 100.0.
            horse_state = self.rating_engine.get_or_init_rating(
                EntityType.HORSE, starter.horse_id
            )
            rating_uncertainties[starter.id] = horse_state.rd or 100.0

        if not effective_ratings:
            logger.warning(f"No ratings available for race {race.id}")
            return self._build_race_prediction(race, [], {})

        # Compute win probabilities using softmax
        win_probs = self._compute_win_probabilities(effective_ratings)

        # Compute place scores/probabilities (top 3)
        place_scores = self._compute_place_scores(race, starters, effective_ratings)
        place_probs = self._compute_place_probabilities(place_scores, top_n=3)

        # Generate predictions for each starter
        predictions = []
        place_ranks = self._compute_place_ranks(place_scores)
        for starter in starters:
            if starter.id not in effective_ratings:
                continue

            r_eff = effective_ratings[starter.id]
            rd = rating_uncertainties.get(starter.id, 100.0)

            predictions.append(
                PredictionResult(
                    starter_id=starter.id,
                    horse_id=starter.horse_id,
                    horse_name=starter.horse.name if starter.horse else None,
                    driver_id=starter.driver_id,
                    driver_name=starter.driver.name if starter.driver else None,
                    trainer_id=starter.trainer_id,
                    trainer_name=starter.trainer.name if starter.trainer else None,
                    barrier=starter.barrier,
                    handicap_m=starter.handicap_m,
                    effective_rating=r_eff,
                    win_probability=win_probs.get(starter.id, 0.0),
                    place_probability=place_probs.get(starter.id, 0.0),
                    # Fall back to the raw rating when no place score exists.
                    place_score=place_scores.get(starter.id, r_eff),
                    # Compute confidence interval (95% CI = ±1.96 * RD)
                    confidence_interval_low=r_eff - 1.96 * rd,
                    confidence_interval_high=r_eff + 1.96 * rd,
                    # Unranked starters are placed last.
                    predicted_placing=place_ranks.get(starter.id, len(place_ranks)),
                )
            )

        # Sort by predicted placing
        predictions.sort(key=lambda p: p.predicted_placing)

        field_size = len(predictions)
        if predictions:
            ratings = [p.effective_rating for p in predictions]
            avg_rating = sum(ratings) / field_size
            rating_spread = max(ratings) - min(ratings)
            avg_place_score = sum(p.place_score for p in predictions) / field_size
        else:
            avg_rating = rating_spread = avg_place_score = 0

        metadata = {
            "field_size": field_size,
            "avg_rating": avg_rating,
            "rating_spread": rating_spread,
            "avg_place_score": avg_place_score,
        }

        return self._build_race_prediction(race, predictions, metadata)

    def _build_race_prediction(
        self, race: Race, predictions: list[PredictionResult], metadata: dict
    ) -> RacePrediction:
        """Wrap predictions in a ``RacePrediction`` carrying the race's details."""
        meeting = race.meeting
        return RacePrediction(
            race_id=race.id,
            race_number=race.race_number,
            venue=meeting.venue if meeting else None,
            distance_m=race.distance_m,
            race_date=meeting.meeting_date if meeting else None,
            predictions=predictions,
            # NOTE: naive local time, as produced by datetime.now().
            prediction_timestamp=datetime.now(),
            metadata=metadata,
        )

    def _load_latest_rating(
        self, entity_type: EntityType, entity_id: int | None, race: Race
    ) -> None:
        """Load one entity's latest stored rating into the rating engine.

        Does nothing when ``entity_id`` is falsy or the repository returns no
        snapshot. The lookup passes ``before_race_id=race.id``; presumably
        this restricts it to snapshots taken before the race — confirm
        against ``RatingSnapshotRepository``.
        """
        if not entity_id:
            return

        latest = RatingSnapshotRepository.get_latest_rating(
            self.session,
            entity_type,
            entity_id,
            before_race_id=race.id,
        )
        if not latest:
            return

        snapshot_race = latest.race
        last_race_date = (
            snapshot_race.meeting.meeting_date
            if snapshot_race and snapshot_race.meeting
            else None
        )
        self.rating_engine.load_rating_state(
            entity_type,
            entity_id,
            latest.rating,
            latest.rd,
            last_race_date=last_race_date,
        )

    def _compute_win_probabilities(
        self, effective_ratings: dict[int, float]
    ) -> dict[int, float]:
        """Compute win probabilities using softmax.

        Args:
            effective_ratings: Dictionary of starter_id -> effective rating

        Returns:
            Dictionary of starter_id -> win probability (sums to 1 across
            the field); empty when the input is empty.
        """
        if not effective_ratings:
            return {}

        # Scale factor (400 = standard Elo scale)
        scale = 400.0

        # Numerically stable softmax: shifting by the maximum keeps every
        # exponent <= 0, so exp() cannot overflow.
        max_rating = max(effective_ratings.values())
        exp_ratings = {
            sid: math.exp((rating - max_rating) / scale)
            for sid, rating in effective_ratings.items()
        }
        total_exp = sum(exp_ratings.values())

        # Normalize to probabilities
        return {sid: exp_val / total_exp for sid, exp_val in exp_ratings.items()}

    def _compute_place_probabilities(
        self, place_scores: dict[int, float], top_n: int = 3
    ) -> dict[int, float]:
        """Compute place probabilities (finishing in top N).

        When the field has at most ``top_n`` runners every runner gets 1.0.
        Otherwise each runner's softmax share of the place scores is
        multiplied by ``top_n + 0.5`` and capped at 1.0. This is a heuristic:
        the results are not guaranteed to sum to ``top_n``.

        Args:
            place_scores: Dictionary of starter_id -> place score
            top_n: Number of top placings to consider

        Returns:
            Dictionary of starter_id -> place probability
        """
        if not place_scores or len(place_scores) <= top_n:
            return dict.fromkeys(place_scores, 1.0)

        multiplier = top_n + 0.5
        base_probs = self._compute_win_probabilities(place_scores)
        return {sid: min(1.0, prob * multiplier) for sid, prob in base_probs.items()}

    def _compute_place_scores(
        self,
        race: Race,
        starters: list[Starter],
        effective_ratings: dict[int, float],
    ) -> dict[int, float]:
        """Compute place scores from ratings and consistency history.

        Each rated starter's score is its effective rating plus an adjustment
        built from its recent top-3 rate (relative to the prior rate) and its
        consistency (relative to 0.5). Both terms are weighted by the
        configured weights and scaled by the rating engine's ``elo_scale_c``.

        Args:
            race: Race being predicted (bounds the history lookup)
            starters: Starters in the race
            effective_ratings: Dictionary of starter_id -> effective rating

        Returns:
            Dictionary of starter_id -> place score for every starter that
            has both an effective rating and a ``horse_id``.
        """
        scores = {}
        scale = self.rating_engine.settings.elo_scale_c

        for starter in starters:
            if starter.id not in effective_ratings or not starter.horse_id:
                continue

            history = self._get_recent_finish_stats(race, starter.horse_id)
            top3_edge = history["top3_rate"] - self.place_prior_rate
            consistency_edge = history["consistency"] - 0.5

            rating_adjustment = (
                self.place_top3_weight * top3_edge * scale
                + self.place_consistency_weight * consistency_edge * scale
            )
            scores[starter.id] = effective_ratings[starter.id] + rating_adjustment

        return scores

    def _compute_place_ranks(self, place_scores: dict[int, float]) -> dict[int, int]:
        """Convert place scores into ordinal ranks.

        The highest score gets rank 1. The sort is stable, so tied scores
        keep their input (dict insertion) order.
        """
        ordered = sorted(place_scores.items(), key=lambda item: item[1], reverse=True)
        return {starter_id: rank for rank, (starter_id, _) in enumerate(ordered, start=1)}

    def _get_recent_finish_stats(self, race: Race, horse_id: int) -> dict[str, float]:
        """Compute smoothed top-3 rate and consistency from recent finishes.

        Queries up to ``place_history_limit`` of the horse's most recent
        finishes that have a placing and are not marked did-not-finish, then
        summarises them with ``_summarize_finishes``.

        The history is bounded by the first available of:

        1. ``race.race_datetime`` — only strictly earlier races;
        2. the meeting date — earlier meetings, or earlier race numbers at
           the same meeting;
        3. neither — every other race, newest ``Race.id`` first.

        Args:
            race: Race being predicted
            horse_id: Horse whose history is summarised

        Returns:
            Dictionary with ``"top3_rate"`` and ``"consistency"`` keys.
        """
        from packages.core.storage.models import Meeting

        # Field size for each race, used to turn a placing into a percentile.
        field_size_subquery = (
            self.session.query(
                Starter.race_id.label("race_id"),
                func.count(Starter.id).label("field_size"),
            )
            .group_by(Starter.race_id)
            .subquery()
        )

        query = (
            self.session.query(
                Starter.placing,
                field_size_subquery.c.field_size,
                Race.race_datetime,
                Race.race_number,
                Meeting.meeting_date,
            )
            .join(Race, Starter.race_id == Race.id)
            .join(Meeting, Race.meeting_id == Meeting.id)
            .join(field_size_subquery, Starter.race_id == field_size_subquery.c.race_id)
            .filter(
                Starter.horse_id == horse_id,
                Starter.placing.isnot(None),
                Starter.did_not_finish.is_(False),
            )
        )

        if race.race_datetime:
            query = query.filter(Race.race_datetime < race.race_datetime)
            query = query.order_by(Race.race_datetime.desc(), Race.race_number.desc())
        elif race.meeting and race.meeting.meeting_date:
            meeting_date = race.meeting.meeting_date
            query = query.filter(
                or_(
                    Meeting.meeting_date < meeting_date,
                    and_(
                        Meeting.meeting_date == meeting_date,
                        Race.race_number < race.race_number,
                    ),
                )
            )
            query = query.order_by(Meeting.meeting_date.desc(), Race.race_number.desc())
        else:
            # No temporal bound is available. At minimum exclude the race
            # being predicted, so a completed race's own result cannot leak
            # into the history used to predict it.
            query = query.filter(Race.id != race.id)
            query = query.order_by(Race.id.desc())

        rows = query.limit(self.place_history_limit).all()
        return self._summarize_finishes(
            [(placing, field_size) for placing, field_size, *_ in rows]
        )

    def _summarize_finishes(
        self, finishes: list[tuple[int, int | None]]
    ) -> dict[str, float]:
        """Summarise ``(placing, field_size)`` pairs into place statistics.

        * ``top3_rate`` is the share of finishes with placing <= 3, smoothed
          towards ``place_prior_rate`` with ``place_prior_weight``
          pseudo-observations.
        * ``consistency`` is ``1 - min(1, stddev / 0.5)`` of the finishing
          percentiles (1.0 for a win, 0.0 for last). It is 0.5 when there
          are fewer than two finishes.

        With no finishes at all the result is the prior rate and 0.5.
        """
        if not finishes:
            return {"top3_rate": self.place_prior_rate, "consistency": 0.5}

        top3_count = 0
        percentiles = []
        for placing, field_size in finishes:
            field_size = field_size or 1
            if placing <= 3:
                top3_count += 1
            if field_size <= 1:
                # A one-runner (or unknown-size) field cannot be ranked.
                percentiles.append(1.0)
            else:
                percentiles.append(1.0 - (placing - 1) / (field_size - 1))

        sample_count = len(finishes)
        top3_rate = (top3_count + self.place_prior_rate * self.place_prior_weight) / (
            sample_count + self.place_prior_weight
        )

        if sample_count < 2:
            consistency = 0.5
        else:
            mean = sum(percentiles) / sample_count
            # Population variance (divides by n, not n - 1).
            variance = sum((p - mean) ** 2 for p in percentiles) / sample_count
            consistency = 1.0 - min(1.0, math.sqrt(variance) / 0.5)

        return {"top3_rate": top3_rate, "consistency": consistency}

    def get_upcoming_races(self, race_date: date | None = None) -> list[Race]:
        """Get upcoming races for prediction.

        Args:
            race_date: Date to get races for (defaults to today)

        Returns:
            List of races whose meeting is on ``race_date``, ordered by
            race number
        """
        from packages.core.storage.models import Meeting

        if race_date is None:
            race_date = date.today()

        return (
            self.session.query(Race)
            .join(Race.meeting)
            .filter(Meeting.meeting_date == race_date)
            .order_by(Race.race_number)
            .all()
        )

    def compare_prediction_to_actual(self, race_id: int) -> dict | None:
        """Compare prediction to actual result for completed race.

        Args:
            race_id: Race ID

        Returns:
            Comparison dictionary with prediction accuracy metrics, or
            ``None`` when the race does not exist, has no starters, has no
            recorded winner (not completed), or none of its starters could
            be rated.
        """
        race = self.session.query(Race).filter(Race.id == race_id).first()
        if not race:
            return None

        starters = race.starters
        if not starters:
            return None

        # Generate prediction
        prediction = self.predict_race(race, starters)

        # Without this guard, max() and the Brier average below would raise
        # on an empty prediction list (no starter had a horse_id).
        if not prediction.predictions:
            return None

        # Collect actual results from finishers only.
        starter_by_id = {starter.id: starter for starter in starters}
        actual_winner_id = None
        actual_top3_ids = []

        for starter in starters:
            if not starter.placing or starter.did_not_finish:
                continue
            if starter.placing == 1:
                actual_winner_id = starter.id
            if starter.placing <= 3:
                actual_top3_ids.append(starter.id)

        if actual_winner_id is None:
            return None  # Race not completed

        # Find predicted winner
        predicted_winner = max(prediction.predictions, key=lambda p: p.win_probability)

        # Check if prediction was correct
        winner_correct = predicted_winner.starter_id == actual_winner_id

        # Check top-3 overlap
        by_placing = sorted(prediction.predictions, key=lambda p: p.predicted_placing)
        predicted_top3_ids = [p.starter_id for p in by_placing[:3]]
        top3_overlap = len(set(predicted_top3_ids) & set(actual_top3_ids))

        # Calculate Brier score for winner prediction
        brier_score = sum(
            (p.win_probability - (1.0 if p.starter_id == actual_winner_id else 0.0)) ** 2
            for p in prediction.predictions
        ) / len(prediction.predictions)

        predictions_with_actuals = []
        for pred in prediction.predictions:
            starter = starter_by_id.get(pred.starter_id)
            actual_placing = None
            if starter and starter.placing and not starter.did_not_finish:
                actual_placing = starter.placing
            predictions_with_actuals.append(
                {
                    "starter_id": pred.starter_id,
                    "horse_id": pred.horse_id,
                    "horse_name": pred.horse_name,
                    "driver_id": pred.driver_id,
                    "driver_name": pred.driver_name,
                    "trainer_id": pred.trainer_id,
                    "trainer_name": pred.trainer_name,
                    "barrier": pred.barrier,
                    "handicap_m": pred.handicap_m,
                    "effective_rating": pred.effective_rating,
                    "win_probability": pred.win_probability,
                    "place_probability": pred.place_probability,
                    "place_score": pred.place_score,
                    "ci_lower": pred.confidence_interval_low,
                    "ci_upper": pred.confidence_interval_high,
                    "predicted_placing": pred.predicted_placing,
                    "actual_placing": actual_placing,
                }
            )

        meeting = race.meeting
        # A meeting may exist without a date; avoid calling isoformat() on None.
        race_date = (
            meeting.meeting_date.isoformat()
            if meeting and meeting.meeting_date
            else None
        )

        return {
            "race_id": race_id,
            "race_number": race.race_number,
            "venue": meeting.venue if meeting else None,
            "race_date": race_date,
            "winner_correct": winner_correct,
            "predicted_winner_id": predicted_winner.starter_id,
            "actual_winner_id": actual_winner_id,
            "top3_overlap": top3_overlap,
            # Always out of three, regardless of field size.
            "top3_overlap_rate": top3_overlap / 3.0,
            "brier_score": brier_score,
            "field_size": len(prediction.predictions),
            "predictions": predictions_with_actuals,
        }

555 

556 

def export_predictions_csv(predictions: list[RacePrediction], output_file: str) -> None:
    """Export predictions to CSV file.

    Writes one header row, then one row per starter prediction across all
    races. Missing (``None``) values are written as empty cells. Ratings and
    confidence bounds are formatted to one decimal place, probabilities to
    three.

    Args:
        predictions: List of race predictions
        output_file: Output CSV file path (overwritten if it exists)
    """
    import csv

    def _blank_if_none(value):
        # Only None means "missing"; a legitimate 0 (e.g. a 0 m handicap)
        # must still be written.
        return "" if value is None else value

    # Explicit UTF-8 keeps the output independent of the platform locale;
    # newline="" is what the csv module requires for file objects.
    with open(output_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)

        # Header
        writer.writerow(
            [
                "Race ID",
                "Race Number",
                "Venue",
                "Distance (m)",
                "Starter ID",
                "Horse Name",
                "Driver Name",
                "Barrier",
                "Handicap (m)",
                "Effective Rating",
                "Win Probability",
                "Place Probability",
                "Predicted Placing",
                "CI Low",
                "CI High",
            ]
        )

        # Data rows
        for race_pred in predictions:
            for pred in race_pred.predictions:
                writer.writerow(
                    [
                        race_pred.race_id,
                        _blank_if_none(race_pred.race_number),
                        race_pred.venue or "",
                        _blank_if_none(race_pred.distance_m),
                        pred.starter_id,
                        pred.horse_name or "",
                        pred.driver_name or "",
                        _blank_if_none(pred.barrier),
                        _blank_if_none(pred.handicap_m),
                        f"{pred.effective_rating:.1f}",
                        f"{pred.win_probability:.3f}",
                        f"{pred.place_probability:.3f}",
                        pred.predicted_placing,
                        f"{pred.confidence_interval_low:.1f}",
                        f"{pred.confidence_interval_high:.1f}",
                    ]
                )

    logger.info(f"Predictions exported to {output_file}")