"""Tests for notification service endpoints and providers."""

import asyncio
import os
import sys
from contextlib import ExitStack
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

# Ensure the backend directory is on the path so we can import server.
# NOTE(review): the path is built relative to this file and is not
# normalised (it still contains ".."), so the membership check below is a
# plain string comparison against sys.path entries.
BACKEND_DIR = os.path.join(os.path.dirname(__file__), "..", "backend")
if BACKEND_DIR not in sys.path:
    # Prepend so the local backend ``server`` module is found before any
    # other module of the same name later on sys.path.
    sys.path.insert(0, BACKEND_DIR)

# Set env vars before importing server (required at import time).
# Plain assignment (not setdefault) deliberately overrides any value already
# exported in the shell, so the suite always uses the local test settings
# below.
os.environ["MONGO_URL"] = "mongodb://localhost:27017"
os.environ["DB_NAME"] = "test_tipsharks"

import server  # noqa: E402

# ===================================================================
#  Helper
# ===================================================================


def _make_mock_collection():
    """Return a MagicMock that behaves like an async MongoDB collection.

    Awaitable methods (``insert_one``, ``find_one``, ``update_one``,
    ``delete_one``, ``delete_many``) are ``AsyncMock`` instances.
    ``find_one`` resolves to ``None`` ("no document") by default; tests may
    replace it.  ``find(...)`` and ``find(...).sort(...)`` return cursor-like
    mocks whose ``to_list`` coroutine resolves to an empty list.
    """
    m = MagicMock()
    m.insert_one = AsyncMock()
    # Without this, ``await collection.find_one(...)`` would try to await a
    # plain MagicMock and raise TypeError.
    m.find_one = AsyncMock(return_value=None)
    m.find = MagicMock()
    m.find.return_value.sort = MagicMock()
    m.find.return_value.sort.return_value.to_list = AsyncMock(return_value=[])
    m.find.return_value.to_list = AsyncMock(return_value=[])
    m.update_one = AsyncMock()
    m.delete_one = AsyncMock()
    m.delete_many = AsyncMock()
    return m


# ===================================================================
#  Fixtures
# ===================================================================


@pytest.fixture
def client():
    """FastAPI TestClient with all external dependencies mocked.

    Patches ``server.db``, the ``TwilioClient`` / ``EmailClient`` provider
    classes, and the ``check_rate_limit`` / ``increment_rate_limit`` helpers
    so that TestClient-based tests never hit real services.

    Defaults installed:
      * ``db.notifications``, ``db.scheduled_notifications`` and
        ``db.rate_limits`` are async collection mocks, and
        ``rate_limits.find_one`` resolves to ``None``.
      * ``check_rate_limit`` resolves to ``(True, 10)``;
        ``increment_rate_limit`` resolves to ``None``.
      * Provider instances report success from ``send_sms`` / ``send_email``.

    Tests can reconfigure these through the patched ``server`` attributes
    (e.g. ``server.check_rate_limit.return_value``).

    All patches are registered on an ``ExitStack``, so they are undone even
    if entering the TestClient (application startup) raises; manual
    ``start()``/``stop()`` pairs would leak the patches into later tests.
    """
    with ExitStack() as stack:
        mock_db = stack.enter_context(patch.object(server, "db"))
        mock_twilio_cls = stack.enter_context(patch.object(server, "TwilioClient"))
        mock_email_cls = stack.enter_context(patch.object(server, "EmailClient"))
        mock_rate_check = stack.enter_context(
            patch.object(server, "check_rate_limit", new_callable=AsyncMock)
        )
        mock_rate_incr = stack.enter_context(
            patch.object(server, "increment_rate_limit", new_callable=AsyncMock)
        )

        # --- MongoDB mocks ---
        mock_db.notifications = _make_mock_collection()
        mock_db.scheduled_notifications = _make_mock_collection()
        mock_db.rate_limits = _make_mock_collection()
        mock_db.rate_limits.find_one = AsyncMock(return_value=None)

        # --- Rate-limit helpers ---
        mock_rate_check.return_value = (True, 10)
        mock_rate_incr.return_value = None

        # --- Provider mocks ---
        mock_twilio = MagicMock()
        mock_twilio.send_sms = AsyncMock(return_value=True)
        mock_twilio_cls.return_value = mock_twilio

        mock_email = MagicMock()
        mock_email.send_email = AsyncMock(return_value=True)
        mock_email_cls.return_value = mock_email

        # Entered innermost so the app shuts down first, while the patches
        # are still active.
        with TestClient(server.app) as c:
            yield c


# ===================================================================
#  Provider-unit tests
# ===================================================================


