"""
Invoice business logic services.
"""

import logging
from datetime import date, datetime, timedelta, timezone
from typing import Optional, Tuple, List

from sqlalchemy import select, func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload, selectinload

from src.apps.invoices.models.invoice import Invoice
from src.apps.invoices.schemas.invoice_requests import InvoiceListFilterSchema
from src.apps.base.utils.functions import generate_secure_id
from src.core.utils.enums import InvoiceStatusTypes
from src.apps.payment_requests.models.payment_request import PaymentRequest

logger = logging.getLogger(__name__)


def generate_invoice_literal(db: Session) -> str:
    """
    Generate the next sequential invoice literal in the format INV000001.

    Looks up the highest existing ``INV``-prefixed literal and increments its
    numeric suffix by 1. The suffix is zero-padded to at least six digits, so
    once the sequence passes 999999 the literals grow to seven or more digits.

    Candidates are ranked by ``char_length`` first and only then by the string
    itself. A plain string ``MAX()`` would rank "INV999999" above "INV1000000",
    which would make every literal generated after INV1000000 a duplicate.

    If none of the top candidates has an all-digit suffix, falls back to
    ``max(Invoice.id) + 1``.

    NOTE(review): ordering by ``char_length(...)`` cannot be served by a plain
    index on invoice_literal, so this scans the INV-prefixed rows. An
    expression index or a database sequence would keep it cheap as the table
    grows.

    Args:
        db: Active SQLAlchemy session.

    Returns:
        The next literal, e.g. ``"INV000042"``.
    """
    # Inspect a handful of top candidates so a stray non-numeric INV-prefixed
    # literal (e.g. a long legacy value) does not immediately force the fallback.
    candidate_limit = 10
    stmt = (
        select(Invoice.invoice_literal)
        .where(Invoice.invoice_literal.like("INV%"))
        .order_by(
            func.char_length(Invoice.invoice_literal).desc(),
            Invoice.invoice_literal.desc(),
        )
        .limit(candidate_limit)
    )
    candidates = db.execute(stmt).scalars().all()

    if not candidates:
        return f"INV{1:06d}"

    for literal in candidates:
        suffix = literal[3:]
        # isascii() + isdigit() rejects signs, whitespace, underscores and
        # non-ASCII digits that int() would otherwise accept or choke on.
        if suffix.isascii() and suffix.isdigit():
            return f"INV{int(suffix) + 1:06d}"

    # No parseable suffix among the candidates: derive the number from the PK.
    max_id = db.execute(select(func.max(Invoice.id))).scalar_one_or_none() or 0
    return f"INV{max_id + 1:06d}"


