"""
Subscriptions CRUD — pure DB operations, zero business logic.

All functions accept a SQLAlchemy Session and return ORM instances or plain
dicts. Callers own the transaction boundary (commit/rollback).
"""

from datetime import datetime, time, timezone
from typing import Dict, List, Optional, Tuple

from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session, joinedload, selectinload

from src.apps.subscriptions.models.subscription import Subscription
from src.apps.subscriptions.models.subscription_activity import SubscriptionActivity
from src.apps.subscriptions.enums import SubscriptionStatus
from src.apps.base.utils.functions import generate_secure_id


# ─── Subscription lookups ─────────────────────────────────────────────────────


def get_subscription_by_id(
    db: Session,
    subscription_id: str,
    merchant_id: int,
) -> Optional[Subscription]:
    """
    Look up one live (not soft-deleted) subscription by its opaque ID.

    The lookup is restricted to *merchant_id*. customer and payment_request
    are eager-loaded via JOIN; activities are loaded with a separate
    SELECT ... IN. Returns None when nothing matches.
    """
    loaders = (
        joinedload(Subscription.customer),
        joinedload(Subscription.payment_request),
        selectinload(Subscription.activities),
    )
    conditions = (
        Subscription.subscription_id == subscription_id,
        Subscription.merchant_id == merchant_id,
        Subscription.deleted_at.is_(None),
    )
    query = select(Subscription).where(*conditions).options(*loaders)
    # unique() de-duplicates parent rows produced by the joined eager loads.
    return db.execute(query).unique().scalar_one_or_none()


def get_subscription_by_literal(
    db: Session,
    literal: str,
    merchant_id: int,
) -> Optional[Subscription]:
    """
    Look up one live (not soft-deleted) subscription by its human-readable literal.

    Scoped to *merchant_id*. customer and payment_request are eager-loaded
    via JOIN; activities via SELECT ... IN. Returns None when nothing matches.
    """
    loaders = (
        joinedload(Subscription.customer),
        joinedload(Subscription.payment_request),
        selectinload(Subscription.activities),
    )
    conditions = (
        Subscription.subscription_literal == literal,
        Subscription.merchant_id == merchant_id,
        Subscription.deleted_at.is_(None),
    )
    query = select(Subscription).where(*conditions).options(*loaders)
    # unique() de-duplicates parent rows produced by the joined eager loads.
    return db.execute(query).unique().scalar_one_or_none()


def get_subscription_by_id_or_literal(
    db: Session,
    sub_id: str,
    merchant_id: int,
) -> Optional[Subscription]:
    """
    Resolve *sub_id* as an opaque subscription_id, else as a subscription_literal.

    Both lookups are scoped to *merchant_id* and skip soft-deleted rows. The
    literal lookup only runs when the ID lookup finds nothing. Returns None if
    neither matches.
    """
    by_opaque_id = get_subscription_by_id(db, sub_id, merchant_id)
    if by_opaque_id is not None:
        return by_opaque_id
    return get_subscription_by_literal(db, sub_id, merchant_id)


def get_subscription_by_payment_request_id(
    db: Session,
    payment_request_id: int,
) -> Optional[Subscription]:
    """
    Find the live subscription linked to *payment_request_id*.

    Not scoped to a merchant, and no relationships are eager-loaded. Returns
    None when no non-deleted row matches. scalar_one_or_none() raises if more
    than one row matches.
    """
    matches_request = Subscription.payment_request_id == payment_request_id
    not_deleted = Subscription.deleted_at.is_(None)
    query = select(Subscription).where(matches_request, not_deleted)
    return db.execute(query).scalar_one_or_none()


def get_subscription_unscoped(
    db: Session,
    subscription_id: str,
) -> Optional[Subscription]:
    """
    Admin use only: fetch a subscription with NO merchant scoping.

    *subscription_id* may be either the opaque subscription_id or the
    subscription_literal; both columns are checked in a single query.
    Soft-deleted rows are excluded. customer, merchant and payment_request
    are eager-loaded via JOIN; activities via SELECT ... IN.
    """
    identifier_matches = or_(
        Subscription.subscription_id == subscription_id,
        Subscription.subscription_literal == subscription_id,
    )
    loaders = [
        joinedload(Subscription.customer),
        joinedload(Subscription.merchant),
        joinedload(Subscription.payment_request),
        selectinload(Subscription.activities),
    ]
    query = (
        select(Subscription)
        .where(identifier_matches, Subscription.deleted_at.is_(None))
        .options(*loaders)
    )
    return db.execute(query).unique().scalar_one_or_none()


