"""
Dashboard service layer — HWUI-001
Aggregates data from transactions, invoices, payment_request_authorizations, and customers.
No events emitted — read-only aggregation.
"""
from datetime import datetime, timezone
from typing import Tuple

from sqlalchemy import select, func, case, and_, distinct
from sqlalchemy.orm import Session

from src.apps.dashboard.schemas import (
    DateRangeEnum,
    DashboardSummaryResponse,
    RevenueTrendPoint,
    RevenueTrendResponse,
    PaymentMethodDataPoint,
    PaymentMethodsChartResponse,
    TransactionStatusDataPoint,
    TransactionStatusChartResponse,
    TopCustomerDataPoint,
    TopCustomersChartResponse,
    OverdueInvoiceItem,
    FailedTransactionItem,
    PendingAuthorizationItem,
    ActionItemsResponse,
)
from src.core.utils.enums import (
    TransactionStatusTypes,
    TransactionCategories,
    InvoiceStatusTypes,
    AuthorizationStatus,
)


# ---------------------------------------------------------------------------
# Date range helpers
# ---------------------------------------------------------------------------

def _get_date_range(range_val: DateRangeEnum) -> Tuple[datetime, datetime]:
    """Resolve a dashboard range selector into a ``(start, end)`` window.

    ``end`` is always the current moment expressed as a naive UTC datetime.
    ``start`` is a UTC midnight: today's, the one N days back for the rolling
    ranges, or the first of the current month/year. A selector that matches
    none of the known members gets the same window as ``LAST_7_DAYS``.
    """
    from datetime import timedelta

    now = datetime.now(timezone.utc).replace(tzinfo=None)
    midnight = now.replace(hour=0, minute=0, second=0, microsecond=0)

    # Ordered (selector, window start) pairs, matched with ``==`` in the same
    # order the selectors were originally tested.
    window_starts = (
        (DateRangeEnum.TODAY, midnight),
        (DateRangeEnum.LAST_7_DAYS, midnight - timedelta(days=7)),
        (DateRangeEnum.LAST_10_DAYS, midnight - timedelta(days=10)),
        (DateRangeEnum.LAST_30_DAYS, midnight - timedelta(days=30)),
        (DateRangeEnum.LAST_90_DAYS, midnight - timedelta(days=90)),
        (DateRangeEnum.THIS_MONTH, midnight.replace(day=1)),
        # LAST_3_MONTHS is treated as a rolling 90 days, same as LAST_90_DAYS.
        (DateRangeEnum.LAST_3_MONTHS, midnight - timedelta(days=90)),
        (DateRangeEnum.THIS_YEAR, midnight.replace(month=1, day=1)),
    )
    for selector, window_start in window_starts:
        if range_val == selector:
            return window_start, now

    # Unknown selector: fall back to the 7-day window.
    return midnight - timedelta(days=7), now


def _get_prior_period(start: datetime, end: datetime) -> Tuple[datetime, datetime]:
    """Return the window of equal length that ends where ``start`` begins.

    Used as the comparison baseline for the ``*_change_pct`` figures: for a
    window ``(start, end)`` the result is ``(start - (end - start), start)``.
    """
    span = end - start
    return start - span, start


def _get_trunc_unit(range_val: DateRangeEnum) -> str:
    """Pick the ``date_trunc`` bucket size that suits the selected range.

    ``TODAY`` is bucketed hourly, the 90-day / 3-month ranges weekly,
    ``THIS_YEAR`` monthly, and every other range daily.
    """
    if range_val == DateRangeEnum.TODAY:
        return "hour"
    if range_val in (DateRangeEnum.LAST_3_MONTHS, DateRangeEnum.LAST_90_DAYS):
        return "week"
    if range_val == DateRangeEnum.THIS_YEAR:
        return "month"
    return "day"


