"""SQLAlchemy models for TipSharks database schema."""

from enum import StrEnum

from sqlalchemy import (
    Boolean,
    Column,
    Date,
    DateTime,
    Enum,
    Float,
    ForeignKey,
    Integer,
    String,
    UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func

# Declarative base class shared by every model defined in this module.
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is deprecated since
# SQLAlchemy 1.4 in favour of sqlalchemy.orm.declarative_base; consider switching
# once the minimum supported SQLAlchemy version is confirmed to be >= 1.4.
Base = declarative_base()


class EntityType(StrEnum):
    """Entity types for ratings."""

    HORSE = "horse"
    DRIVER = "driver"
    TRAINER = "trainer"


class Meeting(Base):
    """A racing meeting (one venue on one date) as reported by the TAB API.

    A meeting owns its races: deleting it through the ORM deletes them too.
    """

    __tablename__ = "meetings"

    # The TAB meeting identifier doubles as the primary key.
    id = Column(String(64), comment="TAB meeting ID", primary_key=True)
    meeting_date = Column(Date, index=True, nullable=False)
    venue = Column(String(255), nullable=False)
    category = Column(
        String(1),
        comment="Racing type: T, H, or G",
        index=True,
        nullable=False,
    )
    raw_json = Column(JSONB, comment="Original TAB API response", nullable=False)

    # Row timestamps: set by the database on insert; updated_at is refreshed by
    # the ORM on every UPDATE.
    created_at = Column(DateTime, nullable=False, server_default=func.now())
    updated_at = Column(
        DateTime,
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
    )

    races = relationship(
        "Race",
        cascade="all, delete-orphan",
        back_populates="meeting",
    )

    def __repr__(self):
        return (
            "<Meeting("
            f"id={self.id}, "
            f"venue='{self.venue}', "
            f"date={self.meeting_date}, "
            f"category='{self.category}'"
            ")>"
        )


class Race(Base):
    """Individual race within a meeting from TAB API.

    A race is identified by ``(meeting_id, race_number)``. Deleting a race
    through the ORM also deletes its starters and its rating snapshots.
    """

    __tablename__ = "races"

    id = Column(Integer, primary_key=True, autoincrement=True)
    meeting_id = Column(
        String(64),
        ForeignKey("meetings.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    tab_event_id = Column(String(64), nullable=True, index=True, comment="TAB event ID")
    race_number = Column(Integer, nullable=False)

    distance_m = Column(Integer, nullable=True, comment="Distance in meters")
    start_type = Column(String(50), nullable=True, comment="Mobile or Standing")
    gait = Column(String(50), nullable=True, comment="Pace or Trot")
    weather = Column(String(64), nullable=True, comment="Weather at race time")
    track_condition = Column(String(64), nullable=True, comment="Track condition")

    race_datetime = Column(DateTime, nullable=True, index=True)
    raw_json = Column(JSONB, nullable=False, comment="Original TAB API response")

    created_at = Column(DateTime, server_default=func.now(), nullable=False)
    updated_at = Column(
        DateTime, server_default=func.now(), onupdate=func.now(), nullable=False
    )

    # A meeting cannot contain two races with the same race number.
    __table_args__ = (
        UniqueConstraint("meeting_id", "race_number", name="uq_meeting_race"),
    )

    # Relationships
    meeting = relationship("Meeting", back_populates="races")
    starters = relationship(
        "Starter", back_populates="race", cascade="all, delete-orphan"
    )
    # rating_snapshots.as_of_race_id is NOT NULL with ON DELETE CASCADE. Under the
    # default cascade ("save-update, merge") the ORM reacts to a race deletion
    # by setting that column to NULL on each snapshot, which violates the NOT
    # NULL constraint and aborts the flush. Cascade the delete instead, and let
    # the database remove snapshot rows that were never loaded (passive_deletes).
    rating_snapshots = relationship(
        "RatingSnapshot",
        back_populates="race",
        cascade="save-update, merge, delete",
        passive_deletes=True,
    )

    def __repr__(self):
        return (
            f"<Race(id={self.id}, meeting_id={self.meeting_id}, "
            f"race_number={self.race_number})>"
        )


class Horse(Base):
    """Dimension table of horses, keyed by the TAB horse ID."""

    __tablename__ = "horses"

    id = Column(Integer, comment="TAB horse ID", primary_key=True)
    name = Column(String(255), index=True, nullable=False)
    raw_json = Column(JSONB, comment="Additional horse metadata", nullable=True)

    created_at = Column(DateTime, nullable=False, server_default=func.now())
    updated_at = Column(
        DateTime,
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
    )

    # Every race entry for this horse, joined through Starter.horse_id.
    starters = relationship(
        "Starter",
        back_populates="horse",
        foreign_keys="Starter.horse_id",
    )

    def __repr__(self):
        return "<Horse(id={}, name='{}')>".format(self.id, self.name)


class Driver(Base):
    """Dimension table of drivers.

    The TAB API only supplies driver names, so the primary key is derived from
    a hash of the name rather than taken from the API.
    """

    __tablename__ = "drivers"

    id = Column(Integer, comment="Generated from name hash", primary_key=True)
    name = Column(String(255), index=True, nullable=False)
    raw_json = Column(JSONB, comment="Additional driver metadata", nullable=True)

    created_at = Column(DateTime, nullable=False, server_default=func.now())
    updated_at = Column(
        DateTime,
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
    )

    # Every race entry this driver took part in, joined through Starter.driver_id.
    starters = relationship(
        "Starter",
        back_populates="driver",
        foreign_keys="Starter.driver_id",
    )

    def __repr__(self):
        return "<Driver(id={}, name='{}')>".format(self.id, self.name)


class Trainer(Base):
    """Dimension table of trainers.

    The TAB API only supplies trainer names, so the primary key is derived from
    a hash of the name rather than taken from the API.
    """

    __tablename__ = "trainers"

    id = Column(Integer, comment="Generated from name hash", primary_key=True)
    name = Column(String(255), index=True, nullable=False)
    raw_json = Column(JSONB, comment="Additional trainer metadata", nullable=True)

    created_at = Column(DateTime, nullable=False, server_default=func.now())
    updated_at = Column(
        DateTime,
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
    )

    # Every race entry prepared by this trainer, joined through Starter.trainer_id.
    starters = relationship(
        "Starter",
        back_populates="trainer",
        foreign_keys="Starter.trainer_id",
    )

    def __repr__(self):
        return "<Trainer(id={}, name='{}')>".format(self.id, self.name)


class Starter(Base):
    """One runner's entry in one race, as reported by the TAB API.

    Links a race to the horse, driver and trainer involved, together with the
    starting position and the result. Each of the three links is optional;
    the database sets it to NULL if the referenced row is deleted.
    """

    __tablename__ = "starters"

    id = Column(Integer, autoincrement=True, primary_key=True)
    race_id = Column(
        Integer,
        ForeignKey("races.id", ondelete="CASCADE"),
        index=True,
        nullable=False,
    )

    # Links to the dimension tables (ON DELETE SET NULL).
    horse_id = Column(
        Integer,
        ForeignKey("horses.id", ondelete="SET NULL"),
        index=True,
        nullable=True,
    )
    driver_id = Column(
        Integer,
        ForeignKey("drivers.id", ondelete="SET NULL"),
        index=True,
        nullable=True,
    )
    trainer_id = Column(
        Integer,
        ForeignKey("trainers.id", ondelete="SET NULL"),
        index=True,
        nullable=True,
    )

    # Starting position.
    runner_number = Column(
        Integer, comment="Saddlecloth number", index=True, nullable=True
    )
    barrier = Column(Integer, comment="Starting barrier/gate number", nullable=True)
    barrier_position = Column(
        String(10), comment="Harness position: 1F, 2B, etc.", nullable=True
    )
    handicap_m = Column(
        Integer, comment="Handicap in meters (back marks)", nullable=True
    )

    # Result.
    placing = Column(
        Integer,
        comment="Final placing (1=winner, NULL=DNF/no result)",
        nullable=True,
    )
    # NOTE(review): default=False is a client-side (ORM) default with no
    # server_default and no nullable=False, so rows written outside the ORM may
    # hold NULL here.
    did_not_finish = Column(Boolean, comment="DNF, pulled up, etc.", default=False)

    raw_json = Column(JSONB, comment="Original runner data from TAB", nullable=False)

    created_at = Column(DateTime, nullable=False, server_default=func.now())
    updated_at = Column(
        DateTime,
        nullable=False,
        server_default=func.now(),
        onupdate=func.now(),
    )

    race = relationship("Race", back_populates="starters")
    horse = relationship("Horse", back_populates="starters")
    driver = relationship("Driver", back_populates="starters")
    trainer = relationship("Trainer", back_populates="starters")

    def __repr__(self):
        parts = (
            f"id={self.id}",
            f"race_id={self.race_id}",
            f"horse_id={self.horse_id}",
            f"placing={self.placing}",
        )
        return "<Starter(" + ", ".join(parts) + ")>"


class RatingSnapshot(Base):
    """Rating snapshot for an entity at a point in time (after a race).

    ``entity_id`` is polymorphic: depending on ``entity_type`` it refers to a
    row in the horses, drivers or trainers table, so it carries no foreign key.
    At most one snapshot exists per (entity_type, entity_id, as_of_race_id).
    """

    __tablename__ = "rating_snapshots"

    id = Column(Integer, primary_key=True, autoincrement=True)

    # NOTE(review): SQLAlchemy's Enum type persists enum member *names* by
    # default (HORSE/DRIVER/TRAINER), not the lowercase values listed in the
    # column comment; confirm against existing data before relying on either
    # form in raw SQL.
    entity_type = Column(
        Enum(EntityType),
        nullable=False,
        index=True,
        comment="horse, driver, or trainer",
    )
    entity_id = Column(
        Integer, nullable=False, index=True, comment="ID in respective table"
    )

    as_of_race_id = Column(
        Integer,
        ForeignKey("races.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
        comment="Rating after this race",
    )

    rating = Column(Float, nullable=False, comment="Elo rating")
    rd = Column(Float, nullable=True, comment="Rating deviation (uncertainty)")

    meta = Column(
        JSONB,
        nullable=True,
        comment="Additional metadata: components, K used, race count, etc.",
    )

    created_at = Column(DateTime, server_default=func.now(), nullable=False)

    # Unique constraint over (entity_type, entity_id, as_of_race_id); PostgreSQL
    # backs it with a unique index, which also serves lookups on that prefix.
    __table_args__ = (
        UniqueConstraint(
            "entity_type",
            "entity_id",
            "as_of_race_id",
            name="uq_entity_race_snapshot",
        ),
    )

    # Relationship
    race = relationship("Race", back_populates="rating_snapshots")

    def __repr__(self):
        # Must not raise on partially built objects: before a flush,
        # entity_type may be None or a plain string, and rating may be None.
        entity_type = getattr(self.entity_type, "value", self.entity_type)
        rating = "None" if self.rating is None else f"{self.rating:.1f}"
        return (
            f"<RatingSnapshot(entity_type={entity_type}, "
            f"entity_id={self.entity_id}, rating={rating})>"
        )


class BarrierAdjustment(Base):
    """Learned adjustments for barrier positions.

    Each row holds the rating-point adjustment for one barrier within a
    (venue, start_type, distance_bucket) context. A NULL venue marks a global
    row; a NULL start_type marks a row that applies to any start type.
    """

    __tablename__ = "barrier_adjustments"

    id = Column(Integer, primary_key=True, autoincrement=True)

    venue = Column(
        String(255), nullable=True, comment="Specific venue or NULL for global"
    )
    start_type = Column(
        String(50), nullable=True, comment="mobile/standing or NULL for any"
    )
    distance_bucket = Column(
        String(50), nullable=False, comment="e.g., '<1700', '1700-2000'"
    )
    barrier = Column(Integer, nullable=False, comment="Barrier number")

    adjustment = Column(
        Float,
        default=0.0,
        nullable=False,
        comment="Rating points to add for this barrier",
    )

    sample_count = Column(
        Integer, default=0, nullable=False, comment="Number of observations"
    )

    created_at = Column(DateTime, server_default=func.now(), nullable=False)
    updated_at = Column(
        DateTime, server_default=func.now(), onupdate=func.now(), nullable=False
    )

    # NOTE(review): PostgreSQL treats NULLs as distinct in unique constraints
    # (unless NULLS NOT DISTINCT is used, PostgreSQL 15+). This constraint
    # therefore does not stop duplicate rows whose venue or start_type is NULL,
    # i.e. the global / any-start-type rows; confirm how callers upsert those.
    __table_args__ = (
        UniqueConstraint(
            "venue",
            "start_type",
            "distance_bucket",
            "barrier",
            name="uq_barrier_adjustment",
        ),
    )

    def __repr__(self):
        # default=0.0 is applied by the ORM only at INSERT time, so a pending
        # object can still have adjustment=None; repr() must not raise then.
        adjustment = "None" if self.adjustment is None else f"{self.adjustment:.2f}"
        return (
            f"<BarrierAdjustment(venue={self.venue}, barrier={self.barrier}, "
            f"adjustment={adjustment})>"
        )


class HandicapAdjustment(Base):
    """Learned adjustments for handicaps (back marks).

    Each row holds the rating-point adjustment for one handicap distance within
    a (venue, start_type, distance_bucket) context. A NULL venue marks a global
    row; a NULL start_type marks a row that applies to any start type.
    """

    __tablename__ = "handicap_adjustments"

    id = Column(Integer, primary_key=True, autoincrement=True)

    venue = Column(
        String(255), nullable=True, comment="Specific venue or NULL for global"
    )
    start_type = Column(
        String(50), nullable=True, comment="mobile/standing or NULL for any"
    )
    distance_bucket = Column(
        String(50), nullable=False, comment="e.g., '<1700', '1700-2000'"
    )
    handicap_m = Column(Integer, nullable=False, comment="Handicap in meters")

    adjustment = Column(
        Float,
        default=0.0,
        nullable=False,
        comment="Rating points to add for this handicap",
    )

    sample_count = Column(
        Integer, default=0, nullable=False, comment="Number of observations"
    )

    created_at = Column(DateTime, server_default=func.now(), nullable=False)
    updated_at = Column(
        DateTime, server_default=func.now(), onupdate=func.now(), nullable=False
    )

    # NOTE(review): PostgreSQL treats NULLs as distinct in unique constraints
    # (unless NULLS NOT DISTINCT is used, PostgreSQL 15+). This constraint
    # therefore does not stop duplicate rows whose venue or start_type is NULL,
    # i.e. the global / any-start-type rows; confirm how callers upsert those.
    __table_args__ = (
        UniqueConstraint(
            "venue",
            "start_type",
            "distance_bucket",
            "handicap_m",
            name="uq_handicap_adjustment",
        ),
    )

    def __repr__(self):
        # default=0.0 is applied by the ORM only at INSERT time, so a pending
        # object can still have adjustment=None; repr() must not raise then.
        adjustment = "None" if self.adjustment is None else f"{self.adjustment:.2f}"
        return (
            f"<HandicapAdjustment(venue={self.venue}, handicap_m={self.handicap_m}, "
            f"adjustment={adjustment})>"
        )


class PredictionHistory(Base):
    """Record of win/place predictions and actual results for accuracy tracking.

    Each row ties a race and a horse to the predicted probabilities and
    placing. ``actual_placing`` and ``brier_score`` are nullable (presumably
    filled in once the race result is known; confirm with the writer).
    Rows are removed by the database when the race or horse is deleted.
    """

    __tablename__ = "prediction_history"

    id = Column(Integer, primary_key=True, autoincrement=True)
    race_id = Column(
        Integer, ForeignKey("races.id", ondelete="CASCADE"), nullable=False, index=True
    )
    horse_id = Column(
        Integer, ForeignKey("horses.id", ondelete="CASCADE"), nullable=False, index=True
    )
    predicted_win_prob = Column(Float, nullable=False)
    predicted_place_prob = Column(Float, nullable=False)
    predicted_placing = Column(Integer, nullable=False)
    actual_placing = Column(Integer, nullable=True)
    brier_score = Column(Float, nullable=True)
    created_at = Column(DateTime, server_default=func.now(), nullable=False)

    def __repr__(self):
        # Matches the other models in this module; avoids float formatting so
        # it cannot raise on unset values.
        return (
            f"<PredictionHistory(id={self.id}, race_id={self.race_id}, "
            f"horse_id={self.horse_id}, predicted_placing={self.predicted_placing}, "
            f"actual_placing={self.actual_placing})>"
        )


class AuditLog(Base):
    """Audit log for tracking data changes and corrections.

    ``record_id`` is a string, presumably so it can hold both the integer and
    the string primary keys used by the tables in this module (confirm with
    the writer). ``old_values`` / ``new_values`` hold the before/after state as
    JSON.
    """

    __tablename__ = "audit_logs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    table_name = Column(String(64), nullable=False, index=True)
    record_id = Column(String(64), nullable=False, index=True)
    # Expected values: INSERT, UPDATE, DELETE, CORRECT (free text, not enforced).
    action = Column(String(20), nullable=False, index=True)
    old_values = Column(JSONB, nullable=True)
    new_values = Column(JSONB, nullable=True)
    changed_by = Column(String(255), nullable=True, index=True)  # user identifier
    change_reason = Column(String(500), nullable=True)
    created_at = Column(DateTime, server_default=func.now(), nullable=False)

    def __repr__(self):
        return (
            f"<AuditLog(id={self.id}, table_name={self.table_name}, "
            f"record_id={self.record_id}, action={self.action})>"
        )
