"""
Payment Methods Report Service
"""
import logging
from datetime import date
from typing import Optional, List

from sqlalchemy import select, func, case, and_
from sqlalchemy.orm import Session

from src.apps.reports.schemas.payment_methods_report import (
    PaymentMethodRow, PaymentMethodSubRow, PaymentMethodReportResponse,
    PaymentMethodSummaryItem, PaymentMethodSummaryResponse
)

logger = logging.getLogger(__name__)

# Card brands that always get their own sub-row under the "card" method in the
# report (zero-filled when no transactions exist), in display order. Kept as a
# list because the report concatenates it with a list of extra brands.
CARD_BRANDS = [
    "Visa",
    "Mastercard",
    "Discover",
    "Amex",
]


# (payment_count, refund_count, payment_amount, refund_amount, fees) for an empty bucket.
_ZERO_METRICS = (0, 0, 0.0, 0.0, 0.0)


def _metric_columns(txn_model, paid_status, refund_statuses):
    """Return the labelled aggregate columns shared by the ACH and card queries.

    ``txn_model`` is the ``Transactions`` ORM class (the caller imports it lazily).
    Labels produced: payment_count, refund_count, payment_amount, refund_amount, fees.
    """
    return (
        func.count(case((txn_model.txn_status == paid_status, 1))).label("payment_count"),
        func.count(case((txn_model.txn_status.in_(refund_statuses), 1))).label("refund_count"),
        func.coalesce(
            func.sum(case((txn_model.txn_status == paid_status, txn_model.txn_amount), else_=0)), 0
        ).label("payment_amount"),
        # NOTE(review): partially refunded transactions contribute their full txn_amount,
        # which overstates refunds if only part was returned — confirm against the schema.
        func.coalesce(
            func.sum(case((txn_model.txn_status.in_(refund_statuses), txn_model.txn_amount), else_=0)), 0
        ).label("refund_amount"),
        # NOTE(review): unlike the amounts above, fees are summed over every matching
        # transaction regardless of status — confirm that is intended.
        func.coalesce(func.sum(func.coalesce(txn_model.platform_fee_amount, 0)), 0).label("fees"),
    )


def _read_metrics(result_row):
    """Convert one aggregate result row to a metrics tuple, treating NULLs as zero."""
    return (
        result_row.payment_count or 0,
        result_row.refund_count or 0,
        float(result_row.payment_amount or 0),
        float(result_row.refund_amount or 0),
        float(result_row.fees or 0),
    )


def _add_metrics(left, right):
    """Element-wise sum of two metrics tuples."""
    return tuple(a + b for a, b in zip(left, right))


def _sub_row(brand, metrics):
    """Build a per-brand ``PaymentMethodSubRow`` from a metrics tuple."""
    pc, rc, pa, ra, fees = metrics
    return PaymentMethodSubRow(
        brand=brand,
        payment_count=pc,
        refund_count=rc,
        payment_amount=round(pa, 2),
        refund_amount=round(ra, 2),
        fees=round(fees, 2),
        net_settlement=round(pa - ra - fees, 2),
    )


def _method_row(method, metrics, sub_rows=None):
    """Build a top-level ``PaymentMethodRow`` from a metrics tuple."""
    pc, rc, pa, ra, fees = metrics
    return PaymentMethodRow(
        method=method,
        payment_count=pc,
        refund_count=rc,
        payment_amount=round(pa, 2),
        refund_amount=round(ra, 2),
        fees=round(fees, 2),
        net_settlement=round(pa - ra - fees, 2),
        sub_rows=[] if sub_rows is None else sub_rows,
    )


