"""
Modern pagination utility for SQLAlchemy queries with Pydantic schemas.

This module provides a QueryPaginator class that handles pagination of SQLAlchemy
queries with automatic serialization to Pydantic models, designed for Python 3.12+.
"""

from __future__ import annotations

import math
from typing import Any, TypeVar, Generic, Union

from pydantic import BaseModel
from sqlalchemy.orm.query import Query
from sqlalchemy.sql import Select
from sqlalchemy.orm import Session
from sqlalchemy import select, func
from src.core.utils.enums import DEFAULT_PER_PAGE, MAX_PER_PAGE

# Type variable for the response schema. The upper bound restricts it to
# Pydantic BaseModel subclasses, which is what PaginationResult and
# QueryPaginator are parametrized with.
ResponseSchemaType = TypeVar("ResponseSchemaType", bound=BaseModel)


class PaginationResult(BaseModel, Generic[ResponseSchemaType]):
    """Pagination result container with type safety.

    Generic over the Pydantic schema used for the items in ``result``.
    ``QueryPaginator.paginate`` builds instances of this model from the
    paginator's computed properties.
    """

    # Total number of records matched by the query (all pages combined).
    total: int
    # Current page number, 1-indexed.
    page: int
    # Page size used for this result.
    per_page: int
    # Link to the following page, or None when there is no next page.
    next: str | None
    # Link to the preceding page, or None when there is no previous page.
    previous: str | None
    # Link to the first page.
    first: str
    # Link to the last page.
    last: str
    # Serialized records belonging to the current page.
    result: list[ResponseSchemaType]