def _change_pct(current: float, prior: float) -> float:
    """Return the percentage change from ``prior`` to ``current``.

    The change is measured against ``abs(prior)`` so that a negative baseline
    still produces a correctly signed result, and is rounded to one decimal.

    A zero baseline has no meaningful ratio, so the result is pinned instead:
    ``100.0`` when ``current`` is positive, ``-100.0`` when it is negative,
    and ``0.0`` when both values are zero.
    """
    if prior == 0:
        if current > 0:
            return 100.0
        if current < 0:
            # A drop below a zero baseline is a decline, not "no change".
            return -100.0
        return 0.0
    return round(((current - prior) / abs(prior)) * 100, 1)


def _txn_date_col():
    """Build the SQL expression used as a transaction's effective timestamp.

    Returns ``COALESCE(occurred_at, ocurred_at)`` over the ``Transactions``
    model, for use in date filters, ordering and bucketing.
    """
    from src.apps.transactions.models.transactions import Transactions

    preferred = Transactions.occurred_at
    # The misspelled column is kept as a fallback for backwards compatibility
    # with legacy rows.
    legacy = Transactions.ocurred_at
    return func.coalesce(preferred, legacy)


# ---------------------------------------------------------------------------
# KPI Summary (HWUI-101)
# ---------------------------------------------------------------------------

def get_dashboard_summary(
    db: Session,
    merchant_id: int,
    range_val: DateRangeEnum = DateRangeEnum.LAST_7_DAYS,
) -> DashboardSummaryResponse:
    """Build the KPI summary for one merchant over the selected range.

    Transaction metrics (paid-charge revenue and count, refund amount and
    count, distinct customers) are aggregated for the selected window and for
    the equally long window immediately before it; each ``*_change_pct``
    field compares the two. Outstanding invoice totals (pending + overdue,
    not soft-deleted) are not date-filtered.

    Args:
        db: Session used to run the aggregate queries.
        merchant_id: Merchant whose rows are aggregated.
        range_val: Range selector, resolved by ``_get_date_range``.

    Returns:
        A populated ``DashboardSummaryResponse``.
    """
    from src.apps.transactions.models.transactions import Transactions
    from src.apps.invoices.models.invoice import Invoice

    period_start, period_end = _get_date_range(range_val)
    prior_start, prior_end = _get_prior_period(period_start, period_end)

    date_col = _txn_date_col()
    is_paid_charge = and_(
        Transactions.txn_status == TransactionStatusTypes.PAID,
        Transactions.category == TransactionCategories.CHARGE,
    )
    is_refund = Transactions.category == TransactionCategories.REFUND

    def _revenue_stats(start: datetime, end: datetime, include_end: bool = True):
        """Aggregate transaction KPIs for ``[start, end]`` (or ``[start, end)``)."""
        upper_bound = (date_col <= end) if include_end else (date_col < end)
        stmt = select(
            func.coalesce(
                func.sum(case((is_paid_charge, Transactions.txn_amount), else_=0)), 0
            ).label("revenue"),
            # COUNT ignores the NULL produced for non-matching rows.
            func.count(case((is_paid_charge, 1))).label("txn_count"),
            func.coalesce(
                func.sum(case((is_refund, Transactions.txn_amount), else_=0)), 0
            ).label("refund_amount"),
            func.count(case((is_refund, 1))).label("refund_count"),
            func.count(distinct(Transactions.customer_id)).label("active_customers"),
        ).where(
            Transactions.merchant_id == merchant_id,
            date_col >= start,
            upper_bound,
        )
        return db.execute(stmt).one()

    curr = _revenue_stats(period_start, period_end)
    # The prior window ends exactly where the current one starts; keep it
    # half-open so a transaction stamped at that instant is counted only once.
    prior = _revenue_stats(prior_start, prior_end, include_end=False)

    total_revenue = float(curr.revenue or 0)
    txn_count = int(curr.txn_count or 0)
    refund_amount = float(curr.refund_amount or 0)
    refund_count = int(curr.refund_count or 0)
    active_customers = int(curr.active_customers or 0)

    total_txns = txn_count + refund_count
    refund_rate_pct = round((refund_count / total_txns * 100) if total_txns > 0 else 0.0, 2)
    net_revenue = round(total_revenue - refund_amount, 2)
    avg_txn_value = round(total_revenue / txn_count if txn_count > 0 else 0.0, 2)

    prior_revenue = float(prior.revenue or 0)
    prior_txn_count = int(prior.txn_count or 0)
    prior_refund_amount = float(prior.refund_amount or 0)
    prior_refund_count = int(prior.refund_count or 0)
    prior_active_customers = int(prior.active_customers or 0)
    prior_total_txns = prior_txn_count + prior_refund_count
    prior_refund_rate = (prior_refund_count / prior_total_txns * 100) if prior_total_txns > 0 else 0.0
    prior_net = prior_revenue - prior_refund_amount
    prior_avg = prior_revenue / prior_txn_count if prior_txn_count > 0 else 0.0

    # Outstanding invoices — not date-filtered, always current state.
    # The labels deliberately avoid "count": a SQLAlchemy Row is a Sequence,
    # so ``row.count`` resolves to the inherited ``count()`` method instead of
    # the column value.
    inv_stmt = select(
        func.count(Invoice.id).label("invoice_count"),
        func.coalesce(func.sum(Invoice.amount), 0).label("invoice_amount"),
    ).where(
        Invoice.merchant_id == merchant_id,
        Invoice.status.in_([InvoiceStatusTypes.PENDING, InvoiceStatusTypes.OVERDUE]),
        Invoice.deleted_at.is_(None),
    )
    inv_result = db.execute(inv_stmt).one()

    return DashboardSummaryResponse(
        total_revenue=round(total_revenue, 2),
        total_revenue_change_pct=_change_pct(total_revenue, prior_revenue),
        transaction_count=txn_count,
        transaction_count_change_pct=_change_pct(txn_count, prior_txn_count),
        refund_count=refund_count,
        refund_amount=round(refund_amount, 2),
        refund_rate_pct=refund_rate_pct,
        refund_rate_change_pct=_change_pct(refund_rate_pct, prior_refund_rate),
        outstanding_invoice_count=int(inv_result.invoice_count or 0),
        outstanding_invoice_amount=round(float(inv_result.invoice_amount or 0), 2),
        net_revenue=net_revenue,
        net_revenue_change_pct=_change_pct(net_revenue, prior_net),
        avg_transaction_value=avg_txn_value,
        avg_transaction_value_change_pct=_change_pct(avg_txn_value, prior_avg),
        active_customer_count=active_customers,
        active_customer_count_change_pct=_change_pct(active_customers, prior_active_customers),
        date_range=range_val,
        period_start=period_start,
        period_end=period_end,
    )


