Coverage for packages / core / ratings / predictions.py: 57%
206 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 08:37 +1200
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 08:37 +1200
1"""Enhanced race prediction functionality with confidence intervals and tracking."""
3import math
4from dataclasses import dataclass
5from datetime import date, datetime
7from sqlalchemy import and_, func, or_
8from sqlalchemy.orm import Session
10from packages.core.common.logging import get_logger
11from packages.core.common.settings import get_settings
12from packages.core.ratings.engine import RatingEngine
13from packages.core.storage.models import EntityType, Race, Starter
14from packages.core.storage.repositories import RatingSnapshotRepository
16logger = get_logger(__name__)
@dataclass
class PredictionResult:
    """Model output for one starter in a race.

    Holds the identity of the starter and its connections, the effective
    rating used by the model, and the derived probabilities and ranking.
    """

    # --- Starter and connections -------------------------------------------
    starter_id: int
    horse_id: int
    horse_name: str | None
    driver_id: int | None
    driver_name: str | None
    trainer_id: int | None
    trainer_name: str | None

    # --- Race-day conditions -----------------------------------------------
    barrier: int | None
    handicap_m: float | None

    # --- Model outputs -----------------------------------------------------
    effective_rating: float
    win_probability: float
    place_probability: float  # probability of a top-3 finish
    place_score: float
    confidence_interval_low: float  # lower bound of the 95% CI on the rating
    confidence_interval_high: float  # upper bound of the 95% CI on the rating
    predicted_placing: int
@dataclass
class RacePrediction:
    """All predictions produced for one race, plus race context.

    ``predictions`` holds one ``PredictionResult`` per rated starter and is
    empty when the race could not be predicted.
    """

    # --- Race identity and context -----------------------------------------
    race_id: int
    race_number: int | None
    venue: str | None
    distance_m: int | None
    race_date: date | None

    # --- Prediction payload ------------------------------------------------
    predictions: list[PredictionResult]
    prediction_timestamp: datetime  # when this prediction was generated
    metadata: dict  # summary statistics about the predicted field
