Coverage for packages / core / ratings / engine.py: 94%

262 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-05-08 08:37 +1200

1"""Multi-runner Elo rating engine for harness racing. 

2 

3Implements pairwise logistic Elo with support for: 

4- Multi-entity ratings (horse, driver, trainer) 

5- Condition adjustments (barrier, handicap) 

6- Rating deviation (RD) for uncertainty tracking 

7""" 

8 

9import math 

10from dataclasses import dataclass 

11from datetime import date 

12 

13from sqlalchemy.orm import Session 

14 

15from packages.core.common.logging import get_logger 

16from packages.core.common.settings import get_settings 

17from packages.core.common.utils import get_distance_bucket 

18from packages.core.storage.models import EntityType, Race, Starter 

19from packages.core.storage.repositories import ( 

20 BarrierAdjustmentRepository, 

21 HandicapAdjustmentRepository, 

22) 

23 

# Module-level logger, keyed by this module's ``__name__``.
logger = get_logger(__name__)

25 

26 

27@dataclass 

28class RatingState: 

29 """Current rating state for an entity.""" 

30 

31 rating: float 

32 rd: float | None = None 

33 race_count: int = 0 

34 last_race_id: int | None = None 

35 last_race_date: date | None = None 

36 

37 

@dataclass()
class RatingUpdate:
    """Record of one rating change produced while processing a race.

    Field order defines the positional constructor:
    ``RatingUpdate(entity_type, entity_id, old_rating, new_rating, delta, rd, meta)``.
    """

    # Kind of entity that was updated (horse, driver or trainer).
    entity_type: EntityType
    # Identifier of the updated entity.
    entity_id: int
    # Rating before the update.
    old_rating: float
    # Rating after the update.
    new_rating: float
    # Change applied: new_rating - old_rating.
    delta: float
    # Rating deviation after the update, if tracked.
    rd: float | None = None
    # Optional extra information attached to the update.
    meta: dict | None = None

49 

50 