# ---------------------------------------------------------------------------
# Revenue Trend Chart (HWUI-102)
# ---------------------------------------------------------------------------

def get_revenue_trend(
    db: Session,
    merchant_id: int,
    range_val: DateRangeEnum = DateRangeEnum.LAST_7_DAYS,
) -> RevenueTrendResponse:
    """Return bucketed paid-charge revenue for the selected and prior windows.

    Rows are grouped with ``date_trunc`` at the granularity chosen by
    ``_get_trunc_unit``. Each point carries the bucket label, the summed
    amount of paid charges and their count. Buckets are emitted only where
    the merchant has at least one transaction in the window.

    Args:
        db: Session used to run the queries.
        merchant_id: Merchant whose transactions are charted.
        range_val: Range selector, resolved by ``_get_date_range``.

    Returns:
        A ``RevenueTrendResponse`` with ``data`` for the selected window and
        ``prior_period_data`` for the equally long window before it.
    """
    from src.apps.transactions.models.transactions import Transactions

    period_start, period_end = _get_date_range(range_val)
    prior_start, prior_end = _get_prior_period(period_start, period_end)
    trunc_unit = _get_trunc_unit(range_val)
    date_col = _txn_date_col()

    is_paid_charge = and_(
        Transactions.txn_status == TransactionStatusTypes.PAID,
        Transactions.category == TransactionCategories.CHARGE,
    )
    # Hourly and monthly buckets get their own label shape; weekly and daily
    # buckets are both labelled by calendar date.
    bucket_format = {
        "hour": "%Y-%m-%dT%H:00",
        "month": "%Y-%m",
    }.get(trunc_unit, "%Y-%m-%d")

    def _trend_query(start: datetime, end: datetime, include_end: bool = True):
        """Return the trend points for ``[start, end]`` (or ``[start, end)``)."""
        upper_bound = (date_col <= end) if include_end else (date_col < end)
        stmt = (
            select(
                func.date_trunc(trunc_unit, date_col).label("bucket"),
                func.coalesce(
                    func.sum(case((is_paid_charge, Transactions.txn_amount), else_=0)), 0
                ).label("amount"),
                # COUNT ignores the NULL produced for non-matching rows.
                func.count(case((is_paid_charge, 1))).label("txn_count"),
            )
            .where(
                Transactions.merchant_id == merchant_id,
                date_col >= start,
                upper_bound,
            )
            .group_by("bucket")
            .order_by("bucket")
        )
        return [
            RevenueTrendPoint(
                date=r.bucket.strftime(bucket_format),
                amount=round(float(r.amount or 0), 2),
                transaction_count=int(r.txn_count or 0),
            )
            for r in db.execute(stmt).all()
            if r.bucket is not None  # rows with no timestamp cannot be placed
        ]

    return RevenueTrendResponse(
        data=_trend_query(period_start, period_end),
        # The prior window ends exactly where the current one starts; keep it
        # half-open so a boundary transaction appears in only one series.
        prior_period_data=_trend_query(prior_start, prior_end, include_end=False),
        date_range=range_val,
    )


