"""
Cart Plugin CRUD — pure DB operations, zero business logic.
All callers own the transaction boundary (commit / rollback).
"""
from __future__ import annotations

from datetime import datetime, timezone
from typing import List, Optional, Tuple

from sqlalchemy import func, select, update
from sqlalchemy.orm import Session, selectinload

from src.apps.cart_plugin.models.merchant_widget_key import MerchantWidgetKey
from src.apps.cart_plugin.models.cart_session import CartSession
from src.apps.cart_plugin.models.cart_session_item import CartSessionItem
from src.apps.cart_plugin.schemas.cart_schemas import (
    CreateWidgetKeyRequest,
    UpdateWidgetKeyRequest,
)


# ─── Widget key CRUD ──────────────────────────────────────────────────────────

def get_widget_key_by_public_key(db: Session, public_key: str) -> Optional[MerchantWidgetKey]:
    """
    Resolve a public key string to its MerchantWidgetKey.

    Rows with a non-null ``deleted_at`` (soft-deleted) are excluded. The
    ``is_active`` / ``revoked_at`` state is NOT checked here; callers that
    care must inspect it themselves.

    Returns None when nothing matches; ``one_or_none`` raises if more than
    one row matches.
    """
    query = (
        select(MerchantWidgetKey)
        .where(MerchantWidgetKey.public_key == public_key)
        .where(MerchantWidgetKey.deleted_at.is_(None))
    )
    return db.execute(query).scalars().one_or_none()


def get_widget_key_by_id(
    db: Session, key_id: int, merchant_id: int
) -> Optional[MerchantWidgetKey]:
    """
    Load one widget key by primary key, only if it belongs to ``merchant_id``.

    Scoping by merchant means a key id owned by another merchant yields
    None rather than the foreign row. Soft-deleted keys are excluded.
    """
    criteria = (
        MerchantWidgetKey.id == key_id,
        MerchantWidgetKey.merchant_id == merchant_id,
        MerchantWidgetKey.deleted_at.is_(None),
    )
    owned_key_query = select(MerchantWidgetKey).where(*criteria)
    return db.execute(owned_key_query).scalar_one_or_none()


def get_widget_keys_for_merchant(
    db: Session, merchant_id: int
) -> List[MerchantWidgetKey]:
    """
    List every widget key of ``merchant_id`` that has not been soft-deleted.

    Results are ordered by ``created_at`` descending (newest first).
    Revoked or inactive keys are still included.
    """
    newest_first = MerchantWidgetKey.created_at.desc()
    query = (
        select(MerchantWidgetKey)
        .where(
            MerchantWidgetKey.merchant_id == merchant_id,
            MerchantWidgetKey.deleted_at.is_(None),
        )
        .order_by(newest_first)
    )
    return list(db.execute(query).scalars())


def create_widget_key(
    db: Session,
    merchant_id: int,
    data: CreateWidgetKeyRequest,
    public_key: str,
    encrypted_secret: Optional[str],
) -> MerchantWidgetKey:
    """
    Insert a new, active widget key for ``merchant_id`` (caller must commit).

    Args:
        db: Session the new row is added to and flushed on.
        merchant_id: Owning merchant.
        data: Request payload supplying display name, allowed origins,
            webhook URL and checkout mode.
        public_key: Public key string supplied by the caller.
        encrypted_secret: Already-encrypted webhook secret, or None.

    Returns:
        The new MerchantWidgetKey, flushed so that DB-generated values
        (such as its primary key) are available.

    A falsy ``allowed_origins`` is stored as ``[]`` and a falsy
    ``checkout_mode`` as ``"auto"``.
    """
    origins = data.allowed_origins or []
    mode = data.checkout_mode or "auto"
    widget_key = MerchantWidgetKey(
        merchant_id=merchant_id,
        public_key=public_key,
        display_name=data.display_name,
        webhook_url=data.webhook_url,
        webhook_secret_encrypted=encrypted_secret,
        allowed_origins=origins,
        checkout_mode=mode,
        is_active=True,
    )
    db.add(widget_key)
    db.flush()
    return widget_key


def update_widget_key(
    db: Session, key: MerchantWidgetKey, data: UpdateWidgetKeyRequest
) -> MerchantWidgetKey:
    """
    Copy the explicitly-set fields of ``data`` onto ``key`` and flush.

    Only fields the client actually sent (``exclude_unset=True``) are
    applied, so omitted fields keep their current values. A field sent as
    None IS applied and overwrites the stored value. The caller commits.
    """
    for attr_name, new_value in data.model_dump(exclude_unset=True).items():
        setattr(key, attr_name, new_value)
    db.flush()
    return key


def revoke_widget_key(db: Session, key: MerchantWidgetKey) -> MerchantWidgetKey:
    """
    Revoke ``key`` by deactivating it and stamping ``revoked_at`` (UTC now).

    The key is not soft-deleted (``deleted_at`` is left untouched), so
    lookups that only filter on ``deleted_at`` still return it. The change
    is flushed, not committed. Returns the same object.
    """
    key.revoked_at = datetime.now(timezone.utc)
    key.is_active = False
    db.flush()
    return key


