"""
Response Formatting Middleware
"""
import json
from typing import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, StreamingResponse


class CommonResponseMiddleware(BaseHTTPMiddleware):
    """
    Middleware to format all API responses with a common structure.

    Wraps response data in a consistent format:
    {
        "data": <original_response>,
        "success": true,
        "message": null
    }

    Only applies to successful (status < 400) responses whose Content-Type is
    JSON (``application/json`` or any ``+json`` media type), for the configured
    HTTP methods. A body that is already a dict containing a ``"data"`` key is
    treated as pre-wrapped and sent back unchanged.
    """

    def __init__(
        self,
        app,
        methods: list[str] | None = None,
        wrap_responses: bool = True
    ):
        """
        Args:
            app: The ASGI application being wrapped.
            methods: HTTP methods whose responses are wrapped. Defaults to
                GET, POST, PUT, PATCH and DELETE.
            wrap_responses: Master switch; when False every response is passed
                through untouched.
        """
        super().__init__(app)
        self.methods = methods or ["GET", "POST", "PUT", "PATCH", "DELETE"]
        self.wrap_responses = wrap_responses

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Format response with common structure.

        Returns the downstream response untouched when wrapping is disabled,
        the method is not configured, the status is >= 400, or the response is
        not JSON. Otherwise returns a new response carrying the original status
        and headers, whose body is the common envelope — or the original body
        if it turns out to be unparseable or already wrapped.
        """
        response: Response = await call_next(request)

        # Only process specific HTTP methods
        if not self.wrap_responses or request.method not in self.methods:
            return response

        # Skip wrapping if response is already an error (4xx, 5xx)
        if response.status_code >= 400:
            return response

        # Decide from the headers *before* touching the body, so non-JSON
        # payloads (file downloads, long-lived streams such as SSE) are never
        # buffered into memory.
        if not self._is_json(response):
            return response

        # NOTE: the concrete class returned by call_next varies across Starlette
        # versions, so streamed bodies are detected by duck-typing.
        body_iterator = getattr(response, "body_iterator", None)
        streamed = body_iterator is not None
        if streamed:
            # Draining the iterator means the original object can no longer be
            # returned as-is: every exit below must rebuild a response.
            body = await self._read_body(body_iterator)
        elif isinstance(getattr(response, "body", None), bytes):
            body = response.body
        else:
            return response

        try:
            payload = json.loads(body.decode("utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Not valid JSON after all: send the body back unchanged.
            return self._replay(response, body) if streamed else response

        # Check if already wrapped (avoid double-wrapping)
        if isinstance(payload, dict) and "data" in payload:
            return self._replay(response, body) if streamed else response

        wrapped = JSONResponse(
            content={
                "data": payload,
                "success": True,
                "message": None
            },
            status_code=response.status_code,
            background=getattr(response, "background", None),
        )
        return self._adopt_headers(wrapped, response)

    @staticmethod
    def _is_json(response: Response) -> bool:
        """Return True if the response's Content-Type is a JSON media type."""
        content_type = response.headers.get("content-type", "")
        media_type = content_type.split(";", 1)[0].strip().lower()
        return media_type == "application/json" or media_type.endswith("+json")

    @staticmethod
    async def _read_body(body_iterator) -> bytes:
        """Drain a streamed body into a single bytes object."""
        chunks = []
        async for chunk in body_iterator:
            # Streaming bodies may yield str; encode it the same way Starlette
            # does by default (UTF-8).
            chunks.append(chunk.encode("utf-8") if isinstance(chunk, str) else chunk)
        return b"".join(chunks)

    @staticmethod
    def _adopt_headers(new_response: Response, original: Response) -> Response:
        """Give ``new_response`` the original's headers but its own Content-Length.

        Working on raw header pairs keeps duplicates (e.g. several Set-Cookie
        headers) that converting to a dict would collapse. The original
        Content-Length is dropped because it describes the old body.
        """
        carried = [
            (key, value)
            for key, value in original.raw_headers
            if key.lower() != b"content-length"
        ]
        own_length = [
            (key, value)
            for key, value in new_response.raw_headers
            if key.lower() == b"content-length"
        ]
        new_response.raw_headers = own_length + carried
        return new_response

    @classmethod
    def _replay(cls, original: Response, body: bytes) -> Response:
        """Rebuild a response that resends ``body`` with the original metadata."""
        replay = Response(
            content=body,
            status_code=original.status_code,
            background=getattr(original, "background", None),
        )
        return cls._adopt_headers(replay, original)


async def common_response_middleware_function(request: Request, call_next: Callable) -> Response:
    """
    Functional version of CommonResponseMiddleware.

    Simplified version that wraps successful (status < 400) JSON responses for
    GET/POST/PUT/PATCH/DELETE in the common format
    ``{"data": <original body>, "success": true, "message": null}``.
    A body that is already a dict containing a ``"data"`` key is sent back
    unchanged. Non-JSON responses (by Content-Type) are returned untouched
    without reading their body.
    """
    response: Response = await call_next(request)

    # Only process specific methods
    if request.method not in ("GET", "POST", "PUT", "PATCH", "DELETE"):
        return response

    # Skip error responses
    if response.status_code >= 400:
        return response

    # Check Content-Type before reading the body so file downloads and
    # long-lived streams are never buffered into memory.
    content_type = response.headers.get("content-type", "")
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type != "application/json" and not media_type.endswith("+json"):
        return response

    body_iterator = getattr(response, "body_iterator", None)
    if body_iterator is None:
        # Nothing to drain; leave the response as it is.
        return response

    chunks = []
    async for chunk in body_iterator:
        chunks.append(chunk.encode("utf-8") if isinstance(chunk, str) else chunk)
    body = b"".join(chunks)
    background = getattr(response, "background", None)

    # The iterator is now exhausted, so every path below must send a rebuilt
    # response rather than the original object.
    try:
        # Decode only after joining: decoding per chunk fails whenever a
        # multi-byte UTF-8 character straddles a chunk boundary.
        parsed_content = json.loads(body.decode("utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError):
        parsed_content = None
        rebuilt: Response = Response(
            content=body, status_code=response.status_code, background=background
        )
    else:
        if isinstance(parsed_content, dict) and "data" in parsed_content:
            # Already wrapped: resend the original bytes unchanged.
            rebuilt = Response(
                content=body, status_code=response.status_code, background=background
            )
        else:
            rebuilt = JSONResponse(
                content={
                    "data": parsed_content,
                    "success": True,
                    "message": None
                },
                status_code=response.status_code,
                background=background,
            )

    # Keep the original headers as raw pairs (preserving duplicates such as
    # multiple Set-Cookie), but replace Content-Length with the rebuilt body's.
    carried = [
        (key, value)
        for key, value in response.raw_headers
        if key.lower() != b"content-length"
    ]
    own_length = [
        (key, value)
        for key, value in rebuilt.raw_headers
        if key.lower() == b"content-length"
    ]
    rebuilt.raw_headers = own_length + carried
    return rebuilt
