"""Celery task: monthly tax rate cache refresh."""
import asyncio
import logging
from datetime import datetime, timedelta

from src.worker.celery_app import celery_app
from src.core.database import SessionCelery
from src.core.config import settings as app_settings
from src.apps.tax.models import TaxRateCache
from src.apps.tax.taxjar_client import TaxJarClient

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


async def _refresh_tax_rate_rows(client, rows, ttl):
    """Re-fetch each cached row from TaxJar and update it in place.

    Rows are processed sequentially on the caller's event loop. A row whose
    fetch or update raises is logged and counted as skipped; the run always
    continues with the next row.

    Args:
        client: TaxJarClient whose ``get_rate`` coroutine is awaited per row.
        rows: TaxRateCache instances loaded from an open session.
        ttl: timedelta added to the refresh time to produce ``expires_at``.

    Returns:
        A ``(refreshed, skipped)`` tuple of counts.
    """
    refreshed = skipped = 0
    for row in rows:
        try:
            fresh = await client.get_rate(row.zip, row.city, row.state, row.country)
            # Read every field before touching the row: if the response lacks
            # one, we raise here with the row still unmodified, rather than
            # leaving it half-updated (and later committed) while counted as
            # skipped.
            new_values = {
                "combined_rate": fresh.combined_rate,
                "state_rate": fresh.state_rate,
                "county_rate": fresh.county_rate,
                "city_rate": fresh.city_rate,
                "special_district_rate": fresh.special_district_rate,
                "label": fresh.label,
            }
            now = datetime.utcnow()
            for field, value in new_values.items():
                setattr(row, field, value)
            row.cached_at = now
            row.expires_at = now + ttl
            refreshed += 1
        except Exception as exc:
            # Best-effort per row: log with traceback, then move on.
            logger.error(
                "Failed to refresh tax cache for zip=%s: %s", row.zip, exc, exc_info=True
            )
            skipped += 1
    return refreshed, skipped


@celery_app.task(name="tax.refresh_tax_rate_cache", bind=True, max_retries=0)
def refresh_tax_rate_cache(self):
    """Re-fetch every row in tax_rate_cache from TaxJar and reset expires_at.

    Runs on the 1st of every month at 02:00 UTC via Celery Beat.
    Skips individual rows that fail; never aborts the full run.
    Idempotent — safe to run manually at any time.

    All successfully refreshed rows are written in a single commit at the end
    of the run. ``cached_at``/``expires_at`` are naive UTC timestamps
    (``datetime.utcnow()``), with ``expires_at`` set
    ``settings.TAX_RATE_CACHE_TTL_DAYS`` days after the refresh.
    """
    client = TaxJarClient(app_settings.TAXJAR_API_KEY, app_settings.TAXJAR_BASE_URL)
    # Loop-invariant: compute the cache lifetime once, not once per row.
    ttl = timedelta(days=app_settings.TAX_RATE_CACHE_TTL_DAYS)

    with SessionCelery() as db:
        rows = db.query(TaxRateCache).all()
        # Drive every fetch from ONE event loop. asyncio.run() creates and
        # closes a new loop on each call, so calling it per row risks reusing
        # loop-bound client resources (e.g. pooled async connections) on an
        # already-closed loop.
        refreshed, skipped = asyncio.run(_refresh_tax_rate_rows(client, rows, ttl))

        db.commit()
        logger.info(
            "Tax rate cache refresh complete. refreshed=%d skipped=%d", refreshed, skipped
        )
