Coverage for packages/ml/ensemble.py: 0%
13 statements
coverage.py v7.13.5, created at 2026-05-08 08:14 +1200
"""Ensemble model combining Elo ratings with ML predictions.

The EnsembleModel class takes Elo-based probabilities from the rating engine
and blends them with additional ML predictions (from feature-engineered
models) to produce a final probability estimate. The blending strategy is
chosen per instance: "weighted", "logistic", or "stacking".
"""
from __future__ import annotations

import json
import math
from typing import Any

from sqlalchemy.orm import Session
class EnsembleModel:
    """Combines Elo ratings with ML predictions for improved accuracy.

    Supports multiple blending strategies:
    - Weighted average (fixed weights learned offline)
    - Logistic calibration (train a logistic regression on top of
      Elo + ML features)
    - Stacking (use ML predictions as features for a meta-model)

    The strategy is selected with ``blend_mode``; accepted values are listed
    in ``BLEND_MODES``. Blend weights are held in ``self._weights`` and can be
    persisted with ``save_weights`` / ``load_weights``. ``train``,
    ``predict`` and ``evaluate`` are not implemented yet.
    """

    # Accepted values for ``blend_mode``, in the order documented above.
    BLEND_MODES = ("weighted", "logistic", "stacking")

    def __init__(
        self,
        session: Session,
        blend_mode: str = "weighted",
    ) -> None:
        """Initialize the ensemble model.

        Args:
            session: Database session for loading data.
            blend_mode: Blending strategy — "weighted", "logistic",
                or "stacking".

        Raises:
            ValueError: If ``blend_mode`` is not one of ``BLEND_MODES``.
        """
        # Validate up front so a typo fails at construction time instead of
        # being stored silently and surfacing much later.
        if blend_mode not in self.BLEND_MODES:
            raise ValueError(
                f"Unknown blend_mode {blend_mode!r}; expected one of "
                f"{', '.join(repr(mode) for mode in self.BLEND_MODES)}"
            )
        self._session = session
        self.blend_mode = blend_mode
        # Blend weights keyed by name. Empty until populated; in this class
        # only load_weights() fills it so far. The key naming scheme is not
        # defined yet — TODO confirm once train() is implemented.
        self._weights: dict[str, float] = {}

    def train(
        self,
        date_from: str,
        date_to: str,
    ) -> dict[str, Any]:
        """Train the ensemble (e.g., learn blend weights).

        Args:
            date_from: Start date string (YYYY-MM-DD).
            date_to: End date string (YYYY-MM-DD).

        Returns:
            Training metrics dict.

        Raises:
            NotImplementedError: Always; training is not implemented yet.
        """
        raise NotImplementedError("EnsembleModel.train is not implemented yet")

    def predict(
        self,
        race_id: int,
    ) -> list[dict[str, Any]]:
        """Predict win/place probabilities for all starters in a race.

        Combines Elo ratings with ML feature predictions.

        Args:
            race_id: The race ID.

        Returns:
            List of prediction dicts, one per starter, with keys:
            - starter_id: int
            - win_probability: float
            - place_probability: float
            - elo_win_prob: float
            - ml_win_prob: float (if applicable)

        Raises:
            NotImplementedError: Always; prediction is not implemented yet.
        """
        raise NotImplementedError("EnsembleModel.predict is not implemented yet")

    def evaluate(
        self,
        date_from: str,
        date_to: str,
    ) -> dict[str, Any]:
        """Evaluate ensemble performance on a test window.

        Args:
            date_from: Start date string (YYYY-MM-DD).
            date_to: End date string (YYYY-MM-DD).

        Returns:
            Dict of evaluation metrics (accuracy, Brier score, log loss,
            top-3 rate, etc.).

        Raises:
            NotImplementedError: Always; evaluation is not implemented yet.
        """
        raise NotImplementedError("EnsembleModel.evaluate is not implemented yet")

    def save_weights(self, path: str) -> None:
        """Persist learned blend weights to disk.

        The weights are written as a single JSON object mapping each weight
        name to its value. Keys are sorted so the file content is stable
        for identical weights.

        Args:
            path: File path to save weights to (JSON format). An existing
                file at this path is overwritten.

        Raises:
            ValueError: If any weight is NaN or infinite (not representable
                in strict JSON). The target file is left untouched.
            OSError: If the file cannot be written.
        """
        # Serialize before opening the file so a serialization error cannot
        # leave a truncated or partially written file behind.
        payload = json.dumps(
            self._weights, indent=2, sort_keys=True, allow_nan=False
        )
        with open(path, "w", encoding="utf-8") as fh:
            fh.write(payload + "\n")

    def load_weights(self, path: str) -> None:
        """Load learned blend weights from disk.

        Expects the format written by ``save_weights``: a JSON object
        mapping weight names to finite numbers. Integer values are stored
        as floats. On any error the current weights are left unchanged.

        Args:
            path: File path to load weights from (JSON format).

        Raises:
            ValueError: If the file is not valid JSON, is not a JSON object,
                or contains a value that is not a finite number (booleans
                are rejected).
            OSError: If the file cannot be read.
        """
        with open(path, encoding="utf-8") as fh:
            data = json.load(fh)  # JSONDecodeError is a ValueError subclass
        if not isinstance(data, dict):
            raise ValueError(
                f"Expected a JSON object of weights in {path!r}, "
                f"got {type(data).__name__}"
            )
        weights: dict[str, float] = {}
        for name, value in data.items():
            # bool is an int subclass; reject it explicitly so true/false in
            # the file are not silently read as 1.0/0.0.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError(
                    f"Weight {name!r} in {path!r} is not a number: {value!r}"
                )
            weight = float(value)
            # json accepts NaN/Infinity tokens by default; mirror the
            # strictness of save_weights and refuse them.
            if not math.isfinite(weight):
                raise ValueError(
                    f"Weight {name!r} in {path!r} is not finite: {value!r}"
                )
            weights[name] = weight
        # Assign only after full validation so a bad file never leaves the
        # model with a partially loaded set of weights.
        self._weights = weights