"""
Checkout CRUD — pure DB operations, zero business logic.
All callers own the transaction boundary (commit/rollback).
"""
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Tuple

from sqlalchemy import func, select, update
from sqlalchemy.orm import Session, selectinload, joinedload

from src.apps.checkouts.models.checkout import Checkout
from src.apps.checkouts.models.checkout_link import CheckoutLink
from src.apps.checkouts.models.checkout_line_item import CheckoutLineItem
from src.apps.checkouts.models.checkout_activity import CheckoutActivity
from src.apps.checkouts.models.checkout_settings import CheckoutSettings
from src.apps.customers.models.customer import Customer
from src.apps.payment_requests.models.payment_request import PaymentRequest
from src.apps.transactions.models.transactions import Transactions
from src.apps.users.models.user import User
from src.core.utils.enums import TransactionStatusTypes


# ─── Literal generation ───────────────────────────────────────────────────────

def generate_checkout_literal(db: Session) -> str:
    """Generate the next sequential ``CHK``-prefixed checkout literal.

    The current highest literal is found by ordering on string length first and
    on value second. A plain ``MAX()`` over the string column compares
    lexicographically, so once the numeric suffix grows past six digits
    (``CHK1000000`` sorts *before* ``CHK999999``) it would keep returning the
    same "maximum" and hand out the same next literal forever. Because the
    suffix is zero-padded to at least six digits, a longer literal always
    carries a larger number, which makes length-then-value ordering correct.

    If the highest literal's suffix is not numeric, the sequence falls back to
    ``max(Checkout.id) + 1``.

    NOTE(review): this read-then-increment is not atomic on its own — two
    concurrent sessions can compute the same literal. Presumably a unique
    constraint on ``checkout_literal`` guards against that; confirm.

    Args:
        db: Active session; only read queries are issued.

    Returns:
        A literal of the form ``CHK`` followed by the zero-padded (minimum six
        digits) sequence number.
    """
    stmt = (
        select(Checkout.checkout_literal)
        .where(Checkout.checkout_literal.like("CHK%"))
        .order_by(
            func.length(Checkout.checkout_literal).desc(),
            Checkout.checkout_literal.desc(),
        )
        .limit(1)
    )
    latest_literal = db.execute(stmt).scalar_one_or_none()
    if latest_literal is None:
        # No CHK-prefixed literal exists yet: start the sequence.
        next_seq = 1
    else:
        try:
            next_seq = int(latest_literal[3:]) + 1
        except ValueError:
            # Suffix is not a number; derive the sequence from the primary key instead.
            max_id = db.execute(select(func.max(Checkout.id))).scalar_one_or_none() or 0
            next_seq = max_id + 1
    return f"CHK{next_seq:06d}"


# ─── Fetch single checkout ────────────────────────────────────────────────────

def get_checkout_by_literal(
    db: Session, checkout_literal: str, merchant_id: int
) -> Optional[Checkout]:
    """Fetch one non-deleted checkout by its literal, scoped to a merchant.

    The checkout's links and line items are eagerly loaded with ``selectinload``.

    Args:
        db: Active session.
        checkout_literal: Value matched exactly against ``Checkout.checkout_literal``.
        merchant_id: Owning merchant; rows of other merchants never match.

    Returns:
        The matching checkout, or ``None`` when no row matches.
    """
    criteria = (
        Checkout.checkout_literal == checkout_literal,
        Checkout.merchant_id == merchant_id,
        Checkout.deleted_at.is_(None),
    )
    eager_loads = (
        selectinload(Checkout.links),
        selectinload(Checkout.line_items),
    )
    query = select(Checkout).where(*criteria).options(*eager_loads)
    result = db.execute(query)
    return result.unique().scalar_one_or_none()


def get_checkout_by_id(
    db: Session, checkout_id: int, merchant_id: int
) -> Optional[Checkout]:
    """Fetch one non-deleted checkout by primary key, scoped to a merchant.

    No relationships are eagerly loaded here.

    Returns:
        The matching checkout, or ``None`` when no row matches.
    """
    criteria = (
        Checkout.id == checkout_id,
        Checkout.merchant_id == merchant_id,
        Checkout.deleted_at.is_(None),
    )
    query = select(Checkout).where(*criteria)
    return db.execute(query).scalar_one_or_none()