# ---------------------------------------------------------------------------
# Payment Methods Chart (HWUI-102)
# ---------------------------------------------------------------------------

def get_payment_methods_chart(
    db: Session,
    merchant_id: int,
    range_val: DateRangeEnum = DateRangeEnum.LAST_7_DAYS,
) -> PaymentMethodsChartResponse:
    """Break down paid charges in the selected range by payment method/brand.

    Transactions are joined to their payment method and, where present, its
    card details. Each data point reports the method, the card brand (falling
    back to the method name), the transaction count, the summed amount and
    the share of the total count as a percentage.

    Args:
        db: Session used to run the query.
        merchant_id: Merchant whose transactions are charted.
        range_val: Range selector, resolved by ``_get_date_range``.

    Returns:
        A ``PaymentMethodsChartResponse`` ordered by descending count.
    """
    from src.apps.transactions.models.transactions import Transactions
    from src.apps.payment_methods.models.payment_methods import PaymentMethod
    from src.apps.payment_methods.models.payment_method_card_details import PaymentMethodCardDetails

    period_start, period_end = _get_date_range(range_val)
    date_col = _txn_date_col()

    # The count column is labelled "txn_count", not "count": a SQLAlchemy Row
    # is a Sequence, so ``row.count`` would resolve to the inherited
    # ``count()`` method instead of the column value.
    stmt = (
        select(
            PaymentMethod.method.label("method"),
            func.coalesce(PaymentMethodCardDetails.brand, PaymentMethod.method).label("brand"),
            func.count(Transactions.id).label("txn_count"),
            func.coalesce(func.sum(Transactions.txn_amount), 0).label("amount"),
        )
        .join(PaymentMethod, Transactions.payment_method_id == PaymentMethod.id)
        .outerjoin(PaymentMethodCardDetails, PaymentMethod.card_details_id == PaymentMethodCardDetails.id)
        .where(
            Transactions.merchant_id == merchant_id,
            Transactions.txn_status == TransactionStatusTypes.PAID,
            Transactions.category == TransactionCategories.CHARGE,
            date_col >= period_start,
            date_col <= period_end,
        )
        .group_by(PaymentMethod.method, PaymentMethodCardDetails.brand)
        .order_by(func.count(Transactions.id).desc())
    )

    rows = db.execute(stmt).all()
    total_count = sum(int(r.txn_count or 0) for r in rows)
    total_amount = sum(float(r.amount or 0) for r in rows)

    data = []
    for r in rows:
        row_count = int(r.txn_count or 0)
        pct = round((row_count / total_count * 100) if total_count > 0 else 0.0, 1)
        data.append(PaymentMethodDataPoint(
            method=r.method or "card",
            brand=(r.brand or r.method or "other").lower(),
            count=row_count,
            amount=round(float(r.amount or 0), 2),
            percentage=pct,
        ))

    return PaymentMethodsChartResponse(
        data=data,
        total_count=total_count,
        total_amount=round(total_amount, 2),
        date_range=range_val,
    )


