154 lines
4.9 KiB
Python
154 lines
4.9 KiB
Python
"""
|
|
Authentication and Authorization middleware for FastAPI applications.
|
|
|
|
This module provides an authentication decorator for protecting endpoints
and a middleware for measuring request timing.
|
|
"""
|
|
|
|
from time import perf_counter
|
|
from typing import Callable, Optional, Dict, Any, Tuple
|
|
from functools import wraps
|
|
from fastapi import Request, Response
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from AllConfigs.Token.config import Auth
|
|
from ErrorHandlers.ErrorHandlers.api_exc_handler import HTTPExceptionApi
|
|
|
|
|
|
class MiddlewareModule:
    """
    Module containing authentication and middleware functionality.

    This class provides:
    - Token extraction and validation
    - Authentication decorator for endpoints
    """

    @staticmethod
    def get_access_token(request: Request) -> Tuple[str, str]:
        """
        Extract the access token from the request headers.

        The header named by ``Auth.ACCESS_TOKEN_TAG`` must contain exactly two
        whitespace-separated parts: a ``Bearer`` scheme (compared
        case-insensitively) followed by the token itself.

        Args:
            request: FastAPI request object

        Returns:
            Tuple[str, str]: A tuple containing (scheme, token). The scheme is
            returned exactly as the client sent it.

        Raises:
            HTTPExceptionApi: With error code ``HTTP_401_UNAUTHORIZED`` if the
                header is missing or empty, does not consist of exactly two
                parts, or uses a scheme other than ``Bearer``.
        """
        auth_header = request.headers.get(Auth.ACCESS_TOKEN_TAG)
        if not auth_header:
            raise HTTPExceptionApi(error_code="HTTP_401_UNAUTHORIZED", lang="tr")

        # str.split() without arguments splits on any run of whitespace, so
        # "Bearer   <token>" is accepted while "Bearer" or "Bearer a b" is not.
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            raise HTTPExceptionApi(error_code="HTTP_401_UNAUTHORIZED", lang="tr")

        scheme, token = parts
        return scheme, token

    @staticmethod
    async def validate_token(token: str) -> Dict[str, Any]:
        """
        Validate the authentication token and return the user data it carries.

        WARNING: this is still a placeholder. It performs no validation and
        returns the same hard-coded user for every token. As a result, any
        request with a well-formed ``Bearer`` header is treated as
        authenticated by ``auth_required``. Replace the body with real
        verification before relying on it.

        Args:
            token: Token string taken from the request header (intended to be
                a JWT once real validation is implemented).

        Returns:
            Dict[str, Any]: User data extracted from the token.

        Raises:
            HTTPExceptionApi: With error code ``HTTP_401_UNAUTHORIZED`` if
                validation raises; the original error is chained as the cause.
        """
        try:
            # TODO: Implement your token validation logic, for example:
            # return jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
            return {"user_id": "test", "role": "user"}  # Placeholder
        except Exception as exc:
            raise HTTPExceptionApi(
                error_code="HTTP_401_UNAUTHORIZED", lang="tr"
            ) from exc

    @classmethod
    def auth_required(cls, func: Callable) -> Callable:
        """
        Decorator for protecting FastAPI endpoints with authentication.

        Usage:
            @router.get("/protected")
            @MiddlewareModule.auth_required
            async def protected_endpoint(request: Request):
                user = request.state.user  # Access authenticated user data
                return {"message": "Protected content"}

            @router.get("/public")  # No decorator = public endpoint
            async def public_endpoint():
                return {"message": "Public content"}

        Requirements on the decorated endpoint:
            - It must be a coroutine function (``async def``); its result is
              awaited.
            - It must declare a parameter named ``request``, and that
              parameter must come first. The wrapper receives ``request`` by
              name and forwards it positionally.

        Args:
            func: The FastAPI route handler function to protect

        Returns:
            Callable: Wrapped function that checks authentication before
            execution. Any failure while extracting or validating the token
            is raised as ``HTTPExceptionApi`` with error code
            ``NOT_AUTHORIZED``. Exceptions raised by the endpoint itself
            propagate unchanged.
        """

        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            # Only the authentication phase is guarded. HTTPExceptionApi is
            # listed explicitly so it is converted regardless of its base class.
            try:
                _, token = cls.get_access_token(request)
                token_data = await cls.validate_token(token)
            except (HTTPExceptionApi, Exception) as exc:
                raise HTTPExceptionApi(error_code="NOT_AUTHORIZED", lang="tr") from exc

            # Expose the authenticated user's data to the endpoint.
            request.state.user = token_data

            # Deliberately outside the try-block: errors raised by the endpoint
            # (including its own HTTPExceptionApi responses such as 404/400)
            # must reach the caller as-is instead of being masked as
            # authorization failures.
            return await func(request, *args, **kwargs)

        return wrapper
|
|
|
|
|
|
class RequestTimingMiddleware(BaseHTTPMiddleware):
    """
    Middleware for measuring request timing and reporting it in response headers.

    Only handles timing, no authentication.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Time the downstream handling of a request and annotate the response.

        Adds three response headers:

        - ``request-start``: wall-clock time (seconds since the Unix epoch,
          from ``time.time()``) when this middleware received the request.
        - ``request-end``: wall-clock time when ``call_next`` returned.
        - ``request-duration``: elapsed time in milliseconds between those
          two points, measured with the high-resolution ``perf_counter``
          clock. Because the two stamps come from a different clock, their
          difference may deviate slightly from this value.

        Args:
            request: FastAPI request object
            call_next: Next middleware in the chain

        Returns:
            Response: Processed response with timing headers
        """
        # perf_counter() has an undefined reference point, so its raw values
        # mean nothing outside this process. Use it only for the duration and
        # take the reported start/end stamps from the wall clock.
        from time import time as wall_clock

        start_wall = wall_clock()
        start = perf_counter()

        # Process the request
        response = await call_next(request)

        elapsed_ms = (perf_counter() - start) * 1000  # seconds -> milliseconds
        end_wall = wall_clock()

        response.headers.update(
            {
                "request-start": f"{start_wall:.6f}",
                "request-end": f"{end_wall:.6f}",
                "request-duration": f"{elapsed_ms:.2f}ms",
            }
        )
        return response
|