class TestTwilioClient:
    """TwilioClient – direct unit tests (no fixture-based mocking).

    Coroutines are driven by ``_run`` on a private event loop instead of
    ``asyncio.get_event_loop()``, which is deprecated when no loop is running
    and raises on newer Python versions.
    """

    @staticmethod
    def _run(coro):
        """Run *coro* to completion on a throwaway event loop and return its result.

        ``asyncio.new_event_loop`` does not install the loop as the current
        one, so other tests' view of the default loop is left untouched.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    @staticmethod
    def _configured_client():
        """Return a TwilioClient with dummy credentials so ``enabled`` is True."""
        tc = server.TwilioClient()
        tc.account_sid = "sid"
        tc.auth_token = "token"
        tc.phone_number = "+15551234567"
        return tc

    def test_send_sms_configured(self):
        with patch.object(server.httpx, "AsyncClient") as mock_ac:
            mock_resp = MagicMock()
            mock_resp.json.return_value = {"sid": "SM123"}
            mock_post = AsyncMock(return_value=mock_resp)
            mock_ac.return_value.__aenter__.return_value.post = mock_post

            tc = self._configured_client()

            success = self._run(tc.send_sms("+15559876543", "Hello!"))
            assert success is True
            # The logging fallback also returns True, so prove the HTTP path
            # actually ran.
            mock_post.assert_awaited_once()

    def test_send_sms_fallback(self):
        tc = server.TwilioClient()
        # Clear credentials explicitly so the test does not depend on TWILIO_*
        # variables in the environment → enabled is False → fallback to logging.
        tc.account_sid = None
        tc.auth_token = None
        tc.phone_number = None
        success = self._run(tc.send_sms("+15559876543", "Hello!"))
        assert success is True  # fallback always succeeds

    def test_send_sms_failure(self):
        with patch.object(server.httpx, "AsyncClient") as mock_ac:
            mock_ac.return_value.__aenter__.return_value.post = AsyncMock(
                side_effect=Exception("Twilio down")
            )

            tc = self._configured_client()

            success = self._run(tc.send_sms("+15559876543", "Oops"))
            assert success is False

    def test_send_sms_not_enabled_without_creds(self):
        tc = server.TwilioClient()
        tc.account_sid = None
        tc.auth_token = None
        tc.phone_number = None
        assert tc.enabled is False

    def test_send_sms_enabled_with_creds(self):
        tc = self._configured_client()
        assert tc.enabled is True


class TestEmailClient:
    """EmailClient – direct unit tests (no fixture-based mocking).

    Coroutines are driven by ``_run`` on a private event loop instead of
    ``asyncio.get_event_loop()``, which is deprecated when no loop is running
    and raises on newer Python versions.
    """

    @staticmethod
    def _run(coro):
        """Run *coro* to completion on a throwaway event loop and return its result.

        ``asyncio.new_event_loop`` does not install the loop as the current
        one, so other tests' view of the default loop is left untouched.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    def test_sendgrid_configured(self):
        with patch.object(server.httpx, "AsyncClient") as mock_ac:
            mock_post = AsyncMock()
            mock_ac.return_value.__aenter__.return_value.post = mock_post

            ec = server.EmailClient()
            ec.sendgrid_key = "sg_key"
            ec.resend_key = None

            result = self._run(ec.send_email("a@b.com", "Sub", "Body"))
            assert result is True
            # The logging fallback also returns True; prove SendGrid was hit.
            mock_post.assert_awaited_once()

    def test_resend_fallback(self):
        with patch.object(server.httpx, "AsyncClient") as mock_ac:
            mock_post = AsyncMock()
            mock_ac.return_value.__aenter__.return_value.post = mock_post

            ec = server.EmailClient()
            ec.sendgrid_key = None
            ec.resend_key = "re_key"

            result = self._run(ec.send_email("a@b.com", "Sub", "Body"))
            assert result is True
            # The logging fallback also returns True; prove Resend was hit.
            mock_post.assert_awaited_once()

    def test_email_fallback_logging(self):
        ec = server.EmailClient()
        ec.sendgrid_key = None
        ec.resend_key = None

        result = self._run(ec.send_email("a@b.com", "Sub", "Body"))
        assert result is True  # fallback always succeeds

    def test_email_failure(self):
        with patch.object(server.httpx, "AsyncClient") as mock_ac:
            mock_ac.return_value.__aenter__.return_value.post = AsyncMock(
                side_effect=Exception("SendGrid down")
            )

            ec = server.EmailClient()
            ec.sendgrid_key = "sg_key"
            ec.resend_key = None

            result = self._run(ec.send_email("a@b.com", "Sub", "Body"))
            assert result is False

    def test_send_email_with_html(self):
        with patch.object(server.httpx, "AsyncClient") as mock_ac:
            mock_post = AsyncMock()
            mock_ac.return_value.__aenter__.return_value.post = mock_post

            ec = server.EmailClient()
            ec.sendgrid_key = "sg_key"
            # Pin Resend off so the result doesn't depend on the environment.
            ec.resend_key = None

            result = self._run(
                ec.send_email("a@b.com", "Sub", "Body", html="<p>Body</p>")
            )
            assert result is True

            # Verify the content includes text/html
            payload = mock_post.call_args.kwargs["json"]
            content_types = {c["type"] for c in payload["content"]}
            assert "text/plain" in content_types
            assert "text/html" in content_types


