"""
Celery task: bulk-mark PENDING/WAITING invoices as OVERDUE.

Task name: invoice.mark_overdue_invoices
Schedule:  daily at 00:05 UTC (Celery Beat)

The task is fully idempotent — re-running it for the same date is safe.

For each affected invoice the task:
  1. Updates the status to OVERDUE.
  2. Writes an InvoiceActivity(INVOICE_OVERDUE) record.
  3. Emits an ``invoice.overdue`` Kafka event so the listener can
     dispatch the customer overdue notification email.
"""

import asyncio
import logging
from datetime import datetime, timezone

from celery.utils.log import get_task_logger

from src.worker.celery_app import celery_app
from src.core.database import SessionCelery
from src.core.utils.enums import InvoiceStatusTypes, InvoiceActivityTypes

logger = get_task_logger(__name__)

# Statuses that can transition to OVERDUE
_OVERDUEABLE_STATUSES = [
    InvoiceStatusTypes.PENDING,
    InvoiceStatusTypes.WAITING,
]


@celery_app.task(name="invoice.mark_overdue_invoices")
def mark_overdue_invoices() -> dict:
    """
    Idempotent daily task at 00:05 UTC.

    Bulk-updates all PENDING/WAITING invoices whose due_date < now() and
    have not yet been soft-deleted.  For each affected invoice:
      - Status is set to OVERDUE.
      - An InvoiceActivity(INVOICE_OVERDUE) record is written.
      - An ``invoice.overdue`` Kafka event is emitted so the listener can
        dispatch the customer notification email.

    Returns:
        Dict with count of invoices transitioned.
    """
    logger.info("mark_overdue_invoices task starting")

    event_payloads = []

    with SessionCelery() as db:
        from sqlalchemy import select, update
        from src.apps.invoices.models.invoice import Invoice
        from src.apps.invoices import crud as invoice_crud

        now = datetime.now(timezone.utc)
        overdueable_values = [s.value for s in _OVERDUEABLE_STATUSES]

        # Fetch affected invoice rows (need merchant_id, literal, amount, due_date,
        # customer_id so we can build rich Kafka event payloads).
        stmt = select(
            Invoice.id,
            Invoice.merchant_id,
            Invoice.invoice_literal,
            Invoice.amount,
            Invoice.due_date,
            Invoice.customer_id,
        ).where(
            Invoice.due_date < now,
            Invoice.status.in_(overdueable_values),
            Invoice.deleted_at == None,
        )
        rows = db.execute(stmt).all()

        if not rows:
            logger.info("mark_overdue_invoices: no invoices to update")
            return {"updated": 0}

        affected_ids = [r[0] for r in rows]

        # Bulk update status
        update_stmt = (
            update(Invoice)
            .where(Invoice.id.in_(affected_ids))
            .values(
                status=InvoiceStatusTypes.OVERDUE,
                updated_at=now,
            )
        )
        db.execute(update_stmt)
        db.flush()

        # Write activity per invoice and collect event payloads
        for inv_id, merchant_id, invoice_literal, amount, due_date, customer_id in rows:
            invoice_crud.write_activity(
                db=db,
                invoice_id=inv_id,
                activity_type=InvoiceActivityTypes.INVOICE_OVERDUE,
                description="Invoice automatically marked overdue",
                actor_type="system",
                actor_id=None,
                metadata={"transitioned_at": now.isoformat()},
            )
            event_payloads.append({
                "invoice_id": inv_id,
                "merchant_id": merchant_id,
                "invoice_literal": invoice_literal,
                "amount": float(amount or 0),
                "due_date": due_date.isoformat() if due_date else None,
                "customer_id": customer_id,
            })

        logger.info("mark_overdue_invoices: marked %d invoices as OVERDUE", len(affected_ids))

    # Emit invoice.overdue Kafka events outside the DB session so that a
    # dispatcher failure never rolls back the status update already committed.
    if event_payloads:
        _emit_overdue_events(event_payloads)

    return {"updated": len(affected_ids), "invoice_ids": affected_ids}


def _emit_overdue_events(payloads: list) -> None:
    """Fire one ``invoice.overdue`` Kafka event per affected invoice."""
    try:
        from src.events.base import BaseEvent
        from src.events.dispatcher import EventDispatcher

        async def _dispatch_all():
            for payload in payloads:
                await EventDispatcher.dispatch(
                    BaseEvent(event_type="invoice.overdue", data=payload)
                )

        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(_dispatch_all())
        finally:
            loop.close()

        logger.info("mark_overdue_invoices: emitted %d invoice.overdue events", len(payloads))
    except Exception as exc:
        logger.error("mark_overdue_invoices: failed to emit overdue events: %s", exc)
