base context for wrappers updated
This commit is contained in:
@@ -13,7 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi import FastAPI, Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
|
||||
from middleware.auth_middleware import RequestTimingMiddleware
|
||||
from middleware.auth_middleware import RequestTimingMiddleware, LoggerTimingMiddleware
|
||||
|
||||
|
||||
def setup_cors_middleware(app: FastAPI) -> None:
|
||||
@@ -74,6 +74,7 @@ def setup_middleware(app: FastAPI) -> None:
|
||||
"""
|
||||
setup_cors_middleware(app)
|
||||
app.add_middleware(RequestTimingMiddleware)
|
||||
app.add_middleware(LoggerTimingMiddleware)
|
||||
setup_exception_handlers(app)
|
||||
|
||||
|
||||
|
||||
@@ -68,13 +68,14 @@ def create_app() -> FastAPI:
|
||||
"""
|
||||
|
||||
from open_api_creator import create_openapi_schema
|
||||
|
||||
# Get all routers and protected routes using the dynamic route creation
|
||||
|
||||
app = FastAPI(
|
||||
title=Config.TITLE,
|
||||
description=Config.DESCRIPTION,
|
||||
default_response_class=JSONResponse,
|
||||
) # Initialize FastAPI app
|
||||
) # Initialize FastAPI app
|
||||
|
||||
@app.get("/", include_in_schema=False, summary=str(Config.DESCRIPTION))
|
||||
async def home() -> RedirectResponse:
|
||||
@@ -89,4 +90,5 @@ def create_app() -> FastAPI:
|
||||
app.include_router(router)
|
||||
|
||||
app.openapi = lambda app=app: create_openapi_schema(app)
|
||||
|
||||
return app
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
from .token_event_middleware import TokenEventMiddleware
|
||||
from .auth_middleware import RequestTimingMiddleware, MiddlewareModule
|
||||
from .auth_middleware import (
|
||||
LoggerTimingMiddleware,
|
||||
RequestTimingMiddleware,
|
||||
MiddlewareModule,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["TokenEventMiddleware", "RequestTimingMiddleware", "MiddlewareModule"]
|
||||
__all__ = [
|
||||
"TokenEventMiddleware",
|
||||
"RequestTimingMiddleware",
|
||||
"MiddlewareModule",
|
||||
"LoggerTimingMiddleware",
|
||||
]
|
||||
|
||||
@@ -8,53 +8,15 @@ and a middleware for request timing measurements.
|
||||
from time import perf_counter
|
||||
from typing import Callable, Optional, Dict, Any, Tuple, Union
|
||||
from functools import wraps
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from fastapi import Request, Response
|
||||
from AllConfigs.Token.config import Auth
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from ApiLibrary.common.line_number import get_line_number_for_error
|
||||
from ErrorHandlers.ErrorHandlers.api_exc_handler import HTTPExceptionApi
|
||||
from .base_context import BaseContext
|
||||
from ApiServices.Token.token_handler import OccupantTokenObject, EmployeeTokenObject
|
||||
import inspect
|
||||
|
||||
|
||||
class AuthContext(BaseContext):
|
||||
"""
|
||||
Context class for authentication middleware.
|
||||
Extends BaseContext to provide authentication-specific functionality.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, token_context: Union[OccupantTokenObject, EmployeeTokenObject]
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.token_context = token_context
|
||||
|
||||
@property
|
||||
def is_employee(self) -> bool:
|
||||
"""Check if authenticated token is for an employee."""
|
||||
return isinstance(self.token_context, EmployeeTokenObject)
|
||||
|
||||
@property
|
||||
def is_occupant(self) -> bool:
|
||||
"""Check if authenticated token is for an occupant."""
|
||||
return isinstance(self.token_context, OccupantTokenObject)
|
||||
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
"""Get the user's UUID from token context."""
|
||||
return self.token_context.user_uu_id if self.token_context else None
|
||||
|
||||
def as_dict(self):
|
||||
if not isinstance(self.token_context, dict):
|
||||
return self.token_context.model_dump()
|
||||
return self.token_context
|
||||
|
||||
def __repr__(self) -> str:
|
||||
user_type = "Employee" if self.is_employee else "Occupant"
|
||||
return f"AuthContext({user_type}Token: {self.user_id})"
|
||||
|
||||
|
||||
class MiddlewareModule:
|
||||
"""
|
||||
Middleware module for handling authentication and request timing.
|
||||
@@ -124,7 +86,9 @@ class MiddlewareModule:
|
||||
async def wrapper(request: Request, *args, **kwargs):
|
||||
# Get and validate token context from request
|
||||
# Create auth context and Attach auth context to both wrapper and original function
|
||||
func.auth = cls.get_user_from_request(request) # This ensures the context is available in both places
|
||||
func.auth = cls.get_user_from_request(
|
||||
request
|
||||
) # This ensures the context is available in both places
|
||||
# Call the original endpoint function
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(request, *args, **kwargs)
|
||||
@@ -167,3 +131,20 @@ class RequestTimingMiddleware(BaseHTTPMiddleware):
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class LoggerTimingMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware for measuring and logging request timing.
|
||||
Only handles timing, no authentication.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# Log the request
|
||||
print(f"Handling request: {request.method} {request.url}")
|
||||
response = await call_next(request)
|
||||
# Log the response
|
||||
print(
|
||||
f"Completed request: {request.method} {request.url} with status {response.status_code}"
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Base context for middleware."""
|
||||
|
||||
from typing import Optional, Dict, Any, Union, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ApiServices.Token.token_handler import OccupantTokenObject, EmployeeTokenObject
|
||||
|
||||
|
||||
class BaseContext:
|
||||
"""Base context class for middleware."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._token_context: Optional[
|
||||
Union["OccupantTokenObject", "EmployeeTokenObject"]
|
||||
] = None
|
||||
self._function_code: Optional[str] = None
|
||||
|
||||
@property
|
||||
def token_context(
|
||||
self,
|
||||
) -> Optional[Union["OccupantTokenObject", "EmployeeTokenObject"]]:
|
||||
"""Get token context if available."""
|
||||
return self._token_context
|
||||
|
||||
@token_context.setter
|
||||
def token_context(
|
||||
self, value: Union["OccupantTokenObject", "EmployeeTokenObject"]
|
||||
) -> None:
|
||||
"""Set token context."""
|
||||
self._token_context = value
|
||||
|
||||
@property
|
||||
def function_code(self) -> Optional[str]:
|
||||
"""Get function code if available."""
|
||||
return self._function_code
|
||||
|
||||
@function_code.setter
|
||||
def function_code(self, value: str) -> None:
|
||||
"""Set function code."""
|
||||
self._function_code = value
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert context to dictionary."""
|
||||
return {
|
||||
"token_context": self._token_context,
|
||||
"function_code": self._function_code,
|
||||
}
|
||||
@@ -5,24 +5,9 @@ Token event middleware for handling authentication and event tracking.
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, Any
|
||||
from .auth_middleware import MiddlewareModule
|
||||
from .base_context import BaseContext
|
||||
import inspect
|
||||
|
||||
|
||||
class TokenEventHandler(BaseContext):
|
||||
"""Handler for token events with authentication context."""
|
||||
|
||||
def __init__(self, func: Callable, url_of_endpoint: str):
|
||||
"""Initialize the handler with function and URL."""
|
||||
super().__init__()
|
||||
self.func = func
|
||||
self.url_of_endpoint = url_of_endpoint
|
||||
|
||||
def update_context(self, function_code: str):
|
||||
"""Update the event context with function code."""
|
||||
self.function_code = function_code
|
||||
|
||||
|
||||
class TokenEventMiddleware:
|
||||
"""
|
||||
Module containing token and event handling functionality.
|
||||
@@ -54,7 +39,9 @@ class TokenEventMiddleware:
|
||||
@wraps(authenticated_func)
|
||||
async def wrapper(*args, **kwargs) -> Dict[str, Any]:
|
||||
# Create handler with context
|
||||
function_code = "7192c2aa-5352-4e36-98b3-dafb7d036a3d" # Keep function_code as URL
|
||||
function_code = (
|
||||
"7192c2aa-5352-4e36-98b3-dafb7d036a3d" # Keep function_code as URL
|
||||
)
|
||||
|
||||
# Make handler available to all functions in the chain
|
||||
func.func_code = {"function_code": function_code}
|
||||
@@ -62,4 +49,5 @@ class TokenEventMiddleware:
|
||||
if inspect.iscoroutinefunction(authenticated_func):
|
||||
return await authenticated_func(*args, **kwargs)
|
||||
return authenticated_func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -123,7 +123,9 @@ class OpenAPISchemaCreator:
|
||||
},
|
||||
}
|
||||
|
||||
def _process_request_body(self, path: str, method: str, schema: Dict[str, Any]) -> None:
|
||||
def _process_request_body(
|
||||
self, path: str, method: str, schema: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Process request body to include examples from model config.
|
||||
|
||||
@@ -140,17 +142,24 @@ class OpenAPISchemaCreator:
|
||||
content = request_body["content"]
|
||||
if "application/json" in content:
|
||||
json_content = content["application/json"]
|
||||
if "schema" in json_content and "$ref" in json_content["schema"]:
|
||||
if (
|
||||
"schema" in json_content
|
||||
and "$ref" in json_content["schema"]
|
||||
):
|
||||
ref = json_content["schema"]["$ref"]
|
||||
model_name = ref.split("/")[-1]
|
||||
if model_name in schema["components"]["schemas"]:
|
||||
model_schema = schema["components"]["schemas"][model_name]
|
||||
model_schema = schema["components"]["schemas"][
|
||||
model_name
|
||||
]
|
||||
if "example" in model_schema:
|
||||
json_content["example"] = model_schema["example"]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def _process_response_examples(self, path: str, method: str, schema: Dict[str, Any]) -> None:
|
||||
def _process_response_examples(
|
||||
self, path: str, method: str, schema: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Process response body to include examples from model config.
|
||||
|
||||
@@ -169,13 +178,20 @@ class OpenAPISchemaCreator:
|
||||
content = response["content"]
|
||||
if "application/json" in content:
|
||||
json_content = content["application/json"]
|
||||
if "schema" in json_content and "$ref" in json_content["schema"]:
|
||||
if (
|
||||
"schema" in json_content
|
||||
and "$ref" in json_content["schema"]
|
||||
):
|
||||
ref = json_content["schema"]["$ref"]
|
||||
model_name = ref.split("/")[-1]
|
||||
if model_name in schema["components"]["schemas"]:
|
||||
model_schema = schema["components"]["schemas"][model_name]
|
||||
model_schema = schema["components"]["schemas"][
|
||||
model_name
|
||||
]
|
||||
if "example" in model_schema:
|
||||
json_content["example"] = model_schema["example"]
|
||||
json_content["example"] = model_schema[
|
||||
"example"
|
||||
]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
@@ -198,7 +214,7 @@ class OpenAPISchemaCreator:
|
||||
schema["paths"][path][method]["responses"].update(
|
||||
self._create_common_responses()
|
||||
)
|
||||
|
||||
|
||||
# Process request body examples
|
||||
self._process_request_body(path, method, schema)
|
||||
# Process response examples
|
||||
@@ -223,7 +239,9 @@ class OpenAPISchemaCreator:
|
||||
if "components" not in openapi_schema:
|
||||
openapi_schema["components"] = {}
|
||||
|
||||
openapi_schema["components"]["securitySchemes"] = self._create_security_schemes()
|
||||
openapi_schema["components"][
|
||||
"securitySchemes"
|
||||
] = self._create_security_schemes()
|
||||
# Configure route security and responses
|
||||
for route in self.app.routes:
|
||||
if isinstance(route, APIRoute) and route.include_in_schema:
|
||||
|
||||
Reference in New Issue
Block a user