"""Cache-first TaxJar rate lookup service."""
import logging
from datetime import datetime, timedelta
from typing import Optional

from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from src.core.config import settings as app_settings
from src.apps.tax.models import TaxRateCache
from src.apps.tax.taxjar_client import TaxJarClient
from src.apps.tax.schemas import TaxRateResponse

logger = logging.getLogger(__name__)


def _client() -> TaxJarClient:
    """Build a new TaxJar client from the API key and base URL in app settings."""
    api_key = app_settings.TAXJAR_API_KEY
    base_url = app_settings.TAXJAR_BASE_URL
    return TaxJarClient(api_key, base_url)


def _ttl_days() -> int:
    """Return the cache lifetime, in days, configured by ``TAX_RATE_CACHE_TTL_DAYS``."""
    configured_days = app_settings.TAX_RATE_CACHE_TTL_DAYS
    return configured_days


async def get_rate_cached(
    zip: str,
    city: Optional[str],
    state: Optional[str],
    country: str,
    db: Session,
) -> TaxRateResponse:
    """
    Return the tax rate for ``zip``/``country``, preferring the local cache.

    Cache-first lookup:
    1. Fresh cached row  → return immediately (no TaxJar call).
    2. Expired cached row → re-fetch TaxJar, upsert, return fresh.
       On TaxJar failure: serve stale value + log warning (graceful degradation).
    3. No cached row → fetch TaxJar, insert, return.
       On TaxJar failure: propagate exception so router can surface 502.

    Writing the fetched rate back to the cache is best-effort. A database
    error during the upsert/commit is logged and the session is rolled back,
    but the fresh TaxJar result is still returned. A DB problem is therefore
    never reported as a TaxJar outage, and it never causes a stale value to
    be served when a fresh one was obtained.

    Args:
        zip: Postal code; together with ``country`` it forms the cache key.
        city: Optional city, forwarded to TaxJar only (not part of the cache key).
        state: Optional state, forwarded to TaxJar only (not part of the cache key).
        country: Country code; together with ``zip`` it forms the cache key.
        db: SQLAlchemy session used for the cache read and write.

    Returns:
        The fresh or cached ``TaxRateResponse``.

    Raises:
        Whatever ``TaxJarClient.get_rate`` raises, when no cached row exists
        to fall back on. Errors from the initial cache query also propagate.
    """
    # NOTE(review): compares against naive UTC; assumes ``expires_at`` is stored
    # as a naive UTC timestamp (as written by ``_upsert``) — confirm column type.
    now = datetime.utcnow()
    cached = (
        db.query(TaxRateCache)
        .filter(TaxRateCache.zip == zip, TaxRateCache.country == country)
        .first()
    )

    if cached is not None and cached.expires_at > now:
        # Cache hit — fresh
        return _to_response(cached)

    # Cache miss or expired — call TaxJar. Only the remote call is guarded here,
    # so that database failures are not mistaken for TaxJar outages.
    try:
        fresh = await _client().get_rate(zip, city, state, country)
    except Exception as exc:
        if cached is None:
            raise  # No cache at all — let router surface the 502
        # Stale but available — degrade gracefully
        logger.warning("TaxJar unreachable, serving stale cache for zip=%s: %s", zip, exc)
        return _to_response(cached)

    try:
        _upsert(fresh, db)
        db.commit()
    except SQLAlchemyError:
        # A failed flush/commit leaves the session unusable until rolled back;
        # restore it so the caller can keep using ``db``.
        db.rollback()
        logger.exception("Failed to cache TaxJar rate for zip=%s", zip)
    return fresh


def _upsert(rate: TaxRateResponse, db: Session) -> None:
    """
    Insert ``rate`` into the cache table, or overwrite the existing row.

    Conflicts are detected via the ``uq_tax_rate_zip_country`` constraint.
    On conflict, every column except ``zip``/``country`` is replaced, and
    ``cached_at``/``expires_at`` are reset so the TTL starts over.

    The statement is only executed on ``db``, not committed; the caller is
    responsible for committing.
    """
    cached_at = datetime.utcnow()
    expires_at = cached_at + timedelta(days=_ttl_days())

    # Single source of truth for the row; the conflict-update payload is derived from it.
    row = {
        "zip": rate.zip,
        "city": rate.city,
        "state": rate.state,
        "country": rate.country,
        "label": rate.label,
        "combined_rate": rate.combined_rate,
        "state_rate": rate.state_rate,
        "county_rate": rate.county_rate,
        "city_rate": rate.city_rate,
        "special_district_rate": rate.special_district_rate,
        "cached_at": cached_at,
        "expires_at": expires_at,
    }
    # The unique key columns identify the row and are never overwritten.
    key_columns = ("zip", "country")
    overwrite = {column: value for column, value in row.items() if column not in key_columns}

    upsert_stmt = pg_insert(TaxRateCache).values(**row).on_conflict_do_update(
        constraint="uq_tax_rate_zip_country",
        set_=overwrite,
    )
    db.execute(upsert_stmt)


def _to_response(row: TaxRateCache) -> TaxRateResponse:
    """Copy a cache-table row field-for-field into a ``TaxRateResponse``."""
    copied_fields = (
        "zip",
        "city",
        "state",
        "country",
        "combined_rate",
        "state_rate",
        "county_rate",
        "city_rate",
        "special_district_rate",
        "label",
    )
    return TaxRateResponse(**{name: getattr(row, name) for name in copied_fields})
