"""
Subscription event listener.

Listens to transaction.completed and transaction.failed events.
When a completed/failed transaction is linked to a subscription invoice:

  - transaction.completed:
      * Mark the invoice PAID
      * Increment subscription.invoices_paid, total_paid
      * Reset dunning state; PAST_DUE → ACTIVE
      * Write CHARGE_SUCCEEDED activity
      * If all cycles done → mark COMPLETED

  - transaction.failed:
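      * Mark the invoice FAILED and increment the dunning counter
      * Schedule a retry Celery task (only while the failure count is below
        MAX_DUNNING_RETRIES)
      * Write CHARGE_FAILED activity
      * On first failure: set past_due_since; ACTIVE → PAST_DUE
      * On reaching MAX_DUNNING_RETRIES failures: move to DUNNING_EXHAUSTED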

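Activity records are written synchronously here — same pattern as the
invoice listener.  listen() registers both handlers with the global
EventDispatcher, so they fire on both the Kafka consumer path and the
local-fallback path.  It is called automatically when this module is imported.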
"""

from __future__ import annotations

import logging
import math
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

# --- Dunning policy -----------------------------------------------------------
# The failure whose running dunning_retry_count reaches this value exhausts
# dunning; every earlier failure schedules another retry instead.
MAX_DUNNING_RETRIES: int = 3
# Retry delays in days, looked up by failure number (the last entry is reused
# if the count ever runs past the end of the list).
# PRD HWSUB-114: retry 1 at T+3 days, retry 2 at T+7 days.
# NOTE(review): the failure handler adds each delay to the time of the failure
# being processed, so retry 2 lands 3 + 7 days after the first failure — confirm
# this matches the PRD's meaning of "T".
DUNNING_RETRY_DAYS: list[int] = [3, 7]


def handle_transaction_completed(event: Any) -> None:
    """
    Process a transaction.completed event for subscription invoices.

    event keys expected (in event.data, or in ``event`` itself when it has no
    ``data`` attribute):
        transaction_id  (int)
        invoice_id      (str)  — opaque invoice ID (optional)
        txn_amount      (float)
        merchant_id     (int)
        subscription_id (str, optional) — if known by emitter (not read here;
                        the subscription is resolved through the invoice)

    When the invoice is linked to a subscription (and, if merchant_id is
    given, belongs to that merchant), this:
        * marks the invoice PAID with paid_amount / paid_date
        * increments subscription.invoices_paid and total_paid
        * clears dunning state and moves PAST_DUE back to ACTIVE
        * writes a CHARGE_SUCCEEDED activity, then runs the completion check
        * commits the session

    Errors are logged rather than raised: malformed payloads are skipped, and
    any failure while processing is logged with a traceback (this function
    does not commit in that case).
    """
    payload: Dict[str, Any] = getattr(event, "data", event)
    if not hasattr(payload, "get"):
        logger.warning(
            "handle_transaction_completed: unsupported payload type %s, skipping",
            type(payload).__name__,
        )
        return

    transaction_id = payload.get("transaction_id")
    invoice_id_str = payload.get("invoice_id")
    event_merchant_id = payload.get("merchant_id")
    try:
        # float(None) / float("abc") would otherwise raise here, outside the
        # try block below, and escape to the event dispatcher.
        txn_amount = float(payload.get("txn_amount", 0.0))
    except (TypeError, ValueError):
        txn_amount = math.nan  # rejected by the validation below

    if not transaction_id:
        logger.debug("handle_transaction_completed: no transaction_id in event, skipping")
        return

    # VULN-009: Validate txn_amount before processing.  NaN compares False
    # against everything (so ``nan <= 0`` is False) and inf is not a real
    # amount, hence the explicit finiteness check.
    if not math.isfinite(txn_amount) or txn_amount <= 0:
        logger.warning(
            "handle_transaction_completed: invalid txn_amount=%.4f for transaction_id=%s, skipping",
            txn_amount,
            transaction_id,
        )
        return

    if not invoice_id_str:
        return  # No invoice to resolve, so this cannot be a subscription invoice

    try:
        from src.core.database import SessionCelery
        from src.apps.invoices.models.invoice import Invoice
        from src.apps.subscriptions.models.subscription import Subscription
        from src.apps.subscriptions.enums import SubscriptionStatus, SubscriptionActivityTypes
        from src.apps.subscriptions import crud as sub_crud
        from src.core.utils.enums import InvoiceStatusTypes
        from sqlalchemy import select

        with SessionCelery() as db:
            invoice = db.execute(
                select(Invoice).where(
                    Invoice.invoice_id == invoice_id_str,
                    Invoice.deleted_at.is_(None),
                )
            ).scalar_one_or_none()
            if invoice is None or invoice.subscription_id is None:
                return  # Not a subscription invoice

            # VULN-010: Validate merchant scope before processing
            if event_merchant_id is not None and invoice.merchant_id != int(event_merchant_id):
                logger.error(
                    "handle_transaction_completed: merchant_id mismatch for invoice %s "
                    "(event merchant_id=%s, invoice merchant_id=%s) — skipping",
                    invoice_id_str,
                    event_merchant_id,
                    invoice.merchant_id,
                )
                return

            subscription = db.execute(
                select(Subscription).where(Subscription.id == invoice.subscription_id)
            ).scalar_one_or_none()
            if subscription is None:
                return

            # NOTE(review): a redelivered event for an invoice that is already
            # PAID is counted again below; consider de-duplicating on
            # transaction_id if the transport is at-least-once.
            now = datetime.now(timezone.utc)

            invoice.status = InvoiceStatusTypes.PAID
            invoice.paid_amount = txn_amount
            invoice.paid_date = now

            # total_billed is already incremented by the generate_invoices Celery
            # task when the invoice is created; do not modify it here.
            subscription.invoices_paid = (subscription.invoices_paid or 0) + 1
            subscription.total_paid = (subscription.total_paid or 0.0) + txn_amount

            # A successful payment clears any dunning in progress.
            subscription.dunning_retry_count = 0
            subscription.dunning_last_retry_at = None
            subscription.dunning_next_retry_at = None
            subscription.past_due_since = None
            if subscription.status == SubscriptionStatus.PAST_DUE:
                subscription.status = SubscriptionStatus.ACTIVE

            subscription.updated_at = now
            db.flush()

            # Record the payment before the completion check so the activity
            # timeline reads CHARGE_SUCCEEDED → COMPLETED.
            sub_crud.write_activity(
                db=db,
                subscription_id=subscription.id,
                activity_type=SubscriptionActivityTypes.CHARGE_SUCCEEDED,
                description=f"Payment of ${txn_amount:.2f} collected for invoice {invoice.invoice_literal}",
                actor_type="system",
                actor_id=None,
                metadata={
                    "transaction_id": transaction_id,
                    "invoice_id": invoice_id_str,
                    "amount": txn_amount,
                },
            )

            _check_completion(db, subscription, invoice)

            # Capture identifiers before commit (commit expires loaded attributes).
            subscription_literal = subscription.subscription_literal
            invoice_literal = invoice.invoice_literal

            # Commit explicitly.  A plain SQLAlchemy Session used as a context
            # manager only closes on exit, which discards flushed-but-uncommitted
            # changes.  NOTE(review): harmless if SessionCelery or write_activity
            # already commits.
            db.commit()

            logger.info(
                "handle_transaction_completed: subscription %s invoice %s paid",
                subscription_literal,
                invoice_literal,
            )

    except Exception as exc:
        logger.error(
            "handle_transaction_completed: error processing transaction_id=%s: %s",
            transaction_id,
            exc,
            exc_info=True,
        )


def handle_transaction_failed(event: Any) -> None:
    """
    Process a transaction.failed event for subscription invoices.

    event keys expected (in event.data, or in ``event`` itself when it has no
    ``data`` attribute):
        transaction_id  (int)
        invoice_id      (str, optional)
        txn_amount      (float)  — not read here
        merchant_id     (int)    — when present, must match the invoice's merchant
        error_message   (str, optional) — defaults to "Payment failed"

    When the invoice is linked to a subscription (and belongs to the event's
    merchant, if one is given), this:
        * marks the invoice FAILED and increments subscription.dunning_retry_count
        * on the first failure, sets past_due_since and moves ACTIVE → PAST_DUE
        * once the count reaches MAX_DUNNING_RETRIES, moves to DUNNING_EXHAUSTED;
          otherwise computes the next retry time from DUNNING_RETRY_DAYS
        * writes CHARGE_FAILED, then DUNNING_EXHAUSTED or DUNNING_RETRY_SCHEDULED
        * commits, and only then enqueues the Celery retry task (if any)

    Errors are logged rather than raised.
    """
    payload: Dict[str, Any] = getattr(event, "data", event)
    if not hasattr(payload, "get"):
        logger.warning(
            "handle_transaction_failed: unsupported payload type %s, skipping",
            type(payload).__name__,
        )
        return

    transaction_id = payload.get("transaction_id")
    invoice_id_str = payload.get("invoice_id")
    event_merchant_id = payload.get("merchant_id")
    # ``or`` also replaces an explicit None / empty message with the default.
    error_message = payload.get("error_message") or "Payment failed"

    if not transaction_id or not invoice_id_str:
        return

    try:
        from src.core.database import SessionCelery
        from src.apps.invoices.models.invoice import Invoice
        from src.apps.subscriptions.models.subscription import Subscription
        from src.apps.subscriptions.enums import SubscriptionStatus, SubscriptionActivityTypes
        from src.apps.subscriptions import crud as sub_crud
        from src.core.utils.enums import InvoiceStatusTypes
        from sqlalchemy import select

        with SessionCelery() as db:
            invoice = db.execute(
                select(Invoice).where(
                    Invoice.invoice_id == invoice_id_str,
                    Invoice.deleted_at.is_(None),
                )
            ).scalar_one_or_none()
            if invoice is None or invoice.subscription_id is None:
                return  # Not a subscription invoice

            # VULN-010: Validate merchant scope before processing.  Without
            # this, an event carrying another merchant's invoice_id could push
            # that merchant's subscription into dunning / exhaustion.
            if event_merchant_id is not None and invoice.merchant_id != int(event_merchant_id):
                logger.error(
                    "handle_transaction_failed: merchant_id mismatch for invoice %s "
                    "(event merchant_id=%s, invoice merchant_id=%s) — skipping",
                    invoice_id_str,
                    event_merchant_id,
                    invoice.merchant_id,
                )
                return

            subscription = db.execute(
                select(Subscription).where(Subscription.id == invoice.subscription_id)
            ).scalar_one_or_none()
            if subscription is None:
                return

            now = datetime.now(timezone.utc)

            invoice.status = InvoiceStatusTypes.FAILED
            invoice.updated_at = now

            retry_count = (subscription.dunning_retry_count or 0) + 1
            subscription.dunning_retry_count = retry_count
            subscription.dunning_last_retry_at = now

            if retry_count == 1:
                # First failure: start the past-due clock
                subscription.past_due_since = now
                if subscription.status == SubscriptionStatus.ACTIVE:
                    subscription.status = SubscriptionStatus.PAST_DUE

            next_retry_at = None  # stays None when dunning is exhausted
            if retry_count >= MAX_DUNNING_RETRIES:
                subscription.status = SubscriptionStatus.DUNNING_EXHAUSTED
                subscription.dunning_exhausted_at = now
                subscription.dunning_next_retry_at = None
            else:
                # NOTE(review): the delay is added to the time of *this* failure,
                # so retry 2 lands 3 + 7 days after the first failure, whereas
                # the DUNNING_RETRY_DAYS comment reads "retry 2 at T+7 days" —
                # confirm the intended anchor against PRD HWSUB-114.
                delay_days = DUNNING_RETRY_DAYS[min(retry_count - 1, len(DUNNING_RETRY_DAYS) - 1)]
                next_retry_at = now + timedelta(days=delay_days)
                subscription.dunning_next_retry_at = next_retry_at

            subscription.updated_at = now
            db.flush()

            # Failure first, then its consequence, so the timeline reads in order.
            sub_crud.write_activity(
                db=db,
                subscription_id=subscription.id,
                activity_type=SubscriptionActivityTypes.CHARGE_FAILED,
                description=f"Payment failed: {error_message}",
                actor_type="system",
                actor_id=None,
                metadata={
                    "transaction_id": transaction_id,
                    "invoice_id": invoice_id_str,
                    "error_message": error_message,
                    "dunning_retry_count": retry_count,
                },
            )

            if next_retry_at is None:
                sub_crud.write_activity(
                    db=db,
                    subscription_id=subscription.id,
                    activity_type=SubscriptionActivityTypes.DUNNING_EXHAUSTED,
                    description="Maximum dunning retries exceeded. Subscription suspended.",
                    actor_type="system",
                    actor_id=None,
                    metadata={"retry_count": retry_count},
                )
            else:
                sub_crud.write_activity(
                    db=db,
                    subscription_id=subscription.id,
                    activity_type=SubscriptionActivityTypes.DUNNING_RETRY_SCHEDULED,
                    description=f"Retry #{retry_count} scheduled for {next_retry_at.strftime('%Y-%m-%d')}",
                    actor_type="system",
                    actor_id=None,
                    metadata={
                        "retry_number": retry_count,
                        "next_retry_at": next_retry_at.isoformat(),
                    },
                )

            # Capture identifiers before commit (commit expires loaded attributes).
            subscription_key = subscription.subscription_id
            subscription_literal = subscription.subscription_literal

            # Commit explicitly.  A plain SQLAlchemy Session used as a context
            # manager only closes on exit, which discards flushed-but-uncommitted
            # changes.  NOTE(review): harmless if SessionCelery or write_activity
            # already commits.
            db.commit()

            # Enqueue the retry only after the dunning state is persisted, so a
            # failed commit never leaves an orphaned retry task behind.
            if next_retry_at is not None:
                try:
                    from src.apps.subscriptions.tasks.retry_payment import retry_subscription_payment
                    retry_subscription_payment.apply_async(
                        kwargs={
                            "subscription_id": subscription_key,
                            "invoice_id": invoice_id_str,
                        },
                        eta=next_retry_at,
                    )
                except Exception as task_err:
                    logger.warning("Could not schedule retry task: %s", task_err)

            logger.info(
                "handle_transaction_failed: subscription %s dunning retry=%d",
                subscription_literal,
                retry_count,
            )

    except Exception as exc:
        logger.error(
            "handle_transaction_failed: error processing transaction_id=%s: %s",
            transaction_id,
            exc,
            exc_info=True,
        )


def _check_completion(db, subscription, last_invoice) -> None:
    """
    Move *subscription* to COMPLETED once its recurring schedule has run out.

    The end condition is read from ``subscription.payment_request.recurring_config``
    (its first element when the config is a list).  If the subscription has no
    payment request, or there is no recurring config, nothing happens.

    * ``end_type == "until_count"``: complete once ``invoices_paid`` has reached
      a truthy ``pay_until_count``.
    * ``end_type == "date"``: complete once the current UTC time is after
      ``end_date`` *and* ``next_billing_date`` is set and also after ``end_date``.
    * any other ``end_type``: never completes here.

    On completion the status is set to COMPLETED and a COMPLETED activity is
    written through ``db``.  ``last_invoice`` is accepted but not read.
    """
    from src.apps.subscriptions.enums import SubscriptionStatus, SubscriptionActivityTypes
    from src.apps.subscriptions import crud as sub_crud

    payment_request = getattr(subscription, "payment_request", None)
    if payment_request is None:
        return

    config = getattr(payment_request, "recurring_config", None)
    if isinstance(config, list):
        # List-valued config: only the first entry is consulted.
        config = config[0] if config else None
    if config is None:
        return

    end_type = getattr(config, "end_type", None)
    if end_type == "until_count":
        target_count = getattr(config, "pay_until_count", None)
        finished = bool(target_count) and subscription.invoices_paid >= target_count
    elif end_type == "date":
        end_date = getattr(config, "end_date", None)
        # NOTE(review): comparing with an aware UTC "now" assumes end_date and
        # next_billing_date are tz-aware datetimes; a naive datetime or a plain
        # date would raise TypeError here — confirm the column types.
        finished = bool(
            end_date
            and datetime.now(timezone.utc) > end_date
            and subscription.next_billing_date
            and subscription.next_billing_date > end_date
        )
    else:
        finished = False

    if not finished:
        return

    subscription.status = SubscriptionStatus.COMPLETED
    sub_crud.write_activity(
        db=db,
        subscription_id=subscription.id,
        activity_type=SubscriptionActivityTypes.COMPLETED,
        description="All billing cycles completed",
        actor_type="system",
        actor_id=None,
        metadata={
            "invoices_paid": subscription.invoices_paid,
            "total_paid": subscription.total_paid,
        },
    )


def listen() -> None:
    """
    Hook the subscription handlers into the global EventDispatcher.

    handle_transaction_completed and handle_transaction_failed are registered
    as synchronous listeners, so they run on the Kafka consumer path as well
    as the local-fallback path.  Any exception raised while importing the
    dispatcher or registering is logged with a traceback and not propagated.
    """
    registrations = (
        ("transaction.completed", handle_transaction_completed),
        ("transaction.failed", handle_transaction_failed),
    )
    try:
        from src.events.dispatcher import EventDispatcher

        for event_name, handler in registrations:
            EventDispatcher.register(event_name, handler)

        logger.info(
            "Subscription event listener registered for: "
            "transaction.completed, transaction.failed"
        )
    except Exception as exc:
        logger.error(
            "Subscription event listener: failed to register with EventDispatcher: %s",
            exc,
            exc_info=True,
        )


# Importing this module (e.g. from the event-listener discovery system)
# registers the handlers immediately.
# NOTE(review): if the discovery system also calls listen() explicitly, the
# handlers would be registered twice unless EventDispatcher.register
# de-duplicates — confirm.
listen()
