""" Authentication and Authorization middleware for FastAPI applications. This module provides authentication decorator for protecting endpoints and a middleware for request timing measurements. """ from time import perf_counter from typing import Callable, Optional, Dict, Any, Tuple from functools import wraps from fastapi import HTTPException, Request, Response, status from starlette.middleware.base import BaseHTTPMiddleware from AllConfigs.Token.config import Auth from ErrorHandlers.ErrorHandlers.api_exc_handler import HTTPExceptionApi class MiddlewareModule: """ Module containing authentication and middleware functionality. This class provides: - Token extraction and validation - Authentication decorator for endpoints - Request timing middleware """ @staticmethod def get_access_token(request: Request) -> Tuple[str, str]: """ Extract access token from request headers. Args: request: FastAPI request object Returns: Tuple[str, str]: A tuple containing (scheme, token) Raises: HTTPExceptionApi: If token is missing or malformed """ auth_header = request.headers.get(Auth.ACCESS_TOKEN_TAG) if not auth_header: raise HTTPExceptionApi( status_code=status.HTTP_401_UNAUTHORIZED, detail="No authorization header", ) try: scheme, token = auth_header.split() if scheme.lower() != "bearer": raise HTTPExceptionApi( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication scheme", ) return scheme, token except ValueError: raise HTTPExceptionApi( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token format" ) @staticmethod async def validate_token(token: str) -> Dict[str, Any]: """ Validate the authentication token. Args: token: JWT token to validate Returns: Dict[str, Any]: User data extracted from token Raises: HTTPExceptionApi: If token is invalid """ try: # TODO: Implement your token validation logic # Example: # return jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) return {"user_id": "test", "role": "user"} # Placeholder except Exception as e: raise HTTPExceptionApi( status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Token validation failed: {str(e)}", ) @classmethod def auth_required(cls, func: Callable) -> Callable: """ Decorator for protecting FastAPI endpoints with authentication. Usage: @router.get("/protected") @MiddlewareModule.auth_required async def protected_endpoint(request: Request): user = request.state.user # Access authenticated user data return {"message": "Protected content"} @router.get("/public") # No decorator = public endpoint async def public_endpoint(): return {"message": "Public content"} Args: func: The FastAPI route handler function to protect Returns: Callable: Wrapped function that checks authentication before execution """ @wraps(func) async def wrapper(request: Request, *args, **kwargs): try: # Get token from header _, token = cls.get_access_token(request) # Validate token and get user data token_data = await cls.validate_token(token) # Add user data to request state for use in endpoint request.state.user = token_data # Call the original endpoint function return await func(request, *args, **kwargs) except HTTPExceptionApi: raise except Exception as e: raise HTTPExceptionApi( status_code=status.HTTP_401_UNAUTHORIZED, detail=f"Authentication failed: {str(e)}", ) return wrapper class RequestTimingMiddleware(BaseHTTPMiddleware): """ Middleware for measuring and logging request timing. Only handles timing, no authentication. """ async def dispatch(self, request: Request, call_next: Callable) -> Response: """ Process each request through the middleware. Args: request: FastAPI request object call_next: Next middleware in the chain Returns: Response: Processed response with timing headers """ start_time = perf_counter() # Process the request response = await call_next(request) # Add timing information to response headers self._add_timing_headers(response, start_time) return response @staticmethod def _add_timing_headers(response: Response, start_time: float) -> None: """ Add request timing information to response headers. Args: response: FastAPI response object start_time: Time when request processing started """ end_time = perf_counter() elapsed = (end_time - start_time) * 1000 # Convert to milliseconds response.headers.update( { "request-start": f"{start_time:.6f}", "request-end": f"{end_time:.6f}", "request-duration": f"{elapsed:.2f}ms", } )