def get_checkout_by_id_public(db: Session, checkout_id: int) -> Optional[Checkout]:
    """Fetch one non-deleted checkout by primary key WITHOUT a merchant filter.

    No merchant scope — used for public-facing lookups after token resolution.
    Line items are eagerly loaded with ``selectinload``.

    Returns:
        The matching checkout, or ``None`` when no row matches.
    """
    query = select(Checkout)
    query = query.where(Checkout.id == checkout_id, Checkout.deleted_at.is_(None))
    query = query.options(selectinload(Checkout.line_items))
    result = db.execute(query)
    return result.unique().scalar_one_or_none()


# ─── List checkouts ───────────────────────────────────────────────────────────

def list_checkouts(
    db: Session,
    merchant_id: int,
    status: Optional[str] = None,
    checkout_type: Optional[str] = None,
    q: Optional[str] = None,
    from_date: Optional[datetime] = None,
    to_date: Optional[datetime] = None,
    page: int = 1,
    page_size: int = 20,
    statuses: Optional[List[str]] = None,
    checkout_types: Optional[List[str]] = None,
) -> Tuple[List[Checkout], int]:
    """Return one page of a merchant's non-deleted checkouts plus the total match count.

    Args:
        db: Active session.
        merchant_id: Owning merchant; only this merchant's rows are returned.
        status: Single status filter (case-insensitive). Ignored when ``statuses`` is given.
        checkout_type: Single type filter (case-insensitive). Ignored when
            ``checkout_types`` is given.
        q: Case-insensitive substring matched against the title. ``%`` and ``_``
            in the term are matched literally, not as wildcards.
        from_date: Inclusive lower bound on ``created_at``.
        to_date: Inclusive upper bound on ``created_at``.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; negative values are treated as 0.
        statuses: Multi-value status filter (case-insensitive).
        checkout_types: Multi-value type filter (case-insensitive).

    Returns:
        ``(rows, total)`` where ``rows`` is the requested page ordered newest
        first and ``total`` counts every row matching the filters.
    """
    base = select(Checkout).where(
        Checkout.merchant_id == merchant_id,
        Checkout.deleted_at.is_(None),
    )
    # Multi-value filters take precedence over single-value ones
    if statuses:
        base = base.where(func.upper(Checkout.status).in_([s.upper() for s in statuses]))
    elif status:
        base = base.where(func.upper(Checkout.status) == status.upper())
    if checkout_types:
        base = base.where(
            func.upper(Checkout.checkout_type).in_([ct.upper() for ct in checkout_types])
        )
    elif checkout_type:
        base = base.where(func.upper(Checkout.checkout_type) == checkout_type.upper())
    if q:
        # Escape LIKE metacharacters so a user-supplied "%" or "_" is searched
        # for literally instead of acting as a wildcard. "/" is the escape
        # character, so it has to be doubled first.
        escaped_q = q.replace("/", "//").replace("%", "/%").replace("_", "/_")
        base = base.where(Checkout.title.ilike(f"%{escaped_q}%", escape="/"))
    if from_date:
        base = base.where(Checkout.created_at >= from_date)
    if to_date:
        base = base.where(Checkout.created_at <= to_date)

    total = db.execute(select(func.count()).select_from(base.subquery())).scalar_one()

    # A negative OFFSET/LIMIT is rejected by some databases; clamp instead of failing.
    safe_page = max(page, 1)
    safe_page_size = max(page_size, 0)
    rows = db.execute(
        # The id tiebreaker keeps page boundaries stable when several rows
        # share the same created_at value.
        base.order_by(Checkout.created_at.desc(), Checkout.id.desc())
        .offset((safe_page - 1) * safe_page_size)
        .limit(safe_page_size)
    ).scalars().all()
    return list(rows), total


# ─── Summary ──────────────────────────────────────────────────────────────────