# ---------------------------------------------------------------------------
# Transaction Status Chart (HWUI-102)
# ---------------------------------------------------------------------------

def get_transaction_status_chart(
    db: Session,
    merchant_id: int,
    range_val: DateRangeEnum = DateRangeEnum.LAST_7_DAYS,
) -> TransactionStatusChartResponse:
    """Summarise the merchant's transactions in the range by status.

    Always returns one data point for each of paid, pending, failed and
    refunded, in that order; a status with no transactions in the window is
    reported with a zero count and amount. Other statuses are not reported.

    Args:
        db: Session used to run the query.
        merchant_id: Merchant whose transactions are charted.
        range_val: Range selector, resolved by ``_get_date_range``.

    Returns:
        A ``TransactionStatusChartResponse`` with four data points.
    """
    from src.apps.transactions.models.transactions import Transactions

    period_start, period_end = _get_date_range(range_val)
    date_col = _txn_date_col()

    status_labels = {
        TransactionStatusTypes.PAID: "paid",
        TransactionStatusTypes.PENDING: "pending",
        TransactionStatusTypes.FAILED: "failed",
        TransactionStatusTypes.REFUNDED: "refunded",
    }

    # The count column is labelled "txn_count", not "count": a SQLAlchemy Row
    # is a Sequence, so ``row.count`` would resolve to the inherited
    # ``count()`` method instead of the column value.
    stmt = (
        select(
            Transactions.txn_status.label("status"),
            func.count(Transactions.id).label("txn_count"),
            func.coalesce(func.sum(Transactions.txn_amount), 0).label("amount"),
        )
        .where(
            Transactions.merchant_id == merchant_id,
            date_col >= period_start,
            date_col <= period_end,
        )
        .group_by(Transactions.txn_status)
    )

    # Key rows by the raw status value so the lookup below works whether the
    # column comes back as an enum member or as its underlying value.
    rows_by_status = {
        getattr(r.status, "value", r.status): r
        for r in db.execute(stmt).all()
    }

    data = []
    for status_type, label in status_labels.items():
        r = rows_by_status.get(status_type.value)
        data.append(TransactionStatusDataPoint(
            status=label,
            count=int(r.txn_count if r else 0),
            amount=round(float(r.amount if r else 0), 2),
        ))

    return TransactionStatusChartResponse(data=data, date_range=range_val)


# ---------------------------------------------------------------------------
# Top Customers Chart (HWUI-102)
# ---------------------------------------------------------------------------

def get_top_customers_chart(
    db: Session,
    merchant_id: int,
    range_val: DateRangeEnum = DateRangeEnum.LAST_7_DAYS,
    limit: int = 10,
) -> TopCustomersChartResponse:
    """Rank the merchant's customers by paid-charge volume in the range.

    Only paid charges belonging to customers that are not soft-deleted are
    considered. Customers are ordered by summed transaction amount,
    descending, and truncated to ``limit`` rows.

    Args:
        db: Session used to run the query.
        merchant_id: Merchant whose transactions are ranked.
        range_val: Range selector, resolved by ``_get_date_range``.
        limit: Maximum number of customers to return.

    Returns:
        A ``TopCustomersChartResponse`` with one data point per customer.
    """
    from src.apps.transactions.models.transactions import Transactions
    from src.apps.customers.models.customer import Customer

    period_start, period_end = _get_date_range(range_val)
    date_col = _txn_date_col()

    stmt = (
        select(
            Customer.id.label("customer_id"),
            func.concat(
                func.coalesce(Customer.first_name, ""),
                " ",
                func.coalesce(Customer.last_name, ""),
            ).label("customer_name"),
            func.coalesce(Customer.email, "").label("email"),
            func.count(Transactions.id).label("txn_count"),
            func.coalesce(func.sum(Transactions.txn_amount), 0).label("total_amount"),
        )
        .join(Customer, Transactions.customer_id == Customer.id)
        .where(
            Transactions.merchant_id == merchant_id,
            Transactions.txn_status == TransactionStatusTypes.PAID,
            Transactions.category == TransactionCategories.CHARGE,
            date_col >= period_start,
            date_col <= period_end,
            Customer.deleted_at.is_(None),
        )
        .group_by(Customer.id, Customer.first_name, Customer.last_name, Customer.email)
        .order_by(func.sum(Transactions.txn_amount).desc())
        .limit(limit)
    )

    rows = db.execute(stmt).all()
    data = [
        TopCustomerDataPoint(
            customer_id=int(r.customer_id),
            # A customer with neither name part yields " " from the SQL
            # concat, so strip before applying the "Unknown" fallback.
            customer_name=(r.customer_name or "").strip() or "Unknown",
            email=r.email or "",
            transaction_count=int(r.txn_count or 0),
            total_amount=round(float(r.total_amount or 0), 2),
        )
        for r in rows
    ]

    return TopCustomersChartResponse(data=data, date_range=range_val)


