import logging
from typing import Optional, List
from sqlalchemy.orm import Session
from sqlalchemy import select, func
from src.apps.payment_providers.models.payment_provider import PaymentProvider
from src.apps.payment_providers.models.merchant_provider_config import MerchantProviderConfig
from src.apps.payment_providers.models.merchant_provider_credential import MerchantProviderCredential
from src.apps.payment_providers.helpers.credentials import encrypt_credential, decrypt_credential

# Module-level logger for this repository module (e.g. get_credentials reports
# per-key decrypt failures through it).
logger = logging.getLogger(__name__)


def get_provider_by_slug(slug: str, db: Session) -> Optional[PaymentProvider]:
    """Look up a non-deleted payment provider by its slug.

    Returns the matching provider, or None when no live provider has that
    slug. SQLAlchemy raises MultipleResultsFound if more than one live row
    matches.
    """
    query = (
        select(PaymentProvider)
        .where(PaymentProvider.slug == slug)
        .where(PaymentProvider.deleted_at.is_(None))
    )
    return db.execute(query).scalars().one_or_none()


def get_provider_by_id(provider_id: int, db: Session) -> Optional[PaymentProvider]:
    """Look up a non-deleted payment provider by primary key.

    Returns None when the id does not exist or the provider is soft-deleted.
    """
    query = (
        select(PaymentProvider)
        .where(PaymentProvider.id == provider_id)
        .where(PaymentProvider.deleted_at.is_(None))
    )
    return db.execute(query).scalars().one_or_none()


def get_all_providers(db: Session) -> List[PaymentProvider]:
    """Return every non-deleted payment provider (active or not), sorted by name."""
    query = (
        select(PaymentProvider)
        .where(PaymentProvider.deleted_at.is_(None))
        .order_by(PaymentProvider.name)
    )
    return list(db.execute(query).scalars())


def get_active_providers(db: Session) -> List[PaymentProvider]:
    """Return the active, non-deleted payment providers, sorted by name."""
    query = (
        select(PaymentProvider)
        # `== True` is intentional: SQLAlchemy overloads == to build SQL;
        # a Python `is True` check would not produce a filter.
        .where(PaymentProvider.is_active == True)  # noqa: E712
        .where(PaymentProvider.deleted_at.is_(None))
        .order_by(PaymentProvider.name)
    )
    return list(db.execute(query).scalars())


def get_provider_id_by_slug(slug: str, db: Session) -> Optional[int]:
    """Resolve a provider slug to its primary key, or None if no live provider matches."""
    match = get_provider_by_slug(slug, db)
    if not match:
        return None
    return match.id


def get_merchant_provider_config(
    merchant_id: int, provider_id: int, db: Session
) -> Optional[MerchantProviderConfig]:
    """Fetch the non-deleted config linking ``merchant_id`` to ``provider_id``.

    The config's ``is_active`` flag is not filtered on. Returns None when no
    live config exists. SQLAlchemy raises MultipleResultsFound if several
    live rows match.
    """
    query = (
        select(MerchantProviderConfig)
        .where(MerchantProviderConfig.merchant_id == merchant_id)
        .where(MerchantProviderConfig.provider_id == provider_id)
        .where(MerchantProviderConfig.deleted_at.is_(None))
    )
    return db.execute(query).scalars().one_or_none()


def get_available_providers_for_merchant(merchant_id: int, db: Session) -> List[MerchantProviderConfig]:
    """Return the merchant's active, non-deleted provider configs.

    The result holds MerchantProviderConfig rows, not PaymentProvider rows.
    The order is unspecified because there is no ORDER BY.

    NOTE(review): only the config's own is_active/deleted_at columns are
    checked. A config whose PaymentProvider is inactive or soft-deleted is
    still returned, so confirm that callers filter on the provider if that
    matters.
    """
    query = (
        select(MerchantProviderConfig)
        .where(MerchantProviderConfig.merchant_id == merchant_id)
        .where(MerchantProviderConfig.is_active == True)  # noqa: E712 - SQLAlchemy expression
        .where(MerchantProviderConfig.deleted_at.is_(None))
    )
    return list(db.execute(query).scalars())


