from __future__ import annotations
from datetime import datetime, timezone, timedelta
from typing import List, Optional, Tuple
from sqlalchemy import select, func, and_, or_
from sqlalchemy.orm import Session

from src.apps.merchants.models.merchant import Merchant
from src.apps.merchants.models.merchant_users import MerchantUsers
from src.apps.admin.models.admin_audit_log import AdminAuditLog


def get_merchants_paginated(
    db: Session,
    page: int = 1,
    per_page: int = 20,
    search: Optional[str] = None,
    status: Optional[str] = None,
    is_active: Optional[bool] = None,
    is_onboarded: Optional[bool] = None,
) -> Tuple[List[Merchant], int]:
    """Return one page of non-deleted merchants plus the total match count.

    Args:
        db: SQLAlchemy session used to run the queries.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Maximum number of merchants returned for the page.
        search: Case-insensitive substring matched against ``name``,
            ``email`` and ``uin``. LIKE wildcards (``%``, ``_``) in the
            search term are matched literally.
        status: Accepted for interface compatibility but currently NOT
            applied to the query. NOTE(review): confirm which column this
            should filter before wiring it up.
        is_active: When not ``None``, keep only merchants whose
            ``is_active`` equals this value.
        is_onboarded: When not ``None``, keep only merchants whose
            ``is_onboarded`` equals this value.

    Returns:
        ``(items, total)`` where ``items`` is the requested page ordered by
        ``Merchant.id`` and ``total`` counts every matching merchant.
    """
    # A negative OFFSET is rejected by databases such as PostgreSQL.
    page = max(page, 1)

    stmt = select(Merchant).where(Merchant.deleted_at.is_(None))
    if search:
        # Escape LIKE metacharacters so user input is matched literally
        # instead of being interpreted as wildcards.
        escaped = (
            search.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        pattern = f"%{escaped}%"
        stmt = stmt.where(
            or_(
                Merchant.name.ilike(pattern, escape="\\"),
                Merchant.email.ilike(pattern, escape="\\"),
                Merchant.uin.ilike(pattern, escape="\\"),
            )
        )
    if is_active is not None:
        stmt = stmt.where(Merchant.is_active == is_active)
    if is_onboarded is not None:
        stmt = stmt.where(Merchant.is_onboarded == is_onboarded)

    count_stmt = select(func.count()).select_from(stmt.subquery())
    total = db.execute(count_stmt).scalar_one()

    # Without ORDER BY the database may return rows in any order, so
    # OFFSET/LIMIT pages could overlap or skip rows between requests.
    stmt = (
        stmt.order_by(Merchant.id)
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    items = list(db.execute(stmt).scalars().all())
    return items, total


def get_merchant_by_id(db: Session, merchant_id: int) -> Optional[Merchant]:
    """Look up a single non-deleted merchant by primary id.

    Args:
        db: SQLAlchemy session used to run the query.
        merchant_id: Id of the merchant to fetch.

    Returns:
        The matching ``Merchant``, or ``None`` when no row has that id or
        the row is soft-deleted (``deleted_at`` is set).
    """
    # Chained .where() calls are ANDed together.
    query = (
        select(Merchant)
        .where(Merchant.id == merchant_id)
        .where(Merchant.deleted_at.is_(None))
    )
    result = db.execute(query)
    return result.scalar_one_or_none()


def get_audit_logs(
    db: Session,
    page: int = 1,
    per_page: int = 20,
    admin_user_id: Optional[int] = None,
    action: Optional[str] = None,
    date_from: Optional[datetime] = None,
    date_to: Optional[datetime] = None,
) -> Tuple[List[AdminAuditLog], int]:
    """Return one page of admin audit-log entries plus the total match count.

    Args:
        db: SQLAlchemy session used to run the queries.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Maximum number of entries returned for the page.
        admin_user_id: When not ``None``, keep only entries for this admin.
        action: When non-empty, keep only entries whose ``action`` equals it.
        date_from: When not ``None``, inclusive lower bound on ``created_at``.
        date_to: When not ``None``, inclusive upper bound on ``created_at``.

    Returns:
        ``(items, total)`` where ``items`` is the requested page, newest
        ``created_at`` first, and ``total`` counts every matching entry.
    """
    # A negative OFFSET is rejected by databases such as PostgreSQL.
    page = max(page, 1)

    stmt = select(AdminAuditLog)
    # Compare against None explicitly so a falsy-but-valid id (0) still filters.
    if admin_user_id is not None:
        stmt = stmt.where(AdminAuditLog.admin_user_id == admin_user_id)
    if action:
        stmt = stmt.where(AdminAuditLog.action == action)
    if date_from is not None:
        stmt = stmt.where(AdminAuditLog.created_at >= date_from)
    if date_to is not None:
        stmt = stmt.where(AdminAuditLog.created_at <= date_to)

    # Count before adding ORDER BY: ordering inside the count subquery is
    # wasted work and does not affect the total.
    count_stmt = select(func.count()).select_from(stmt.subquery())
    total = db.execute(count_stmt).scalar_one()

    # NOTE(review): entries sharing the same created_at have no tiebreaker,
    # so their relative order across pages is not guaranteed; consider a
    # unique secondary sort key.
    stmt = (
        stmt.order_by(AdminAuditLog.created_at.desc())
        .offset((page - 1) * per_page)
        .limit(per_page)
    )
    items = list(db.execute(stmt).scalars().all())
    return items, total