# ─── Paginated list ───────────────────────────────────────────────────────────


def list_subscriptions(
    db: Session,
    merchant_id: int,
    filters: dict,
    page: int = 1,
    per_page: int = 10,
) -> Tuple[List[Subscription], int]:
    """
    Paginated subscription list with optional filters.

    filters keys:
        status          — comma-separated int statuses, e.g. "200,201"; the
                          whole filter is skipped if any entry is not an int
        customer_id     — integer FK, or a customer's opaque customer_id /
                          account_literal string
        payer_literal   — CustomerContact.contact_id; keeps subscriptions whose
                          customer_id matches that (non-deleted) contact's
                          customer_id
        search          — case-insensitive substring match on
                          subscription_literal or the customer's first name,
                          last name or business legal name; "%", "_" and "/"
                          in the term are matched literally
        invoice_literal — keep subscriptions having an invoice with this literal
        created_after   — datetime lower bound (naive values are taken as UTC)
        created_before  — datetime upper bound (naive values are taken as UTC);
                          an exact-midnight value is widened to the end of that
                          day in its own timezone
        sort_by         — "created_at" (default) | "next_billing_date" |
                          "total_paid"; always descending, ties broken by id
        auth_type       — matched against PaymentRequest.authorization_type,
                          e.g. "pre_auth" | "request_auth"

    Args:
        db: active session; this function does not commit.
        merchant_id: only this merchant's non-deleted subscriptions are listed.
        filters: see above; unknown keys are ignored.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: page size.

    Returns:
        (subscriptions on the requested page, total matching count).
    """
    from src.apps.customers.models.customer import Customer
    from src.apps.invoices.models.invoice import Invoice
    from src.apps.payment_requests.models.payment_request import PaymentRequest as PR

    # A page below 1 would produce a negative OFFSET, which databases reject.
    page = max(page, 1)

    base = [
        Subscription.merchant_id == merchant_id,
        Subscription.deleted_at == None,  # noqa: E711 — renders IS NULL
    ]

    # Auth type filter — needs a join to payment_request.
    # IMPORTANT: do NOT add PR.authorization_type to `base`. Referencing a
    # column from an unjoined table in WHERE makes SQLAlchemy add that table
    # to FROM implicitly, producing a cartesian product. Every subscription
    # would be cross-joined with every matching payment_request (across all
    # merchants), so the filter would return ALL subscriptions. The clause is
    # applied in _apply_joins(), after the explicit JOIN.
    auth_type = filters.get("auth_type")
    auth_type_clause = (PR.authorization_type == auth_type) if auth_type else None

    # Status filter — deliberately best-effort: a malformed value is ignored.
    status_raw = filters.get("status")
    if status_raw:
        try:
            statuses = [int(s.strip()) for s in str(status_raw).split(",") if s.strip()]
        except ValueError:
            statuses = []
        if statuses:
            base.append(Subscription.status.in_(statuses))

    # Customer filter — integer PK, or opaque customer_id / account_literal.
    customer_id_raw = filters.get("customer_id")
    if customer_id_raw:
        try:
            customer_pk = int(customer_id_raw)
        except (ValueError, TypeError):
            customer_pk_subq = select(Customer.id).where(
                or_(
                    Customer.customer_id == str(customer_id_raw),
                    Customer.account_literal == str(customer_id_raw),
                ),
                Customer.deleted_at == None,  # noqa: E711
            )
            base.append(Subscription.customer_id.in_(customer_pk_subq))
        else:
            base.append(Subscription.customer_id == customer_pk)

    # Date range — naive datetimes are made UTC-aware before comparing
    # against the timezone-aware created_at column.
    created_after = filters.get("created_after")
    if created_after:
        if isinstance(created_after, datetime) and created_after.tzinfo is None:
            created_after = created_after.replace(tzinfo=timezone.utc)
        base.append(Subscription.created_at >= created_after)

    created_before = filters.get("created_before")
    if created_before:
        if isinstance(created_before, datetime):
            tz = created_before.tzinfo if created_before.tzinfo is not None else timezone.utc
            if created_before.time() == time.min:
                # An exact-midnight value means "through the end of that day".
                # Keep the caller's timezone; only naive values are taken as UTC.
                created_before = datetime.combine(
                    created_before.date(), time(23, 59, 59, 999999), tzinfo=tz
                )
            else:
                created_before = created_before.replace(tzinfo=tz)
        base.append(Subscription.created_at <= created_before)

    # Payer filter — by the contact's opaque contact_id string.
    payer_literal = filters.get("payer_literal")
    if payer_literal:
        from src.apps.customers.models.customer_contact import CustomerContact

        payer_subq = select(CustomerContact.customer_id).where(
            CustomerContact.contact_id == payer_literal,
            CustomerContact.deleted_at == None,  # noqa: E711
        )
        base.append(Subscription.customer_id.in_(payer_subq))

    # Search — subscription_literal OR customer first/last/business name.
    search = filters.get("search")
    search_clause = None
    if search:
        # Escape LIKE wildcards so the term is matched literally. "/" is the
        # escape character, so it is escaped first.
        escaped = str(search).replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"
        search_clause = or_(
            Subscription.subscription_literal.ilike(pattern, escape="/"),
            Customer.first_name.ilike(pattern, escape="/"),
            Customer.last_name.ilike(pattern, escape="/"),
            Customer.business_legal_name.ilike(pattern, escape="/"),
        )

    invoice_literal = filters.get("invoice_literal")

    def _apply_joins(stmt):
        """Add the joins (and their WHERE clauses) shared by the count and data queries."""
        if auth_type_clause is not None:
            stmt = stmt.join(PR, PR.id == Subscription.payment_request_id).where(
                auth_type_clause
            )
        if search_clause is not None:
            # Outer join: subscriptions without a customer row can still match
            # on subscription_literal.
            stmt = stmt.outerjoin(Customer, Customer.id == Subscription.customer_id).where(
                search_clause
            )
        if invoice_literal:
            stmt = stmt.join(Invoice, Invoice.subscription_id == Subscription.id).where(
                Invoice.invoice_literal == invoice_literal
            )
        return stmt

    count_stmt = _apply_joins(select(func.count(Subscription.id)).where(*base))
    total = db.execute(count_stmt).scalar_one()

    sort_col = {
        "next_billing_date": Subscription.next_billing_date,
        "total_paid": Subscription.total_paid,
    }.get(filters.get("sort_by", "created_at"), Subscription.created_at)

    # HPMNTP-966: use selectinload instead of joinedload to avoid LIMIT/JOIN
    # interaction — joinedload with LIMIT applies the limit to joined rows, not
    # distinct parent rows, which can return fewer unique subscriptions than
    # expected when associations exist.
    data_stmt = _apply_joins(
        select(Subscription)
        .where(*base)
        .options(
            selectinload(Subscription.customer),
            selectinload(Subscription.payment_request),
        )
    )
    data_stmt = (
        data_stmt
        # id breaks ties in the sort column so OFFSET paging is deterministic.
        # Many rows can share a sort value (e.g. total_paid = 0); without a
        # unique key the DB may order them differently per page, duplicating
        # or skipping rows.
        .order_by(sort_col.desc(), Subscription.id.desc())
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    items = db.execute(data_stmt).scalars().all()
    return list(items), total


# ─── Summary stats ────────────────────────────────────────────────────────────


def get_subscription_summary(
    db: Session,
    merchant_id: int,
    customer_id: Optional[int] = None,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
) -> dict:
    """
    Return counts by status and cumulative totals for a merchant.

    MRR is computed by the service layer after normalisation per-interval.
    Uses a single aggregation query instead of one COUNT per status.

    Args:
        db: active session.
        merchant_id: only this merchant's non-deleted subscriptions are counted.
        customer_id: optional integer FK narrowing to one customer.
        date_from: optional inclusive lower bound on created_at, "YYYY-MM-DD",
            taken as 00:00:00 UTC.
        date_to: optional inclusive upper bound on created_at, "YYYY-MM-DD",
            covering the whole day through 23:59:59.999999 UTC.

    Returns:
        dict of int status counts ("failed_count" mirrors
        "dunning_exhausted_count") and float totals: total_billed,
        total_collected (sum of total_paid) and avg_subscription_value
        (mean total_billed over ACTIVE subscriptions, 0.0 if none).

    Raises:
        ValueError: if date_from / date_to is not in "YYYY-MM-DD" format.
    """
    from sqlalchemy import case

    where_clauses = [
        Subscription.merchant_id == merchant_id,
        Subscription.deleted_at.is_(None),
    ]
    if customer_id is not None:
        where_clauses.append(Subscription.customer_id == customer_id)
    if date_from:
        day_start = datetime.strptime(date_from, "%Y-%m-%d").replace(tzinfo=timezone.utc)
        where_clauses.append(Subscription.created_at >= day_start)
    if date_to:
        # Include the entire final second (through .999999), matching the
        # end-of-day bound used by list_subscriptions.
        day_end = datetime.combine(
            datetime.strptime(date_to, "%Y-%m-%d").date(),
            time(23, 59, 59, 999999),
            tzinfo=timezone.utc,
        )
        where_clauses.append(Subscription.created_at <= day_end)

    def _count_status(status):
        # SUM(CASE WHEN status = :status THEN 1 ELSE 0 END)
        return func.sum(case((Subscription.status == status, 1), else_=0))

    row = db.execute(
        select(
            func.count(Subscription.id).label("total"),
            _count_status(SubscriptionStatus.ACTIVE).label("active_count"),
            _count_status(SubscriptionStatus.INITIALIZING).label("initializing_count"),
            _count_status(SubscriptionStatus.PAST_DUE).label("past_due_count"),
            _count_status(SubscriptionStatus.DUNNING_EXHAUSTED).label("dunning_exhausted_count"),
            _count_status(SubscriptionStatus.CANCELLED).label("cancelled_count"),
            _count_status(SubscriptionStatus.COMPLETED).label("completed_count"),
            _count_status(SubscriptionStatus.PAUSED).label("paused_count"),
            func.coalesce(func.sum(Subscription.total_billed), 0.0).label("total_billed"),
            func.coalesce(func.sum(Subscription.total_paid), 0.0).label("total_collected"),
            func.coalesce(
                func.avg(
                    case(
                        (Subscription.status == SubscriptionStatus.ACTIVE, Subscription.total_billed),
                        else_=None,
                    )
                ),
                0.0,
            ).label("avg_subscription_value"),
        ).where(*where_clauses)
    ).first()

    return {
        "active_count": int(row.active_count or 0),
        "initializing_count": int(row.initializing_count or 0),
        "past_due_count": int(row.past_due_count or 0),
        "failed_count": int(row.dunning_exhausted_count or 0),
        "cancelled_count": int(row.cancelled_count or 0),
        "completed_count": int(row.completed_count or 0),
        "dunning_exhausted_count": int(row.dunning_exhausted_count or 0),
        "paused_count": int(row.paused_count or 0),
        "total_billed": float(row.total_billed or 0.0),
        "total_collected": float(row.total_collected or 0.0),
        "avg_subscription_value": float(row.avg_subscription_value or 0.0),
    }


# ─── Invoice queries ──────────────────────────────────────────────────────────


def get_subscription_invoices(db: Session, subscription_id: int):
    """
    List the non-deleted invoices of one subscription as a plain list.

    *subscription_id* is the subscription's integer key, matched against
    Invoice.subscription_id. Results are ordered by sequence_id, then
    created_at, both ascending (oldest first).
    """
    from src.apps.invoices.models.invoice import Invoice

    ordering = (Invoice.sequence_id.asc(), Invoice.created_at.asc())
    query = (
        select(Invoice)
        .where(
            Invoice.subscription_id == subscription_id,
            Invoice.deleted_at.is_(None),
        )
        .order_by(*ordering)
    )
    invoices = db.execute(query).scalars().all()
    return list(invoices)


# ─── Export (unbounded list) ─────────────────────────────────────────────────


def list_subscriptions_for_export(
    db: Session,
    merchant_id: int,
    filters: dict,
    limit: int = 10000,
) -> List[Subscription]:
    """
    Return up to *limit* subscriptions matching *filters* — no pagination.

    Accepts the same filter keys as list_subscriptions(): status, customer_id,
    payer_literal, search, invoice_literal, created_after, created_before and
    auth_type. sort_by is ignored. Results are always newest first
    (created_at DESC), with ties broken by id DESC so the *limit* cut-off is
    deterministic.

    Args:
        db: active session; this function does not commit.
        merchant_id: only this merchant's non-deleted subscriptions are returned.
        filters: filter dict; unknown keys are ignored.
        limit: maximum number of rows returned.

    Returns:
        A list of Subscription rows with customer and payment_request loaded.
    """
    from src.apps.customers.models.customer import Customer
    from src.apps.invoices.models.invoice import Invoice
    from src.apps.payment_requests.models.payment_request import PaymentRequest as PR

    base = [
        Subscription.merchant_id == merchant_id,
        Subscription.deleted_at == None,  # noqa: E711 — renders IS NULL
    ]

    # The auth_type clause must be applied only AFTER the explicit JOIN to PR.
    # Putting it in `base` would implicitly cross-join payment_request.
    auth_type = filters.get("auth_type")
    auth_type_clause = (PR.authorization_type == auth_type) if auth_type else None

    # Status filter — deliberately best-effort: a malformed value is ignored.
    status_raw = filters.get("status")
    if status_raw:
        try:
            statuses = [int(s.strip()) for s in str(status_raw).split(",") if s.strip()]
        except ValueError:
            statuses = []
        if statuses:
            base.append(Subscription.status.in_(statuses))

    # Customer filter — integer PK, or opaque customer_id / account_literal.
    customer_id_raw = filters.get("customer_id")
    if customer_id_raw:
        try:
            customer_pk = int(customer_id_raw)
        except (ValueError, TypeError):
            customer_pk_subq = select(Customer.id).where(
                or_(
                    Customer.customer_id == str(customer_id_raw),
                    Customer.account_literal == str(customer_id_raw),
                ),
                Customer.deleted_at == None,  # noqa: E711
            )
            base.append(Subscription.customer_id.in_(customer_pk_subq))
        else:
            base.append(Subscription.customer_id == customer_pk)

    # Date range — naive datetimes are taken as UTC.
    created_after = filters.get("created_after")
    if created_after:
        if isinstance(created_after, datetime) and created_after.tzinfo is None:
            created_after = created_after.replace(tzinfo=timezone.utc)
        base.append(Subscription.created_at >= created_after)

    created_before = filters.get("created_before")
    if created_before:
        if isinstance(created_before, datetime):
            tz = created_before.tzinfo if created_before.tzinfo is not None else timezone.utc
            if created_before.time() == time.min:
                # An exact-midnight value means "through the end of that day".
                # Keep the caller's timezone; only naive values are taken as UTC.
                created_before = datetime.combine(
                    created_before.date(), time(23, 59, 59, 999999), tzinfo=tz
                )
            else:
                created_before = created_before.replace(tzinfo=tz)
        base.append(Subscription.created_at <= created_before)

    # Payer filter — by the contact's opaque contact_id string.
    payer_literal = filters.get("payer_literal")
    if payer_literal:
        from src.apps.customers.models.customer_contact import CustomerContact

        payer_subq = select(CustomerContact.customer_id).where(
            CustomerContact.contact_id == payer_literal,
            CustomerContact.deleted_at == None,  # noqa: E711
        )
        base.append(Subscription.customer_id.in_(payer_subq))

    # Search — subscription_literal OR customer first/last/business name.
    search = filters.get("search")
    search_clause = None
    if search:
        # Escape LIKE wildcards so the term is matched literally. "/" is the
        # escape character, so it is escaped first.
        escaped = str(search).replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"
        search_clause = or_(
            Subscription.subscription_literal.ilike(pattern, escape="/"),
            Customer.first_name.ilike(pattern, escape="/"),
            Customer.last_name.ilike(pattern, escape="/"),
            Customer.business_legal_name.ilike(pattern, escape="/"),
        )

    stmt = (
        select(Subscription)
        .where(*base)
        .options(
            selectinload(Subscription.customer),
            selectinload(Subscription.payment_request),
        )
    )

    if auth_type_clause is not None:
        stmt = stmt.join(PR, PR.id == Subscription.payment_request_id).where(auth_type_clause)
    if search_clause is not None:
        # Outer join: subscriptions without a customer row can still match
        # on subscription_literal.
        stmt = stmt.outerjoin(Customer, Customer.id == Subscription.customer_id).where(
            search_clause
        )

    invoice_literal = filters.get("invoice_literal")
    if invoice_literal:
        stmt = stmt.join(Invoice, Invoice.subscription_id == Subscription.id).where(
            Invoice.invoice_literal == invoice_literal
        )

    # id breaks created_at ties so which rows fall inside *limit* is deterministic.
    stmt = stmt.order_by(Subscription.created_at.desc(), Subscription.id.desc()).limit(limit)

    return list(db.execute(stmt).scalars().all())


# ─── Literal / ID generation ─────────────────────────────────────────────────


def generate_subscription_literal(db: Session) -> str:
    """
    Generate the next sequential subscription literal: SUB000001, SUB000002, …

    The number is zero-padded to at least six digits, so once it passes
    999999 the literals get longer (SUB1000000, …). The current highest
    literal is therefore found by ordering on length first and then on the
    string value. A plain string MAX() would rank "SUB999999" above
    "SUB1000000" and keep producing the same "next" literal forever.

    Concurrency: this is NOT race-free. Two transactions running at the same
    time can read the same highest literal and produce the same value.
    NOTE(review): presumably a unique constraint on subscription_literal
    (plus caller retry) guards against duplicates — confirm.

    If the highest literal's suffix is not purely numeric, falls back to
    MAX(Subscription.id) + 1.
    NOTE(review): that fallback can collide with existing literal numbers.
    NOTE(review): unlike MAX(), ordering by length() cannot use a plain index
    on subscription_literal, so all SUB% rows are scanned.
    """
    literal_col = Subscription.subscription_literal
    stmt = (
        select(literal_col)
        .where(literal_col.like("SUB%"))
        .order_by(func.length(literal_col).desc(), literal_col.desc())
        .limit(1)
    )
    max_literal = db.execute(stmt).scalar_one_or_none()

    if max_literal is None:
        next_seq = 1
    else:
        suffix = max_literal[3:]
        # isdecimal() rejects signs/whitespace that int() would silently accept.
        if suffix.isdecimal():
            next_seq = int(suffix) + 1
        else:
            max_id = db.execute(select(func.max(Subscription.id))).scalar_one_or_none() or 0
            next_seq = max_id + 1

    return f"SUB{next_seq:06d}"


def generate_subscription_id() -> str:
    """
    Return a fresh opaque public subscription identifier.

    Delegates to generate_secure_id with prefix "sub" and length 16.
    NOTE(review): the exact format (separator, alphabet) is defined by that
    helper — presumably ``sub_<16 random chars>``; confirm.
    """
    public_id = generate_secure_id(prepend="sub", length=16)
    return public_id


# ─── Activity ─────────────────────────────────────────────────────────────────


def write_activity(
    db: Session,
    subscription_id: int,
    activity_type: str,
    description: Optional[str] = None,
    actor_type: Optional[str] = None,
    actor_id: Optional[int] = None,
    metadata: Optional[dict] = None,
) -> SubscriptionActivity:
    """
    Add a SubscriptionActivity row to the session and flush it.

    The flush sends the INSERT, so the returned object's id can be read
    within the same transaction. Nothing is committed — the caller owns the
    transaction boundary.

    *metadata* is stored on the model's ``metadata_`` attribute. Presumably
    the trailing underscore avoids Declarative's reserved ``metadata`` name —
    confirm against the model.
    """
    record = SubscriptionActivity(
        subscription_id=subscription_id,
        activity_type=activity_type,
        description=description,
        actor_type=actor_type,
        actor_id=actor_id,
        metadata_=metadata,
    )
    db.add(record)
    db.flush()
    return record


def get_activities(
    db: Session,
    subscription_id: int,
    page: int = 1,
    per_page: int = 20,
) -> Tuple[List[SubscriptionActivity], int]:
    """
    Paginated activity trail for a subscription, newest first.

    Args:
        db: active session.
        subscription_id: integer key matched against
            SubscriptionActivity.subscription_id.
        page: 1-based page number; values below 1 are treated as 1 so the
            computed OFFSET is never negative.
        per_page: page size.

    Returns:
        (activities on the requested page, total activity count).
    """
    page = max(page, 1)
    belongs_to_sub = SubscriptionActivity.subscription_id == subscription_id

    total = db.execute(
        select(func.count(SubscriptionActivity.id)).where(belongs_to_sub)
    ).scalar_one()

    # Several activities can share a created_at value. id is added as a
    # tiebreaker so OFFSET paging is deterministic and does not duplicate or
    # skip rows between pages. NOTE(review): assumes id grows with insertion
    # order, which would make the tiebreak "newest first" too.
    data_stmt = (
        select(SubscriptionActivity)
        .where(belongs_to_sub)
        .order_by(SubscriptionActivity.created_at.desc(), SubscriptionActivity.id.desc())
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    items = db.execute(data_stmt).scalars().all()
    return list(items), total