# ---------------------------------------------------------------------------
# Action Items (HWUI-103)
# ---------------------------------------------------------------------------

def _action_item_name_col(customer_model):
    """Return the SQL expression "<first_name> <last_name>" labelled ``customer_name``.

    Missing name parts are coalesced to empty strings, so the value may be a
    lone space when the customer has no name (or no customer row matched).
    """
    return func.concat(
        func.coalesce(customer_model.first_name, ""),
        " ",
        func.coalesce(customer_model.last_name, ""),
    ).label("customer_name")


def _action_item_display_name(raw_name) -> str:
    """Trim a concatenated customer name, substituting "Unknown" when blank."""
    return (raw_name or "").strip() or "Unknown"


def _overdue_invoice_items(db: Session, merchant_id: int, now: datetime, limit: int = 5) -> Tuple[list, int]:
    """Return (oldest-due overdue invoices, total overdue count) for a merchant."""
    from src.apps.invoices.models.invoice import Invoice
    from src.apps.customers.models.customer import Customer

    filters = (
        Invoice.merchant_id == merchant_id,
        Invoice.status == InvoiceStatusTypes.OVERDUE,
        Invoice.deleted_at.is_(None),
    )
    total = db.execute(select(func.count(Invoice.id)).where(*filters)).scalar() or 0

    stmt = (
        select(
            Invoice.id,
            Invoice.invoice_id,
            Invoice.invoice_literal,
            Invoice.amount,
            Invoice.due_date,
            _action_item_name_col(Customer),
        )
        # Outer join: an invoice without a matching customer is included in
        # the count above, so it must not vanish from the preview list.
        .outerjoin(Customer, Invoice.customer_id == Customer.id)
        .where(*filters)
        .order_by(Invoice.due_date.asc())
        .limit(limit)
    )

    items = []
    for r in db.execute(stmt).all():
        days_overdue = 0
        if r.due_date:
            # Accept both datetime and plain date values for due_date.
            due_day = r.due_date.date() if isinstance(r.due_date, datetime) else r.due_date
            days_overdue = max(0, (now.date() - due_day).days)
        items.append(OverdueInvoiceItem(
            id=r.id,
            invoice_id=r.invoice_id,
            invoice_literal=r.invoice_literal or r.invoice_id,
            customer_name=_action_item_display_name(r.customer_name),
            amount=round(float(r.amount or 0), 2),
            due_date=r.due_date or now,
            days_overdue=days_overdue,
        ))
    return items, int(total)


