""" Authentication and Authorization middleware for FastAPI applications. This module provides authentication decorator for protecting endpoints and a middleware for request timing measurements. """ from time import perf_counter from typing import Callable, Optional, Dict, Any, Tuple, Union from functools import wraps from starlette.middleware.base import BaseHTTPMiddleware from fastapi import Request, Response from AllConfigs.Token.config import Auth from ApiLibrary.common.line_number import get_line_number_for_error from ErrorHandlers.ErrorHandlers.api_exc_handler import HTTPExceptionApi from .base_context import BaseContext from ApiServices.Token.token_handler import OccupantTokenObject, EmployeeTokenObject class AuthContext(BaseContext): """ Context class for authentication middleware. Extends BaseContext to provide authentication-specific functionality. """ def __init__( self, token_context: Union[OccupantTokenObject, EmployeeTokenObject] ) -> None: super().__init__() self.token_context = token_context @property def is_employee(self) -> bool: """Check if authenticated token is for an employee.""" return isinstance(self.token_context, EmployeeTokenObject) @property def is_occupant(self) -> bool: """Check if authenticated token is for an occupant.""" return isinstance(self.token_context, OccupantTokenObject) @property def user_id(self) -> str: """Get the user's UUID from token context.""" return self.token_context.user_uu_id if self.token_context else "" def __repr__(self) -> str: user_type = "Employee" if self.is_employee else "Occupant" return f"AuthContext({user_type}Token: {self.user_id})" class MiddlewareModule: """ Middleware module for handling authentication and request timing. """ @staticmethod def get_user_from_request( request: Request, ) -> AuthContext: """ Get authenticated token context from request. Args: request: FastAPI request object Returns: AuthContext: Context containing the authenticated token data Raises: HTTPExceptionApi: If token is missing, invalid, or user not found """ from ApiServices.Token.token_handler import TokenService # Get token and validate - will raise HTTPExceptionApi if invalid redis_token = TokenService.get_access_token_from_request(request=request) # Get token context - will validate token and raise appropriate errors token_context = TokenService.get_object_via_access_key(access_token=redis_token) if not token_context: raise HTTPExceptionApi( error_code="USER_NOT_FOUND", lang="tr", loc=get_line_number_for_error(), sys_msg="TokenService: Token Context couldnt retrieved from redis", ) return AuthContext(token_context=token_context) @classmethod def auth_required(cls, func: Callable) -> Callable: """ Decorator for protecting FastAPI endpoints with authentication. Usage: @router.get("/protected") @MiddlewareModule.auth_required async def protected_endpoint(request: Request): auth = protected_endpoint.auth # Access auth context if auth.is_employee: # Handle employee logic employee_id = auth.token_context.employee_id else: # Handle occupant logic occupant_id = auth.token_context.occupant_id return {"user_id": auth.user_id} Args: func: The FastAPI route handler function to protect Returns: Callable: Wrapped function that checks authentication before execution Raises: HTTPExceptionApi: If authentication fails """ @wraps(func) def wrapper(request: Request, *args, **kwargs): # Get and validate token context from request auth_context = cls.get_user_from_request(request) # Attach auth context to function func.auth = auth_context # Call the original endpoint function return func(request, *args, **kwargs) return wrapper class RequestTimingMiddleware(BaseHTTPMiddleware): """ Middleware for measuring and logging request timing. Only handles timing, no authentication. """ async def dispatch(self, request: Request, call_next: Callable) -> Response: """ Process each request through the middleware. Args: request: FastAPI request object call_next: Next middleware in the chain Returns: Response: Processed response with timing headers """ start_time = perf_counter() # Process the request response = await call_next(request) # Add timing information to response headers end_time = perf_counter() elapsed = (end_time - start_time) * 1000 # Convert to milliseconds response.headers.update( { "request-start": f"{start_time:.6f}", "request-end": f"{end_time:.6f}", "request-duration": f"{elapsed:.2f}ms", } ) return response