# ===================================================================
#  Rate-limit unit tests (use real functions, mock only db)
# ===================================================================


class TestRateLimiting:
    """Rate-limit helper tests using real functions with mocked db.

    Coroutines are driven by ``_run`` on a private event loop instead of
    ``asyncio.get_event_loop()``, which is deprecated when no loop is running.
    """

    @staticmethod
    def _run(coro):
        """Run *coro* to completion on a throwaway event loop and return its result.

        ``asyncio.new_event_loop`` does not install the loop as the current
        one, so other tests' view of the default loop is left untouched.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    @pytest.fixture(autouse=True)
    def _mock_db(self):
        """Mock server.db only (don't patch the functions themselves)."""
        with patch.object(server, "db") as mock_db:
            mock_db.rate_limits = _make_mock_collection()
            yield mock_db

    def test_check_rate_limit_allows(self, _mock_db):
        _mock_db.rate_limits.find_one = AsyncMock(return_value=None)

        allowed, remaining = self._run(server.check_rate_limit("user-1", "sms"))
        assert allowed is True
        assert remaining == 10

    def test_check_rate_limit_blocks(self, _mock_db):
        _mock_db.rate_limits.find_one = AsyncMock(return_value={"count": 10})

        allowed, remaining = self._run(server.check_rate_limit("user-1", "sms"))
        assert allowed is False
        assert remaining == 0

    def test_check_rate_limit_partial(self, _mock_db):
        _mock_db.rate_limits.find_one = AsyncMock(return_value={"count": 4})

        allowed, remaining = self._run(server.check_rate_limit("user-1", "sms"))
        assert allowed is True
        assert remaining == 6

    def test_increment_rate_limit(self, _mock_db):
        self._run(server.increment_rate_limit("user-1", "email"))
        # assert_called_once would also pass if the coroutine were created but
        # never awaited (i.e. the write silently never happened).
        _mock_db.rate_limits.update_one.assert_awaited_once()

    def test_rate_limits_constants(self):
        assert server.RATE_LIMITS["sms"] == 10
        assert server.RATE_LIMITS["email"] == 50
        assert server.RATE_LIMITS["push"] == 100


# ===================================================================
#  API endpoint tests (use the client fixture)
# ===================================================================


class TestSendNotification:
    """Endpoint tests for ``POST /api/notifications/send``.

    Every test relies on the ``client`` fixture, which mocks the database,
    the providers and the rate-limit helpers.
    """

    _URL = "/api/notifications/send"

    def _send(self, client, **fields):
        """POST *fields* (in the given order) as the JSON body; return the response."""
        return client.post(self._URL, json=fields)

    def test_send_sms(self, client):
        resp = self._send(
            client,
            to="+15551234567",
            subject="Test SMS",
            body="Hello from test",
            channel="sms",
        )
        assert resp.status_code == 200, resp.text
        result = resp.json()
        assert result["status"] == "sent"
        assert result["channel"] == "sms"
        assert "id" in result

        # A successful send must bump the user's rate-limit counter.
        server.increment_rate_limit.assert_called_once()

    def test_send_email(self, client):
        resp = self._send(
            client,
            to="user@example.com",
            subject="Test Email",
            body="Hello from test",
            channel="email",
        )
        assert resp.status_code == 200, resp.text
        result = resp.json()
        assert result["status"] == "sent"
        assert result["channel"] == "email"

    def test_send_push(self, client):
        resp = self._send(
            client,
            to="device-token",
            subject="Test Push",
            body="Hello from test",
            channel="push",
        )
        assert resp.status_code == 200, resp.text
        result = resp.json()
        assert result["status"] == "sent"
        assert result["channel"] == "push"

    def test_invalid_channel(self, client):
        resp = self._send(
            client,
            to="someone",
            subject="Bad",
            body="Bad channel",
            channel="fax",
        )
        assert resp.status_code == 400, resp.text
        assert any(hint in resp.text for hint in ("fax", "Unsupported channel"))

    def test_rate_limit_exceeded(self, client):
        """A ``(False, 0)`` answer from check_rate_limit yields HTTP 429."""
        server.check_rate_limit.return_value = (False, 0)

        resp = self._send(
            client,
            to="+15551234567",
            subject="Rate limited",
            body="Should fail",
            channel="sms",
        )
        assert resp.status_code == 429, resp.text
        assert "Rate limit" in resp.text

    def test_sms_provider_failure(self, client):
        """A falsy result from the SMS provider is reported as ``failed``."""
        server.TwilioClient.return_value.send_sms = AsyncMock(return_value=False)

        resp = self._send(
            client,
            to="+15551234567",
            subject="Failing SMS",
            body="Will fail",
            channel="sms",
        )
        assert resp.status_code == 200, resp.text
        result = resp.json()
        assert result["status"] == "failed"
        assert result["error"] is not None

    def test_custom_user_id(self, client):
        """A body-level ``user_id`` is accepted and the push is still sent."""
        resp = self._send(
            client,
            to="+15551234567",
            subject="Custom user",
            body="Test",
            channel="push",
            user_id="custom-user-123",
        )
        assert resp.status_code == 200, resp.text
        assert resp.json()["status"] == "sent"


class TestScheduleNotification:
    """POST /api/notifications/schedule"""

    @staticmethod
    def _utc_iso(offset):
        """Return ``now + offset`` as a naive-UTC ISO-8601 string.

        Produces the same value as ``(datetime.utcnow() + offset).isoformat()``
        without ``utcnow``, which is deprecated since Python 3.12. The tzinfo
        is dropped on purpose so the payload format stays naive, as before.
        """
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        return (now + offset).isoformat()

    def test_schedule_future(self, client):
        future = self._utc_iso(timedelta(hours=2))
        payload = {
            "to": "user@example.com",
            "subject": "Scheduled",
            "body": "Later",
            "channel": "email",
            "scheduled_time": future,
        }
        resp = client.post("/api/notifications/schedule", json=payload)
        assert resp.status_code == 200, resp.text
        data = resp.json()
        assert data["status"] == "scheduled"
        assert data["channel"] == "email"
        assert "id" in data

    def test_schedule_past_rejected(self, client):
        past = self._utc_iso(-timedelta(hours=1))
        payload = {
            "to": "user@example.com",
            "subject": "Past",
            "body": "Too late",
            "channel": "email",
            "scheduled_time": past,
        }
        resp = client.post("/api/notifications/schedule", json=payload)
        assert resp.status_code == 400, resp.text
        assert "future" in resp.text

    def test_schedule_with_race_id(self, client):
        future = self._utc_iso(timedelta(hours=2))
        payload = {
            "to": "user@example.com",
            "subject": "Race alert",
            "body": "Race starts soon",
            "channel": "sms",
            "scheduled_time": future,
            "race_id": "race-123",
        }
        resp = client.post("/api/notifications/schedule", json=payload)
        assert resp.status_code == 200, resp.text
        # Same response contract as a plain scheduled notification.
        data = resp.json()
        assert data["status"] == "scheduled"
        assert data["channel"] == "sms"


class TestProvidersEndpoint:
    """GET /api/notifications/providers"""

    # Env-var prefixes for the optional providers. test_providers_default
    # strips them so its outcome doesn't depend on the developer's shell.
    # NOTE(review): the exact variable name the server reads for Resend is not
    # visible here; matching by prefix covers it without guessing.
    _PROVIDER_ENV_PREFIXES = ("TWILIO_", "SENDGRID_", "RESEND_")

    def test_providers_default(self, client):
        """Without env vars for Twilio/SendGrid/Resend, only push should show."""
        scrubbed = {
            key: value
            for key, value in os.environ.items()
            if not key.startswith(self._PROVIDER_ENV_PREFIXES)
        }
        # clear=True + scrubbed copy removes provider vars for the duration of
        # the request; patch.dict restores the full environment afterwards.
        with patch.dict(os.environ, scrubbed, clear=True):
            resp = client.get("/api/notifications/providers")
        assert resp.status_code == 200
        data = resp.json()
        prov = data["providers"]
        assert prov["push"]["available"] is True
        assert prov["sms"]["available"] is False
        assert prov["email"]["available"] is False

    def test_providers_twilio_configured(self, client):
        with patch.dict(
            os.environ,
            {
                "TWILIO_ACCOUNT_SID": "sid",
                "TWILIO_AUTH_TOKEN": "token",
                "TWILIO_PHONE_NUMBER": "+15551234567",
            },
        ):
            resp = client.get("/api/notifications/providers")
            assert resp.status_code == 200
            data = resp.json()
            assert data["providers"]["sms"]["available"] is True
            assert data["providers"]["sms"]["provider"] == "twilio"

    def test_providers_email_configured(self, client):
        with patch.dict(os.environ, {"SENDGRID_API_KEY": "sg_key"}):
            resp = client.get("/api/notifications/providers")
            assert resp.status_code == 200
            data = resp.json()
            assert data["providers"]["email"]["available"] is True
            assert data["providers"]["email"]["providers"]["sendgrid"] is True