def get_checkout_summary(db: Session, merchant_id: int) -> Dict[str, Any]:
    """Build dashboard totals for a merchant's non-deleted checkouts.

    Three queries are issued, in this order: per-status checkout counts, the
    summed link click counts, and the summed amount of PAID transactions that
    belong to checkouts whose status is exactly ``"ACTIVE"``.

    Returns:
        A dict with ``draft_count``, ``active_count``, ``inactive_count``,
        ``expired_count``, ``active_total_collected`` and ``total_link_views``.
        Counts default to 0 for statuses with no rows.
    """
    owned_and_live = (
        Checkout.merchant_id == merchant_id,
        Checkout.deleted_at.is_(None),
    )

    # 1) How many checkouts sit in each status.
    status_stmt = (
        select(Checkout.status, func.count(Checkout.id))
        .where(*owned_and_live)
        .group_by(Checkout.status)
    )
    per_status = {}
    for status_value, how_many in db.execute(status_stmt).all():
        per_status[status_value] = how_many

    # 2) Clicks summed over every link of those checkouts.
    views_stmt = (
        select(func.coalesce(func.sum(CheckoutLink.click_count), 0))
        .join(Checkout, CheckoutLink.checkout_id == Checkout.id)
        .where(*owned_and_live)
    )
    views = db.execute(views_stmt).scalar_one()

    # 3) Money collected through ACTIVE checkouts (PAID transactions only).
    collected_stmt = (
        select(func.coalesce(func.sum(Transactions.txn_amount), 0.0))
        .join(PaymentRequest, Transactions.payment_request_id == PaymentRequest.id)
        .join(Checkout, PaymentRequest.checkout_id == Checkout.id)
        .where(
            *owned_and_live,
            Checkout.status == "ACTIVE",
            Transactions.txn_status == TransactionStatusTypes.PAID,
        )
    )
    active_total = db.execute(collected_stmt).scalar_one()

    summary: Dict[str, Any] = {}
    summary["draft_count"] = per_status.get("DRAFT", 0)
    summary["active_count"] = per_status.get("ACTIVE", 0)
    summary["inactive_count"] = per_status.get("INACTIVE", 0)
    summary["expired_count"] = per_status.get("EXPIRED", 0)
    summary["active_total_collected"] = active_total
    summary["total_link_views"] = views or 0
    return summary


# ─── Create / update / delete ─────────────────────────────────────────────────

def create_checkout(db: Session, data: dict, merchant_id: int) -> Checkout:
    """Insert a checkout and its line items; flushes but does not commit.

    Args:
        db: Active session; the caller owns commit/rollback.
        data: Column values for the checkout. An optional ``"line_items"`` entry
            holds a list of dicts, one per line item. The dict is NOT modified.
        merchant_id: Owning merchant, set on the new checkout.

    Returns:
        The flushed checkout (its primary key is populated).
    """
    # Work on a shallow copy so the caller's dict keeps its "line_items" entry.
    checkout_fields = dict(data)
    line_items_data = checkout_fields.pop("line_items", None) or []

    checkout = Checkout(merchant_id=merchant_id, **checkout_fields)
    db.add(checkout)
    # Flush now so checkout.id exists before the line items reference it.
    db.flush()

    # Keys from the incoming payload that are not passed to CheckoutLineItem.
    _LINE_ITEM_SKIP = {"id", "title"}
    db.add_all(
        CheckoutLineItem(
            checkout_id=checkout.id,
            **{k: v for k, v in li.items() if k not in _LINE_ITEM_SKIP},
        )
        for li in line_items_data
        if isinstance(li, dict)  # non-dict entries are ignored
    )
    db.flush()
    return checkout