class QueryPaginator(Generic[ResponseSchemaType]):
    """
    Modern SQLAlchemy query paginator with Pydantic schema serialization.

    This paginator provides offset/limit pagination for SQLAlchemy queries with
    automatic conversion to Pydantic models. It accepts both legacy ``Query``
    objects and ``Select`` statements and generates navigation links
    (first / previous / next / last) from a base URL.

    The total count and the serialized page are computed lazily and cached on
    the instance, so each is fetched from the database at most once.

    Args:
        query: SQLAlchemy Query object or Select statement to paginate
        schema: Pydantic model class for response serialization
        url: Base URL for generating pagination links
        db: Database session (required for Select statements)
        model: SQLAlchemy model class (required for Select statements; kept
            for interface compatibility)
        offset: Number of records to skip (default: 0, negatives become 0)
        limit: Number of records to return (default: DEFAULT_PER_PAGE,
            clamped to the range 1..MAX_PER_PAGE)
        use_orm: Whether to use ORM mode for serialization (default: True)

    Raises:
        ValueError: If ``query`` is a Select statement and ``db`` or ``model``
            is missing.

    Example:
        >>> # For legacy Query objects
        >>> paginator = QueryPaginator(
        ...     query=session.query(User),
        ...     schema=UserSchema,
        ...     url="/api/users",
        ...     limit=10
        ... )
        >>>
        >>> # For modern Select statements
        >>> paginator = QueryPaginator(
        ...     query=select(User),
        ...     schema=UserSchema,
        ...     db=session,
        ...     model=User,
        ...     url="/api/users",
        ...     limit=10
        ... )
        >>> result = paginator.paginate()
    """

    def __init__(
        self,
        query: Union[Query, Select],
        schema: type[ResponseSchemaType],
        url: str,
        *,
        db: Session | None = None,
        model: Any = None,
        offset: int = 0,
        limit: int = DEFAULT_PER_PAGE,
        use_orm: bool = True,
    ) -> None:
        self.query = query
        self.schema = schema
        self.db = db
        self.model = model
        self.url = url.rstrip('/')  # Normalize: drop trailing slashes
        self.offset = max(0, offset)  # Ensure non-negative
        self.limit = max(1, min(limit, MAX_PER_PAGE))  # Clamp to 1..MAX_PER_PAGE
        self.use_orm = use_orm

        # Select statements are executed through the session; legacy Query
        # objects carry their own session.
        self.is_select_stmt = isinstance(query, Select)

        # Validate requirements for Select statements
        if self.is_select_stmt and (db is None or model is None):
            raise ValueError("Database session and model are required for Select statements")

        # Lazily populated caches
        self._count: int | None = None
        self._records: list[ResponseSchemaType] | None = None

    @property
    def count(self) -> int:
        """Get total record count with caching."""
        if self._count is None:
            if self.is_select_stmt:
                # Count the rows of the statement itself by wrapping it in a
                # subquery. This keeps joins, GROUP BY, HAVING and DISTINCT
                # intact, which a hand-rebuilt FROM + WHERE count would lose.
                # ORDER BY is removed because it cannot affect the count and
                # some backends reject it inside a subquery.
                subquery = self.query.order_by(None).subquery()
                count_query = select(func.count()).select_from(subquery)
                # COUNT(*) always yields exactly one row.
                self._count = self.db.execute(count_query).scalar_one()
            else:
                # For legacy Query objects
                self._count = self.query.count()
        return self._count

    @property
    def current_page(self) -> int:
        """Get current page number (1-indexed)."""
        return (self.offset // self.limit) + 1

    @property
    def total_pages(self) -> int:
        """Get total number of pages (at least 1, even for an empty result)."""
        return max(1, math.ceil(self.count / self.limit))

    @property
    def has_next(self) -> bool:
        """Check if there's a next page."""
        return self.offset + self.limit < self.count

    @property
    def has_previous(self) -> bool:
        """Check if there's a previous page."""
        return self.offset > 0

    def _build_url(self, page: int, per_page: int | None = None) -> str:
        """
        Build pagination URL for the given page.

        Args:
            page: Page number (1-indexed)
            per_page: Items per page (defaults to current limit)

        Returns:
            Formatted URL with pagination parameters
        """
        per_page = per_page or self.limit
        # If the base URL already carries a query string, extend it instead of
        # starting a second one.
        separator = "&" if "?" in self.url else "?"
        return f"{self.url}{separator}page={page}&per_page={per_page}"

    @property
    def next_url(self) -> str | None:
        """Get URL for next page, or None if no next page exists."""
        if not self.has_next:
            return None
        return self._build_url(self.current_page + 1)

    @property
    def previous_url(self) -> str | None:
        """Get URL for previous page, or None if no previous page exists."""
        if not self.has_previous:
            return None
        # An offset that is not a multiple of the limit (e.g. offset=5,
        # limit=10) still maps to page 1; clamp so the link never points at
        # the non-existent page 0.
        return self._build_url(max(1, self.current_page - 1))

    @property
    def first_url(self) -> str:
        """Get URL for first page."""
        return self._build_url(1)

    @property
    def last_url(self) -> str:
        """Get URL for last page."""
        return self._build_url(self.total_pages)

    def get_records(self) -> list[ResponseSchemaType]:
        """
        Retrieve and serialize the paginated records.

        The result is cached, so repeated calls do not hit the database again.

        Returns:
            List of serialized Pydantic models for the current page

        Raises:
            ValueError: If serialization fails
        """
        if self._records is not None:
            return self._records

        # Execute the paginated query
        if self.is_select_stmt:
            paginated_query = self.query.offset(self.offset).limit(self.limit)
            result = self.db.execute(paginated_query)
            if self.use_orm:
                # ORM mode: each row holds one mapped entity; unwrap it.
                raw_records = result.scalars().all()
            else:
                # Non-ORM mode: keep multi-column rows whole so no column is
                # dropped; single-column rows are unwrapped to their value,
                # matching what scalars() returned for them.
                raw_records = [
                    row[0] if len(row) == 1 else row
                    for row in result.all()
                ]
        else:
            # For legacy Query objects
            raw_records = self.query.offset(self.offset).limit(self.limit).all()

        # Serialize records based on mode
        try:
            if self.use_orm:
                # Use Pydantic's model_validate method for SQLAlchemy ORM objects
                self._records = [self.schema.model_validate(record) for record in raw_records]
            else:
                # Rows / named tuples expose _asdict(); anything else is
                # handed to the schema unchanged.
                self._records = [
                    self.schema.model_validate(record._asdict() if hasattr(record, '_asdict') else record)
                    for record in raw_records
                ]
        except Exception as e:
            # Surface any serialization problem under one documented type,
            # keeping the original error chained for debugging.
            raise ValueError(f"Failed to serialize records with schema {self.schema.__name__}: {e}") from e

        return self._records

    def paginate(self) -> PaginationResult[ResponseSchemaType]:
        """
        Execute pagination and return complete result.

        Returns:
            PaginationResult containing all pagination data and serialized records

        Example:
            >>> result = paginator.paginate()
            >>> print(f"Page {result.page} of {math.ceil(result.total / result.per_page)}")
            >>> for item in result.result:
            ...     print(item.name)
        """
        return PaginationResult[ResponseSchemaType](
            total=self.count,
            page=self.current_page,
            per_page=self.limit,
            next=self.next_url,
            previous=self.previous_url,
            first=self.first_url,
            last=self.last_url,
            result=self.get_records(),
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Convert pagination result to dictionary format.

        This method provides backward compatibility with the original implementation
        while maintaining the same response structure.

        Returns:
            Dictionary containing pagination metadata and results, with each
            record dumped to a plain dict
        """
        paginated_result = self.paginate()
        return {
            "total": paginated_result.total,
            "page": paginated_result.page,
            "per_page": paginated_result.per_page,
            "next": paginated_result.next,
            "previous": paginated_result.previous,
            "first": paginated_result.first,
            "last": paginated_result.last,
            "result": [item.model_dump() for item in paginated_result.result],
        }

    def __repr__(self) -> str:
        """String representation for debugging; never triggers a count query."""
        # Read the cache directly so repr() has no database side effects.
        total = self._count if self._count is not None else 'unknown'
        return (
            f"{self.__class__.__name__}("
            f"schema={self.schema.__name__}, "
            f"page={self.current_page}, "
            f"per_page={self.limit}, "
            f"total={total}"
            f")"
        )
