"""Payer JWT helpers for the public checkout flow."""
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

import jwt as pyjwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session

from src.core.config import settings
from src.core.database import get_db

# Token lifetimes used by the minting functions below.
PAYER_TOKEN_EXPIRY_HOURS: int = 24  # access token
PAYER_REFRESH_TOKEN_EXPIRY_DAYS: int = 7  # refresh token

# Bearer credential extractor. With auto_error=False a missing/non-Bearer
# Authorization header yields None instead of an automatic error, so
# get_current_payer can raise its own 401 with a payer-specific message.
_bearer = HTTPBearer(auto_error=False)

logger = logging.getLogger(__name__)


def _encode_payer_jwt(
    user_id: int,
    merchant_id: int,
    customer_id: Optional[int],
    user_type: str,
    lifetime: timedelta,
) -> str:
    """
    Sign a payer JWT with the given ``user_type`` that expires ``lifetime`` from now.

    Access and refresh tokens share this claim set; they differ only in
    ``user_type`` (which verify_payer_refresh_token / get_current_payer check)
    and in their expiry.
    """
    now = datetime.now(timezone.utc)
    payload = {
        # "sub" must be a string per RFC 7519 (recent PyJWT versions enforce this).
        "sub": str(user_id),
        "user_type": user_type,
        "merchant_id": merchant_id,
        "customer_id": customer_id,
        "iat": now,
        "exp": now + lifetime,
    }
    return pyjwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)


def create_payer_token(user_id: int, merchant_id: int, customer_id: Optional[int] = None) -> str:
    """Mint a short-lived payer access JWT (PAYER_TOKEN_EXPIRY_HOURS, i.e. 24 hours)."""
    return _encode_payer_jwt(
        user_id,
        merchant_id,
        customer_id,
        "payer",
        timedelta(hours=PAYER_TOKEN_EXPIRY_HOURS),
    )


def create_payer_refresh_token(user_id: int, merchant_id: int, customer_id: Optional[int] = None) -> str:
    """
    Mint a long-lived payer refresh token (PAYER_REFRESH_TOKEN_EXPIRY_DAYS, i.e. 7 days).

    Used only to obtain a new access token; its ``user_type`` of ``payer_refresh``
    means get_current_payer rejects it as an access token.
    """
    return _encode_payer_jwt(
        user_id,
        merchant_id,
        customer_id,
        "payer_refresh",
        timedelta(days=PAYER_REFRESH_TOKEN_EXPIRY_DAYS),
    )


# Upper bound (seconds) on any single Redis connect/read during the session
# invalidation check. The client is synchronous and is called from async code,
# so an unreachable Redis must fail open quickly instead of stalling requests.
_REDIS_TIMEOUT_SECONDS = 0.5

# Redis clients keyed by URL, created lazily and reused so that each request
# does not build (and leak) a fresh connection pool.
_redis_clients: dict = {}

_SESSION_INVALIDATED_DETAIL = "Session invalidated after password reset. Please log in again."


def _get_redis_client():
    """Return a cached Redis client for the configured URL, creating it on first use."""
    import redis as _redis  # lazy: an ImportError is handled (fail-open) by the caller

    redis_url = getattr(settings, "REDIS_URL", None) or getattr(
        settings, "CELERY_RESULT_BACKEND", "redis://localhost:6379/0"
    )
    client = _redis_clients.get(redis_url)
    if client is None:
        client = _redis.from_url(
            redis_url,
            decode_responses=True,
            socket_connect_timeout=_REDIS_TIMEOUT_SECONDS,
            socket_timeout=_REDIS_TIMEOUT_SECONDS,
        )
        _redis_clients[redis_url] = client
    return client


def _is_session_invalidated(payload: dict) -> bool:
    """
    Return True if the token was issued before the user's last password reset.

    Compares the token's ``iat`` with the timestamp stored under
    ``payer:invalidate_before:{sub}`` in Redis. Fails open (returns False) when
    the token lacks ``sub``/``iat``, when Redis is unavailable, or when the stored
    value cannot be parsed, so a Redis outage does not break every payer session.
    """
    user_id = payload.get("sub")
    token_iat = payload.get("iat")
    if not user_id or token_iat is None:
        return False
    try:
        invalidate_ts = _get_redis_client().get(f"payer:invalidate_before:{user_id}")
        # NOTE(review): PyJWT encodes a datetime "iat" as whole seconds. If the reset
        # flow stores a fractional timestamp, a token minted in the same second just
        # after the reset compares as older and is rejected — confirm how the key is written.
        return bool(invalidate_ts) and float(token_iat) < float(invalidate_ts)
    except Exception as exc:
        # Deliberate fail-open, but leave a trace so outages are visible.
        logger.warning("Payer session-invalidation check skipped for user %s: %s", user_id, exc)
        return False


def verify_payer_refresh_token(token: str) -> dict:
    """
    Decode and validate a payer refresh token.

    Raises HTTPException on any failure:
      - 401 if the token is expired, invalid, or was issued before a
        password-reset session invalidation;
      - 403 if it is not a payer refresh token.
    Returns the decoded payload.
    """
    try:
        payload = pyjwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
    except pyjwt.ExpiredSignatureError:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Payer refresh token expired")
    except pyjwt.InvalidTokenError:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid payer refresh token")

    if payload.get("user_type") != "payer_refresh":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a payer refresh token")

    # Without this check, a refresh token issued before a password reset could
    # mint a brand-new access token (with a fresh iat) and bypass the invalidation
    # that get_current_payer enforces on access tokens.
    if _is_session_invalidated(payload):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=_SESSION_INVALIDATED_DETAIL)

    return payload


async def get_current_payer(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(_bearer),
    db: Session = Depends(get_db),  # not used here; kept to preserve the dependency signature
) -> dict:
    """
    Decode the Bearer token and verify ``user_type == 'payer'``.

    Rejects tokens issued before a password-reset session invalidation. That
    lookup uses a synchronous Redis client bounded by a short timeout, and fails
    open if Redis is unavailable.

    Raises HTTPException 401 (missing, expired, invalid or invalidated token)
    or 403 (not a payer access token). Returns the decoded payload dict.
    """
    if credentials is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Payer token required")
    try:
        payload = pyjwt.decode(
            credentials.credentials,
            settings.JWT_SECRET_KEY,
            algorithms=[settings.JWT_ALGORITHM],
        )
    except pyjwt.ExpiredSignatureError:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Payer token expired")
    except pyjwt.InvalidTokenError:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid payer token")

    if payload.get("user_type") != "payer":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a payer token")

    if _is_session_invalidated(payload):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=_SESSION_INVALIDATED_DETAIL)

    return payload