def update_checkout(db: Session, checkout: Checkout, data: dict) -> Checkout:
    """Apply field updates to a checkout and optionally replace its line items.

    Args:
        db: Active session; the caller owns commit/rollback.
        checkout: The checkout instance to modify.
        data: Attribute name -> new value. Names the checkout does not have are
            ignored. An optional ``"line_items"`` entry (list of dicts) replaces
            ALL existing line items; when the key is absent or ``None`` the
            line items are left untouched. The dict is NOT modified.

    Returns:
        The same checkout instance, flushed.
    """
    # Work on a shallow copy so the caller's dict keeps its "line_items" entry.
    changes = dict(data)
    line_items_data = changes.pop("line_items", None)

    for field, value in changes.items():
        if hasattr(checkout, field):
            setattr(checkout, field, value)

    replace_line_items = line_items_data is not None
    if replace_line_items:
        # Replace all existing line items
        _LINE_ITEM_SKIP = {"id", "title"}
        db.query(CheckoutLineItem).filter(
            CheckoutLineItem.checkout_id == checkout.id
        ).delete(synchronize_session="evaluate")
        db.add_all(
            CheckoutLineItem(
                checkout_id=checkout.id,
                **{k: v for k, v in li.items() if k not in _LINE_ITEM_SKIP},
            )
            for li in line_items_data
            if isinstance(li, dict)  # non-dict entries are ignored
        )

    db.flush()

    if replace_line_items and checkout in db:
        # The bulk delete and the FK-only inserts bypass the relationship, so an
        # already-loaded ``checkout.line_items`` collection would still list the
        # deleted rows and miss the new ones. Expire it so the next access reloads.
        db.expire(checkout, ["line_items"])
    return checkout


def soft_delete_checkout(db: Session, checkout: Checkout) -> Checkout:
    """Soft-delete a checkout by stamping ``deleted_at`` with the current UTC time.

    The change is flushed but not committed. Returns the same instance.
    """
    deleted_on = datetime.now(tz=timezone.utc)
    checkout.deleted_at = deleted_on
    db.flush()
    return checkout


# ─── Checkout Links ───────────────────────────────────────────────────────────

def get_active_link(db: Session, checkout_id: int) -> Optional[CheckoutLink]:
    """Return the link of a checkout whose status is ``"ACTIVE"``, or ``None``.

    ``scalar_one_or_none`` is used, so more than one matching row raises
    SQLAlchemy's ``MultipleResultsFound``.
    """
    belongs_to_checkout = CheckoutLink.checkout_id == checkout_id
    is_active = CheckoutLink.status == "ACTIVE"
    query = select(CheckoutLink).where(belongs_to_checkout, is_active)
    return db.execute(query).scalar_one_or_none()


def get_link_by_token(db: Session, token: str) -> Optional[CheckoutLink]:
    """Look up a checkout link by its token, regardless of link status.

    The link's checkout and that checkout's line items are eagerly loaded.

    Returns:
        The matching link, or ``None`` when the token is unknown.
    """
    with_checkout_and_items = selectinload(CheckoutLink.checkout).selectinload(
        Checkout.line_items
    )
    query = select(CheckoutLink).where(CheckoutLink.token == token)
    query = query.options(with_checkout_and_items)
    result = db.execute(query)
    return result.unique().scalar_one_or_none()


def create_checkout_link(db: Session, checkout_id: int, token: str) -> CheckoutLink:
    """Insert a new ``"ACTIVE"`` link with a zero click count for a checkout.

    The row is flushed but not committed. Returns the new link.
    """
    new_link = CheckoutLink(
        checkout_id=checkout_id,
        token=token,
        status="ACTIVE",
        click_count=0,
    )
    db.add(new_link)
    db.flush()
    return new_link


def revoke_checkout_link(db: Session, link: CheckoutLink) -> CheckoutLink:
    """Mark a link ``"REVOKED"`` and record the current UTC time in ``revoked_at``.

    The change is flushed but not committed. Returns the same instance.
    """
    link.status = "REVOKED"
    revoked_on = datetime.now(tz=timezone.utc)
    link.revoked_at = revoked_on
    db.flush()
    return link


def expire_checkout_link(db: Session, link: CheckoutLink) -> CheckoutLink:
    """Mark a link ``"EXPIRED"``; flushes without committing and returns the link."""
    setattr(link, "status", "EXPIRED")
    db.flush()
    return link


