"""
JWT Token utilities for authentication.
"""

from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any
import jwt
from pydantic import EmailStr

from src.core.config import settings


class JWTManager:
    """
    JWT token management utility.

    All methods are static. Tokens are signed and verified with
    ``settings.jwt_secret_key`` using ``settings.jwt_algorithm``.
    """

    @staticmethod
    def _encode(
        data: Dict[str, Any],
        lifetime: timedelta,
        extra_claims: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Sign a copy of *data* after stamping ``exp``/``iat`` (and any extra claims).

        Args:
            data: Caller-supplied claims; copied, never mutated. Any ``exp``/``iat``
                keys it contains are overwritten.
            lifetime: Offset from the issue time to the expiry time.
            extra_claims: Additional claims applied last (they override *data*).

        Returns:
            Encoded JWT token
        """
        # Read the clock once so ``exp`` and ``iat`` derive from the same instant.
        issued_at = datetime.now(timezone.utc)
        to_encode = data.copy()
        to_encode.update({"exp": issued_at + lifetime, "iat": issued_at})
        if extra_claims:
            to_encode.update(extra_claims)
        return jwt.encode(to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)

    @staticmethod
    def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """
        Create JWT access token.

        Args:
            data: Data to encode in the token (copied, not mutated).
            expires_delta: Custom expiration time. ``None`` uses
                ``settings.jwt_expires_minutes``; any other value, including
                ``timedelta(0)``, is honored as given.

        Returns:
            Encoded JWT token
        """
        # ``is None`` rather than truthiness: timedelta(0) is falsy and would
        # otherwise be silently replaced by the default lifetime.
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.jwt_expires_minutes)
        return JWTManager._encode(data, expires_delta)

    @staticmethod
    def create_refresh_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
        """
        Create JWT refresh token.

        The token carries a ``"type": "refresh"`` claim in addition to ``exp``/``iat``.

        Args:
            data: Data to encode in the token (copied, not mutated).
            expires_delta: Custom expiration time. ``None`` uses
                ``settings.jwt_refresh_expires_days``; any other value, including
                ``timedelta(0)``, is honored as given.

        Returns:
            Encoded JWT refresh token
        """
        # Same truthiness pitfall as in create_access_token.
        if expires_delta is None:
            expires_delta = timedelta(days=settings.jwt_refresh_expires_days)
        return JWTManager._encode(data, expires_delta, {"type": "refresh"})

    @staticmethod
    def verify_token(token: str) -> Optional[Dict[str, Any]]:
        """
        Verify and decode JWT token.

        The signature is checked against ``settings.jwt_secret_key`` and only
        ``settings.jwt_algorithm`` is accepted; PyJWT rejects tokens whose
        ``exp`` claim has passed.

        Args:
            token: JWT token to verify

        Returns:
            Decoded token data, or None if the token is invalid or expired
        """
        # NOTE(review): no ``type`` claim is checked here, so refresh tokens
        # (created with ``"type": "refresh"``) verify exactly like access tokens
        # and are accepted by the get_*_from_token helpers. Confirm that callers
        # needing an access token reject ``payload.get("type") == "refresh"``.
        try:
            return jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])
        except jwt.PyJWTError:
            # ExpiredSignatureError is a PyJWTError subclass, so expired,
            # tampered and malformed tokens all end up here.
            return None

    @staticmethod
    def is_token_expired(token: str) -> bool:
        """
        Check if token is expired.

        Args:
            token: JWT token to check

        Returns:
            True if the token is expired, has no ``exp`` claim, or fails
            verification for any other reason (bad signature, malformed);
            False only for a valid token that has not yet expired.
        """
        payload = JWTManager.verify_token(token)
        if payload is None:
            # Unverifiable tokens are reported as expired (fail closed).
            return True
        exp = payload.get("exp")
        if exp is None:
            return True
        # jwt.decode already rejects past ``exp``; kept as an explicit guard.
        return datetime.fromtimestamp(exp, tz=timezone.utc) < datetime.now(timezone.utc)

    @staticmethod
    def get_user_id_from_token(token: str) -> Optional[int]:
        """
        Extract user ID from token.

        Args:
            token: JWT token

        Returns:
            The ``user_id`` claim, or None if the token is invalid/expired or
            the claim is absent
        """
        payload = JWTManager.verify_token(token)
        if payload is None:
            return None
        return payload.get("user_id")

    @staticmethod
    def get_user_email_from_token(token: str) -> Optional[str]:
        """
        Extract user email from token.

        Args:
            token: JWT token

        Returns:
            The ``email`` claim, or None if the token is invalid/expired or
            the claim is absent
        """
        payload = JWTManager.verify_token(token)
        if payload is None:
            return None
        return payload.get("email")


# Shared module-level JWTManager. Its methods are static, so the instance
# carries no state of its own.
jwt_manager: JWTManager = JWTManager()
