Coverage for packages / core / ratings / engine.py: 94%
262 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-05-08 08:37 +1200
"""Multi-runner Elo rating engine for harness racing.

Every eligible starter in a race is compared against every other one with a
pairwise logistic Elo model. The engine supports:

- Multi-entity ratings (horse, driver, trainer)
- Condition adjustments (barrier, handicap)
- Rating deviation (RD) for uncertainty tracking
"""

import math
from dataclasses import dataclass
from datetime import date

from sqlalchemy.orm import Session

from packages.core.common.logging import get_logger
from packages.core.common.settings import get_settings
from packages.core.common.utils import get_distance_bucket
from packages.core.storage.models import EntityType, Race, Starter
from packages.core.storage.repositories import (
    BarrierAdjustmentRepository,
    HandicapAdjustmentRepository,
)

# Module-level logger named after this module.
logger = get_logger(__name__)
27@dataclass
28class RatingState:
29 """Current rating state for an entity."""
31 rating: float
32 rd: float | None = None
33 race_count: int = 0
34 last_race_id: int | None = None
35 last_race_date: date | None = None
@dataclass()
class RatingUpdate:
    """Rating change for one entity produced by processing a single race.

    ``RatingEngine.process_race`` returns a list of these for the caller to
    apply.
    """

    # Kind of entity that changed (horse, driver or trainer).
    entity_type: EntityType
    # Primary key of the entity.
    entity_id: int
    # Rating before the race was applied.
    old_rating: float
    # Rating after the race, after any configured min/max clamping.
    new_rating: float
    # Change actually applied (new_rating - old_rating).
    delta: float
    # Rating deviation after the update; None when RD is not tracked.
    rd: float | None = None
    # Optional extra data; the engine stores {"race_count": ...} here.
    meta: dict | None = None