def upsert_merchant_provider_config(
    merchant_id: int,
    provider_id: int,
    onboarding_status: str,
    config_data: Optional[dict],
    db: Session,
) -> MerchantProviderConfig:
    """Create or update the live config linking a merchant to a provider.

    If a non-deleted config already exists (see get_merchant_provider_config):
      - its ``onboarding_status`` is overwritten;
      - ``config_data`` is replaced only when a non-None value is passed;
      - its ``is_active`` flag is left unchanged.

    Otherwise a new config is added with ``is_active=True`` and the given
    ``config_data``, which may be None.

    The session is flushed, not committed, and the config is returned.
    """
    config = get_merchant_provider_config(merchant_id, provider_id, db)
    if not config:
        config = MerchantProviderConfig(
            merchant_id=merchant_id,
            provider_id=provider_id,
            onboarding_status=onboarding_status,
            config_data=config_data,
            is_active=True,
        )
        db.add(config)
    else:
        config.onboarding_status = onboarding_status
        # None means "leave the stored config_data alone" on updates.
        if config_data is not None:
            config.config_data = config_data
    db.flush()
    return config


def upsert_credential(
    config_id: int, key: str, value: str, db: Session
) -> MerchantProviderCredential:
    """Store ``value`` encrypted under ``key`` for a provider config.

    ``value`` must be the raw plaintext; it goes through ``encrypt_credential``
    before being written.

    An existing row for the same (config_id, key) pair has its ciphertext
    replaced; otherwise a new credential row is added. The session is flushed,
    not committed, and the row is returned.
    """
    ciphertext = encrypt_credential(value)
    lookup = select(MerchantProviderCredential).where(
        MerchantProviderCredential.merchant_provider_config_id == config_id,
        MerchantProviderCredential.credential_key == key,
    )
    credential = db.execute(lookup).scalars().one_or_none()
    if not credential:
        credential = MerchantProviderCredential(
            merchant_provider_config_id=config_id,
            credential_key=key,
            credential_value=ciphertext,
        )
        db.add(credential)
    else:
        credential.credential_value = ciphertext
    db.flush()
    return credential


def get_credentials(config_id: int, db: Session) -> dict:
    """Decrypt and return every credential stored for a provider config.

    The result maps each ``credential_key`` to its decrypted (plaintext) value.

    Decryption is best-effort per key. When ``decrypt_credential`` raises for
    a row (e.g. a key rotation mismatch), that key is left out of the mapping
    and a warning naming only the key is logged. One unreadable credential
    therefore does not hide the rest.
    """
    query = select(MerchantProviderCredential).where(
        MerchantProviderCredential.merchant_provider_config_id == config_id,
    )
    decrypted = {}
    for cred in db.execute(query).scalars().all():
        try:
            plaintext = decrypt_credential(cred.credential_value)
        except Exception:
            # Deliberately broad: any decrypt failure just drops this key.
            # Only the key name is logged, never the ciphertext or plaintext.
            logger.warning(
                "op=get_credentials config_id=%s credential_key=%s result=decrypt_failed — skipping key",
                config_id,
                cred.credential_key,
            )
            continue
        decrypted[cred.credential_key] = plaintext
    return decrypted


def get_merchant_by_id(merchant_id: int, db: Session):
    """Fetch a non-deleted Merchant by primary key; return None if absent."""
    # Imported inside the function, presumably to avoid a circular import
    # between the merchants and payment_providers apps (TODO confirm).
    from src.apps.merchants.models.merchant import Merchant

    query = (
        select(Merchant)
        .where(Merchant.id == merchant_id)
        .where(Merchant.deleted_at.is_(None))
    )
    return db.execute(query).scalars().one_or_none()


def update_merchant_active_provider(
    merchant_id: int, provider_id: Optional[int], db: Session
) -> None:
    """Point a merchant at a payment provider, or clear it, and flush.

    Passing ``provider_id=None`` clears ``active_provider_id``. The provider
    id itself is not validated here.

    Raises NotFoundError when no non-deleted merchant has ``merchant_id``.
    """
    merchant = get_merchant_by_id(merchant_id, db)
    if merchant is not None:
        merchant.active_provider_id = provider_id
        db.flush()
        return

    from src.core.exceptions import NotFoundError

    raise NotFoundError(message=f"Merchant {merchant_id} not found")