def create_invoice(
    db: Session,
    *,
    payment_request: PaymentRequest,
    merchant_id: int,
    customer_id: int,
    amount: float,
    status: InvoiceStatusTypes,
    payer_id: Optional[int] = None,
    approver_id: Optional[int] = None,
    due_date: Optional[datetime] = None,
    billing_date: Optional[datetime] = None,
    reference: Optional[str] = None,
    shipping_fee: float = 0.0,
    tax_fee: float = 0.0,
    paid_amount: float = 0.0,
    paid_date: Optional[datetime] = None,
    is_surcharge_enabled: bool = False,
    surcharge_type: Optional[str] = None,
    comments: Optional[str] = None,
    sequence_id: Optional[int] = None,
) -> Invoice:
    """
    Create a new Invoice ORM record and flush it to the database.

    The caller owns the transaction boundary — this function does NOT commit.
    Call db.commit() after appending any M2M relationships (e.g. transactions).

    Each attempt runs inside its own SAVEPOINT. If the INSERT fails with an
    IntegrityError (presumably a concurrent insert that took the same
    invoice_literal), only that savepoint is rolled back. A fresh literal is
    then generated, up to 5 attempts in total.

    Args:
        db: Active SQLAlchemy session.
        payment_request: PaymentRequest ORM instance (already flushed, has .id).
        merchant_id: FK to merchants table.
        customer_id: FK to customers table.
        amount: Invoice total amount.
        status: InvoiceStatusTypes enum value.
        payer_id: Optional FK to customer_contacts.
        approver_id: Optional FK to customer_contacts.
        due_date: Optional due date.
        billing_date: Optional billing/issue date.
        reference: Optional free-text reference.
        shipping_fee: Shipping fee amount (default 0.0).
        tax_fee: Tax fee amount (default 0.0).
        paid_amount: Amount already paid (default 0.0).
        paid_date: Datetime the invoice was paid.
        is_surcharge_enabled: Whether surcharge is active.
        surcharge_type: Surcharge type string (e.g. "inclusive"/"exclusive").
        comments: Optional comments.
        sequence_id: Optional sequence identifier for recurring invoices.

    Returns:
        The newly created Invoice ORM instance (flushed, not committed).

    Raises:
        IntegrityError: if the final attempt still violates a constraint.
    """
    max_attempts = 5
    now = datetime.now(timezone.utc)

    for attempt in range(1, max_attempts + 1):
        try:
            # The context manager issues SAVEPOINT on entry and releases it on a
            # clean exit. On ANY exception it rolls back to the savepoint (not the
            # outer transaction), so a non-IntegrityError failure no longer
            # leaves a dangling nested transaction behind.
            with db.begin_nested():
                invoice = Invoice(
                    invoice_id=generate_secure_id(prepend="inv", length=20),
                    invoice_literal=generate_invoice_literal(db),
                    amount=amount,
                    status=status,
                    payment_request_id=payment_request.id,
                    merchant_id=merchant_id,
                    customer_id=customer_id,
                    payer_id=payer_id,
                    approver_id=approver_id,
                    due_date=due_date,
                    billing_date=billing_date,
                    reference=reference,
                    shipping_fee=shipping_fee,
                    tax_fee=tax_fee,
                    paid_amount=paid_amount,
                    paid_date=paid_date,
                    is_surcharge_enabled=is_surcharge_enabled,
                    surcharge_type=surcharge_type,
                    comments=comments,
                    sequence_id=sequence_id,
                    following=False,
                    created_at=now,
                    updated_at=now,
                )
                db.add(invoice)
                db.flush()
            return invoice
        except IntegrityError:
            if attempt == max_attempts:
                raise
            logger.warning(
                "create_invoice: IntegrityError on attempt %d/%d; retrying with a new invoice_literal",
                attempt,
                max_attempts,
            )

    # Unreachable (the last failed attempt re-raises); kept so every code path
    # visibly returns or raises.
    raise RuntimeError("Failed to generate unique invoice_literal after 5 attempts")


def get_invoice_by_literal(
    db: Session,
    invoice_literal: str,
    merchant_id: int,
) -> Optional[Invoice]:
    """
    Look up one non-deleted Invoice by its human-readable literal, scoped to a merchant.

    Relationships needed by the detail view are eager-loaded. Many-to-one links
    are joined into the main SELECT. Line items, transactions and the payment
    request's split/recurring/adjustment children come from follow-up selectin
    queries.

    Args:
        db: Active SQLAlchemy session.
        invoice_literal: The INVxxxxxx identifier string.
        merchant_id: Only invoices owned by this merchant are returned.

    Returns:
        The matching Invoice ORM instance, or None when nothing matches.
    """
    # One joined payment_request carrying all three of its selectin children.
    payment_request_loader = joinedload(Invoice.payment_request).options(
        selectinload(PaymentRequest.split_config),
        selectinload(PaymentRequest.recurring_config),
        selectinload(PaymentRequest.payment_request_adjustment),
    )

    query = (
        select(Invoice)
        .options(
            joinedload(Invoice.customer),
            joinedload(Invoice.payer),
            joinedload(Invoice.approver),
            payment_request_loader,
            selectinload(Invoice.invoice_line_items),
            selectinload(Invoice.transactions),
            joinedload(Invoice.adjustment),
        )
        .where(
            Invoice.invoice_literal == invoice_literal,
            Invoice.merchant_id == merchant_id,
            Invoice.deleted_at.is_(None),
        )
    )

    # unique() is required whenever joined eager loading is combined with a
    # Result in SQLAlchemy 1.4+/2.0 style.
    result = db.execute(query)
    return result.unique().scalar_one_or_none()