class RatingEngine:
    """Multi-runner Elo rating engine.

    Keeps an in-memory cache of per-entity ``RatingState`` objects (horse,
    driver, trainer) plus barrier/handicap adjustment tables. It turns a race
    result into a list of ``RatingUpdate`` records using pairwise logistic Elo
    (see ``process_race``).
    """

    def __init__(self, db_session: Session | None = None):
        """Initialize rating engine with configuration.

        Args:
            db_session: Optional database session for loading/saving adjustments
        """
        # NOTE: ``get_settings()`` may hand back a shared instance, so the
        # engine only reads from ``self.settings`` and never mutates it.
        self.settings = get_settings().rating
        self.db_session = db_session

        # Rating states (in-memory cache during computation)
        self.states: dict[tuple[EntityType, int], RatingState] = {}

        # Condition adjustments (loaded/updated during computation)
        self.barrier_adjustments: dict[tuple, float] = {}
        self.handicap_adjustments: dict[tuple, float] = {}
        self.barrier_adjustment_samples: dict[tuple, int] = {}
        self.handicap_adjustment_samples: dict[tuple, int] = {}

        # Load adjustments from database if session provided
        if self.db_session and self.settings.enable_adjustments:
            self.load_adjustments_from_db()

    # ------------------------------------------------------------------
    # Rating state
    # ------------------------------------------------------------------

    def get_or_init_rating(
        self, entity_type: EntityType, entity_id: int
    ) -> RatingState:
        """Get current rating state or initialize new entity.

        New entities start at ``settings.initial_rating``. They also get
        ``settings.initial_rd`` when RD tracking is enabled, otherwise no RD.

        Args:
            entity_type: Type of entity
            entity_id: Entity ID

        Returns:
            Current rating state (the cached object, not a copy)
        """
        key = (entity_type, entity_id)
        if key not in self.states:
            self.states[key] = RatingState(
                rating=self.settings.initial_rating,
                rd=self.settings.initial_rd if self.settings.enable_rd else None,
                race_count=0,
            )
        return self.states[key]

    def load_rating_state(
        self,
        entity_type: EntityType,
        entity_id: int,
        rating: float,
        rd: float | None = None,
        last_race_date: date | None = None,
        race_count: int = 0,
        last_race_id: int | None = None,
    ) -> None:
        """Load existing rating state from database.

        Replaces any cached state for the entity.

        Args:
            entity_type: Type of entity
            entity_id: Entity ID
            rating: Current rating
            rd: Current rating deviation
            last_race_date: Date of last race
            race_count: Number of updates already applied (defaults to 0)
            last_race_id: ID of the last race (defaults to None)
        """
        self.states[(entity_type, entity_id)] = RatingState(
            rating=rating,
            rd=rd,
            race_count=race_count,
            last_race_id=last_race_id,
            last_race_date=last_race_date,
        )

    def sigmoid(self, x: float) -> float:
        """Logistic sigmoid function.

        Branches on the sign of ``x`` so ``math.exp`` is only called with a
        non-positive argument, which avoids overflow for large ``|x|``.

        Args:
            x: Input value

        Returns:
            Value between 0 and 1
        """
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        exp_x = math.exp(x)
        return exp_x / (1.0 + exp_x)

    def get_effective_k_factor(self, entity_type: EntityType, entity_id: int) -> float:
        """Compute effective K-factor based on rating deviation.

        When RD is enabled, K is scaled by ``rd / initial_rd``. How the ratio
        is applied depends on ``rd_scaling_mode``:

        - ``"sqrt"``: the square root of the ratio is used.
        - ``"none"``: the ratio is ignored (K stays at the base value).
        - anything else: the ratio is used linearly.

        The result is then clamped to ``[elo_k_min, elo_k_max]`` where those
        limits are set.

        Args:
            entity_type: Type of entity (HORSE, DRIVER, TRAINER)
            entity_id: Entity ID

        Returns:
            Effective K-factor for this entity
        """
        base_k = self.settings.elo_k_base
        k_eff = base_k

        # If RD not enabled, use base K-factor
        if self.settings.enable_rd:
            state = self.get_or_init_rating(entity_type, entity_id)
            initial_rd = self.settings.initial_rd
            if state.rd is not None and initial_rd > 0:
                ratio = state.rd / initial_rd
                if self.settings.rd_scaling_mode == "sqrt":
                    ratio = math.sqrt(ratio)
                elif self.settings.rd_scaling_mode == "none":
                    ratio = 1.0
                k_eff = base_k * ratio

        if self.settings.elo_k_min is not None:
            k_eff = max(k_eff, self.settings.elo_k_min)
        if self.settings.elo_k_max is not None:
            k_eff = min(k_eff, self.settings.elo_k_max)

        return k_eff

    # ------------------------------------------------------------------
    # Effective rating and condition adjustments
    # ------------------------------------------------------------------

    @staticmethod
    def _race_venue(race: Race) -> str | None:
        """Return the race's venue, or None when the race has no meeting."""
        return race.meeting.venue if race.meeting else None

    def _distance_bucket(self, distance_m: int | None):
        """Map a race distance to its adjustment bucket using current settings."""
        return get_distance_bucket(
            distance_m,
            self.settings.distance_buckets,
            mode=self.settings.distance_bucket_mode,
            bucket_size=self.settings.distance_bucket_size,
        )

    def compute_effective_rating(
        self,
        starter: Starter,
        race: Race,
        include_adjustments: bool = True,
    ) -> float:
        """Compute effective rating for a starter.

        R_eff = R_horse + α*R_driver + β*R_trainer + barrier_adj + handicap_adj

        Driver and trainer terms are added only when enabled in settings and
        the starter has the corresponding ID. Condition adjustments are added
        only when ``include_adjustments`` is True AND
        ``settings.enable_adjustments`` is set.

        Args:
            starter: Starter instance
            race: Race instance
            include_adjustments: Set False to get the rating without barrier /
                handicap adjustments (used when learning those adjustments).

        Returns:
            Effective rating
        """
        # Base horse rating
        r_eff = self.get_or_init_rating(EntityType.HORSE, starter.horse_id).rating

        # Add driver contribution
        if self.settings.enable_driver and starter.driver_id:
            driver_state = self.get_or_init_rating(EntityType.DRIVER, starter.driver_id)
            r_eff += self.settings.driver_weight_alpha * driver_state.rating

        # Add trainer contribution
        if self.settings.enable_trainer and starter.trainer_id:
            trainer_state = self.get_or_init_rating(
                EntityType.TRAINER, starter.trainer_id
            )
            r_eff += self.settings.trainer_weight_beta * trainer_state.rating

        if not (include_adjustments and self.settings.enable_adjustments):
            return r_eff

        venue = self._race_venue(race)

        # Barrier adjustment
        if starter.barrier is not None:
            r_eff += self._get_barrier_adjustment(
                venue,
                race.start_type,
                race.distance_m,
                starter.barrier,
            )

        # Handicap adjustment (a zero handicap means no adjustment)
        if starter.handicap_m is not None and starter.handicap_m != 0:
            r_eff += self._get_handicap_adjustment(
                venue,
                race.start_type,
                race.distance_m,
                starter.handicap_m,
            )

        return r_eff

    def _get_barrier_adjustment(
        self,
        venue: str | None,
        start_type: str | None,
        distance_m: int | None,
        barrier: int,
    ) -> float:
        """Get barrier adjustment from learned table.

        Looks up ``(venue, start_type, bucket, barrier)`` first, then the
        global ``(None, None, bucket, barrier)`` entry.

        Args:
            venue: Venue name
            start_type: mobile/standing
            distance_m: Distance in meters
            barrier: Barrier number

        Returns:
            Adjustment value (default 0.0)
        """
        if not self.settings.adj_barrier_enabled:
            return 0.0

        distance_bucket = self._distance_bucket(distance_m)
        return self._resolve_adjustment(
            (venue, start_type, distance_bucket, barrier),
            (None, None, distance_bucket, barrier),
            self.barrier_adjustments,
            self.barrier_adjustment_samples,
        )

    def _get_handicap_adjustment(
        self,
        venue: str | None,
        start_type: str | None,
        distance_m: int | None,
        handicap_m: int,
    ) -> float:
        """Get handicap adjustment from learned table.

        Looks up ``(venue, start_type, bucket, handicap_m)`` first, then the
        global ``(None, None, bucket, handicap_m)`` entry.

        Args:
            venue: Venue name
            start_type: mobile/standing
            distance_m: Distance in meters
            handicap_m: Handicap in meters

        Returns:
            Adjustment value (default 0.0)
        """
        if not self.settings.adj_handicap_enabled:
            return 0.0

        distance_bucket = self._distance_bucket(distance_m)
        return self._resolve_adjustment(
            (venue, start_type, distance_bucket, handicap_m),
            (None, None, distance_bucket, handicap_m),
            self.handicap_adjustments,
            self.handicap_adjustment_samples,
        )

    def _resolve_adjustment(
        self,
        key: tuple,
        global_key: tuple,
        adjustments: dict[tuple, float],
        samples: dict[tuple, int],
    ) -> float:
        """Return the clamped adjustment for ``key``, else ``global_key``, else 0.0.

        A candidate key is skipped when it is missing from ``adjustments``. It
        is also skipped when ``adj_min_samples`` > 0 and the key's sample count
        is below that threshold.
        """
        for candidate in (key, global_key):
            if candidate not in adjustments:
                continue
            if self.settings.adj_min_samples > 0:
                if samples.get(candidate, 0) < self.settings.adj_min_samples:
                    continue
            return self._clamp_adjustment(adjustments[candidate])
        return 0.0

    def _clamp_adjustment(self, adjustment: float) -> float:
        """Clamp an adjustment to ``[adj_clamp_min, adj_clamp_max]`` where set."""
        if self.settings.adj_clamp_min is not None:
            adjustment = max(adjustment, self.settings.adj_clamp_min)
        if self.settings.adj_clamp_max is not None:
            adjustment = min(adjustment, self.settings.adj_clamp_max)
        return adjustment

    def load_adjustments_from_db(self) -> None:
        """Load barrier and handicap adjustments from database into memory."""
        if not self.db_session:
            return

        # Load barrier adjustments
        barrier_adjs = BarrierAdjustmentRepository.get_all(self.db_session)
        for adj in barrier_adjs:
            key = (adj.venue, adj.start_type, adj.distance_bucket, adj.barrier)
            self.barrier_adjustments[key] = adj.adjustment
            self.barrier_adjustment_samples[key] = adj.sample_count

        # Load handicap adjustments
        handicap_adjs = HandicapAdjustmentRepository.get_all(self.db_session)
        for adj in handicap_adjs:
            key = (adj.venue, adj.start_type, adj.distance_bucket, adj.handicap_m)
            self.handicap_adjustments[key] = adj.adjustment
            self.handicap_adjustment_samples[key] = adj.sample_count

        logger.info(
            f"Loaded {len(barrier_adjs)} barrier adjustments and "
            f"{len(handicap_adjs)} handicap adjustments from database"
        )

    # ------------------------------------------------------------------
    # Pairwise helpers
    # ------------------------------------------------------------------

    def _pairwise_outcome(self, placing_i: int, placing_j: int) -> float | None:
        """Score of i against j: 1.0 win, 0.0 loss; ties per ``tie_handling``.

        For ties, the ``tie_handling`` setting decides the result:

        - ``"skip"``: returns None, meaning the pair is not compared.
        - ``"half"``: returns 0.5.
        - anything else: returns 0.0.
        """
        if placing_i == placing_j:
            if self.settings.tie_handling == "skip":
                return None
            return 0.5 if self.settings.tie_handling == "half" else 0.0
        return 1.0 if placing_i < placing_j else 0.0

    def _pairwise_normalizer(self, n: int, comparisons: int) -> int:
        """Divisor for the summed pairwise residual, per ``pairwise_normalizer``.

        ``"comparisons"`` uses the number of pairs actually scored; ``"n"``
        uses the field size; anything else uses ``n - 1``.
        """
        mode = self.settings.pairwise_normalizer
        if mode == "comparisons":
            return comparisons
        if mode == "n":
            return n
        return n - 1

    # ------------------------------------------------------------------
    # Learning and rating updates
    # ------------------------------------------------------------------

    def learn_adjustments_from_race(
        self, race: Race, starters: list[Starter], use_global_only: bool | None = None
    ) -> None:
        """Learn barrier and handicap adjustments from a completed race.

        Uses performance residuals: if a horse performs better than expected
        given its rating, some of that is attributed to favorable conditions.
        Pairwise outcomes follow ``tie_handling``, the same rule used by
        ``process_race``.

        Does nothing without a DB session, when adjustments are disabled, or
        when fewer than ``min_finishers`` starters finished with a placing.

        NOTE(review): results are written through the repositories only; the
        in-memory adjustment tables are filled by ``load_adjustments_from_db``
        and are not refreshed here. Confirm that is intended.

        Args:
            race: Race instance
            starters: List of starters with results
            use_global_only: If True, only update global adjustments (no
                venue-specific). None means use ``settings.adj_global_only``.
        """
        if not self.db_session or not self.settings.enable_adjustments:
            return
        if not (self.settings.adj_barrier_enabled or self.settings.adj_handicap_enabled):
            return

        valid_starters = [
            s for s in starters if s.placing is not None and not s.did_not_finish
        ]
        if len(valid_starters) < self.settings.min_finishers:
            return

        # Effective ratings WITHOUT condition adjustments, so the residual is
        # not fed by the adjustments being learned (feedback loop). Passed as
        # an argument instead of toggling the settings object, which could be
        # shared and would stay disabled if an exception escaped mid-way.
        # Starters without a horse are never compared, so they are skipped.
        effective_ratings = {
            s.id: self.compute_effective_rating(s, race, include_adjustments=False)
            for s in valid_starters
            if s.horse_id
        }

        if use_global_only is None:
            use_global_only = self.settings.adj_global_only

        # Race-level values are identical for every starter: resolve them once.
        distance_bucket = self._distance_bucket(race.distance_m)
        if use_global_only:
            venue, start_type = None, None
        else:
            venue, start_type = self._race_venue(race), race.start_type

        denominator = max(len(valid_starters) - 1, 1)

        for i, starter in enumerate(valid_starters):
            if not starter.horse_id:
                continue

            r_eff = effective_ratings[starter.id]

            # Expected vs actual, summed over pairwise comparisons
            expected_sum = 0.0
            actual_sum = 0.0
            for j, other in enumerate(valid_starters):
                if i == j or not other.horse_id:
                    continue
                actual = self._pairwise_outcome(starter.placing, other.placing)
                if actual is None:
                    continue
                r_other = effective_ratings[other.id]
                expected_sum += self.sigmoid((r_eff - r_other) / self.settings.elo_scale_c)
                actual_sum += actual

            # Positive means the starter outperformed expectations
            delta = (actual_sum - expected_sum) / denominator
            scaled_delta = delta * self.settings.adj_update_scale

            if self.settings.adj_barrier_enabled and starter.barrier is not None:
                BarrierAdjustmentRepository.increment_sample(
                    self.db_session,
                    venue=venue,
                    start_type=start_type,
                    distance_bucket=distance_bucket,
                    barrier=starter.barrier,
                    delta=scaled_delta,
                    learning_rate=self.settings.adj_learning_rate,
                )

            # A zero handicap carries no adjustment, matching compute_effective_rating
            if (
                self.settings.adj_handicap_enabled
                and starter.handicap_m is not None
                and starter.handicap_m != 0
            ):
                HandicapAdjustmentRepository.increment_sample(
                    self.db_session,
                    venue=venue,
                    start_type=start_type,
                    distance_bucket=distance_bucket,
                    handicap_m=starter.handicap_m,
                    delta=scaled_delta,
                    learning_rate=self.settings.adj_learning_rate,
                )

    def process_race(self, race: Race, starters: list[Starter]) -> list[RatingUpdate]:
        """Process a race and compute rating updates.

        Uses pairwise logistic Elo:
        - For each pair (i, j), compute expected outcome E_ij
        - Update based on actual outcome S_ij (1 if i beat j, 0 otherwise;
          ties per ``tie_handling``)
        - ΔR_i = K * (1/normalizer) * Σ_j (S_ij - E_ij)

        All effective ratings are computed before any update is applied.
        Driver and trainer deltas are the horse delta scaled by their weight
        and K scale. Cached states are mutated in place.

        Args:
            race: Race instance
            starters: List of starters in race

        Returns:
            List of rating updates to apply
        """
        # Filter starters with valid placings
        finishers = [
            s for s in starters if s.placing is not None and not s.did_not_finish
        ]
        dnf_starters = [s for s in starters if s.did_not_finish or s.placing is None]

        placing_by_id: dict[int, int] = {s.id: s.placing for s in finishers}

        if self.settings.dnf_treated_as_last and dnf_starters:
            # All non-finishers share the place right after the last finisher
            last_place = max(placing_by_id.values(), default=0) + 1
            for starter in dnf_starters:
                placing_by_id[starter.id] = last_place
            valid_starters = finishers + dnf_starters
        else:
            valid_starters = finishers

        if len(valid_starters) < self.settings.min_finishers:
            logger.debug(
                f"Skipping race {race.id} - fewer than {self.settings.min_finishers} finishers"
            )
            return []

        n = len(valid_starters)
        updates: list[RatingUpdate] = []
        race_date = race.meeting.meeting_date if race.meeting else None

        # Effective ratings for all comparable starters, before any update.
        # Starters without a horse are never compared, so they are skipped.
        effective_ratings = {
            s.id: self.compute_effective_rating(s, race)
            for s in valid_starters
            if s.horse_id
        }

        for i, starter_i in enumerate(valid_starters):
            if not starter_i.horse_id:
                continue

            r_eff_i = effective_ratings[starter_i.id]
            placing_i = placing_by_id.get(starter_i.id, starter_i.placing)

            # Accumulate pairwise residuals S_ij - E_ij
            delta_sum = 0.0
            comparisons = 0
            for j, starter_j in enumerate(valid_starters):
                if i == j or not starter_j.horse_id:
                    continue
                placing_j = placing_by_id.get(starter_j.id, starter_j.placing)
                s_ij = self._pairwise_outcome(placing_i, placing_j)
                if s_ij is None:
                    continue
                r_eff_j = effective_ratings[starter_j.id]
                e_ij = self.sigmoid((r_eff_i - r_eff_j) / self.settings.elo_scale_c)
                delta_sum += s_ij - e_ij
                comparisons += 1

            normalizer = self._pairwise_normalizer(n, comparisons)
            if normalizer <= 0:
                continue

            # NOTE(review): RD inflation for inactivity happens inside
            # _apply_update, i.e. after this K has been computed. A horse
            # returning from a layoff therefore only gets a larger K from its
            # next race onward. Confirm this ordering is intended.
            k_eff = (
                self.get_effective_k_factor(EntityType.HORSE, starter_i.horse_id)
                * self.settings.horse_k_scale
            )
            delta_r = k_eff * (delta_sum / normalizer)

            self._apply_update(
                EntityType.HORSE,
                starter_i.horse_id,
                delta_r,
                race.id,
                race_date,
                updates,
            )

            # NOTE(review): a driver/trainer with several starters in this race
            # is updated once per starter, so race_count increments and RD
            # decays once per starter rather than once per race. Confirm.
            if self.settings.enable_driver and starter_i.driver_id:
                driver_delta = (
                    delta_r
                    * self.settings.driver_weight_alpha
                    * self.settings.driver_k_scale
                )
                self._apply_update(
                    EntityType.DRIVER,
                    starter_i.driver_id,
                    driver_delta,
                    race.id,
                    race_date,
                    updates,
                )

            if self.settings.enable_trainer and starter_i.trainer_id:
                trainer_delta = (
                    delta_r
                    * self.settings.trainer_weight_beta
                    * self.settings.trainer_k_scale
                )
                self._apply_update(
                    EntityType.TRAINER,
                    starter_i.trainer_id,
                    trainer_delta,
                    race.id,
                    race_date,
                    updates,
                )

        return updates

    def _apply_update(
        self,
        entity_type: EntityType,
        entity_id: int,
        delta: float,
        race_id: int,
        race_date: date | None,
        updates: list[RatingUpdate],
    ) -> None:
        """Apply rating update to an entity.

        The new rating is clamped to ``[rating_min, rating_max]`` where set.
        The recorded delta is the post-clamp change.

        When RD is tracked, RD is first inflated by
        ``days_inactive * rd_inflation_per_day`` (days optionally capped, RD
        capped at ``rd_max``). It is then reduced by
        ``max(rd_decay_per_race, rd_decay_floor)``, floored at ``rd_min``.

        Args:
            entity_type: Type of entity
            entity_id: Entity ID
            delta: Rating change
            race_id: Race ID
            race_date: Date of the race
            updates: List to append update to
        """
        state = self.get_or_init_rating(entity_type, entity_id)
        old_rating = state.rating
        new_rating = old_rating + delta
        if self.settings.rating_min is not None:
            new_rating = max(new_rating, self.settings.rating_min)
        if self.settings.rating_max is not None:
            new_rating = min(new_rating, self.settings.rating_max)
        delta = new_rating - old_rating

        # Update RD if enabled
        if self.settings.enable_rd and state.rd is not None:
            # First apply inflation for inactivity
            if state.last_race_date and race_date:
                days_inactive = (race_date - state.last_race_date).days
                if days_inactive > 0:
                    if self.settings.rd_inflation_cap_days is not None:
                        days_inactive = min(
                            days_inactive, self.settings.rd_inflation_cap_days
                        )
                    inflation = days_inactive * self.settings.rd_inflation_per_day
                    state.rd = min(state.rd + inflation, self.settings.rd_max)

            # Then apply decay for participating in race
            decay = max(self.settings.rd_decay_per_race, self.settings.rd_decay_floor)
            state.rd = max(state.rd - decay, self.settings.rd_min)

        # Update state
        state.rating = new_rating
        state.race_count += 1
        state.last_race_id = race_id
        state.last_race_date = race_date

        # Record update
        updates.append(
            RatingUpdate(
                entity_type=entity_type,
                entity_id=entity_id,
                old_rating=old_rating,
                new_rating=new_rating,
                delta=delta,
                rd=state.rd,
                meta={
                    "race_count": state.race_count,
                },
            )
        )

        # Balanced parentheses around the delta; RD shown whenever tracked,
        # including a legitimate 0.0.
        logger.debug(
            f"{entity_type.value} {entity_id}: "
            f"{old_rating:.1f} -> {new_rating:.1f} ({delta:+.1f})"
            + (f", RD={state.rd:.1f}" if state.rd is not None else "")
        )