def increment_link_click_count(db: Session, link_id: int) -> None:
    """Bump a link's click counter by one and stamp ``last_clicked_at``.

    The increment is expressed in SQL (``click_count = click_count + 1``) rather
    than read-modify-write in Python, so concurrent clicks are not lost. The
    timestamp comes from the database's ``now()``.
    """
    bump_stmt = update(CheckoutLink).where(CheckoutLink.id == link_id)
    bump_stmt = bump_stmt.values(
        click_count=CheckoutLink.click_count + 1,
        last_clicked_at=func.now(),
    )
    db.execute(bump_stmt)


# ─── Activities ───────────────────────────────────────────────────────────────

def create_checkout_activity(
    db: Session,
    checkout_id: int,
    merchant_id: int,
    event_type: str,
    description: str,
    actor_user_id: Optional[int] = None,
    actor_customer_id: Optional[int] = None,
    metadata: Optional[dict] = None,
) -> CheckoutActivity:
    """Record an activity entry for a checkout; flushes without committing.

    Args:
        db: Active session.
        checkout_id: Checkout the event belongs to.
        merchant_id: Merchant the event belongs to.
        event_type: Event identifier stored as-is.
        description: Human-readable text stored as-is.
        actor_user_id: Optional acting user.
        actor_customer_id: Optional acting customer.
        metadata: Optional extra payload, stored on the model's ``metadata_`` attribute.

    Returns:
        The flushed activity row.
    """
    fields = {
        "checkout_id": checkout_id,
        "merchant_id": merchant_id,
        "event_type": event_type,
        "description": description,
        "actor_user_id": actor_user_id,
        "actor_customer_id": actor_customer_id,
        # The model attribute carries a trailing underscore.
        "metadata_": metadata,
    }
    entry = CheckoutActivity(**fields)
    db.add(entry)
    db.flush()
    return entry


# ─── Settings ─────────────────────────────────────────────────────────────────

def get_checkout_settings(db: Session, merchant_id: int) -> Optional[CheckoutSettings]:
    """Return the merchant's checkout settings row, or ``None`` if there is none."""
    for_merchant = CheckoutSettings.merchant_id == merchant_id
    result = db.execute(select(CheckoutSettings).where(for_merchant))
    return result.scalar_one_or_none()


def upsert_checkout_settings(db: Session, merchant_id: int, data: dict) -> CheckoutSettings:
    """Create or update the merchant's checkout settings; flushes without committing.

    When no settings row exists a new one is added to the session first. Each
    ``data`` entry is then assigned to the attribute of the same name; names the
    settings object does not have are ignored.

    Returns:
        The created or updated settings row.
    """
    settings = get_checkout_settings(db, merchant_id)
    if settings is None:
        settings = CheckoutSettings(merchant_id=merchant_id)
        db.add(settings)

    # Lazy filter: each name is checked right before it is assigned.
    known_fields = ((name, val) for name, val in data.items() if hasattr(settings, name))
    for name, val in known_fields:
        setattr(settings, name, val)

    db.flush()
    return settings


# ─── Bulk actions ─────────────────────────────────────────────────────────────

def get_checkouts_by_literals(
    db: Session, literals: List[str], merchant_id: int
) -> List[Checkout]:
    """Fetch a merchant's non-deleted checkouts whose literal is in ``literals``.

    Links are eagerly loaded with ``selectinload``. Literals that match no row
    are simply absent from the result; no particular order is guaranteed.

    Args:
        db: Active session.
        literals: Checkout literals to look up. An empty list returns ``[]``
            without querying the database.
        merchant_id: Owning merchant; rows of other merchants never match.

    Returns:
        The matching checkouts as a list.
    """
    if not literals:
        # Nothing can match an empty IN-list; skip the round trip.
        return []
    stmt = (
        select(Checkout)
        .where(
            Checkout.checkout_literal.in_(literals),
            Checkout.merchant_id == merchant_id,
            Checkout.deleted_at.is_(None),
        )
        .options(selectinload(Checkout.links))
    )
    return list(db.execute(stmt).unique().scalars().all())