def _escape_like(value: str) -> str:
    """Escape LIKE wildcards so *value* matches literally (escape char: backslash)."""
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def _parse_iso_date(value: str) -> Optional[date]:
    """Return the ISO date in the first 10 characters of *value*, or None if invalid."""
    try:
        return date.fromisoformat(value[:10])
    except ValueError:
        return None


def _invoice_list_conditions(merchant_id: int, filters: InvoiceListFilterSchema) -> list:
    """Build the WHERE conditions shared by list_invoices' count and data queries."""
    conditions = [
        Invoice.merchant_id == merchant_id,
        Invoice.deleted_at.is_(None),
    ]

    if filters.search:
        pattern = f"%{_escape_like(filters.search)}%"
        conditions.append(
            Invoice.invoice_literal.ilike(pattern, escape="\\")
            | Invoice.invoice_id.ilike(pattern, escape="\\")
        )

    if filters.status:
        tokens = (token.strip() for token in filters.status.split(","))
        # isdecimal() rather than isdigit(): isdigit() accepts characters such as
        # "²" that int() rejects, which would surface as an unhandled ValueError.
        status_values = [int(token) for token in tokens if token.isdecimal()]
        if status_values:
            conditions.append(Invoice.status.in_(status_values))

    if filters.customer_id:
        conditions.append(Invoice.customer_id == filters.customer_id)

    if filters.invoice_literal:
        conditions.append(Invoice.invoice_literal == filters.invoice_literal)

    if filters.invoice_id:
        conditions.append(Invoice.invoice_id == filters.invoice_id)

    if filters.payment_frequency:
        freq_values = [f.strip() for f in filters.payment_frequency.split(",") if f.strip()]
        if freq_values:
            conditions.append(
                Invoice.payment_request_id.in_(
                    select(PaymentRequest.id).where(
                        PaymentRequest.payment_frequency.in_(freq_values),
                        PaymentRequest.deleted_at.is_(None),
                    )
                )
            )

    # Unparseable date strings are silently ignored rather than rejected.
    date_from = _parse_iso_date(filters.date_from) if filters.date_from else None
    if date_from is not None:
        conditions.append(Invoice.billing_date >= date_from)

    date_to = _parse_iso_date(filters.date_to) if filters.date_to else None
    if date_to is not None:
        # Half-open upper bound so the whole date_to day is included even when
        # billing_date carries a time component. "<= date_to" would only match
        # values at exactly midnight on that day.
        try:
            conditions.append(Invoice.billing_date < date_to + timedelta(days=1))
        except OverflowError:  # date_to == date.max has no following day
            conditions.append(Invoice.billing_date <= date_to)

    return conditions


def _invoice_list_order(sort_by) -> list:
    """Translate sort specs ("field" ascending, "-field" descending) into ORDER BY clauses."""
    sortable_fields = {
        "invoice_literal": Invoice.invoice_literal,
        "amount": Invoice.amount,
        "due_date": Invoice.due_date,
        "billing_date": Invoice.billing_date,
        "created_at": Invoice.created_at,
        "status": Invoice.status,
        "paid_amount": Invoice.paid_amount,
        "reference": Invoice.reference,
    }

    clauses = []
    for sort_field in sort_by or []:
        column = sortable_fields.get(sort_field.lstrip("-"))
        if column is None:
            continue  # unknown fields are ignored
        clauses.append(column.desc() if sort_field.startswith("-") else column.asc())

    if not clauses:
        clauses.append(Invoice.created_at.desc())

    # Unique tiebreaker: without it, rows with equal sort keys can come back in a
    # different order on each query, so OFFSET pagination may repeat or skip rows.
    clauses.append(Invoice.id.desc())
    return clauses