def _failed_transaction_items(
    db: Session,
    merchant_id: int,
    range_val: DateRangeEnum,
    now: datetime,
    limit: int = 5,
) -> Tuple[list, int]:
    """Return (most recent failed transactions, total failed count) within the range."""
    from src.apps.transactions.models.transactions import Transactions
    from src.apps.customers.models.customer import Customer

    fail_start, fail_end = _get_date_range(range_val)
    date_col = _txn_date_col()

    filters = (
        Transactions.merchant_id == merchant_id,
        Transactions.txn_status == TransactionStatusTypes.FAILED,
        date_col >= fail_start,
        date_col <= fail_end,
    )
    total = db.execute(select(func.count(Transactions.id)).where(*filters)).scalar() or 0

    stmt = (
        select(
            Transactions.id,
            Transactions.txn_id,
            Transactions.txn_amount,
            date_col.label("occurred_at"),
            _action_item_name_col(Customer),
        )
        # Outer join: a failed transaction without a matching customer is
        # included in the count above, so keep it in the preview list too.
        .outerjoin(Customer, Transactions.customer_id == Customer.id)
        .where(*filters)
        .order_by(date_col.desc())
        .limit(limit)
    )

    items = [
        FailedTransactionItem(
            id=r.id,
            txn_id=r.txn_id,
            customer_name=_action_item_display_name(r.customer_name),
            amount=round(float(r.txn_amount or 0), 2),
            occurred_at=r.occurred_at or now,
        )
        for r in db.execute(stmt).all()
    ]
    return items, int(total)


def _pending_authorization_items(db: Session, merchant_id: int, now: datetime, limit: int = 5) -> Tuple[list, int]:
    """Return (newest pending authorizations, total pending count) for a merchant."""
    from src.apps.payment_requests.models.payment_request_authorizations import PaymentRequestAuthorizations
    from src.apps.customers.models.customer import Customer

    filters = (
        PaymentRequestAuthorizations.merchant_id == merchant_id,
        PaymentRequestAuthorizations.status == AuthorizationStatus.PENDING.value,
        PaymentRequestAuthorizations.deleted_at.is_(None),
    )
    total = db.execute(
        select(func.count(PaymentRequestAuthorizations.id)).where(*filters)
    ).scalar() or 0

    stmt = (
        select(
            PaymentRequestAuthorizations.id,
            PaymentRequestAuthorizations.authorization_id,
            PaymentRequestAuthorizations.authorization_type,
            PaymentRequestAuthorizations.created_at,
            _action_item_name_col(Customer),
        )
        # Outer join: an authorization without a matching customer is included
        # in the count above, so keep it in the preview list too.
        .outerjoin(Customer, PaymentRequestAuthorizations.customer_id == Customer.id)
        .where(*filters)
        .order_by(PaymentRequestAuthorizations.created_at.desc())
        .limit(limit)
    )

    items = [
        PendingAuthorizationItem(
            id=r.id,
            authorization_id=r.authorization_id or str(r.id),
            customer_name=_action_item_display_name(r.customer_name),
            authorization_type=r.authorization_type or "",
            created_at=r.created_at or now,
        )
        for r in db.execute(stmt).all()
    ]
    return items, int(total)


def get_action_items(
    db: Session,
    merchant_id: int,
    range_val: DateRangeEnum = DateRangeEnum.LAST_7_DAYS,
) -> ActionItemsResponse:
    """Collect the items on the dashboard that need the merchant's attention.

    Three sections are returned, each with a total count and a preview of at
    most five rows:

    * overdue invoices (not soft-deleted), oldest due date first — not
      date-filtered;
    * failed transactions inside the selected range, most recent first;
    * pending payment-request authorizations (not soft-deleted), newest
      first — not date-filtered.

    Preview rows whose customer has no name, or that have no matching
    customer record, are shown with the name "Unknown".

    Args:
        db: Session used to run the queries.
        merchant_id: Merchant whose items are collected.
        range_val: Range selector applied to failed transactions only.

    Returns:
        A populated ``ActionItemsResponse``.
    """
    # Naive UTC "now": used for days-overdue and as a fallback timestamp.
    now = datetime.now(timezone.utc).replace(tzinfo=None)

    overdue_invoices, overdue_total = _overdue_invoice_items(db, merchant_id, now)
    failed_transactions, failed_total = _failed_transaction_items(db, merchant_id, range_val, now)
    pending_authorizations, auth_total = _pending_authorization_items(db, merchant_id, now)

    return ActionItemsResponse(
        overdue_invoices=overdue_invoices,
        overdue_invoice_count=overdue_total,
        failed_transactions=failed_transactions,
        failed_transaction_count=failed_total,
        pending_authorizations=pending_authorizations,
        pending_authorization_count=auth_total,
    )