class PredictionEngine:
    """Enhanced prediction engine with confidence intervals and tracking.

    Combines the effective ratings computed by ``RatingEngine`` with each
    horse's recent finishing history to produce win probabilities, top-3
    (place) probabilities and a predicted finishing order for a race.
    """

    def __init__(self, session: Session):
        """Initialize prediction engine.

        Args:
            session: Database session
        """
        self.session = session
        self.rating_engine = RatingEngine(session)

        # Place-model tuning values are read once from the rating settings.
        rating_settings = get_settings().rating
        self.place_history_limit = rating_settings.place_history_limit
        self.place_prior_rate = rating_settings.place_prior_rate
        self.place_prior_weight = rating_settings.place_prior_weight
        self.place_top3_weight = rating_settings.place_top3_weight
        self.place_consistency_weight = rating_settings.place_consistency_weight

    def predict_race(self, race: Race, starters: list[Starter]) -> RacePrediction:
        """Generate predictions for a race.

        Args:
            race: Race instance
            starters: List of starters

        Returns:
            Complete race prediction with probabilities. ``predictions`` is
            empty (and ``metadata`` is ``{}``) when there are no starters or
            no starter has a ``horse_id``.
        """
        if not starters:
            logger.warning(f"No starters for race {race.id}")
            return self._build_race_prediction(race, [], {})

        # Load latest ratings for all starters (use pre-race ratings if available)
        for starter in starters:
            self._load_latest_rating(EntityType.HORSE, starter.horse_id, race)
            self._load_latest_rating(EntityType.DRIVER, starter.driver_id, race)
            self._load_latest_rating(EntityType.TRAINER, starter.trainer_id, race)

        # Compute effective ratings for all starters; starters without a
        # horse cannot be rated and are left out of the prediction.
        effective_ratings: dict[int, float] = {}
        rating_uncertainties: dict[int, float] = {}
        for starter in starters:
            if not starter.horse_id:
                continue

            effective_ratings[starter.id] = self.rating_engine.compute_effective_rating(
                starter, race
            )

            # Rating deviation (RD) of the horse drives the confidence interval.
            horse_state = self.rating_engine.get_or_init_rating(
                EntityType.HORSE, starter.horse_id
            )
            rating_uncertainties[starter.id] = horse_state.rd or 100.0

        if not effective_ratings:
            logger.warning(f"No ratings available for race {race.id}")
            return self._build_race_prediction(race, [], {})

        # Win probabilities via softmax; place scores/probabilities for top 3.
        win_probs = self._compute_win_probabilities(effective_ratings)
        place_scores = self._compute_place_scores(race, starters, effective_ratings)
        place_probs = self._compute_place_probabilities(place_scores, top_n=3)
        place_ranks = self._compute_place_ranks(place_scores)

        predictions = []
        for starter in starters:
            if starter.id not in effective_ratings:
                continue

            r_eff = effective_ratings[starter.id]
            rd = rating_uncertainties.get(starter.id, 100.0)

            predictions.append(
                PredictionResult(
                    starter_id=starter.id,
                    horse_id=starter.horse_id,
                    horse_name=starter.horse.name if starter.horse else None,
                    driver_id=starter.driver_id,
                    driver_name=starter.driver.name if starter.driver else None,
                    trainer_id=starter.trainer_id,
                    trainer_name=starter.trainer.name if starter.trainer else None,
                    barrier=starter.barrier,
                    handicap_m=starter.handicap_m,
                    effective_rating=r_eff,
                    win_probability=win_probs.get(starter.id, 0.0),
                    place_probability=place_probs.get(starter.id, 0.0),
                    place_score=place_scores.get(starter.id, r_eff),
                    # 95% CI = rating +/- 1.96 * RD
                    confidence_interval_low=r_eff - 1.96 * rd,
                    confidence_interval_high=r_eff + 1.96 * rd,
                    predicted_placing=place_ranks.get(starter.id, len(place_ranks)),
                )
            )

        # Sort by predicted placing
        predictions.sort(key=lambda p: p.predicted_placing)

        field_size = len(predictions)
        if field_size:
            ratings = [p.effective_rating for p in predictions]
            metadata = {
                "field_size": field_size,
                "avg_rating": sum(ratings) / field_size,
                "rating_spread": max(ratings) - min(ratings),
                "avg_place_score": sum(p.place_score for p in predictions) / field_size,
            }
        else:
            metadata = {
                "field_size": 0,
                "avg_rating": 0,
                "rating_spread": 0,
                "avg_place_score": 0,
            }

        return self._build_race_prediction(race, predictions, metadata)

    def _build_race_prediction(
        self, race: Race, predictions: list[PredictionResult], metadata: dict
    ) -> RacePrediction:
        """Wrap predictions in a ``RacePrediction`` stamped with the current time."""
        meeting = race.meeting
        return RacePrediction(
            race_id=race.id,
            race_number=race.race_number,
            venue=meeting.venue if meeting else None,
            distance_m=race.distance_m,
            race_date=meeting.meeting_date if meeting else None,
            predictions=predictions,
            prediction_timestamp=datetime.now(),
            metadata=metadata,
        )

    def _load_latest_rating(self, entity_type, entity_id: int | None, race: Race) -> None:
        """Load the latest snapshot before ``race`` for one entity into the rating engine.

        Does nothing when ``entity_id`` is falsy or no snapshot is found.
        """
        if not entity_id:
            return

        latest = RatingSnapshotRepository.get_latest_rating(
            self.session,
            entity_type,
            entity_id,
            before_race_id=race.id,
        )
        if not latest:
            return

        snapshot_race = latest.race
        last_race_date = (
            snapshot_race.meeting.meeting_date
            if snapshot_race and snapshot_race.meeting
            else None
        )
        self.rating_engine.load_rating_state(
            entity_type,
            entity_id,
            latest.rating,
            latest.rd,
            last_race_date=last_race_date,
        )

    def _compute_win_probabilities(
        self, effective_ratings: dict[int, float]
    ) -> dict[int, float]:
        """Compute win probabilities using softmax.

        Args:
            effective_ratings: Dictionary of starter_id -> effective rating

        Returns:
            Dictionary of starter_id -> win probability (sums to 1), or an
            empty dictionary when no ratings are given.
        """
        if not effective_ratings:
            return {}

        # Subtracting the maximum keeps every exponent <= 0, so exp() cannot
        # overflow and the largest term is exactly 1.
        max_rating = max(effective_ratings.values())

        # Scale factor (400 = standard Elo scale)
        scale = 400.0

        exp_ratings = {
            sid: math.exp((rating - max_rating) / scale)
            for sid, rating in effective_ratings.items()
        }
        total_exp = sum(exp_ratings.values())

        # Normalize to probabilities
        return {sid: exp_val / total_exp for sid, exp_val in exp_ratings.items()}

    def _compute_place_probabilities(
        self, place_scores: dict[int, float], top_n: int = 3
    ) -> dict[int, float]:
        """Compute place probabilities (finishing in top N).

        Args:
            place_scores: Dictionary of starter_id -> place score
            top_n: Number of top placings to consider

        Returns:
            Dictionary of starter_id -> place probability. When the field has
            at most ``top_n`` runners every starter gets 1.0.
        """
        if not place_scores or len(place_scores) <= top_n:
            return dict.fromkeys(place_scores, 1.0)

        # Heuristic: scale the softmax share by (top_n + 0.5) and cap at 1.
        base_probs = self._compute_win_probabilities(place_scores)
        return {sid: min(1.0, prob * (top_n + 0.5)) for sid, prob in base_probs.items()}

    def _compute_place_scores(
        self,
        race: Race,
        starters: list[Starter],
        effective_ratings: dict[int, float],
    ) -> dict[int, float]:
        """Compute place scores from ratings and consistency history.

        Each score is the starter's effective rating shifted by how far the
        horse's smoothed top-3 rate and consistency sit from their neutral
        values (``place_prior_rate`` and 0.5), in rating-scale units.
        """
        scores = {}
        scale = self.rating_engine.settings.elo_scale_c

        for starter in starters:
            if starter.id not in effective_ratings or not starter.horse_id:
                continue

            history = self._get_recent_finish_stats(race, starter.horse_id)
            rating_adjustment = (
                self.place_top3_weight
                * (history["top3_rate"] - self.place_prior_rate)
                * scale
                + self.place_consistency_weight * (history["consistency"] - 0.5) * scale
            )
            scores[starter.id] = effective_ratings[starter.id] + rating_adjustment

        return scores

    def _compute_place_ranks(self, place_scores: dict[int, float]) -> dict[int, int]:
        """Convert place scores into ordinal ranks (1 = highest score)."""
        ordered = sorted(place_scores.items(), key=lambda item: item[1], reverse=True)
        return {starter_id: idx + 1 for idx, (starter_id, _) in enumerate(ordered)}

    def _get_recent_finish_stats(self, race: Race, horse_id: int) -> dict[str, float]:
        """Compute smoothed top-3 rate and consistency from recent finishes.

        Only finishes strictly before ``race`` are used, newest first, capped
        at ``place_history_limit`` rows.
        """
        from packages.core.storage.models import Meeting

        # Number of starters per race, used to turn a placing into a percentile.
        field_size_subquery = (
            self.session.query(
                Starter.race_id.label("race_id"),
                func.count(Starter.id).label("field_size"),
            )
            .group_by(Starter.race_id)
            .subquery()
        )

        query = (
            self.session.query(
                Starter.placing,
                field_size_subquery.c.field_size,
                Race.race_datetime,
                Race.race_number,
                Meeting.meeting_date,
            )
            .join(Race, Starter.race_id == Race.id)
            .join(Meeting, Race.meeting_id == Meeting.id)
            .join(field_size_subquery, Starter.race_id == field_size_subquery.c.race_id)
            .filter(
                Starter.horse_id == horse_id,
                Starter.placing.isnot(None),
                Starter.did_not_finish.is_(False),
            )
        )

        # Restrict to earlier races using the most precise ordering available.
        if race.race_datetime:
            query = query.filter(Race.race_datetime < race.race_datetime)
            query = query.order_by(Race.race_datetime.desc(), Race.race_number.desc())
        elif race.meeting and race.meeting.meeting_date:
            query = query.filter(
                or_(
                    Meeting.meeting_date < race.meeting.meeting_date,
                    and_(
                        Meeting.meeting_date == race.meeting.meeting_date,
                        Race.race_number < race.race_number,
                    ),
                )
            )
            query = query.order_by(Meeting.meeting_date.desc(), Race.race_number.desc())
        else:
            query = query.order_by(Race.id.desc())

        rows = query.limit(self.place_history_limit).all()
        return self._summarise_finishes(
            [(placing, field_size) for placing, field_size, *_ in rows]
        )

    def _summarise_finishes(
        self, finishes: list[tuple[int, int | None]]
    ) -> dict[str, float]:
        """Reduce ``(placing, field_size)`` pairs to top-3 rate and consistency.

        Returns the prior rate and a neutral consistency of 0.5 when there is
        no history; consistency also stays 0.5 with fewer than two finishes.
        """
        if not finishes:
            return {"top3_rate": self.place_prior_rate, "consistency": 0.5}

        top3_count = 0
        percentiles = []
        for placing, field_size in finishes:
            field_size = field_size or 1
            if placing <= 3:
                top3_count += 1
            if field_size <= 1:
                percentiles.append(1.0)
            else:
                # 1.0 for a win, 0.0 for finishing last in the field.
                percentiles.append(1.0 - (placing - 1) / (field_size - 1))

        sample_count = len(finishes)

        # Shrink the observed rate toward the prior, weighted by place_prior_weight.
        top3_rate = (top3_count + self.place_prior_rate * self.place_prior_weight) / (
            sample_count + self.place_prior_weight
        )

        if sample_count < 2:
            consistency = 0.5
        else:
            mean = sum(percentiles) / sample_count
            variance = sum((p - mean) ** 2 for p in percentiles) / sample_count
            # A standard deviation of 0.5 or more maps to zero consistency.
            consistency = 1.0 - min(1.0, math.sqrt(variance) / 0.5)

        return {"top3_rate": top3_rate, "consistency": consistency}

    def get_upcoming_races(self, race_date: date | None = None) -> list[Race]:
        """Get upcoming races for prediction.

        Args:
            race_date: Date to get races for (defaults to today)

        Returns:
            List of races on that meeting date, ordered by race number
        """
        if race_date is None:
            race_date = date.today()

        from packages.core.storage.models import Meeting

        return (
            self.session.query(Race)
            .join(Race.meeting)
            .filter(Meeting.meeting_date == race_date)
            .order_by(Race.race_number)
            .all()
        )

    @staticmethod
    def _score_predictions(
        predictions: list[PredictionResult],
        actual_winner_id: int,
        actual_top3_ids: list[int],
    ) -> dict | None:
        """Compute accuracy metrics for predictions against actual results.

        Returns ``None`` when ``predictions`` is empty, because neither a
        predicted winner nor a Brier score is defined for an empty field.
        """
        if not predictions:
            return None

        predicted_winner = max(predictions, key=lambda p: p.win_probability)
        predicted_top3_ids = {
            p.starter_id
            for p in sorted(predictions, key=lambda p: p.predicted_placing)[:3]
        }
        top3_overlap = len(predicted_top3_ids & set(actual_top3_ids))

        # Brier score for the winner prediction, averaged over the field.
        brier_score = sum(
            (p.win_probability - (1.0 if p.starter_id == actual_winner_id else 0.0)) ** 2
            for p in predictions
        ) / len(predictions)

        return {
            "winner_correct": predicted_winner.starter_id == actual_winner_id,
            "predicted_winner_id": predicted_winner.starter_id,
            "actual_winner_id": actual_winner_id,
            "top3_overlap": top3_overlap,
            "top3_overlap_rate": top3_overlap / 3.0,
            "brier_score": brier_score,
            "field_size": len(predictions),
        }

    def compare_prediction_to_actual(self, race_id: int) -> dict | None:
        """Compare prediction to actual result for completed race.

        Args:
            race_id: Race ID

        Returns:
            Comparison dictionary with prediction accuracy metrics, or
            ``None`` when the race does not exist, has no starters, has no
            recorded winner, or no prediction could be produced for it.
        """
        race = self.session.query(Race).filter(Race.id == race_id).first()
        if not race:
            return None

        starters = race.starters
        if not starters:
            return None

        # Generate prediction
        prediction = self.predict_race(race, starters)

        # Collect actual results from starters that finished with a placing.
        starter_by_id = {starter.id: starter for starter in starters}
        actual_winner_id = None
        actual_top3_ids = []
        for starter in starters:
            if starter.placing and not starter.did_not_finish:
                if starter.placing == 1:
                    actual_winner_id = starter.id
                if starter.placing <= 3:
                    actual_top3_ids.append(starter.id)

        if actual_winner_id is None:
            return None  # Race not completed

        metrics = self._score_predictions(
            prediction.predictions, actual_winner_id, actual_top3_ids
        )
        if metrics is None:
            return None  # No starter could be rated, so nothing to compare

        predictions_with_actuals = []
        for pred in prediction.predictions:
            starter = starter_by_id.get(pred.starter_id)
            actual_placing = None
            if starter and starter.placing and not starter.did_not_finish:
                actual_placing = starter.placing
            predictions_with_actuals.append(
                {
                    "starter_id": pred.starter_id,
                    "horse_id": pred.horse_id,
                    "horse_name": pred.horse_name,
                    "driver_id": pred.driver_id,
                    "driver_name": pred.driver_name,
                    "trainer_id": pred.trainer_id,
                    "trainer_name": pred.trainer_name,
                    "barrier": pred.barrier,
                    "handicap_m": pred.handicap_m,
                    "effective_rating": pred.effective_rating,
                    "win_probability": pred.win_probability,
                    "place_probability": pred.place_probability,
                    "place_score": pred.place_score,
                    "ci_lower": pred.confidence_interval_low,
                    "ci_upper": pred.confidence_interval_high,
                    "predicted_placing": pred.predicted_placing,
                    "actual_placing": actual_placing,
                }
            )

        meeting = race.meeting
        return {
            "race_id": race_id,
            "race_number": race.race_number,
            "venue": meeting.venue if meeting else None,
            # A meeting may exist without a date; do not call isoformat() on None.
            "race_date": (
                meeting.meeting_date.isoformat()
                if meeting and meeting.meeting_date
                else None
            ),
            **metrics,
            "predictions": predictions_with_actuals,
        }