# ─── Analytics ────────────────────────────────────────────────────────────────

def get_checkout_analytics(
    db: Session, checkout_id: int, merchant_id: int, period: str = "30d"
) -> Dict[str, Any]:
    """Real-time analytics for a single checkout, filtered by period.

    Args:
        db: Active session; only read queries are issued.
        checkout_id: Checkout whose links and payment requests are analysed.
        merchant_id: Merchant scope applied to every payment-request query.
        period: One of ``"7d"``, ``"30d"``, ``"90d"``, ``"365d"`` or ``"all"``.
            Any value that is not one of the four day-windows results in no
            date filter at all.

    Returns:
        A dict with these keys:

        * ``link_views`` — summed click count over all links (never period-filtered).
        * ``total_collected`` — sum of ``txn_amount`` over paid transactions in the period.
        * ``total_authorized`` — sum of payment-request amounts with status 300
          (never period-filtered).
        * ``transaction_count`` — number of paid transactions in the period.
        * ``unique_payers`` — distinct non-empty ``customer_id`` values on those transactions.
        * ``tip_total`` — sum of ``max(0, txn_amount - payment_request.amount)``.
        * ``new_customers_created`` / ``returning_payers`` — split of ``unique_payers``.
        * ``conversion_rate`` — ``transaction_count / link_views * 100``, one decimal.
        * ``average_transaction_amount`` — ``total_collected / transaction_count``, two decimals.
        * ``daily_series`` — list of ``{"date", "count", "total"}`` dicts, one per day.
    """
    # --- Period window ---
    now = datetime.now(timezone.utc)
    _period_days = {"7d": 7, "30d": 30, "90d": 90, "365d": 365}
    # NOTE(review): an unrecognised period (e.g. "1y") yields days=None and is
    # therefore treated exactly like "all" — confirm that is intended rather
    # than falling back to the 30-day default.
    days = _period_days.get(period) if period != "all" else None
    since = (now - timedelta(days=days)) if days is not None else None

    # --- Link views (all-time cumulative across all links for this checkout) ---
    link_views = db.execute(
        select(func.coalesce(func.sum(CheckoutLink.click_count), 0))
        .where(CheckoutLink.checkout_id == checkout_id)
    ).scalar_one() or 0

    # --- Paid transactions via this checkout ---
    # Transactions.customer_id is a direct FK to customers.id.
    # txn_amount = pr.amount + tip_amount (see hpp/services.py submit_payment).
    # So tip_total = SUM(txn_amount - pr.amount) for all paid checkout transactions.
    _PAID_STATUSES = (200, 201)  # TransactionStatusTypes.PAID, CAPTURED
    # Transaction timestamp: first non-null of the two columns. The second
    # spelling ("ocurred_at") is presumably a legacy column — verify on the model.
    txn_date_col = func.coalesce(Transactions.occurred_at, Transactions.ocurred_at)

    paid_q = (
        select(
            Transactions.id,
            Transactions.txn_amount,
            Transactions.customer_id,
            PaymentRequest.amount.label("pr_amount"),
            txn_date_col.label("txn_date"),
        )
        .join(PaymentRequest, Transactions.payment_request_id == PaymentRequest.id)
        .where(
            PaymentRequest.checkout_id == checkout_id,
            PaymentRequest.merchant_id == merchant_id,
            PaymentRequest.deleted_at == None,
            Transactions.txn_status.in_(_PAID_STATUSES),
        )
    )
    if since is not None:
        paid_q = paid_q.where(txn_date_col >= since)

    paid_txns = db.execute(paid_q).all()

    # Aggregations below are done in Python over the fetched rows.
    transaction_count = len(paid_txns)
    total_collected = sum(float(t.txn_amount or 0) for t in paid_txns)
    # Per-transaction tip is clamped at zero, so a transaction below the
    # requested amount does not reduce the total.
    tip_total = sum(
        max(0.0, float(t.txn_amount or 0) - float(t.pr_amount or 0))
        for t in paid_txns
    )
    # Transactions without a customer_id do not count towards unique payers.
    unique_customer_ids = {t.customer_id for t in paid_txns if t.customer_id}
    unique_payers = len(unique_customer_ids)
    average_transaction_amount = (
        round(total_collected / transaction_count, 2) if transaction_count > 0 else 0.0
    )
    # NOTE(review): the numerator is limited to the selected period while
    # link_views is all-time, so the rate mixes two different time windows.
    conversion_rate = (
        round(transaction_count / link_views * 100, 1) if link_views > 0 else 0.0
    )

    # --- Total authorized: sum of PR amounts currently in AUTHORISED state ---
    total_authorized = db.execute(
        select(func.coalesce(func.sum(PaymentRequest.amount), 0.0))
        .where(
            PaymentRequest.checkout_id == checkout_id,
            PaymentRequest.merchant_id == merchant_id,
            PaymentRequest.deleted_at == None,
            PaymentRequest.status == 300,  # PaymentRequestStatusTypes.AUTHORISED
        )
    ).scalar_one() or 0.0

    # --- New customers vs returning payers ---
    # A "new customer" had their first-ever payment request originate from this checkout.
    # A "returning payer" had prior payment requests before their first checkout PR.
    # NOTE(review): both queries below reach the customer through
    # User.customer_id / PaymentRequest.created_by_id; a payer with no such
    # User row is absent from prior_counts and so ends up counted as "new".
    new_customers_created = 0
    returning_payers = 0
    if unique_customer_ids:
        customer_list = list(unique_customer_ids)

        # Earliest checkout PR date per customer (joined through User)
        first_checkout_pr_subq = (
            select(
                User.customer_id.label("customer_id"),
                func.min(PaymentRequest.created_at).label("first_checkout_at"),
            )
            .join(PaymentRequest, PaymentRequest.created_by_id == User.id)
            .where(
                PaymentRequest.checkout_id == checkout_id,
                PaymentRequest.merchant_id == merchant_id,
                PaymentRequest.deleted_at == None,
                User.customer_id.in_(customer_list),
            )
            .group_by(User.customer_id)
            .subquery()
        )

        # Count customers who had any PR before their first checkout PR
        prior_counts = db.execute(
            select(
                first_checkout_pr_subq.c.customer_id,
                func.count(PaymentRequest.id).label("prior_count"),
            )
            .join(User, User.customer_id == first_checkout_pr_subq.c.customer_id)
            .join(PaymentRequest, PaymentRequest.created_by_id == User.id)
            .where(
                PaymentRequest.created_at < first_checkout_pr_subq.c.first_checkout_at,
                PaymentRequest.merchant_id == merchant_id,
                PaymentRequest.deleted_at == None,
            )
            .group_by(first_checkout_pr_subq.c.customer_id)
        ).all()

        returning_payers = sum(1 for row in prior_counts if row.prior_count > 0)
        # Everyone who is not a returning payer is counted as new.
        new_customers_created = unique_payers - returning_payers

    # --- Daily series (transactions grouped by day within the period) ---
    # Skipped entirely when the period has no paid transactions.
    daily_series: List[Dict[str, Any]] = []
    if transaction_count > 0:
        daily_q = (
            select(
                func.date(txn_date_col).label("day"),
                func.count(Transactions.id).label("count"),
                func.sum(Transactions.txn_amount).label("total"),
            )
            .join(PaymentRequest, Transactions.payment_request_id == PaymentRequest.id)
            .where(
                PaymentRequest.checkout_id == checkout_id,
                PaymentRequest.merchant_id == merchant_id,
                PaymentRequest.deleted_at == None,
                Transactions.txn_status.in_(_PAID_STATUSES),
            )
            .group_by(func.date(txn_date_col))
            .order_by(func.date(txn_date_col))
        )
        if since is not None:
            daily_q = daily_q.where(txn_date_col >= since)

        daily_series = [
            {
                "date": str(row.day),
                "count": row.count,
                "total": round(float(row.total or 0), 2),
            }
            for row in db.execute(daily_q).all()
        ]

    return {
        "link_views": link_views,
        "total_collected": round(total_collected, 2),
        "total_authorized": round(float(total_authorized), 2),
        "transaction_count": transaction_count,
        "unique_payers": unique_payers,
        "tip_total": round(tip_total, 2),
        "new_customers_created": new_customers_created,
        "returning_payers": returning_payers,
        "conversion_rate": conversion_rate,
        "average_transaction_amount": average_transaction_amount,
        "daily_series": daily_series,
    }


