"""
Discounts Report Service
"""
import logging
from datetime import date
from typing import Optional

from sqlalchemy import select, func, and_, or_
from sqlalchemy.orm import Session

from src.apps.reports.schemas.discounts_report import (
    DiscountReportRow, DiscountReportResponse, DiscountReportSummaryResponse
)

logger = logging.getLogger(__name__)


def get_discounts_report(
    db: Session,
    merchant_id: int,
    search: Optional[str] = None,
    date_from: Optional[date] = None,
    date_to: Optional[date] = None,
    min_amount: Optional[float] = None,
    max_amount: Optional[float] = None,
) -> DiscountReportResponse:
    """Get discounts report grouped by discount definition.

    Aggregates discounted payment-request adjustments for one merchant into
    one row per (discount name, discount type) pair. The discount name is the
    linked ``MerchantDiscount.title`` when present, otherwise the adjustment's
    own ``discount_name``, otherwise the literal ``'Manual Discount'``.

    Args:
        db: SQLAlchemy session used to execute the query.
        merchant_id: Only payment requests belonging to this merchant are
            included. Payment requests with ``deleted_at`` set are excluded.
        search: Optional case-insensitive substring matched against the
            merchant discount title or the adjustment's discount name.
            ``%`` and ``_`` are matched literally, not as wildcards.
        date_from: Optional inclusive lower bound (start of day) on the
            payment request's ``created_at``.
        date_to: Optional inclusive upper bound (end of day) on the payment
            request's ``created_at``.
        min_amount: Optional inclusive lower bound on each individual
            adjustment's ``discount_amount`` (applied before aggregation).
        max_amount: Optional inclusive upper bound on each individual
            adjustment's ``discount_amount`` (applied before aggregation).

    Returns:
        A ``DiscountReportResponse`` holding the grouped rows and the sum of
        their (2-decimal rounded) totals.
    """
    # NOTE(review): these imports are function-local in the original --
    # presumably to avoid an import cycle; confirm before hoisting them.
    from src.apps.payment_requests.models.payment_request_adjustments import PaymentRequestAdjustments
    from src.apps.payment_requests.models.payment_request import PaymentRequest
    from src.apps.merchants.models.merchant_discount import MerchantDiscount
    from datetime import datetime

    conditions = [
        # SQL expression, not a Python truth test -- `== True` is intentional.
        PaymentRequestAdjustments.is_discounted == True,  # noqa: E712
        PaymentRequest.merchant_id == merchant_id,
        PaymentRequest.deleted_at.is_(None),
    ]

    if date_from is not None:
        conditions.append(
            PaymentRequest.created_at >= datetime.combine(date_from, datetime.min.time())
        )
    if date_to is not None:
        conditions.append(
            PaymentRequest.created_at <= datetime.combine(date_to, datetime.max.time())
        )
    if min_amount is not None:
        conditions.append(PaymentRequestAdjustments.discount_amount >= min_amount)
    if max_amount is not None:
        conditions.append(PaymentRequestAdjustments.discount_amount <= max_amount)

    if search:
        # Escape LIKE metacharacters so user input such as "10%" or "a_b" is
        # matched literally instead of being interpreted as wildcards. The
        # escape character itself must be doubled first.
        escaped = search.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"
        conditions.append(
            or_(
                MerchantDiscount.title.ilike(pattern, escape="/"),
                PaymentRequestAdjustments.discount_name.ilike(pattern, escape="/"),
            )
        )

    discount_name_expr = func.coalesce(
        MerchantDiscount.title,
        PaymentRequestAdjustments.discount_name,
        'Manual Discount',
    )

    stmt = (
        select(
            discount_name_expr.label("discount_name"),
            PaymentRequestAdjustments.discount_type.label("discount_type"),
            func.count(PaymentRequestAdjustments.id).label("discount_applied"),
            func.sum(PaymentRequestAdjustments.discount_amount).label("discount_amount_total"),
        )
        # Pin the left side of the join chain explicitly; the column list also
        # references MerchantDiscount (via coalesce), so do not rely on
        # SQLAlchemy inferring which FROM element to join from.
        .select_from(PaymentRequestAdjustments)
        .join(PaymentRequest, PaymentRequestAdjustments.payment_request_id == PaymentRequest.id)
        # Outer join: adjustments without a linked merchant discount must
        # still be reported (they fall back to discount_name / the literal).
        .outerjoin(MerchantDiscount, PaymentRequestAdjustments.discount_id == MerchantDiscount.id)
        .where(*conditions)
        .group_by(discount_name_expr, PaymentRequestAdjustments.discount_type)
    )

    results = db.execute(stmt).all()

    rows = [
        DiscountReportRow(
            # The Python-side fallbacks cover values that are empty or NULL
            # after aggregation (e.g. an empty-string name, a NULL type).
            discount_name=r.discount_name or "Manual Discount",
            discount_type=r.discount_type or "amount",
            discount_applied=r.discount_applied or 0,
            discount_amount_total=round(float(r.discount_amount_total or 0), 2),
        )
        for r in results
    ]

    total = round(sum(r.discount_amount_total for r in rows), 2)
    return DiscountReportResponse(rows=rows, total_amount=total)


def get_discounts_summary(
    db: Session,
    merchant_id: int,
    date_from: Optional[date] = None,
    date_to: Optional[date] = None,
) -> DiscountReportSummaryResponse:
    """Return overall discount totals for a merchant.

    Runs the grouped discounts report (without search or amount filters) for
    the given merchant and optional date range, then collapses its rows into
    two figures: how many discounts were applied in total and the total
    discounted amount, rounded to 2 decimals.
    """
    grouped_rows = get_discounts_report(
        db=db,
        merchant_id=merchant_id,
        date_from=date_from,
        date_to=date_to,
    ).rows

    applied_count = sum(row.discount_applied for row in grouped_rows)
    discounted_sum = sum(row.discount_amount_total for row in grouped_rows)

    return DiscountReportSummaryResponse(
        total_discounts_applied=applied_count,
        total_amount_discounted=round(discounted_sum, 2),
    )