class RatingEngine:
    """Multi-runner Elo rating engine.

    Ratings are cached in memory in ``self.states``, keyed by
    ``(entity_type, entity_id)``. Each processed race compares every pair of
    eligible starters with a logistic model. Horse ratings move toward the
    observed outcome, and driver and trainer ratings receive a scaled share of
    the horse's change. Optional learned barrier and handicap adjustments are
    added to a starter's effective rating.
    """

    def __init__(self, db_session: Session | None = None):
        """Initialize rating engine with configuration.

        Args:
            db_session: Optional database session for loading/saving
                adjustments. When given and ``enable_adjustments`` is set,
                the barrier/handicap tables are loaded immediately.
        """
        # This class only reads ``self.settings``. It may be a shared object
        # returned by ``get_settings()``, so it is never mutated here.
        self.settings = get_settings().rating
        self.db_session = db_session

        # Rating states (in-memory cache during computation)
        self.states: dict[tuple[EntityType, int], RatingState] = {}

        # Condition adjustments, keyed by
        # (venue, start_type, distance_bucket, barrier | handicap_m).
        # A key with venue/start_type of None is the "global" entry.
        self.barrier_adjustments: dict[tuple, float] = {}
        self.handicap_adjustments: dict[tuple, float] = {}
        self.barrier_adjustment_samples: dict[tuple, int] = {}
        self.handicap_adjustment_samples: dict[tuple, int] = {}

        # Load adjustments from database if session provided
        if self.db_session and self.settings.enable_adjustments:
            self.load_adjustments_from_db()

    def get_or_init_rating(
        self, entity_type: EntityType, entity_id: int
    ) -> RatingState:
        """Get current rating state or initialize new entity.

        A new entity starts at ``initial_rating``. It starts at
        ``initial_rd`` when RD is enabled, otherwise with no RD. The new
        state is cached.

        Args:
            entity_type: Type of entity
            entity_id: Entity ID

        Returns:
            Current (cached) rating state
        """
        key = (entity_type, entity_id)
        if key not in self.states:
            self.states[key] = RatingState(
                rating=self.settings.initial_rating,
                rd=self.settings.initial_rd if self.settings.enable_rd else None,
                race_count=0,
            )
        return self.states[key]

    def load_rating_state(
        self,
        entity_type: EntityType,
        entity_id: int,
        rating: float,
        rd: float | None = None,
        last_race_date: date | None = None,
        race_count: int = 0,
        last_race_id: int | None = None,
    ) -> None:
        """Load existing rating state from database.

        Replaces any cached state for the entity.

        Args:
            entity_type: Type of entity
            entity_id: Entity ID
            rating: Current rating
            rd: Current rating deviation
            last_race_date: Date of last race
            race_count: Races already counted for the entity. Defaults to 0,
                so existing callers behave as before.
            last_race_id: ID of the entity's last race, if known
        """
        self.states[(entity_type, entity_id)] = RatingState(
            rating=rating,
            rd=rd,
            race_count=race_count,
            last_race_id=last_race_id,
            last_race_date=last_race_date,
        )

    def sigmoid(self, x: float) -> float:
        """Logistic sigmoid function, evaluated without overflow.

        Args:
            x: Input value

        Returns:
            Value between 0 and 1
        """
        # Branch on sign so math.exp is only ever called on a non-positive
        # argument, which cannot overflow.
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        exp_x = math.exp(x)
        return exp_x / (1.0 + exp_x)

    def get_effective_k_factor(self, entity_type: EntityType, entity_id: int) -> float:
        """Compute effective K-factor based on rating deviation.

        When RD is enabled, scale the base K by ``rd / initial_rd``:
        - High RD (new/inactive) → larger K → faster rating changes
        - Low RD (established) → smaller K → more stable ratings

        ``rd_scaling_mode`` of "sqrt" takes the square root of that ratio, and
        "none" disables scaling. The result is clamped to ``elo_k_min`` and
        ``elo_k_max`` when they are configured.

        Args:
            entity_type: Type of entity (HORSE, DRIVER, TRAINER)
            entity_id: Entity ID

        Returns:
            Effective K-factor for this entity
        """
        k_eff = self.settings.elo_k_base

        if self.settings.enable_rd:
            state = self.get_or_init_rating(entity_type, entity_id)
            initial_rd = self.settings.initial_rd
            if state.rd is not None and initial_rd > 0:
                ratio = state.rd / initial_rd
                mode = self.settings.rd_scaling_mode
                if mode == "sqrt":
                    ratio = math.sqrt(ratio)
                elif mode == "none":
                    ratio = 1.0
                k_eff = k_eff * ratio

        if self.settings.elo_k_min is not None:
            k_eff = max(k_eff, self.settings.elo_k_min)
        if self.settings.elo_k_max is not None:
            k_eff = min(k_eff, self.settings.elo_k_max)

        return k_eff

    def compute_effective_rating(
        self,
        starter: Starter,
        race: Race,
        *,
        include_adjustments: bool = True,
    ) -> float:
        """Compute effective rating for a starter.

        R_eff = R_horse + α*R_driver + β*R_trainer + barrier_adj + handicap_adj

        Args:
            starter: Starter instance
            race: Race instance
            include_adjustments: When False, the barrier/handicap terms are
                skipped even if ``enable_adjustments`` is set. Adjustment
                learning uses this to avoid a feedback loop.

        Returns:
            Effective rating
        """
        # Base horse rating
        r_eff = self.get_or_init_rating(EntityType.HORSE, starter.horse_id).rating

        # Add driver contribution
        if self.settings.enable_driver and starter.driver_id:
            driver_state = self.get_or_init_rating(EntityType.DRIVER, starter.driver_id)
            r_eff += self.settings.driver_weight_alpha * driver_state.rating

        # Add trainer contribution
        if self.settings.enable_trainer and starter.trainer_id:
            trainer_state = self.get_or_init_rating(
                EntityType.TRAINER, starter.trainer_id
            )
            r_eff += self.settings.trainer_weight_beta * trainer_state.rating

        if not (include_adjustments and self.settings.enable_adjustments):
            return r_eff

        # Add condition adjustments. Races without a meeting fall back to a
        # venue of None instead of raising AttributeError.
        venue = self._race_venue(race)
        if starter.barrier is not None:
            r_eff += self._get_barrier_adjustment(
                venue,
                race.start_type,
                race.distance_m,
                starter.barrier,
            )
        if starter.handicap_m is not None and starter.handicap_m != 0:
            r_eff += self._get_handicap_adjustment(
                venue,
                race.start_type,
                race.distance_m,
                starter.handicap_m,
            )

        return r_eff

    @staticmethod
    def _race_venue(race: Race) -> str | None:
        """Return the venue of the race's meeting, or None when there is no meeting."""
        meeting = race.meeting
        return meeting.venue if meeting else None

    def _distance_bucket(self, distance_m: int | None):
        """Map a race distance onto the configured distance bucket."""
        return get_distance_bucket(
            distance_m,
            self.settings.distance_buckets,
            mode=self.settings.distance_bucket_mode,
            bucket_size=self.settings.distance_bucket_size,
        )

    def _get_barrier_adjustment(
        self,
        venue: str | None,
        start_type: str | None,
        distance_m: int | None,
        barrier: int,
    ) -> float:
        """Get barrier adjustment from learned table.

        Tries the venue/start-type specific entry first, then the global
        entry (venue and start_type of None).

        Args:
            venue: Venue name
            start_type: mobile/standing
            distance_m: Distance in meters
            barrier: Barrier number

        Returns:
            Adjustment value (default 0.0)
        """
        if not self.settings.adj_barrier_enabled:
            return 0.0

        bucket = self._distance_bucket(distance_m)
        return self._resolve_adjustment(
            (venue, start_type, bucket, barrier),
            (None, None, bucket, barrier),
            self.barrier_adjustments,
            self.barrier_adjustment_samples,
        )

    def _get_handicap_adjustment(
        self,
        venue: str | None,
        start_type: str | None,
        distance_m: int | None,
        handicap_m: int,
    ) -> float:
        """Get handicap adjustment from learned table.

        Tries the venue/start-type specific entry first, then the global
        entry (venue and start_type of None).

        Args:
            venue: Venue name
            start_type: mobile/standing
            distance_m: Distance in meters
            handicap_m: Handicap in meters

        Returns:
            Adjustment value (default 0.0)
        """
        if not self.settings.adj_handicap_enabled:
            return 0.0

        bucket = self._distance_bucket(distance_m)
        return self._resolve_adjustment(
            (venue, start_type, bucket, handicap_m),
            (None, None, bucket, handicap_m),
            self.handicap_adjustments,
            self.handicap_adjustment_samples,
        )

    def _resolve_adjustment(
        self,
        key: tuple,
        global_key: tuple,
        adjustments: dict[tuple, float],
        samples: dict[tuple, int],
    ) -> float:
        """Return the first usable adjustment for ``key``, then ``global_key``.

        A candidate is skipped when it is absent from ``adjustments``. It is
        also skipped when ``adj_min_samples`` > 0 and its sample count is below
        that threshold. The first usable value is clamped and returned. If no
        candidate qualifies, 0.0 is returned.
        """
        min_samples = self.settings.adj_min_samples
        for candidate in (key, global_key):
            if candidate not in adjustments:
                continue
            if min_samples > 0 and samples.get(candidate, 0) < min_samples:
                continue
            return self._clamp_adjustment(adjustments[candidate])
        return 0.0

    def _clamp_adjustment(self, adjustment: float) -> float:
        """Clamp an adjustment to [adj_clamp_min, adj_clamp_max] when configured."""
        if self.settings.adj_clamp_min is not None:
            adjustment = max(adjustment, self.settings.adj_clamp_min)
        if self.settings.adj_clamp_max is not None:
            adjustment = min(adjustment, self.settings.adj_clamp_max)
        return adjustment

    def load_adjustments_from_db(self) -> None:
        """Load barrier and handicap adjustments from database into memory.

        Rows are merged into the existing in-memory tables; existing keys are
        overwritten and no keys are removed. Does nothing without a session.
        """
        if not self.db_session:
            return

        # Load barrier adjustments
        barrier_adjs = BarrierAdjustmentRepository.get_all(self.db_session)
        for adj in barrier_adjs:
            key = (adj.venue, adj.start_type, adj.distance_bucket, adj.barrier)
            self.barrier_adjustments[key] = adj.adjustment
            self.barrier_adjustment_samples[key] = adj.sample_count

        # Load handicap adjustments
        handicap_adjs = HandicapAdjustmentRepository.get_all(self.db_session)
        for adj in handicap_adjs:
            key = (adj.venue, adj.start_type, adj.distance_bucket, adj.handicap_m)
            self.handicap_adjustments[key] = adj.adjustment
            self.handicap_adjustment_samples[key] = adj.sample_count

        logger.info(
            f"Loaded {len(barrier_adjs)} barrier adjustments and "
            f"{len(handicap_adjs)} handicap adjustments from database"
        )

    def learn_adjustments_from_race(
        self, race: Race, starters: list[Starter], use_global_only: bool | None = None
    ) -> None:
        """Learn barrier and handicap adjustments from a completed race.

        Uses performance residuals: if a horse performs better than expected
        given its rating, part of that is attributed to favorable conditions.
        Pairwise ties follow ``tie_handling``, the same rule as
        ``process_race``.

        Updates are written through the repositories' ``increment_sample``.
        The in-memory tables on this engine are not refreshed here.
        NOTE(review): presumably callers re-run ``load_adjustments_from_db``
        when they need the new values — confirm.

        Args:
            race: Race instance
            starters: List of starters with results
            use_global_only: If True, only update global adjustments (no
                venue-specific). Defaults to ``adj_global_only`` from settings.
        """
        if not self.db_session or not self.settings.enable_adjustments:
            return
        barrier_enabled = self.settings.adj_barrier_enabled
        handicap_enabled = self.settings.adj_handicap_enabled
        if not barrier_enabled and not handicap_enabled:
            return

        valid_starters = [
            s for s in starters if s.placing is not None and not s.did_not_finish
        ]
        if len(valid_starters) < self.settings.min_finishers:
            return

        # Ratings are computed WITHOUT adjustments to avoid a feedback loop.
        # A keyword flag is used instead of toggling the (possibly shared)
        # settings object.
        effective_ratings = {
            s.id: self.compute_effective_rating(s, race, include_adjustments=False)
            for s in valid_starters
        }

        if use_global_only is None:
            use_global_only = self.settings.adj_global_only
        if use_global_only:
            venue, start_type = None, None
        else:
            venue, start_type = self._race_venue(race), race.start_type

        # Race-level invariants, computed once instead of per starter.
        distance_bucket = self._distance_bucket(race.distance_m)
        normalizer = max(len(valid_starters) - 1, 1)
        scale_c = self.settings.elo_scale_c

        for i, starter in enumerate(valid_starters):
            if not starter.horse_id:
                continue

            r_eff = effective_ratings[starter.id]

            # Expected vs actual, summed over pairwise comparisons
            expected_sum = 0.0
            actual_sum = 0.0
            for j, other in enumerate(valid_starters):
                if i == j or not other.horse_id:
                    continue
                actual = self._pairwise_outcome(starter.placing, other.placing)
                if actual is None:  # tie skipped per tie_handling
                    continue
                expected_sum += self.sigmoid(
                    (r_eff - effective_ratings[other.id]) / scale_c
                )
                actual_sum += actual

            # Positive delta means the starter outperformed expectations.
            delta = (actual_sum - expected_sum) / normalizer
            scaled_delta = delta * self.settings.adj_update_scale

            # Update barrier adjustment if present
            if barrier_enabled and starter.barrier is not None:
                BarrierAdjustmentRepository.increment_sample(
                    self.db_session,
                    venue=venue,
                    start_type=start_type,
                    distance_bucket=distance_bucket,
                    barrier=starter.barrier,
                    delta=scaled_delta,
                    learning_rate=self.settings.adj_learning_rate,
                )

            # Update handicap adjustment if present and non-zero
            if (
                handicap_enabled
                and starter.handicap_m is not None
                and starter.handicap_m != 0
            ):
                HandicapAdjustmentRepository.increment_sample(
                    self.db_session,
                    venue=venue,
                    start_type=start_type,
                    distance_bucket=distance_bucket,
                    handicap_m=starter.handicap_m,
                    delta=scaled_delta,
                    learning_rate=self.settings.adj_learning_rate,
                )

    def _pairwise_outcome(self, placing_i: int, placing_j: int) -> float | None:
        """Return the actual outcome S_ij of starter i against starter j.

        Returns 1.0 if i finished ahead and 0.0 if behind. Equal placings
        follow ``tie_handling``: "skip" returns None (the pair is ignored),
        "half" returns 0.5, and any other value returns 0.0.
        """
        if placing_i == placing_j:
            if self.settings.tie_handling == "skip":
                return None
            return 0.5 if self.settings.tie_handling == "half" else 0.0
        return 1.0 if placing_i < placing_j else 0.0

    def process_race(self, race: Race, starters: list[Starter]) -> list[RatingUpdate]:
        """Process a race and compute rating updates.

        Uses pairwise logistic Elo:
        - For each pair (i, j), compute expected outcome E_ij
        - Update based on actual outcome S_ij (1 if i beat j, 0 otherwise;
          ties per ``tie_handling``)
        - ΔR_i = K * (1/normalizer) * Σ_j (S_ij - E_ij)

        ``normalizer`` is selected by ``pairwise_normalizer``: "comparisons",
        "n", or n-1 otherwise. When ``dnf_treated_as_last`` is set, DNF or
        unplaced starters are included and share a placing one below the
        worst finisher. Ratings in ``self.states`` are updated in place.

        Args:
            race: Race instance
            starters: List of starters in race

        Returns:
            List of rating updates to apply
        """
        finishers = [
            s for s in starters if s.placing is not None and not s.did_not_finish
        ]
        dnf_starters = [s for s in starters if s.did_not_finish or s.placing is None]

        placing_by_id: dict[int, int] = {s.id: s.placing for s in finishers}
        if self.settings.dnf_treated_as_last and dnf_starters:
            last_place = max(placing_by_id.values(), default=0) + 1
            for starter in dnf_starters:
                placing_by_id[starter.id] = last_place
            valid_starters = finishers + dnf_starters
        else:
            valid_starters = finishers

        if len(valid_starters) < self.settings.min_finishers:
            logger.debug(
                f"Skipping race {race.id} - fewer than {self.settings.min_finishers} finishers"
            )
            return []

        n = len(valid_starters)
        updates: list[RatingUpdate] = []
        scale_c = self.settings.elo_scale_c

        # Race date for RD calculations (same for every starter).
        race_date = race.meeting.meeting_date if race.meeting else None

        # Pre-race effective ratings, so updates applied within this race do
        # not influence other starters' expectations.
        effective_ratings = {
            s.id: self.compute_effective_rating(s, race) for s in valid_starters
        }

        for i, starter_i in enumerate(valid_starters):
            if not starter_i.horse_id:
                continue

            r_eff_i = effective_ratings[starter_i.id]
            placing_i = placing_by_id.get(starter_i.id, starter_i.placing)

            delta_sum = 0.0
            comparisons = 0
            for j, starter_j in enumerate(valid_starters):
                if i == j or not starter_j.horse_id:
                    continue
                placing_j = placing_by_id.get(starter_j.id, starter_j.placing)
                s_ij = self._pairwise_outcome(placing_i, placing_j)
                if s_ij is None:  # tie skipped per tie_handling
                    continue
                e_ij = self.sigmoid((r_eff_i - effective_ratings[starter_j.id]) / scale_c)
                delta_sum += s_ij - e_ij
                comparisons += 1

            if self.settings.pairwise_normalizer == "comparisons":
                normalizer = comparisons
            elif self.settings.pairwise_normalizer == "n":
                normalizer = n
            else:
                normalizer = n - 1

            if normalizer <= 0:
                continue

            k_eff = (
                self.get_effective_k_factor(EntityType.HORSE, starter_i.horse_id)
                * self.settings.horse_k_scale
            )
            delta_r = k_eff * (delta_sum / normalizer)

            self._apply_update(
                EntityType.HORSE,
                starter_i.horse_id,
                delta_r,
                race.id,
                race_date,
                updates,
            )

            # Driver and trainer receive a scaled share of the horse's delta.
            if self.settings.enable_driver and starter_i.driver_id:
                driver_delta = (
                    delta_r
                    * self.settings.driver_weight_alpha
                    * self.settings.driver_k_scale
                )
                self._apply_update(
                    EntityType.DRIVER,
                    starter_i.driver_id,
                    driver_delta,
                    race.id,
                    race_date,
                    updates,
                )

            # NOTE(review): a trainer with several starters in this race is
            # updated once per starter (race_count and RD decay included).
            # Confirm this is intended.
            if self.settings.enable_trainer and starter_i.trainer_id:
                trainer_delta = (
                    delta_r
                    * self.settings.trainer_weight_beta
                    * self.settings.trainer_k_scale
                )
                self._apply_update(
                    EntityType.TRAINER,
                    starter_i.trainer_id,
                    trainer_delta,
                    race.id,
                    race_date,
                    updates,
                )

        return updates

    def _apply_update(
        self,
        entity_type: EntityType,
        entity_id: int,
        delta: float,
        race_id: int,
        race_date: date | None,
        updates: list[RatingUpdate],
    ) -> None:
        """Apply rating update to an entity and record it.

        1. The new rating is clamped to ``rating_min``/``rating_max`` when they
           are configured; the recorded delta is the change actually applied.
        2. When RD is enabled and the entity has an RD:
           - RD is first inflated for the days since the entity's last race.
             Days are capped by ``rd_inflation_cap_days``, and the result is
             capped at ``rd_max``.
           - RD is then reduced by max(rd_decay_per_race, rd_decay_floor),
             but never below ``rd_min``.

        Args:
            entity_type: Type of entity
            entity_id: Entity ID
            delta: Requested rating change
            race_id: Race ID
            race_date: Date of the race
            updates: List to append update to
        """
        state = self.get_or_init_rating(entity_type, entity_id)
        old_rating = state.rating
        new_rating = old_rating + delta
        if self.settings.rating_min is not None:
            new_rating = max(new_rating, self.settings.rating_min)
        if self.settings.rating_max is not None:
            new_rating = min(new_rating, self.settings.rating_max)
        delta = new_rating - old_rating

        if self.settings.enable_rd and state.rd is not None:
            # First apply inflation for inactivity
            if state.last_race_date and race_date:
                days_inactive = (race_date - state.last_race_date).days
                if days_inactive > 0:
                    if self.settings.rd_inflation_cap_days is not None:
                        days_inactive = min(
                            days_inactive, self.settings.rd_inflation_cap_days
                        )
                    inflation = days_inactive * self.settings.rd_inflation_per_day
                    state.rd = min(state.rd + inflation, self.settings.rd_max)

            # Then apply decay for participating in race
            decay = max(self.settings.rd_decay_per_race, self.settings.rd_decay_floor)
            state.rd = max(state.rd - decay, self.settings.rd_min)

        state.rating = new_rating
        state.race_count += 1
        state.last_race_id = race_id
        state.last_race_date = race_date

        updates.append(
            RatingUpdate(
                entity_type=entity_type,
                entity_id=entity_id,
                old_rating=old_rating,
                new_rating=new_rating,
                delta=delta,
                rd=state.rd,
                meta={
                    "race_count": state.race_count,
                },
            )
        )

        # Compare against None so an RD of exactly 0.0 is still logged.
        logger.debug(
            f"{entity_type.value} {entity_id}: "
            f"{old_rating:.1f} -> {new_rating:.1f} (Δ{delta:+.1f})"
            + (f", RD={state.rd:.1f}" if state.rd is not None else "")
        )