def get_checkout_payers(
    db: Session, checkout_id: int, merchant_id: int
) -> List[Dict[str, Any]]:
    """Return the distinct customers attached to this checkout's payment requests.

    Each customer appears once, together with the sum of ``PaymentRequest.amount``
    over their non-deleted payment requests on this checkout. Soft-deleted
    customers are excluded. Rows are ordered by customer primary key.

    NOTE(review): the amount is summed over every non-deleted payment request
    regardless of its payment status — confirm that matches what "paid_amount"
    is meant to convey.

    Args:
        db: Active session; only a read query is issued.
        checkout_id: Checkout whose payers are listed.
        merchant_id: Merchant scope applied to the payment requests.

    Returns:
        One dict per customer with ``customer_id``, ``customer_literal``,
        ``first_name``, ``last_name``, ``email``, ``phone``, ``business_name``,
        ``created_at`` (ISO string or ``None``) and ``paid_amount`` (float).
    """
    # HPMNTP-961: use PaymentRequestCustomer junction table instead of going through User.
    # submit_payment() populates payment_requests_customers with customer_id directly,
    # making this join reliable regardless of user.customer_id state.
    from src.apps.payment_requests.models.payment_request_customer import PaymentRequestCustomer

    # One row per customer: the sum of payment-request amounts on this checkout.
    paid_amount_subq = (
        select(
            PaymentRequestCustomer.customer_id.label("customer_id"),
            func.coalesce(func.sum(PaymentRequest.amount), 0.0).label("paid_amount"),
        )
        .join(PaymentRequest, PaymentRequest.id == PaymentRequestCustomer.payment_request_id)
        .where(
            PaymentRequest.checkout_id == checkout_id,
            PaymentRequest.merchant_id == merchant_id,
            PaymentRequest.deleted_at.is_(None),
        )
        .group_by(PaymentRequestCustomer.customer_id)
        .subquery()
    )

    # The subquery already holds exactly the customers with at least one
    # qualifying payment request, one row each. Joining Customer straight to it
    # yields the same set without re-joining the junction/payment-request tables
    # and without a DISTINCT over every Customer column.
    stmt = (
        select(
            Customer,
            func.coalesce(paid_amount_subq.c.paid_amount, 0.0).label("paid_amount"),
        )
        .select_from(Customer)
        .join(paid_amount_subq, paid_amount_subq.c.customer_id == Customer.id)
        .where(Customer.deleted_at.is_(None))
        # Deterministic output order.
        .order_by(Customer.id)
    )
    rows = db.execute(stmt).unique().all()
    return [
        {
            "customer_id": c.id,
            "customer_literal": c.customer_id,
            "first_name": c.first_name,
            "last_name": c.last_name,
            "email": c.email,
            "phone": c.phone,
            "business_name": c.business_legal_name,
            "created_at": c.created_at.isoformat() if c.created_at else None,
            "paid_amount": float(paid_amount),
        }
        for c, paid_amount in rows
    ]


def get_all_expiring_checkouts(db: Session) -> List[Checkout]:
    """Fetch ACTIVE checkouts past their expires_at for the Celery expiry job.

    Covers all merchants. Only non-deleted checkouts with a non-null
    ``expires_at`` at or before the current UTC time are returned; their links
    are eagerly loaded.
    """
    cutoff = datetime.now(tz=timezone.utc)
    overdue = (
        Checkout.status == "ACTIVE",
        Checkout.expires_at.is_not(None),
        Checkout.expires_at <= cutoff,
        Checkout.deleted_at.is_(None),
    )
    query = select(Checkout).where(*overdue).options(selectinload(Checkout.links))
    result = db.execute(query)
    return result.unique().scalars().all()