def rotate_widget_key_secret(
    db: Session, key: MerchantWidgetKey, new_encrypted_secret: str
) -> MerchantWidgetKey:
    """
    Swap in a new, already-encrypted webhook secret for ``key``.

    No encryption happens here; the caller passes the ciphertext. The change
    is flushed, not committed. Returns the same object.
    """
    setattr(key, "webhook_secret_encrypted", new_encrypted_secret)
    db.flush()
    return key


# ─── Cart session CRUD ────────────────────────────────────────────────────────

def create_cart_session(
    db: Session,
    merchant_id: int,
    widget_key_id: int,
    token: str,
    data,  # CreateCartSessionRequest
    expires_at: Optional[datetime],
    origin: Optional[str] = None,
) -> CartSession:
    """
    Insert a new PENDING CartSession and flush it (caller must commit).

    Args:
        db: Session the new row is added to.
        merchant_id: Owning merchant.
        widget_key_id: Widget key the session was opened through.
        token: Session token supplied by the caller.
        data: Request payload (presumably a CreateCartSessionRequest).
        expires_at: Expiry timestamp, or None.
        origin: Request origin to record, if known.

    ``data.tip_amount`` is multiplied by 100 and rounded to an int (cents).
    Python's ``round`` rounds half to even, and a falsy tip becomes 0. A
    falsy ``data.currency`` defaults to ``"USD"``. Totals are not set here.
    """
    raw_tip = data.tip_amount or 0.0
    tip_in_cents = int(round(raw_tip * 100))
    cart = CartSession(
        token=token,
        merchant_id=merchant_id,
        widget_key_id=widget_key_id,
        status="PENDING",
        currency=data.currency or "USD",
        origin=origin,
        return_url=data.return_url,
        checkout_mode=data.checkout_mode,
        tip_amount=tip_in_cents,
        discount_code=data.discount_code,
        metadata_=data.metadata,
        idempotency_key=data.idempotency_key,
        expires_at=expires_at,
    )
    db.add(cart)
    db.flush()
    return cart


def create_cart_session_items(
    db: Session, session_id: int, items: list
) -> List[CartSessionItem]:
    """
    Add one CartSessionItem per entry of ``items`` and flush them together.

    Each ``unit_price`` is converted with ``float()``, multiplied by 100 and
    rounded to an int (cents). Items are added to the session as they are
    built and returned in input order. The caller commits.
    """
    created: List[CartSessionItem] = []
    for entry in items:
        price_in_cents = int(round(float(entry.unit_price) * 100))
        row = CartSessionItem(
            session_id=session_id,
            external_id=entry.external_id,
            name=entry.name,
            description=entry.description,
            quantity=entry.quantity,
            unit_price=price_in_cents,
            image_url=entry.image_url,
            metadata_=entry.metadata,
        )
        db.add(row)
        created.append(row)
    db.flush()
    return created


def get_cart_session_by_token(db: Session, token: str) -> Optional[CartSession]:
    """
    Look up a CartSession by its token.

    Plain read: no row lock and no eager loading of relationships. Returns
    None when the token is unknown.
    """
    by_token = select(CartSession).where(CartSession.token == token)
    return db.execute(by_token).scalars().one_or_none()


def get_cart_session_with_items(
    db: Session, token: str, for_update: bool = False
) -> Optional[CartSession]:
    """
    Fetch a session by token with ``items`` and ``widget_key`` eagerly loaded.

    Args:
        db: Active session.
        token: Session token to look up.
        for_update: When True, the session row is read with
            ``SELECT ... FOR UPDATE``. Use this in the submit path to
            prevent duplicate submissions.

    Returns:
        The CartSession, or None if no session has this token.

    In the locking mode the query also sets ``populate_existing``. If this
    CartSession is already in the Session's identity map (for example,
    loaded earlier in the same request), its attributes are then overwritten
    with the values read under the lock instead of keeping the possibly
    stale pre-lock copy. Any unflushed changes on that object are discarded.
    """
    stmt = (
        select(CartSession)
        .options(selectinload(CartSession.items), selectinload(CartSession.widget_key))
        .where(CartSession.token == token)
    )
    if for_update:
        # By default the ORM does not re-apply row data to an object it
        # already holds. Without populate_existing, a status check made after
        # taking the lock could still see values from before a concurrent
        # transaction committed.
        stmt = stmt.with_for_update().execution_options(populate_existing=True)
    return db.execute(stmt).unique().scalar_one_or_none()


