"""
Invoice CRUD — pure DB operations, zero business logic.

All functions accept a SQLAlchemy Session and return ORM instances or plain
dicts. Callers own the transaction boundary (commit/rollback).
"""

from datetime import datetime, timezone, timedelta
from typing import Dict, List, Optional, Tuple

from sqlalchemy import func, select, update
from sqlalchemy.orm import Session, joinedload, selectinload

from src.apps.invoices.models.invoice import Invoice
from src.apps.invoices.models.invoice_activity import InvoiceActivity
from src.apps.invoices.models.invoice_line_items import InvoiceLineItems
from src.apps.payment_requests.models.payment_request import PaymentRequest
from src.apps.invoices.schemas.invoice_requests import InvoiceListFilterSchema
from src.apps.base.models.reminder import Reminder
from src.core.utils.enums import InvoiceStatusTypes


# ─── Invoice queries ─────────────────────────────────────────────────────────


def get_invoice_by_id(
    db: Session,
    id: int,
    merchant_id: int,
) -> Optional[Invoice]:
    """
    Look up one invoice by primary key within a merchant's scope.

    Rows whose deleted_at is set are never returned. The customer, payer,
    approver and payment_request relationships are joined-loaded, and
    invoice_line_items is loaded via selectinload.

    Returns the matching Invoice, or None when no row matches.
    """
    criteria = (
        Invoice.id == id,
        Invoice.merchant_id == merchant_id,
        Invoice.deleted_at.is_(None),
    )
    loaders = (
        joinedload(Invoice.customer),
        joinedload(Invoice.payer),
        joinedload(Invoice.approver),
        joinedload(Invoice.payment_request),
        selectinload(Invoice.invoice_line_items),
    )
    query = select(Invoice).where(*criteria).options(*loaders)
    result = db.execute(query)
    return result.unique().scalar_one_or_none()


def get_invoice_by_literal(
    db: Session,
    invoice_literal: str,
    merchant_id: int,
) -> Optional[Invoice]:
    """
    Look up one invoice by its human-readable literal within a merchant's scope.

    Rows whose deleted_at is set are never returned. This is the detail-view
    loader, so it eagerly loads the invoice's related records: customer,
    payer, approver, adjustment, merchant, line items, transactions, payment
    links, and the payment request together with its adjustment, attachments,
    split config, recurring config and line items.

    Returns the matching Invoice, or None when no row matches.
    """
    criteria = (
        Invoice.invoice_literal == invoice_literal,
        Invoice.merchant_id == merchant_id,
        Invoice.deleted_at.is_(None),
    )

    # Collections hanging off the payment request, loaded alongside it.
    payment_request_loader = joinedload(Invoice.payment_request).options(
        selectinload(PaymentRequest.payment_request_adjustment),
        selectinload(PaymentRequest.attachments),
        selectinload(PaymentRequest.split_config),
        selectinload(PaymentRequest.recurring_config),
        selectinload(PaymentRequest.line_items),
    )
    loaders = (
        joinedload(Invoice.customer),
        joinedload(Invoice.payer),
        joinedload(Invoice.approver),
        payment_request_loader,
        selectinload(Invoice.invoice_line_items),
        selectinload(Invoice.transactions),
        selectinload(Invoice.payment_links),
        joinedload(Invoice.adjustment),
        joinedload(Invoice.merchant),
    )

    query = select(Invoice).where(*criteria).options(*loaders)
    result = db.execute(query)
    return result.unique().scalar_one_or_none()