def export_predictions_csv(predictions: list[RacePrediction], output_file: str) -> None:
    """Export predictions to CSV file.

    Writes one header row followed by one row per starter prediction. Missing
    (``None``) values are written as empty cells; legitimate zeros such as a
    0.0 m handicap are written as-is.

    Args:
        predictions: List of race predictions
        output_file: Output CSV file path
    """
    import csv

    def blank_if_none(value):
        # Only None means "no value"; `value or ""` would also blank out 0 and 0.0.
        return "" if value is None else value

    header = [
        "Race ID",
        "Race Number",
        "Venue",
        "Distance (m)",
        "Starter ID",
        "Horse Name",
        "Driver Name",
        "Barrier",
        "Handicap (m)",
        "Effective Rating",
        "Win Probability",
        "Place Probability",
        "Predicted Placing",
        "CI Low",
        "CI High",
    ]

    # Explicit UTF-8 so non-ASCII horse/driver names do not depend on the
    # platform's default encoding.
    with open(output_file, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)

        for race_pred in predictions:
            for pred in race_pred.predictions:
                writer.writerow(
                    [
                        race_pred.race_id,
                        blank_if_none(race_pred.race_number),
                        blank_if_none(race_pred.venue),
                        blank_if_none(race_pred.distance_m),
                        pred.starter_id,
                        blank_if_none(pred.horse_name),
                        blank_if_none(pred.driver_name),
                        blank_if_none(pred.barrier),
                        blank_if_none(pred.handicap_m),
                        f"{pred.effective_rating:.1f}",
                        f"{pred.win_probability:.3f}",
                        f"{pred.place_probability:.3f}",
                        pred.predicted_placing,
                        f"{pred.confidence_interval_low:.1f}",
                        f"{pred.confidence_interval_high:.1f}",
                    ]
                )

    logger.info(f"Predictions exported to {output_file}")