"""Track condition adjustments for Elo ratings.

Foundational stub for a model that learns how different track conditions
(fast, good, soft, heavy — i.e. from dry to wet going) affect
horse/driver/trainer performance.

The core idea:
  - Some horses perform significantly better on wet tracks ("mudders") while
    others prefer fast, dry surfaces.
  - We maintain per-horse track condition adjustments that are learned from
    historical performance residuals, similar to how barrier/handicap
    adjustments are learned.
  - These adjustments are applied to the effective rating: a horse with a
    positive wet-track adjustment gets a rating boost when the track is heavy.

This module is **not yet integrated** into the main ``RatingEngine``.
Integration would require:
  1. Adding a ``TrackConditionAdjustment`` model (DB table).
  2. Loading adjustments into the engine via a repository.
  3. Applying the adjustment in ``compute_effective_rating`` when a race's
     ``track_condition`` field is populated.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, field
from enum import StrEnum


class TrackConditionCategory(StrEnum):
    """Broad categorisation of track conditions.

    Mapped from raw track condition strings (e.g. "Good3", "Heavy10",
    "Soft5", "Fast") into these categories for adjustment learning.
    """

    FAST = "fast"  # Firm / fast / hard
    GOOD = "good"  # Good / dead
    SOFT = "soft"  # Soft / slow / easy
    HEAVY = "heavy"  # Heavy / slop / muddy
    UNKNOWN = "unknown"


# ── Mapping helpers ──────────────────────────────────────────────────────


def _categorise_track(condition: str | None) -> TrackConditionCategory:
    """Classify a free-text track rating into a ``TrackConditionCategory``.

    Matching is case-insensitive and substring-based. Keywords are tried in a
    fixed priority order (roughly wettest first), and the first keyword found
    anywhere in the text wins. Combined descriptions therefore resolve by
    that order, e.g. ``"Good to Soft"`` → ``SOFT`` and ``"Good to Firm"`` →
    ``GOOD``.

    Common NZ/AU formats:
      - "Good3", "Good4", "Dead" → ``GOOD``
      - "Soft5", "Soft6", "Slow" → ``SOFT``
      - "Heavy8", "Heavy10", "Slop"/"Sloppy", "Muddy" → ``HEAVY``
      - "Fast", "Firm", "Hard" → ``FAST``

    Args:
        condition: Raw track condition string, or ``None``.

    Returns:
        The matched category. ``UNKNOWN`` is returned for ``None``, for an
        empty or blank string, and for text containing none of the
        recognised keywords.
    """
    if not condition:
        return TrackConditionCategory.UNKNOWN

    # The position of a keyword in this tuple, not its position in the text,
    # decides which category wins when several keywords are present.
    priority = (
        ("heavy", TrackConditionCategory.HEAVY),
        ("soft", TrackConditionCategory.SOFT),
        ("slow", TrackConditionCategory.SOFT),
        ("slop", TrackConditionCategory.HEAVY),
        ("muddy", TrackConditionCategory.HEAVY),
        ("good", TrackConditionCategory.GOOD),
        ("dead", TrackConditionCategory.GOOD),
        ("fast", TrackConditionCategory.FAST),
        ("firm", TrackConditionCategory.FAST),
        ("hard", TrackConditionCategory.FAST),
    )
    text = condition.strip().lower()
    return next(
        (category for keyword, category in priority if keyword in text),
        TrackConditionCategory.UNKNOWN,
    )


# ── Track condition model ────────────────────────────────────────────────


@dataclass
class TrackConditionModel:
    """Learn and apply track condition adjustments.

    Maintains a dictionary of ``(entity_type, entity_id, condition_category)``
    → adjustment value, learned from performance residuals.
    """

    adjustments: dict[tuple[str, int, str], float] = field(default_factory=dict)
    sample_counts: dict[tuple[str, int, str], int] = field(default_factory=dict)

    # Learning parameters
    learning_rate: float = 0.1
    min_samples: int = 5
    max_adjustment: float = 50.0

    # ── Learning ─────────────────────────────────────────────────────

    def learn_from_performance(
        self,
        entity_type: str,
        entity_id: int,
        track_condition: str | None,
        performance_residual: float,
    ) -> None:
        """Update the adjustment for an entity + track condition.

        ``performance_residual`` measures how much the entity outperformed
        (positive) or underperformed (negative) expectations, in rating
        points. This is analogous to the residual computed in
        ``RatingEngine.learn_adjustments_from_race``.

        Args:
            entity_type: ``"horse"``, ``"driver"``, or ``"trainer"``.
            entity_id: Entity ID.
            track_condition: Raw track condition string (e.g. ``"Good3"``).
            performance_residual: Rating delta to attribute to the condition.
        """
        category = _categorise_track(track_condition)
        if category == TrackConditionCategory.UNKNOWN:
            return

        key = (entity_type, entity_id, category.value)

        current = self.adjustments.get(key, 0.0)
        count = self.sample_counts.get(key, 0)

        # Incremental moving average
        new_count = count + 1
        new_adjustment = current + self.learning_rate * (performance_residual - current)
        new_adjustment = max(
            -self.max_adjustment, min(self.max_adjustment, new_adjustment)
        )

        self.adjustments[key] = new_adjustment
        self.sample_counts[key] = new_count

    # ── Application ───────────────────────────────────────────────────

    def get_adjustment(
        self,
        entity_type: str,
        entity_id: int,
        track_condition: str | None,
    ) -> float:
        """Get the track condition adjustment for an entity.

        Returns 0.0 if there are insufficient samples.

        Args:
            entity_type: ``"horse"``, ``"driver"``, or ``"trainer"``.
            entity_id: Entity ID.
            track_condition: Raw track condition string.

        Returns:
            Adjustment value in rating points (positive = better on this
            surface, negative = worse).
        """
        category = _categorise_track(track_condition)
        if category == TrackConditionCategory.UNKNOWN:
            return 0.0

        key = (entity_type, entity_id, category.value)
        count = self.sample_counts.get(key, 0)

        if count < self.min_samples:
            return 0.0

        return self.adjustments.get(key, 0.0)

    # ── Population-level stats ────────────────────────────────────────

    def get_population_adjustment(
        self,
        track_condition: str | None,
    ) -> float:
        """Get the average adjustment across all entities for a condition.

        Useful as a fallback when no per-entity data is available.

        Args:
            track_condition: Raw track condition string.

        Returns:
            Population-average adjustment.
        """
        category = _categorise_track(track_condition)
        if category == TrackConditionCategory.UNKNOWN:
            return 0.0

        vals = [
            adj
            for (_, _, cat), adj in self.adjustments.items()
            if cat == category.value
        ]
        if not vals:
            return 0.0
        return sum(vals) / len(vals)

    def reset(self) -> None:
        """Clear all learned adjustments (for re-learning)."""
        self.adjustments.clear()
        self.sample_counts.clear()


# ── Example / test usage ─────────────────────────────────────────────────


def _demo() -> None:
    """Print a short walkthrough of learning and querying adjustments.

    Trains a synthetic "mudder" (horse 42) that beats expectations by 20
    points on heavy ground and falls 5 short on good ground. It then prints
    the learned per-condition adjustments, including an untrained condition,
    and the population average for heavy tracks.
    """
    model = TrackConditionModel(learning_rate=0.15, min_samples=3)
    horse_id = 42

    print("Track Condition Model — demo")
    print()

    print("Simulating a horse that performs well on heavy tracks...")
    # Each round feeds one heavy-track and one good-track result.
    rounds = 10
    results = (("Heavy10", 20.0), ("Good3", -5.0))
    for _ in range(rounds):
        for condition, residual in results:
            model.learn_from_performance("horse", horse_id, condition, residual)

    heavy = model.get_adjustment("horse", horse_id, "Heavy10")
    good = model.get_adjustment("horse", horse_id, "Good3")
    fast = model.get_adjustment("horse", horse_id, "Fast")
    population = model.get_population_adjustment("Heavy10")

    print(f"  Horse 42 on Heavy:  {heavy:+.1f}")
    print(f"  Horse 42 on Good:   {good:+.1f}")
    print(f"  Horse 42 on Fast:   {fast:+.1f}  (no data)")
    print(f"  Population heavy:   {population:+.1f}")
    print(f"  Population heavy:   {model.get_population_adjustment('Heavy10'):+.1f}")


if __name__ == "__main__":
    # Script entry point: run the illustrative demo when executed directly.
    _demo()