def get_session_by_idempotency_key(
    db: Session, widget_key_id: int, idempotency_key: Optional[str]
) -> Optional[CartSession]:
    """
    Return the session previously created with this idempotency key, if any.

    The lookup is scoped to ``widget_key_id``, so the same key string used
    through different widget keys does not collide.

    Args:
        db: Active session; nothing is written.
        widget_key_id: Widget key the original request was made through.
        idempotency_key: Client-supplied key. ``None`` returns ``None``
            without querying.

    Returns:
        The matching CartSession, or None. ``scalar_one_or_none`` raises if
        more than one session shares the same (widget_key_id, key) pair.
    """
    if idempotency_key is None:
        # `column == None` compiles to `IS NULL`. That would match every
        # session stored without a key for this widget key (if the column
        # allows NULL) and make scalar_one_or_none() raise once there are
        # two of them. A missing key cannot identify an earlier request.
        return None
    stmt = select(CartSession).where(
        CartSession.widget_key_id == widget_key_id,
        CartSession.idempotency_key == idempotency_key,
    )
    return db.execute(stmt).scalar_one_or_none()


def update_cart_session_paid(
    db: Session,
    session: CartSession,
    transaction_id: int,
    subtotal_cents: int,
    tip_cents: int,
    tax_cents: int,
    discount_cents: int,
    total_cents: int,
    customer_name: Optional[str],
    customer_email: Optional[str],
    billing_info: Optional[dict],
    discount_code: Optional[str],
) -> CartSession:
    """
    Transition ``session`` to PAID and record the final totals and customer data.

    All monetary arguments are integer cents and are stored verbatim; no
    arithmetic or validation happens here. The status is set without
    checking the previous value. Changes are flushed, not committed.
    Returns the same object.
    """
    paid_fields = {
        "status": "PAID",
        "transaction_id": transaction_id,
        "subtotal": subtotal_cents,
        "tip_amount": tip_cents,
        "tax_amount": tax_cents,
        "discount_amount": discount_cents,
        "total": total_cents,
        "customer_name": customer_name,
        "customer_email": customer_email,
        "billing_info": billing_info,
        "discount_code": discount_code,
    }
    for column, value in paid_fields.items():
        setattr(session, column, value)
    db.flush()
    return session


def update_cart_session_webhook_status(
    db: Session, session: CartSession, status_code: int
) -> CartSession:
    """
    Remember when a webhook delivery was last attempted and what it returned.

    ``last_webhook_attempt_at`` is set to the current UTC time and
    ``last_webhook_status`` to ``status_code``. Changes are flushed, not
    committed. Returns the same object.
    """
    attempted_at = datetime.now(timezone.utc)
    session.last_webhook_attempt_at = attempted_at
    session.last_webhook_status = status_code
    db.flush()
    return session


def expire_stale_sessions(db: Session) -> int:
    """
    Flip every PENDING session whose ``expires_at`` is in the past to EXPIRED.

    Runs as a single bulk UPDATE. ``synchronize_session="fetch"`` keeps any
    affected objects already loaded in ``db`` in sync with the new status.
    Sessions with a NULL ``expires_at`` never match (``NULL < now`` is not
    true in SQL). Nothing is committed here.

    Returns:
        The driver-reported number of rows updated.
    """
    cutoff = datetime.now(timezone.utc)
    expire_stmt = (
        update(CartSession)
        .where(CartSession.status == "PENDING")
        .where(CartSession.expires_at < cutoff)
        .values(status="EXPIRED")
        .execution_options(synchronize_session="fetch")
    )
    return db.execute(expire_stmt).rowcount


def list_sessions_for_merchant(
    db: Session,
    merchant_id: int,
    *,
    status: Optional[str] = None,
    limit: int = 25,
    offset: int = 0,
) -> Tuple[List[CartSession], int]:
    """
    Return one page of a merchant's cart sessions plus the unpaginated total.

    Args:
        db: Active session; nothing is written.
        merchant_id: Only sessions owned by this merchant are considered.
        status: Optional status filter, upper-cased before comparison so
            ``"paid"`` matches ``"PAID"``. None or empty means no filter.
        limit: Maximum number of sessions in the returned page.
        offset: Number of matching sessions to skip.

    Returns:
        ``(sessions, total)``. ``sessions`` is newest-first by ``created_at``,
        and ``total`` counts every matching session regardless of
        ``limit``/``offset``.
    """
    filters = [CartSession.merchant_id == merchant_id]
    if status:
        filters.append(CartSession.status == status.upper())

    count_stmt = select(func.count()).select_from(CartSession).where(*filters)
    total = db.execute(count_stmt).scalar_one()

    page_stmt = (
        select(CartSession)
        .where(*filters)
        # created_at alone is not a total order. Sessions sharing a timestamp
        # can come back in any order, so LIMIT/OFFSET could repeat or skip
        # them across pages. token is appended as a deterministic tiebreaker
        # (this module looks sessions up by token as a single row, so it is
        # presumably unique).
        .order_by(CartSession.created_at.desc(), CartSession.token.desc())
        .limit(limit)
        .offset(offset)
    )
    sessions = list(db.execute(page_stmt).scalars().all())
    return sessions, total