def list_invoices(
    db: Session,
    merchant_id: int,
    filters: InvoiceListFilterSchema,
    page: int = 1,
    per_page: int = 10,
) -> Tuple[List[Invoice], int]:
    """
    Return a paginated list of a merchant's non-deleted invoices, with optional filters.

    Filters applied (all optional, all additive):
    - ``filters.search``: case-insensitive substring match on invoice_literal or invoice_id.
    - ``filters.status``: comma-separated integer status codes, e.g. "300,400".
    - ``filters.customer_id``: exact match on customer_id.
    - ``filters.invoice_literal``: exact match on invoice_literal.
    - ``filters.invoice_id``: exact match on invoice_id.
    - ``filters.payment_frequency``: comma-separated payment frequencies of the
      linked (non-deleted) payment request.
    - ``filters.date_from`` / ``filters.date_to``: ISO dates (only the first 10
      characters are read) bounding billing_date, both inclusive. Invalid values
      are ignored.

    ``filters.sort_by`` entries are field names, optionally prefixed with "-" for
    descending order. Unknown fields are ignored. The default order is newest
    ``created_at`` first, and ``Invoice.id`` is always appended as a tiebreaker.

    Args:
        db: Active SQLAlchemy session.
        merchant_id: Merchant scope guard.
        filters: InvoiceListFilterSchema dependency values.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Number of records per page. The documented maximum is 100,
            but this function does not enforce it — presumably the API layer
            validates it (TODO confirm).

    Returns:
        A tuple of (list of Invoice ORM instances, total record count).
    """
    conditions = _invoice_list_conditions(merchant_id, filters)

    total = db.execute(select(func.count(Invoice.id)).where(*conditions)).scalar_one()

    # A page below 1 would yield a negative OFFSET, which databases reject.
    page = max(page, 1)

    data_stmt = (
        select(Invoice)
        .where(*conditions)
        .options(
            joinedload(Invoice.customer),
            joinedload(Invoice.merchant),
            joinedload(Invoice.payer),
            joinedload(Invoice.payment_request),
        )
        .order_by(*_invoice_list_order(filters.sort_by))
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    items = db.execute(data_stmt).unique().scalars().all()

    return list(items), total


def _activate_pending_authorizations(db: Session, payment_request_id: int) -> None:
    """Best-effort: move PENDING authorization records of the payment request to ACTIVE.

    Failures are logged and swallowed. The work runs inside a SAVEPOINT so a
    database error here rolls back only this step. Without it, a failed flush
    would leave the caller's whole transaction unusable, including the invoice
    promotion flushed just before.
    """
    try:
        # Function-scope imports, as in the original code (presumably to avoid
        # an import cycle — TODO confirm).
        from src.apps.payment_requests.models.payment_request_authorizations import (
            PaymentRequestAuthorizations,
        )
        from src.core.utils.enums import AuthorizationStatus

        with db.begin_nested():
            auth_upd_stmt = select(PaymentRequestAuthorizations).where(
                PaymentRequestAuthorizations.payment_request_id == payment_request_id,
                PaymentRequestAuthorizations.status == AuthorizationStatus.PENDING.value,
            )
            for auth_rec in db.execute(auth_upd_stmt).scalars().all():
                auth_rec.status = AuthorizationStatus.ACTIVE.value
            db.flush()
    except Exception as auth_exc:
        logger.warning(
            "prepare_invoice_from_payment_request: could not update auth status for PR %s: %s",
            payment_request_id,
            auth_exc,
            exc_info=True,
        )


def _promote_existing_invoice_to_paid(
    db: Session, invoice: Invoice, payment_request_id: int
) -> None:
    """Mark an unpaid invoice PAID after its payment request completed, then activate authorizations."""
    invoice.status = InvoiceStatusTypes.PAID
    # Keep any amount/date already recorded; otherwise assume the full amount was paid now.
    invoice.paid_amount = invoice.paid_amount or float(invoice.amount or 0)
    invoice.paid_date = invoice.paid_date or datetime.now(timezone.utc)
    db.flush()
    logger.info(
        "prepare_invoice_from_payment_request: promoted existing invoice %s to PAID for PR %s",
        invoice.invoice_literal,
        payment_request_id,
    )
    _activate_pending_authorizations(db, payment_request_id)


def _first_linked_customer_id(pr: PaymentRequest) -> Optional[int]:
    """Return the id of the first customer linked via payment_request_customers, or None."""
    for prc in getattr(pr, "payment_request_customers", []):
        cust = getattr(prc, "customer", None)
        if cust:
            return cust.id
    return None


async def prepare_invoice_from_payment_request(
    payment_request_id: int,
    db: Session,
) -> Optional[Invoice]:
    """
    Create an Invoice record from a completed PaymentRequest.

    Idempotent: if a non-deleted invoice already exists for this
    payment_request_id, it is returned instead of creating a new one. If that
    invoice is still PENDING, AWAITING_APPROVAL or DRAFT, it is first promoted
    to PAID and the request's PENDING authorizations are (best-effort) set to
    ACTIVE.

    Returns None if the PaymentRequest cannot be found (or is soft-deleted) or
    has no linked customer.

    NOTE(review): declared ``async`` but every database call below is
    synchronous, so it blocks the event loop while it runs.

    Args:
        payment_request_id: PK of the PaymentRequest to invoice.
        db: Active SQLAlchemy session. The caller (event listener) is
            responsible for committing.

    Returns:
        The newly created (or pre-existing) Invoice instance, or None.
    """
    # 1. Reuse an existing invoice (e.g. one created via New Invoice before the
    #    customer paid via HPP), promoting it to PAID if it is not paid yet.
    existing_stmt = select(Invoice).where(
        Invoice.payment_request_id == payment_request_id,
        Invoice.deleted_at.is_(None),
    )
    # NOTE(review): scalar_one_or_none() raises MultipleResultsFound if two live
    # invoices ever exist for one payment request. Nothing here serializes
    # concurrent events that both reach step 4.
    existing_invoice = db.execute(existing_stmt).scalar_one_or_none()
    if existing_invoice is not None:
        if existing_invoice.status in (
            InvoiceStatusTypes.PENDING,
            InvoiceStatusTypes.AWAITING_APPROVAL,
            InvoiceStatusTypes.DRAFT,
        ):
            _promote_existing_invoice_to_paid(db, existing_invoice, payment_request_id)
        return existing_invoice

    # 2. Load the (non-deleted) PaymentRequest.
    pr_stmt = select(PaymentRequest).where(
        PaymentRequest.id == payment_request_id,
        PaymentRequest.deleted_at.is_(None),
    )
    pr = db.execute(pr_stmt).scalar_one_or_none()
    if pr is None:
        return None

    # 3. Resolve the customer from payment_request_customers.
    customer_id = _first_linked_customer_id(pr)
    if customer_id is None:
        return None

    # 4. Create the invoice.
    # NOTE: Do NOT call db.commit() here — the caller (invoices/listener.py)
    # is responsible for committing via SessionCelery.__exit__ and
    # mark_event_processed.
    # NOTE(review): a fresh invoice for a *completed* payment is created as
    # PENDING, whereas step 1 promotes existing ones to PAID. Confirm this
    # asymmetry is intended.
    invoice = create_invoice(
        db=db,
        payment_request=pr,
        merchant_id=pr.merchant_id,
        customer_id=customer_id,
        amount=float(pr.amount or 0),
        status=InvoiceStatusTypes.PENDING,
        due_date=getattr(pr, "due_date", None),
    )

    # 5. Write invoice activity so the activity trail is always populated.
    from src.apps.invoices import crud as invoice_crud
    from src.core.utils.enums import InvoiceActivityTypes

    invoice_crud.write_activity(
        db=db,
        invoice_id=invoice.id,
        activity_type=InvoiceActivityTypes.INVOICE_CREATED,
        description="Invoice created from completed payment",
        actor_type="system",
    )

    return invoice