def get_payment_methods_report(
    db: Session,
    merchant_id: int,
    date_from: Optional[date] = None,
    date_to: Optional[date] = None,
    method: Optional[List[str]] = None,
) -> PaymentMethodReportResponse:
    """Get payment methods report grouped by method and brand.

    Args:
        db: SQLAlchemy session used to run the aggregate queries.
        merchant_id: Only transactions with this ``merchant_id`` are counted.
        date_from: Inclusive start date, applied to ``Transactions.ocurred_at``.
        date_to: Inclusive end date (through the last microsecond of that day).
        method: Optional whitelist of methods ('card', 'ach', 'cash', 'cheque').
            ``None`` or an empty list includes all of them.

    Returns:
        A ``PaymentMethodReportResponse`` with one row per included method, in the
        order card, ach, cash, cheque, plus grand totals. The card row carries
        per-brand sub-rows:
          * every brand in ``CARD_BRANDS`` (zero-filled if absent),
          * then any other brand seen,
          * then 'Other'.
        Cash and cheque are always static zero rows. Transactions whose payment
        method has ``deleted_at`` set are excluded.
    """
    from datetime import datetime

    from src.apps.transactions.models.transactions import Transactions
    from src.apps.payment_methods.models.payment_methods import PaymentMethod
    from src.apps.payment_methods.models.payment_method_card_details import PaymentMethodCardDetails
    from src.core.utils.enums import TransactionStatusTypes

    paid = TransactionStatusTypes.PAID.value
    refund_statuses = [
        TransactionStatusTypes.REFUNDED.value,
        TransactionStatusTypes.PARTIALLY_REFUNDED.value,
    ]

    def wants(name):
        # No filter (None or empty list) means every method is included.
        return not method or name in method

    base_conditions = [
        Transactions.merchant_id == merchant_id,
        PaymentMethod.deleted_at.is_(None),
    ]
    if date_from:
        base_conditions.append(Transactions.ocurred_at >= datetime.combine(date_from, datetime.min.time()))
    if date_to:
        base_conditions.append(Transactions.ocurred_at <= datetime.combine(date_to, datetime.max.time()))

    # --- ACH: a single aggregate row ---
    ach_row = None
    if wants("ach"):
        ach_stmt = (
            select(*_metric_columns(Transactions, paid, refund_statuses))
            .join(PaymentMethod, Transactions.payment_method_id == PaymentMethod.id)
            .where(and_(*base_conditions, PaymentMethod.method == "ach"))
        )
        ach_row = _method_row("ach", _read_metrics(db.execute(ach_stmt).one()))

    # --- Card: one aggregate row per brand ---
    card_row = None
    if wants("card"):
        brand_stmt = (
            select(
                func.coalesce(
                    func.concat(
                        func.upper(func.substr(PaymentMethodCardDetails.brand, 1, 1)),
                        func.lower(func.substr(PaymentMethodCardDetails.brand, 2)),
                    ),
                    'Other'
                ).label("brand"),
                *_metric_columns(Transactions, paid, refund_statuses),
            )
            .join(PaymentMethod, Transactions.payment_method_id == PaymentMethod.id)
            .outerjoin(PaymentMethodCardDetails, PaymentMethod.card_details_id == PaymentMethodCardDetails.id)
            .where(and_(*base_conditions, PaymentMethod.method == "card"))
            .group_by("brand")
        )

        # Several result rows can map to the same display name, so accumulate
        # rather than overwrite. Two ways this happens:
        #   * NULL vs '' brand: PostgreSQL concat() ignores NULLs, and both end
        #     up as 'Other'.
        #   * 'visa' vs 'VISA': `GROUP BY brand` may resolve to the raw input
        #     column rather than the label.
        brand_totals = {}
        for result_row in db.execute(brand_stmt).all():
            name = result_row.brand or 'Other'
            brand_totals[name] = _add_metrics(
                brand_totals.get(name, _ZERO_METRICS), _read_metrics(result_row)
            )

        # Fixed brands first, then unknown brands in result order, then 'Other'.
        ordered_brands = CARD_BRANDS + [
            b for b in brand_totals if b not in CARD_BRANDS and b != 'Other'
        ]
        if 'Other' in brand_totals:
            ordered_brands.append('Other')
        sub_rows = [_sub_row(b, brand_totals.get(b, _ZERO_METRICS)) for b in ordered_brands]

        card_metrics = _ZERO_METRICS
        for metrics in brand_totals.values():
            card_metrics = _add_metrics(card_metrics, metrics)
        card_row = _method_row("card", card_metrics, sub_rows)

    rows = [row for row in (card_row, ach_row) if row is not None]

    # Cash and cheque have no backing query; they are static zero rows that
    # still honour the method filter.
    for static_method in ("cash", "cheque"):
        if wants(static_method):
            rows.append(_method_row(static_method, _ZERO_METRICS))

    total_payment = round(sum(r.payment_amount for r in rows), 2)
    total_refund = round(sum(r.refund_amount for r in rows), 2)
    total_net = round(sum(r.net_settlement for r in rows), 2)

    return PaymentMethodReportResponse(
        rows=rows,
        total_payment_amount=total_payment,
        total_refund_amount=total_refund,
        total_net_settlement=total_net,
    )


def get_payment_methods_summary(
    db: Session,
    merchant_id: int,
    date_from: Optional[date] = None,
    date_to: Optional[date] = None,
) -> PaymentMethodSummaryResponse:
    """Summarise the top payment methods for a merchant over a date window.

    Builds the full payment-methods report and keeps only its 'card' and 'ach'
    rows. From those, it returns up to three ranked by payment count
    (``most_used``) and up to three ranked by payment amount (``highest_volume``).
    Both rankings are in descending order, with ties kept in report order.
    """
    report = get_payment_methods_report(db=db, merchant_id=merchant_id, date_from=date_from, date_to=date_to)

    # Cash and cheque rows are left out of the summary.
    candidates = [row for row in report.rows if row.method in ('card', 'ach')]

    def as_item(row):
        return PaymentMethodSummaryItem(method=row.method, count=row.payment_count, amount=row.payment_amount)

    def top_three(rank_key):
        ranked = sorted(candidates, key=rank_key, reverse=True)
        return [as_item(row) for row in ranked[:3]]

    return PaymentMethodSummaryResponse(
        most_used=top_three(lambda row: row.payment_count),
        highest_volume=top_three(lambda row: row.payment_amount),
    )
