"""
Celery task: update subscription statuses based on invoice states.

Task name: subscription.update_subscription_statuses
Schedule:  daily at 00:30 UTC (Celery Beat)

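Transitions (applied in this order):
  - ACTIVE/PAST_DUE subscriptions with a "date" end condition whose end_date has
    passed and whose generated invoices are all PAID → EXPIRED
  - ACTIVE/PAST_DUE subscriptions with an "until_count" end condition whose
    invoices_paid has reached pay_until_count → COMPLETED
  - ACTIVE subscriptions whose latest invoice is OVERDUE → PAST_DUE
  - PAST_DUE subscriptions whose past_due_since is at least 30 days ago AND whose
    dunning_retry_count is >= 3 → DUNNING_EXHAUSTED
  - PAST_DUE subscriptions whose latest invoice is PAID or PARTIALLY_PAID and whose
    next_billing_date is in the future → ACTIVE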
"""

import logging
from datetime import datetime, timedelta, timezone

from celery.utils.log import get_task_logger

from src.worker.celery_app import celery_app
from src.core.database import SessionCelery

logger = get_task_logger(__name__)


@celery_app.task(name="subscription.update_subscription_statuses", bind=True, max_retries=3)
def update_subscription_statuses(self) -> dict:
    """
    Advance subscription statuses based on end conditions and invoice states.

    Intended to run daily at 00:30 UTC. Every transition uses one ``now``
    captured at task start. Phases run in this order:

      0a. ACTIVE/PAST_DUE, end_type "date", end_date < now, and at least
          ``invoices_generated`` PAID invoices            → EXPIRED
      0b. ACTIVE/PAST_DUE, end_type "until_count", and
          invoices_paid >= pay_until_count                 → COMPLETED
      1.  ACTIVE whose latest invoice is OVERDUE           → PAST_DUE
      2.  PAST_DUE since >= 30 days ago and
          dunning_retry_count >= 3                         → DUNNING_EXHAUSTED
      3.  PAST_DUE whose latest invoice is PAID/PARTIALLY_PAID and whose
          next_billing_date is in the future               → ACTIVE

    Each transition writes a STATUS_AUTO_UPDATED activity. Each phase only
    selects subscriptions still in its source status, so re-running the task
    does not repeat transitions that were already persisted.

    Returns:
        Dict with counts of transitions performed. ``completed`` counts both
        EXPIRED and COMPLETED transitions.

    Raises:
        Whatever ``self.retry`` raises on any error (retry after 300s, up to
        ``max_retries=3``).
    """
    logger.info("update_subscription_statuses task starting")
    now = datetime.now(timezone.utc)
    past_due_threshold = now - timedelta(days=30)

    results = {
        "active_to_past_due": 0,
        "past_due_to_active": 0,
        "past_due_to_exhausted": 0,
        "completed": 0,
    }

    try:
        from sqlalchemy import select, func
        from src.apps.subscriptions.models.subscription import Subscription
        from src.apps.subscriptions.enums import SubscriptionStatus, SubscriptionActivityTypes
        from src.apps.subscriptions import crud as sub_crud
        from src.apps.invoices.models.invoice import Invoice
        from src.apps.payment_requests.models.payment_request import PaymentRequest
        from src.apps.payment_requests.models.recurring_payment_request import RecurringPaymentRequests
        from src.core.utils.enums import InvoiceStatusTypes

        # VULN-013: Process all queries in batches of 500 to avoid loading the
        # entire subscription table into memory on large deployments.
        BATCH_SIZE = 500

        def _old_status_label(sub) -> str:
            # Phases 0a/0b select ACTIVE or PAST_DUE only; record which one it
            # actually was instead of always claiming "active".
            return "past_due" if sub.status == SubscriptionStatus.PAST_DUE else "active"

        with SessionCelery() as db:
            # 0a. Rule 1 — DATE-based end: end_date passed and all invoices paid → EXPIRED
            date_end_stmt = (
                select(Subscription)
                .join(PaymentRequest, PaymentRequest.id == Subscription.payment_request_id)
                .join(RecurringPaymentRequests, RecurringPaymentRequests.payment_request_id == PaymentRequest.id)
                .where(
                    RecurringPaymentRequests.end_type == "date",
                    RecurringPaymentRequests.end_date < now,
                    Subscription.status.in_([SubscriptionStatus.ACTIVE, SubscriptionStatus.PAST_DUE]),
                    Subscription.deleted_at == None,
                )
            )
            for batch in _iter_keyset_batches(db, date_end_stmt, Subscription.id, BATCH_SIZE):
                for (sub,) in batch:
                    # Only expire when all generated invoices have been paid;
                    # a subscription that never generated an invoice is skipped.
                    total = sub.invoices_generated or 0
                    if total == 0:
                        continue
                    paid = db.execute(
                        select(func.count(Invoice.id)).where(
                            Invoice.subscription_id == sub.id,
                            Invoice.status == InvoiceStatusTypes.PAID,
                            Invoice.deleted_at == None,
                        )
                    ).scalar() or 0
                    if paid < total:
                        continue
                    old_status = _old_status_label(sub)
                    sub.status = SubscriptionStatus.EXPIRED
                    sub.updated_at = now
                    db.flush()
                    sub_crud.write_activity(
                        db=db,
                        subscription_id=sub.id,
                        activity_type=SubscriptionActivityTypes.STATUS_AUTO_UPDATED,
                        description="Subscription expired: end date reached and all invoices paid",
                        actor_type="system",
                        actor_id=None,
                        metadata={"old_status": old_status, "new_status": "expired", "reason": "end_date_passed"},
                    )
                    results["completed"] += 1

            # 0b. Rule 2 — UNTIL_COUNT end: invoices_paid >= pay_until_count → COMPLETED
            # pay_until_count is selected through the join, so no per-row lookup is needed.
            count_end_stmt = (
                select(Subscription, RecurringPaymentRequests.pay_until_count)
                .join(PaymentRequest, PaymentRequest.id == Subscription.payment_request_id)
                .join(RecurringPaymentRequests, RecurringPaymentRequests.payment_request_id == PaymentRequest.id)
                .where(
                    RecurringPaymentRequests.end_type == "until_count",
                    Subscription.status.in_([SubscriptionStatus.ACTIVE, SubscriptionStatus.PAST_DUE]),
                    Subscription.deleted_at == None,
                )
            )
            for batch in _iter_keyset_batches(db, count_end_stmt, Subscription.id, BATCH_SIZE):
                for sub, pay_until_count in batch:
                    if not pay_until_count or (sub.invoices_paid or 0) < pay_until_count:
                        continue
                    old_status = _old_status_label(sub)
                    sub.status = SubscriptionStatus.COMPLETED
                    sub.updated_at = now
                    db.flush()
                    sub_crud.write_activity(
                        db=db,
                        subscription_id=sub.id,
                        activity_type=SubscriptionActivityTypes.STATUS_AUTO_UPDATED,
                        description="Subscription completed: all billing cycles paid",
                        actor_type="system",
                        actor_id=None,
                        metadata={
                            "old_status": old_status,
                            "new_status": "completed",
                            "reason": "until_count_reached",
                            "invoices_paid": sub.invoices_paid,
                            "pay_until_count": pay_until_count,
                        },
                    )
                    results["completed"] += 1

            # 1. ACTIVE → PAST_DUE: latest invoice is OVERDUE
            active_stmt = select(Subscription).where(
                Subscription.status == SubscriptionStatus.ACTIVE,
                Subscription.deleted_at == None,
            )
            for batch in _iter_keyset_batches(db, active_stmt, Subscription.id, BATCH_SIZE):
                for (sub,) in batch:
                    latest_inv = _get_latest_invoice(db, sub.id)
                    if not (latest_inv and latest_inv.status == InvoiceStatusTypes.OVERDUE):
                        continue
                    sub.status = SubscriptionStatus.PAST_DUE
                    sub.past_due_since = now
                    sub.updated_at = now
                    db.flush()
                    sub_crud.write_activity(
                        db=db,
                        subscription_id=sub.id,
                        activity_type=SubscriptionActivityTypes.STATUS_AUTO_UPDATED,
                        description="Subscription moved to PAST_DUE due to overdue invoice",
                        actor_type="system",
                        actor_id=None,
                    )
                    results["active_to_past_due"] += 1

            # 2. PAST_DUE → DUNNING_EXHAUSTED: past_due_since + 30 days exceeded
            # NOTE(review): this runs before phase 3, so a subscription that
            # meets both this rule and the restore rule is exhausted rather
            # than restored — confirm that ordering is intended.
            past_due_stmt = select(Subscription).where(
                Subscription.status == SubscriptionStatus.PAST_DUE,
                Subscription.past_due_since <= past_due_threshold,
                Subscription.deleted_at == None,
            )
            for batch in _iter_keyset_batches(db, past_due_stmt, Subscription.id, BATCH_SIZE):
                for (sub,) in batch:
                    # Only exhaust if dunning retries also exceeded
                    if (sub.dunning_retry_count or 0) < 3:
                        continue
                    sub.status = SubscriptionStatus.DUNNING_EXHAUSTED
                    sub.dunning_exhausted_at = now
                    sub.updated_at = now
                    db.flush()
                    sub_crud.write_activity(
                        db=db,
                        subscription_id=sub.id,
                        activity_type=SubscriptionActivityTypes.STATUS_AUTO_UPDATED,
                        description="Dunning exhausted: subscription suspended after 30 days past due",
                        actor_type="system",
                        actor_id=None,
                    )
                    results["past_due_to_exhausted"] += 1

            # 3. PAST_DUE → ACTIVE: latest invoice PAID and future next_billing
            still_past_due_stmt = select(Subscription).where(
                Subscription.status == SubscriptionStatus.PAST_DUE,
                Subscription.deleted_at == None,
            )
            for batch in _iter_keyset_batches(db, still_past_due_stmt, Subscription.id, BATCH_SIZE):
                for (sub,) in batch:
                    latest_inv = _get_latest_invoice(db, sub.id)
                    if not (
                        latest_inv
                        and latest_inv.status in (InvoiceStatusTypes.PAID, InvoiceStatusTypes.PARTIALLY_PAID)
                        and sub.next_billing_date
                        and sub.next_billing_date > now
                    ):
                        continue
                    sub.status = SubscriptionStatus.ACTIVE
                    sub.past_due_since = None
                    sub.dunning_retry_count = 0
                    sub.dunning_next_retry_at = None
                    sub.updated_at = now
                    db.flush()
                    sub_crud.write_activity(
                        db=db,
                        subscription_id=sub.id,
                        activity_type=SubscriptionActivityTypes.STATUS_AUTO_UPDATED,
                        description="Subscription restored to ACTIVE after successful payment",
                        actor_type="system",
                        actor_id=None,
                    )
                    results["past_due_to_active"] += 1

            # A plain SQLAlchemy Session used as a context manager only closes
            # (discarding uncommitted work) on exit, so commit explicitly.
            # NOTE(review): if SessionCelery already commits on exit, this
            # extra commit is harmless.
            db.commit()

    except Exception as exc:
        logger.error("update_subscription_statuses: fatal error: %s", exc, exc_info=True)
        raise self.retry(exc=exc, countdown=300)

    logger.info("update_subscription_statuses completed: %s", results)
    return results


def _iter_keyset_batches(db, stmt, id_column, batch_size: int):
    """Yield successive lists of result rows from ``stmt``, at most ``batch_size`` each.

    Rows are ordered by ``id_column``. Each page resumes strictly after the
    last id seen (keyset pagination) instead of using OFFSET. The caller
    mutates rows so they leave ``stmt``'s filter; OFFSET paging would then
    shift the remaining rows down and silently skip them. An explicit ORDER BY
    also makes page contents deterministic.

    The first entity selected by ``stmt`` must expose an ``id`` attribute whose
    value corresponds to ``id_column``.
    """
    last_id = None
    while True:
        page = stmt if last_id is None else stmt.where(id_column > last_id)
        rows = db.execute(page.order_by(id_column).limit(batch_size)).all()
        if not rows:
            return
        last_id = rows[-1][0].id
        yield rows
        if len(rows) < batch_size:
            return


def _get_latest_invoice(db, subscription_id: int):
    """Fetch the newest non-soft-deleted invoice of a subscription, or ``None``.

    "Newest" means the highest ``sequence_id``. ``created_at`` (descending)
    only breaks ties between equal sequence ids. Invoices whose ``deleted_at``
    is set are ignored.

    NOTE(review): how NULL ``sequence_id`` values sort relative to non-NULL
    ones depends on the database backend — confirm the column is non-nullable.
    """
    from sqlalchemy import select
    from src.apps.invoices.models.invoice import Invoice

    newest_first = (Invoice.sequence_id.desc(), Invoice.created_at.desc())
    query = (
        select(Invoice)
        .where(Invoice.subscription_id == subscription_id)
        .where(Invoice.deleted_at.is_(None))
        .order_by(*newest_first)
        .limit(1)
    )
    return db.execute(query).scalars().first()
