176 lines
5.8 KiB
Python
176 lines
5.8 KiB
Python
"""
Authentication and Authorization middleware for FastAPI applications.

This module provides an authentication decorator for protecting endpoints,
a middleware that adds request-timing headers to responses, and a middleware
that logs basic request/response details.
"""
|
|
|
|
import inspect
|
|
|
|
from time import perf_counter
|
|
from typing import Callable
|
|
from functools import wraps
|
|
|
|
from fastapi import Request, Response
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from ApiLayers.ApiLibrary.common.line_number import get_line_number_for_error
|
|
from ApiLayers.ApiValidations.Custom.wrapper_contexts import AuthContext
|
|
from ApiLayers.ErrorHandlers.ErrorHandlers.api_exc_handler import HTTPExceptionApi
|
|
from ApiLayers.AllConfigs.Token.config import Auth
|
|
from ApiLayers.ApiServices.Token.token_handler import TokenService
|
|
|
|
|
|
class MiddlewareModule:
    """
    Middleware module for handling authentication.

    Provides a helper that resolves the token context for a request and an
    ``auth_required`` decorator that protects FastAPI route handlers.
    """

    @staticmethod
    def get_user_from_request(request: Request) -> dict:
        """
        Resolve the authenticated token context for a request.

        Args:
            request: FastAPI request object

        Returns:
            The token context object returned by
            ``TokenService.get_object_via_access_key`` (this is the raw token
            context, not an ``AuthContext``). NOTE(review): annotated as
            ``dict``; confirm the concrete type against TokenService.

        Raises:
            HTTPExceptionApi: with error_code "USER_NOT_FOUND" when no token
                context could be retrieved. TokenService is expected to raise
                its own HTTPExceptionApi for a missing/invalid token
                (presumably; confirm in TokenService).
        """
        # Extract and validate the access token from the request; TokenService
        # is expected to raise HTTPExceptionApi if it is missing or invalid.
        redis_token = TokenService.get_access_token_from_request(request=request)

        # Look up the stored token context for that access key.
        token_context = TokenService.get_object_via_access_key(access_token=redis_token)
        if not token_context:
            raise HTTPExceptionApi(
                error_code="USER_NOT_FOUND",
                lang="tr",
                loc=get_line_number_for_error(),
                sys_msg="TokenService: Token Context couldnt retrieved from redis",
            )
        return token_context

    @classmethod
    def auth_required(cls, func: Callable) -> Callable:
        """
        Decorator for protecting FastAPI endpoints with authentication.

        The wrapper resolves the token context for the incoming request, builds
        an ``AuthContext`` from it and exposes that context to the handler
        before calling it.

        Usage:
            @router.get("/protected")
            @MiddlewareModule.auth_required
            async def protected_endpoint(request: Request):
                # Preferred: per-request storage, safe under concurrency.
                auth_context = request.state.auth_context
                ...

        The context is also still set as ``func.auth_context`` /
        ``wrapper.auth_context`` for backward compatibility. Those are
        attributes of a function object shared by every request, so
        concurrent requests overwrite each other's value. Prefer
        ``request.state.auth_context``.

        The decorated handler must accept the request as its first positional
        parameter, because the wrapper passes it positionally. Synchronous
        handlers are called directly inside the async wrapper, so a blocking
        handler blocks the event loop.

        Args:
            func: The FastAPI route handler function to protect

        Returns:
            Callable: Async wrapper that authenticates before calling ``func``

        Raises:
            HTTPExceptionApi: If authentication fails
        """

        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            # Resolve and validate the token context for this request.
            endpoint_url = str(request.url.path)
            token_context = cls.get_user_from_request(request=request)
            auth_context = AuthContext(
                auth=token_context, url=endpoint_url, request=request
            )

            # Per-request storage: each request has its own ``state``, so this
            # cannot leak between concurrent requests.
            request.state.auth_context = auth_context

            # Legacy access path kept for existing handlers. The function
            # object is shared across requests, so this value is racy across
            # ``await`` points. It is deliberately NOT re-set after the
            # handler returns, because that would clobber the context of other
            # in-flight requests.
            setattr(func, "auth_context", auth_context)
            setattr(wrapper, "auth_context", auth_context)

            # Call the original endpoint function.
            if inspect.iscoroutinefunction(func):
                return await func(request, *args, **kwargs)
            return func(request, *args, **kwargs)

        return wrapper
|
|
|
|
|
|
class RequestTimingMiddleware(BaseHTTPMiddleware):
    """
    Middleware that stamps each response with request-timing headers.

    Timing only: no authentication and no logging happens here.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Time the downstream handling of ``request`` and annotate the response.

        Args:
            request: Incoming FastAPI request object
            call_next: Callable that forwards the request down the middleware chain

        Returns:
            Response: The downstream response with ``request-start``,
            ``request-end`` and ``request-duration`` headers added.
        """
        started = perf_counter()
        response = await call_next(request)
        finished = perf_counter()

        # NOTE(review): perf_counter() has an arbitrary reference point, so the
        # start/end header values are only meaningful relative to each other.
        duration_ms = (finished - started) * 1000
        timing_headers = {
            "request-start": f"{started:.6f}",
            "request-end": f"{finished:.6f}",
            "request-duration": f"{duration_ms:.2f}ms",
        }
        response.headers.update(timing_headers)
        return response
|
|
|
|
|
|
class LoggerTimingMiddleware(BaseHTTPMiddleware):
    """
    Middleware that prints a summary of each request and its response status.

    Despite the class name, no timing is measured here. Also performs no
    authentication. The access token is redacted before printing so that
    credentials do not end up in logs.
    """

    @staticmethod
    def _mask_token(token: str) -> str:
        """Return a redacted form of *token* that is safe to write to logs."""
        if not token:
            return ""
        # Short tokens are fully hidden. Longer ones keep a 4-char suffix so
        # log lines can still be correlated without exposing a usable token.
        if len(token) <= 8:
            return "***"
        return f"***{token[-4:]}"

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Forward the request, then print request details and the response status.

        Args:
            request: Incoming FastAPI request object
            call_next: Callable that forwards the request down the middleware chain

        Returns:
            Response: The downstream response, unchanged.
        """
        import arrow

        # Starlette's Headers mapping is case-insensitive, unlike a plain dict
        # copy, so the token tag is found regardless of its letter case.
        headers = request.headers
        response = await call_next(request)

        print(
            "Loggers :",
            {
                "url": request.url,
                "method": request.method,
                "access_token": self._mask_token(
                    headers.get(Auth.ACCESS_TOKEN_TAG, "")
                ),
                "referer": headers.get("referer", ""),
                "origin": headers.get("origin", ""),
                "user-agent": headers.get("user-agent", ""),
                "datetime": arrow.now().format("YYYY-MM-DD HH:mm:ss ZZ"),
                "status_code": response.status_code,
            },
        )
        return response
|