def list_invoices(
    db: Session,
    merchant_id: int,
    filters: InvoiceListFilterSchema,
    page: int = 1,
    per_page: int = 10,
) -> Tuple[List[Invoice], int]:
    """
    Paginated invoice list with optional search/status/sort filters.

    Thin pass-through: every argument is forwarded unchanged to
    services.invoice_services.list_invoices, and its result is returned
    as-is. It exists so routers can import list access from this module.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably this avoids a circular import with the
    # services module — confirm before hoisting it to the top of the file.
    from src.apps.invoices.services.invoice_services import (
        list_invoices as service_list_invoices,
    )

    return service_list_invoices(
        db=db,
        merchant_id=merchant_id,
        filters=filters,
        page=page,
        per_page=per_page,
    )


def soft_delete_invoice(db: Session, invoice: Invoice) -> Invoice:
    """
    Soft-delete an invoice by stamping deleted_at with the current UTC time.

    The change is flushed to the database but not committed; the caller owns
    the transaction. Returns the same invoice instance.
    """
    deleted_timestamp = datetime.now(tz=timezone.utc)
    invoice.deleted_at = deleted_timestamp
    db.flush()
    return invoice


def get_invoice_summary_stats(db: Session, merchant_id: int) -> Dict:
    """
    Aggregate invoice counts and amount sums per status bucket for a merchant.

    Buckets:
    - total:   every non-draft status listed below
    - pending: CREATED, UPDATED, PENDING, WAITING, AWAITING_APPROVAL
    - overdue: OVERDUE
    - paid:    PAID, PARTIALLY_PAID
    - draft:   DRAFT (count only)

    Soft-deleted invoices are excluded from every bucket. Amount sums default
    to 0 when a bucket is empty and are returned as floats.

    Returns a dict with keys total_count, total_amount, pending_count,
    pending_amount, overdue_count, overdue_amount, paid_count, paid_amount
    and draft_count.
    """
    scope = (Invoice.merchant_id == merchant_id, Invoice.deleted_at.is_(None))

    def _count_and_sum(status_filter):
        """Run one COUNT(id) / SUM(amount) aggregate for the given status filter."""
        aggregate = select(
            func.count(Invoice.id),
            func.coalesce(func.sum(Invoice.amount), 0.0),
        ).where(*scope, status_filter)
        return db.execute(aggregate).one()

    # NOTE(review): the total/pending filters compare against enum .value
    # while the overdue/paid/draft filters compare against the enum members
    # themselves. Kept exactly as-is; confirm both forms bind correctly for
    # the Invoice.status column type.
    pending_values = [
        InvoiceStatusTypes.CREATED.value,
        InvoiceStatusTypes.UPDATED.value,
        InvoiceStatusTypes.PENDING.value,
        InvoiceStatusTypes.WAITING.value,
        InvoiceStatusTypes.AWAITING_APPROVAL.value,
    ]
    # Non-draft = the pending statuses plus every other non-draft status.
    non_draft_values = pending_values + [
        InvoiceStatusTypes.CAPTURED.value,
        InvoiceStatusTypes.PAID.value,
        InvoiceStatusTypes.PARTIALLY_PAID.value,
        InvoiceStatusTypes.AUTHORIZED.value,
        InvoiceStatusTypes.FAILED.value,
        InvoiceStatusTypes.OVERDUE.value,
        InvoiceStatusTypes.CANCELLED.value,
    ]

    total_count, total_amount = _count_and_sum(Invoice.status.in_(non_draft_values))
    pending_count, pending_amount = _count_and_sum(Invoice.status.in_(pending_values))
    overdue_count, overdue_amount = _count_and_sum(
        Invoice.status == InvoiceStatusTypes.OVERDUE
    )
    paid_count, paid_amount = _count_and_sum(
        Invoice.status.in_(
            [InvoiceStatusTypes.PAID, InvoiceStatusTypes.PARTIALLY_PAID]
        )
    )

    # Drafts are reported as a count only, so no SUM is needed.
    draft_query = select(func.count(Invoice.id)).where(
        *scope, Invoice.status == InvoiceStatusTypes.DRAFT
    )
    draft_count = db.execute(draft_query).scalar_one()

    stats = {
        "total_count": total_count,
        "total_amount": float(total_amount),
        "pending_count": pending_count,
        "pending_amount": float(pending_amount),
        "overdue_count": overdue_count,
        "overdue_amount": float(overdue_amount),
        "paid_count": paid_count,
        "paid_amount": float(paid_amount),
        "draft_count": draft_count,
    }
    return stats


# ─── Activity queries ─────────────────────────────────────────────────────────


def list_invoice_activities(
    db: Session,
    invoice_id: int,
    page: int = 1,
    per_page: int = 20,
) -> Tuple[List[InvoiceActivity], int]:
    """
    Paginated activity trail for an invoice, newest first.

    Args:
        db: Session used for both the count and the page query.
        invoice_id: Invoice whose activities are listed.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Page size; negative values are treated as 0.

    Returns:
        (items, total) where items is the requested page and total counts
        every activity of the invoice, independent of pagination.

    NOTE(review): the queries filter on invoice_id only — there is no
    merchant scoping here, so callers must have verified that the invoice
    belongs to the requesting merchant.
    """
    # A negative OFFSET or LIMIT is rejected by some databases (e.g.
    # PostgreSQL), so out-of-range paging arguments are clamped.
    page = max(page, 1)
    per_page = max(per_page, 0)

    base = [InvoiceActivity.invoice_id == invoice_id]

    count_stmt = select(func.count(InvoiceActivity.id)).where(*base)
    total = db.execute(count_stmt).scalar_one()

    data_stmt = (
        select(InvoiceActivity)
        .where(*base)
        # id is a tie-breaker: activities sharing a created_at value would
        # otherwise have no defined relative order, so a row could repeat
        # or be skipped across pages.
        .order_by(InvoiceActivity.created_at.desc(), InvoiceActivity.id.desc())
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    items = db.execute(data_stmt).scalars().all()
    return list(items), total


def write_activity(
    db: Session,
    invoice_id: int,
    activity_type: str,
    description: Optional[str] = None,
    actor_type: Optional[str] = None,
    actor_id: Optional[int] = None,
    metadata: Optional[dict] = None,
) -> InvoiceActivity:
    """
    Record a new activity entry against an invoice.

    The row is added to the session and flushed right away, so the returned
    instance can be read back (including its id) within the current
    transaction. Nothing is committed here — the caller owns the boundary.
    """
    fields = {
        "invoice_id": invoice_id,
        "activity_type": activity_type,
        "description": description,
        "actor_type": actor_type,
        "actor_id": actor_id,
        # The model attribute is metadata_ (trailing underscore).
        "metadata_": metadata,
    }
    record = InvoiceActivity(**fields)
    db.add(record)
    db.flush()
    return record


# ─── Reminder queries ─────────────────────────────────────────────────────────


def get_reminders_for_invoice(
    db: Session,
    invoice_id: int,
    merchant_id: int,
) -> List[Reminder]:
    """Fetch every reminder of an invoice for a merchant, skipping soft-deleted rows."""
    criteria = (
        Reminder.invoice_id == invoice_id,
        Reminder.merchant_id == merchant_id,
        Reminder.deleted_at.is_(None),
    )
    query = select(Reminder).where(*criteria)
    rows = db.execute(query).scalars().all()
    return list(rows)


def get_due_reminders(db: Session, *, window_minutes: int = 15) -> List[Reminder]:
    """
    Pending, non-deleted reminders that fall due within the polling window.

    A reminder qualifies when its reminder_status is 'pending' and its
    scheduled_at is at or before now (UTC) + window_minutes. Used by the
    send_reminder_poll task.

    Args:
        db: Session used to run the query.
        window_minutes: Look-ahead in minutes added to the current time.
            Defaults to 15, the previously hard-coded polling window, so
            existing callers are unaffected.

    Returns:
        The matching reminders, in no particular order.
    """
    cutoff = datetime.now(timezone.utc) + timedelta(minutes=window_minutes)
    stmt = select(Reminder).where(
        Reminder.reminder_status == "pending",
        Reminder.scheduled_at <= cutoff,
        Reminder.deleted_at.is_(None),
    )
    return list(db.execute(stmt).scalars().all())


# ─── Attachment / Transaction / Authorization queries ─────────────────────────


def get_invoice_attachments(db: Session, invoice_id: int) -> list:
    """
    List the File rows attached to an invoice.

    Files are resolved through the invoice_attachments_map association
    table, joined on file_id and filtered by invoice_id.
    """
    from src.apps.invoices.models.invoice import invoice_attachments_map
    from src.apps.files.models.file import File

    link = invoice_attachments_map.c
    file_matches_link = File.id == link.file_id

    query = select(File).join(invoice_attachments_map, file_matches_link)
    query = query.where(link.invoice_id == invoice_id)

    files = db.execute(query).scalars().all()
    return list(files)


def get_invoice_transactions(db: Session, invoice_id: int) -> list:
    """
    List the Transactions rows linked to an invoice.

    Transactions are resolved through the transactions_invoices_map
    association table, joined on transaction_id and filtered by invoice_id.
    """
    from src.apps.transactions.models.transactions import Transactions, transactions_invoices_map

    link = transactions_invoices_map.c
    transaction_matches_link = Transactions.id == link.transaction_id

    query = select(Transactions).join(
        transactions_invoices_map, transaction_matches_link
    )
    query = query.where(link.invoice_id == invoice_id)

    transactions = db.execute(query).scalars().all()
    return list(transactions)


def get_invoice_authorizations(db: Session, invoice_id: int, merchant_id: int) -> list:
    """
    List the PaymentRequestAuthorizations of the invoice's payment request.

    The invoice is first fetched scoped to merchant_id and excluding
    soft-deleted rows, which prevents cross-tenant data leakage. An empty
    list is returned when the invoice is not found or has no
    payment_request_id.

    The payment request's authorization_type then narrows the result:
    - request_auth: only HPP customer authorizations (hpp_session_id IS NOT NULL)
    - pre_auth:     only merchant-signed authorizations (merchant_signer_id IS NOT NULL)
    - anything else (or no payment request loaded): no extra filter
    """
    from src.apps.invoices.models.invoice import Invoice as InvoiceModel
    from src.apps.payment_requests.models.payment_request_authorizations import (
        PaymentRequestAuthorizations,
    )
    from src.apps.payment_requests.enums import PaymentAuthorizationTypes

    invoice_query = select(InvoiceModel).where(
        InvoiceModel.id == invoice_id,
        InvoiceModel.merchant_id == merchant_id,
        InvoiceModel.deleted_at.is_(None),
    )
    invoice = db.execute(invoice_query).scalar_one_or_none()

    # Guard clauses: nothing to look up without an invoice or a linked request.
    if not invoice:
        return []
    if not invoice.payment_request_id:
        return []

    payment_request = invoice.payment_request
    auth_type = None
    if payment_request:
        auth_type = getattr(payment_request, "authorization_type", None)

    conditions = [
        PaymentRequestAuthorizations.payment_request_id == invoice.payment_request_id
    ]
    if auth_type == PaymentAuthorizationTypes.REQUEST_AUTH.value:
        # Only show HPP customer authorizations
        conditions.append(PaymentRequestAuthorizations.hpp_session_id.is_not(None))
    elif auth_type == PaymentAuthorizationTypes.PRE_AUTH.value:
        # Only show merchant-signed preauth authorizations
        conditions.append(PaymentRequestAuthorizations.merchant_signer_id.is_not(None))

    auth_query = select(PaymentRequestAuthorizations).where(*conditions)
    authorizations = db.execute(auth_query).scalars().all()
    return list(authorizations)
