"""
Celery task: generate upcoming subscription invoices.

Task name: subscription.generate_upcoming_invoices
Schedule:  daily at 01:00 UTC (Celery Beat)

Queries subscriptions where:
  - status IN (INITIALIZING, ACTIVE, PAST_DUE)
  - next_billing_date <= now() + 7 days

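For each, creates an Invoice via create_invoice() unless one already exists
for the next sequence_id (idempotency guard), advances next_billing_date,
and creates a PENDING Transaction whose charge is scheduled for the
invoice's due date.  Subscriptions that have reached their end condition
(end date passed or billing-cycle count reached) are marked EXPIRED /
COMPLETED instead of being invoiced.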
"""

import logging
from datetime import datetime, timedelta, timezone

from celery.utils.log import get_task_logger

from src.worker.celery_app import celery_app
from src.core.database import SessionCelery

# Celery task logger shared by every function in this module.
logger = get_task_logger(__name__)

# Raw subscription status codes selected by the daily query.
# NOTE(review): presumably these mirror the SubscriptionStatus enum values
# (which this module only imports lazily inside functions) — keep the two in
# sync; TODO confirm against the enum definition.
_BILLABLE_STATUSES = [100, 200, 201]  # INITIALIZING, ACTIVE, PAST_DUE


@celery_app.task(name="subscription.generate_upcoming_invoices", bind=True, max_retries=3)
def generate_upcoming_invoices(self) -> dict:
    """
    Idempotent daily task at 01:00 UTC.

    Generates invoices for subscriptions whose next_billing_date falls
    within the next 7 days and no invoice already exists for that sequence.

    The candidate subscription IDs are read in one short-lived session.  Each
    subscription is then loaded, invoiced and committed in its own session,
    so an exception raised for one subscription is logged and counted without
    affecting the work already committed for the others.

    Returns:
        Dict with counts: ``processed`` (subscriptions handled and committed
        without error), ``created`` (invoices created) and ``errors``
        (subscriptions whose processing raised).

    Raises:
        Whatever ``self.retry`` raises when the initial ID query (or the
        deferred imports) fails; the retry is requested with a 300 second
        countdown and the task is declared with ``max_retries=3``.
    """
    logger.info("generate_upcoming_invoices task starting")

    horizon = datetime.now(timezone.utc) + timedelta(days=7)
    processed = 0
    created = 0
    errors = 0

    try:
        # Imports stay deferred to call time (as everywhere in this module)
        # but are resolved once per run instead of once per subscription.
        from sqlalchemy import select
        from sqlalchemy.orm import joinedload
        from src.apps.subscriptions.models.subscription import Subscription
        from src.apps.payment_requests.models.payment_request import PaymentRequest
        # Not referenced by name below.  Kept because the original imported it
        # before the joined load — presumably so the mapped class is
        # registered; TODO confirm and drop if unnecessary.
        from src.apps.payment_requests.models.recurring_payment_request import (  # noqa: F401
            RecurringPaymentRequests,
        )

        # Fetch subscription IDs in a dedicated session, then process each
        # in its own isolated session to prevent one failure rolling back others.
        with SessionCelery() as db:
            due_ids_stmt = select(Subscription.id).where(
                Subscription.status.in_(_BILLABLE_STATUSES),
                Subscription.next_billing_date <= horizon,
                Subscription.deleted_at.is_(None),
            )
            sub_ids = db.execute(due_ids_stmt).scalars().all()

    except Exception as exc:
        logger.error("generate_upcoming_invoices: fatal error fetching subscriptions: %s", exc, exc_info=True)
        raise self.retry(exc=exc, countdown=300)

    for sub_id in sub_ids:
        try:
            with SessionCelery() as sub_db:
                load_stmt = (
                    select(Subscription)
                    .where(Subscription.id == sub_id)
                    .options(
                        joinedload(Subscription.payment_request).joinedload(PaymentRequest.recurring_config)
                    )
                )
                subscription = sub_db.execute(load_stmt).unique().scalar_one_or_none()
                if subscription is None:
                    # Row disappeared between the ID query and now; nothing to do.
                    continue

                # The helper reports "invoice created" by incrementing slot 0
                # of this single-element list.
                created_ref_container = [0]
                _generate_invoice_for_subscription(sub_db, subscription, created_ref_container)
                sub_db.commit()

                # Counters are only bumped once the commit has succeeded.
                created += created_ref_container[0]
                processed += 1
        except Exception as sub_err:
            # Best effort per subscription: log, count, and move on.
            logger.error(
                "generate_upcoming_invoices: error for subscription id=%s: %s",
                sub_id,
                sub_err,
                exc_info=True,
            )
            errors += 1

    logger.info(
        "generate_upcoming_invoices: processed=%d created=%d errors=%d",
        processed, created, errors
    )
    return {"processed": processed, "created": created, "errors": errors}


def _generate_invoice_for_subscription(db, subscription, created_ref: list) -> None:
    """
    Generate the next invoice for a subscription if not already generated.

    Steps, all executed on ``db`` without committing (the caller commits):

    1. Return early when the subscription has no payment request or no
       recurring configuration row.
    2. If the subscription's end condition is met, mark it EXPIRED/COMPLETED,
       record an activity entry, and return without invoicing.
    3. Return early when an invoice already exists for the next sequence
       (idempotency guard).
    4. Otherwise create the invoice, update the subscription counters,
       advance ``next_billing_date``, move INITIALIZING subscriptions to
       ACTIVE, create/schedule the payment, and record activity entries.

    Args:
        db: Open session used for every read and write.
        subscription: Subscription instance loaded in ``db``.
        created_ref: Single-element list used as an out-parameter;
            ``created_ref[0]`` is incremented when an invoice is created.
    """
    from sqlalchemy import select
    from src.apps.subscriptions.enums import SubscriptionStatus, SubscriptionActivityTypes
    from src.apps.subscriptions.helpers.billing_date import compute_next_billing_date
    from src.apps.subscriptions import crud as sub_crud
    from src.apps.invoices import crud as invoice_crud
    from src.apps.invoices.models.invoice import Invoice
    from src.apps.invoices.services.invoice_services import create_invoice
    from src.core.utils.enums import InvoiceActivityTypes, InvoiceStatusTypes
    from src.apps.payment_requests.models.recurring_payment_request import RecurringPaymentRequests

    pr = subscription.payment_request
    if pr is None:
        return

    # Load recurring config
    rec = db.execute(
        select(RecurringPaymentRequests).where(
            RecurringPaymentRequests.payment_request_id == pr.id
        )
    ).scalar_one_or_none()
    if rec is None:
        return

    now = datetime.now(timezone.utc)
    # Naive timestamps are treated as UTC so they compare cleanly with the
    # aware values used below.
    next_billing = _assume_utc(subscription.next_billing_date)

    # Check end conditions
    if _end_subscription_if_finished(db, subscription, rec, next_billing):
        return

    # Sequence for this invoice
    next_sequence = (subscription.invoices_generated or 0) + 1

    # Idempotency guard: check if invoice already exists for this sequence
    existing = db.execute(
        select(Invoice).where(
            Invoice.subscription_id == subscription.id,
            Invoice.sequence_id == next_sequence,
            Invoice.deleted_at.is_(None),
        )
    ).scalar_one_or_none()
    if existing is not None:
        logger.debug(
            "Invoice already exists for subscription %s sequence %d, skipping",
            subscription.subscription_literal, next_sequence
        )
        return

    # Determine amount
    amount = pr.amount

    # Handle prorate for sequence 1
    if next_sequence == 1 and getattr(rec, "prorate_first_payment", False):
        prorate_amount = getattr(rec, "prorate_amount", None)
        if prorate_amount:
            amount = float(prorate_amount)
        # Use prorate_date as the due_date for the first prorated invoice.
        # It also becomes the base from which the following billing date is
        # computed further down.
        prorate_date = getattr(rec, "prorate_date", None)
        if prorate_date:
            next_billing = _assume_utc(prorate_date)

    # Create invoice
    due_date = next_billing
    billing_date = now

    invoice = create_invoice(
        db=db,
        payment_request=pr,
        merchant_id=subscription.merchant_id,
        customer_id=subscription.customer_id or getattr(pr, "customer_id", None),
        amount=amount,
        status=InvoiceStatusTypes.PENDING,
        due_date=due_date,
        billing_date=billing_date,
        sequence_id=next_sequence,
    )

    # Link invoice to subscription
    invoice.subscription_id = subscription.id
    db.flush()

    # Write invoice activity
    invoice_crud.write_activity(
        db=db,
        invoice_id=invoice.id,
        activity_type=InvoiceActivityTypes.INVOICE_CREATED,
        description=f"Invoice generated for billing cycle {next_sequence}",
        actor_type="system",
        metadata={"subscription_literal": subscription.subscription_literal, "sequence": next_sequence},
    )

    # Update subscription counters
    subscription.invoices_generated = next_sequence
    # NOTE(review): if pr.amount / total_billed are Decimal columns, adding them
    # to a float (0.0 or the prorated float above) raises TypeError — confirm
    # the column types.
    subscription.total_billed = (subscription.total_billed or 0.0) + amount

    # Advance next_billing_date
    interval = getattr(rec, "interval", "month")
    interval_value = getattr(rec, "interval_value", 1)
    subscription.next_billing_date = compute_next_billing_date(interval, interval_value, next_billing)

    # Transition INITIALIZING → ACTIVE on first invoice
    if subscription.status == SubscriptionStatus.INITIALIZING:
        subscription.status = SubscriptionStatus.ACTIVE

    subscription.updated_at = now
    db.flush()

    # ── Create a pending Transaction and schedule the charge ─────────────────
    _schedule_payment_for_invoice(db, subscription, pr, invoice, amount, due_date, next_sequence)

    sub_crud.write_activity(
        db=db,
        subscription_id=subscription.id,
        activity_type=SubscriptionActivityTypes.INVOICE_AUTO_GENERATED,
        description=f"Invoice {invoice.invoice_literal} generated for cycle {next_sequence}",
        actor_type="system",
        actor_id=None,
        metadata={
            "invoice_id": invoice.invoice_id,
            "invoice_literal": invoice.invoice_literal,
            "sequence": next_sequence,
            "amount": amount,
            "due_date": due_date.isoformat() if due_date else None,
        },
    )

    created_ref[0] += 1
    logger.info(
        "Generated invoice %s (sequence %d) for subscription %s",
        invoice.invoice_literal,
        next_sequence,
        subscription.subscription_literal,
    )


def _assume_utc(value):
    """Return ``value`` tagged as UTC when it is a naive datetime; return it unchanged otherwise (``None`` included)."""
    if value is not None and value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value


def _end_subscription_if_finished(db, subscription, rec, next_billing) -> bool:
    """
    Close the subscription when its recurring config says billing is over.

    ``rec.end_type == "date"``: the subscription becomes EXPIRED when
    ``next_billing`` is later than ``rec.end_date``.
    ``rec.end_type == "until_count"``: it becomes COMPLETED when
    ``invoices_generated`` has reached ``rec.pay_until_count``.
    Any other end type (default ``"until_cancelled"``) never ends here.

    Returns:
        True when the status was changed and an activity entry written (the
        caller must not generate an invoice); False otherwise.
    """
    from src.apps.subscriptions import crud as sub_crud
    from src.apps.subscriptions.enums import SubscriptionStatus, SubscriptionActivityTypes

    end_type = getattr(rec, "end_type", "until_cancelled")
    final_status = None
    description = None

    if end_type == "date":
        end_date = _assume_utc(getattr(rec, "end_date", None))
        if end_date and next_billing and next_billing > end_date:
            # Past end date — expire the subscription (date-based end)
            final_status = SubscriptionStatus.EXPIRED
            description = "Subscription expired: end date reached"
    elif end_type == "until_count":
        pay_until_count = getattr(rec, "pay_until_count", None)
        # An unset counter is treated as 0, matching the sequence computation
        # in _generate_invoice_for_subscription; comparing None >= int would
        # raise TypeError.
        generated = subscription.invoices_generated or 0
        if pay_until_count and generated >= pay_until_count:
            final_status = SubscriptionStatus.COMPLETED
            description = "Subscription completed: all billing cycles generated"

    if final_status is None:
        return False

    subscription.status = final_status
    sub_crud.write_activity(
        db=db,
        subscription_id=subscription.id,
        activity_type=SubscriptionActivityTypes.STATUS_AUTO_UPDATED,
        description=description,
        actor_type="system",
        actor_id=None,
    )
    return True


def _schedule_payment_for_invoice(db, subscription, pr, invoice, amount, due_date, sequence) -> None:
    """
    Record a PENDING Transaction for ``invoice`` and queue its charge.

    The transaction is inserted and linked to the invoice through the
    ``transactions_invoices_map`` association table.  When both a payment
    method and a due date are available, ``process_scheduled_payment`` is
    enqueued with ``eta=due_date`` and the resulting Celery task ID is stored
    under ``"celery_task_id"`` in the transaction's ``txn_metadata`` so that
    cancel/pause can revoke the task before it fires.

    If either the payment method or the due date is missing, the transaction
    is still created but nothing is enqueued (a warning is logged).  A failure
    while enqueueing is logged and swallowed; it does not propagate.

    Nothing is committed here — only flushed.

    NOTE(review): the task is therefore enqueued before the caller commits.
    With a due date that is already in the past the worker could presumably
    start before the transaction row is visible — confirm how
    process_scheduled_payment copes with that.
    """
    from uuid import uuid4
    from src.apps.transactions.models.transactions import Transactions, transactions_invoices_map
    from src.apps.transactions.services import generate_txn_literal
    from src.core.utils.enums import TransactionStatusTypes, TransactionCategories, TransactionTypes

    # Resolve payment method: the first entry of pr.payment_methods, if any.
    # (No "active" filtering happens here — the list is used as returned.)
    methods = pr.payment_methods
    payment_method_id = methods[0].id if methods else None

    txn_ref = f"sub_{subscription.subscription_id}_{sequence}_{uuid4().hex[:8]}"
    txn_literal = generate_txn_literal(db)
    base_metadata = {
        "subscription_id": subscription.subscription_id,
        "sequence": sequence,
        "invoice_id": invoice.invoice_id,
    }

    txn = Transactions(
        txn_id=txn_ref,
        txn_literal=txn_literal,
        txn_amount=amount,
        txn_status=TransactionStatusTypes.PENDING,
        payment_request_id=pr.id,
        merchant_id=subscription.merchant_id,
        customer_id=invoice.customer_id,
        payment_method_id=payment_method_id,
        transaction_type=TransactionTypes.SUBSCRIPTION,
        category=TransactionCategories.CHARGE,
        txn_metadata=base_metadata,
    )
    db.add(txn)
    db.flush()  # assigns txn.id, needed for the association row below

    # Transaction ↔ invoice link row
    link_stmt = transactions_invoices_map.insert().values(
        transaction_id=txn.id,
        invoice_id=invoice.id,
    )
    db.execute(link_stmt)
    db.flush()

    # Guard clauses: without a payment method or a due date there is nothing
    # to enqueue; the PENDING transaction created above is left in place.
    if payment_method_id is None:
        logger.warning(
            "_schedule_payment_for_invoice: no payment method for subscription %s "
            "sequence %d — transaction %s created but payment NOT scheduled",
            subscription.subscription_literal,
            sequence,
            txn.txn_id,
        )
        return

    if due_date is None:
        logger.warning(
            "_schedule_payment_for_invoice: no due_date for subscription %s "
            "sequence %d — payment NOT scheduled",
            subscription.subscription_literal,
            sequence,
        )
        return

    try:
        from src.worker.hpp_tasks import process_scheduled_payment

        async_result = process_scheduled_payment.apply_async(
            args=[txn.id, payment_method_id],
            eta=due_date,
        )

        # Assign a fresh dict rather than mutating in place, then record the
        # task ID so cancel/pause can revoke it before it fires.
        updated_metadata = dict(txn.txn_metadata or {})
        updated_metadata["celery_task_id"] = async_result.id
        txn.txn_metadata = updated_metadata
        db.flush()

        logger.info(
            "_schedule_payment_for_invoice: scheduled task %s for txn %s at %s",
            async_result.id,
            txn.txn_id,
            due_date.isoformat(),
        )
    except Exception as task_err:
        # Best effort: a scheduling failure must not abort invoice generation.
        logger.exception(
            "_schedule_payment_for_invoice: failed to schedule payment for txn %s: %s",
            txn.txn_id,
            task_err,
        )