def get_completed_credential_keys(config_id: int, db: Session) -> List[str]:
    """List the credential keys that have a stored row for a provider config.

    Every credential row for ``config_id`` contributes its key. Only the key
    column is selected, so nothing is decrypted. The order is unspecified
    because there is no ORDER BY.
    """
    query = select(MerchantProviderCredential.credential_key).where(
        MerchantProviderCredential.merchant_provider_config_id == config_id,
    )
    return list(db.execute(query).scalars())


def delete_credentials(config_id: int, db: Session) -> None:
    """Hard-delete every credential row attached to a provider config.

    Used when resetting onboarding. Rows are loaded and removed one at a time
    through the session (``db.delete``) rather than with a bulk DELETE. The
    session is flushed once at the end; nothing is committed here.
    """
    query = select(MerchantProviderCredential).where(
        MerchantProviderCredential.merchant_provider_config_id == config_id,
    )
    doomed = db.execute(query).scalars().all()
    for credential in doomed:
        db.delete(credential)
    db.flush()


def get_default_provider(db: Session) -> Optional[PaymentProvider]:
    """Return the default active provider, or None if none is configured.

    At most one provider is expected to carry ``is_default=True``; the
    ``clear_default_flag`` helper in this module is what maintains that.
    No database-level uniqueness guarantee is visible here, and that helper
    does an unlocked read-modify-write, so duplicates are possible.

    Rather than letting ``scalar_one_or_none()`` raise MultipleResultsFound
    on every lookup in that case, the lowest-id candidate is returned
    deterministically and a warning is logged.
    """
    stmt = (
        select(PaymentProvider)
        .where(
            # `== True` is intentional: SQLAlchemy overloads == to build SQL.
            PaymentProvider.is_default == True,  # noqa: E712
            PaymentProvider.is_active == True,  # noqa: E712
            PaymentProvider.deleted_at.is_(None),
        )
        # Deterministic choice if the single-default invariant is broken;
        # two rows are enough to detect that.
        .order_by(PaymentProvider.id)
        .limit(2)
    )
    candidates = list(db.execute(stmt).scalars().all())
    if not candidates:
        return None
    if len(candidates) > 1:
        logger.warning(
            "op=get_default_provider result=multiple_defaults chosen_provider_id=%s",
            candidates[0].id,
        )
    return candidates[0]


def count_configured_providers_for_merchant(merchant_id: int, db: Session) -> int:
    """Count a merchant's live provider configs that have started onboarding.

    A config is counted when all of these hold:
      - it belongs to ``merchant_id``;
      - it is active and not soft-deleted;
      - its ``onboarding_status`` is not ``"not_started"``.

    NOTE(review): under SQL ``!=`` semantics a NULL onboarding_status is also
    excluded. Confirm that matches intent if the column is nullable.
    """
    query = (
        select(func.count())
        .select_from(MerchantProviderConfig)
        .where(MerchantProviderConfig.merchant_id == merchant_id)
        .where(MerchantProviderConfig.onboarding_status != "not_started")
        .where(MerchantProviderConfig.is_active == True)  # noqa: E712 - SQLAlchemy expression
        .where(MerchantProviderConfig.deleted_at.is_(None))
    )
    total = db.execute(query).scalar()
    return total if total else 0


def clear_default_flag(exclude_provider_id: int, db: Session) -> None:
    """Unset ``is_default`` on every other non-deleted provider, then flush.

    The provider with id ``exclude_provider_id`` keeps its flag. Soft-deleted
    providers are not touched.

    Presumably this is called right after ``exclude_provider_id`` is made the
    default (verify against callers). It is a plain read-then-update with no
    row locking.
    """
    query = (
        select(PaymentProvider)
        .where(PaymentProvider.is_default == True)  # noqa: E712 - SQLAlchemy expression
        .where(PaymentProvider.id != exclude_provider_id)
        .where(PaymentProvider.deleted_at.is_(None))
    )
    for provider in db.execute(query).scalars().all():
        provider.is_default = False
    db.flush()
