base context for wrappers updated
This commit is contained in:
@@ -10,77 +10,9 @@ This module initializes and configures the FastAPI application with:
|
||||
"""
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from create_routes import get_all_routers
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from app_handler import setup_middleware, get_uvicorn_config
|
||||
from create_file import setup_security_schema, configure_route_security
|
||||
from open_api_creator import OpenAPISchemaCreator, create_openapi_schema
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application."""
|
||||
app = FastAPI(
|
||||
responses={
|
||||
422: {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"detail": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"loc": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"msg": {"type": "string"},
|
||||
"type": {"type": "string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Get all routers and protected routes from the new configuration
|
||||
routers, protected_routes = get_all_routers()
|
||||
|
||||
# Include all routers
|
||||
for router in routers:
|
||||
app.include_router(router)
|
||||
|
||||
# Configure OpenAPI schema with security
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
# Create OpenAPI schema using our custom creator
|
||||
openapi_schema = create_openapi_schema(app)
|
||||
|
||||
# Add security scheme
|
||||
openapi_schema.update(setup_security_schema())
|
||||
|
||||
# Configure security for protected routes
|
||||
for path, methods in protected_routes.items():
|
||||
for method in methods:
|
||||
configure_route_security(
|
||||
path, method, openapi_schema, list(protected_routes.keys())
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
return app
|
||||
from create_file import create_app
|
||||
|
||||
|
||||
app = create_app() # Initialize FastAPI application
|
||||
|
||||
@@ -14,6 +14,7 @@ from fastapi import FastAPI, APIRouter
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from AllConfigs.Token.config import Auth
|
||||
from AllConfigs.main import MainConfig as Config
|
||||
|
||||
from create_routes import get_all_routers
|
||||
@@ -29,11 +30,11 @@ def setup_security_schema() -> Dict[str, Any]:
|
||||
return {
|
||||
"components": {
|
||||
"securitySchemes": {
|
||||
"Bearer": {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "JWT",
|
||||
"description": "Enter the token",
|
||||
"Bearer Auth": {
|
||||
"type": "apiKey",
|
||||
"in": "header",
|
||||
"name": Auth.ACCESS_TOKEN_TAG,
|
||||
"description": "Enter: **'Bearer <JWT>'**, where JWT is the access token",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -65,12 +66,15 @@ def create_app() -> FastAPI:
|
||||
Returns:
|
||||
FastAPI: Configured FastAPI application instance
|
||||
"""
|
||||
# Initialize FastAPI app
|
||||
|
||||
from open_api_creator import create_openapi_schema
|
||||
# Get all routers and protected routes using the dynamic route creation
|
||||
|
||||
app = FastAPI(
|
||||
title=Config.TITLE,
|
||||
description=Config.DESCRIPTION,
|
||||
default_response_class=JSONResponse,
|
||||
)
|
||||
) # Initialize FastAPI app
|
||||
|
||||
@app.get("/", include_in_schema=False, summary=str(Config.DESCRIPTION))
|
||||
async def home() -> RedirectResponse:
|
||||
@@ -84,31 +88,5 @@ def create_app() -> FastAPI:
|
||||
for router in routers:
|
||||
app.include_router(router)
|
||||
|
||||
# Configure OpenAPI schema with security
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title="WAG Management API",
|
||||
version="4.0.0",
|
||||
description="WAG Management API Service",
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add security scheme
|
||||
security_schema = setup_security_schema()
|
||||
openapi_schema.update(security_schema)
|
||||
|
||||
# Configure security for protected routes
|
||||
for path, methods in protected_routes.items():
|
||||
for method in methods:
|
||||
configure_route_security(
|
||||
path, method, openapi_schema, list(protected_routes.keys())
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
app.openapi = custom_openapi
|
||||
app.openapi = lambda app=app: create_openapi_schema(app)
|
||||
return app
|
||||
|
||||
@@ -24,8 +24,6 @@ class EndpointFactoryConfig:
|
||||
summary: str
|
||||
description: str
|
||||
endpoint_function: Callable[P, R] # Now accepts any parameters and return type
|
||||
response_model: Optional[type] = None
|
||||
request_model: Optional[type] = None
|
||||
is_auth_required: bool = True
|
||||
is_event_required: bool = False
|
||||
extra_options: Dict[str, Any] = None
|
||||
@@ -56,7 +54,7 @@ class EnhancedEndpointFactory:
|
||||
endpoint_function = config.endpoint_function
|
||||
|
||||
if config.is_auth_required:
|
||||
endpoint_function = MiddlewareModule.auth_required(endpoint_function)
|
||||
# endpoint_function = MiddlewareModule.auth_required(endpoint_function)
|
||||
# Track protected routes
|
||||
full_path = f"{self.router.prefix}{endpoint_path}"
|
||||
if full_path not in self.protected_routes:
|
||||
@@ -66,7 +64,6 @@ class EnhancedEndpointFactory:
|
||||
# Register the endpoint with FastAPI router
|
||||
getattr(self.router, config.method.lower())(
|
||||
endpoint_path,
|
||||
response_model=config.response_model,
|
||||
summary=config.summary,
|
||||
description=config.description,
|
||||
**config.extra_options,
|
||||
|
||||
@@ -15,6 +15,7 @@ from ApiLibrary.common.line_number import get_line_number_for_error
|
||||
from ErrorHandlers.ErrorHandlers.api_exc_handler import HTTPExceptionApi
|
||||
from .base_context import BaseContext
|
||||
from ApiServices.Token.token_handler import OccupantTokenObject, EmployeeTokenObject
|
||||
import inspect
|
||||
|
||||
|
||||
class AuthContext(BaseContext):
|
||||
@@ -42,7 +43,12 @@ class AuthContext(BaseContext):
|
||||
@property
|
||||
def user_id(self) -> str:
|
||||
"""Get the user's UUID from token context."""
|
||||
return self.token_context.user_uu_id if self.token_context else ""
|
||||
return self.token_context.user_uu_id if self.token_context else None
|
||||
|
||||
def as_dict(self):
|
||||
if not isinstance(self.token_context, dict):
|
||||
return self.token_context.model_dump()
|
||||
return self.token_context
|
||||
|
||||
def __repr__(self) -> str:
|
||||
user_type = "Employee" if self.is_employee else "Occupant"
|
||||
@@ -57,7 +63,7 @@ class MiddlewareModule:
|
||||
@staticmethod
|
||||
def get_user_from_request(
|
||||
request: Request,
|
||||
) -> AuthContext:
|
||||
) -> dict:
|
||||
"""
|
||||
Get authenticated token context from request.
|
||||
|
||||
@@ -74,7 +80,6 @@ class MiddlewareModule:
|
||||
|
||||
# Get token and validate - will raise HTTPExceptionApi if invalid
|
||||
redis_token = TokenService.get_access_token_from_request(request=request)
|
||||
|
||||
# Get token context - will validate token and raise appropriate errors
|
||||
token_context = TokenService.get_object_via_access_key(access_token=redis_token)
|
||||
if not token_context:
|
||||
@@ -85,7 +90,7 @@ class MiddlewareModule:
|
||||
sys_msg="TokenService: Token Context couldnt retrieved from redis",
|
||||
)
|
||||
|
||||
return AuthContext(token_context=token_context)
|
||||
return token_context
|
||||
|
||||
@classmethod
|
||||
def auth_required(cls, func: Callable) -> Callable:
|
||||
@@ -116,14 +121,13 @@ class MiddlewareModule:
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(request: Request, *args, **kwargs):
|
||||
async def wrapper(request: Request, *args, **kwargs):
|
||||
# Get and validate token context from request
|
||||
auth_context = cls.get_user_from_request(request)
|
||||
|
||||
# Attach auth context to function
|
||||
func.auth = auth_context
|
||||
|
||||
# Create auth context and Attach auth context to both wrapper and original function
|
||||
func.auth = cls.get_user_from_request(request) # This ensures the context is available in both places
|
||||
# Call the original endpoint function
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(request, *args, **kwargs)
|
||||
return func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -147,7 +151,6 @@ class RequestTimingMiddleware(BaseHTTPMiddleware):
|
||||
Response: Processed response with timing headers
|
||||
"""
|
||||
start_time = perf_counter()
|
||||
|
||||
# Process the request
|
||||
response = await call_next(request)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from functools import wraps
|
||||
from typing import Callable, Dict, Any
|
||||
from .auth_middleware import MiddlewareModule
|
||||
from .base_context import BaseContext
|
||||
import inspect
|
||||
|
||||
|
||||
class TokenEventHandler(BaseContext):
|
||||
@@ -51,26 +52,14 @@ class TokenEventMiddleware:
|
||||
authenticated_func = MiddlewareModule.auth_required(func)
|
||||
|
||||
@wraps(authenticated_func)
|
||||
def wrapper(*args, **kwargs) -> Dict[str, Any]:
|
||||
async def wrapper(*args, **kwargs) -> Dict[str, Any]:
|
||||
# Create handler with context
|
||||
handler = TokenEventHandler(
|
||||
func=authenticated_func,
|
||||
url_of_endpoint=authenticated_func.url_of_endpoint,
|
||||
)
|
||||
|
||||
# Update event-specific context
|
||||
handler.update_context(
|
||||
function_code="7192c2aa-5352-4e36-98b3-dafb7d036a3d" # Keep function_code as URL
|
||||
)
|
||||
|
||||
# Copy auth context from authenticated function
|
||||
if hasattr(authenticated_func, "auth"):
|
||||
handler.token_context = authenticated_func.auth.token_context
|
||||
|
||||
# Make handler available to the function
|
||||
authenticated_func.handler = handler
|
||||
function_code = "7192c2aa-5352-4e36-98b3-dafb7d036a3d" # Keep function_code as URL
|
||||
|
||||
# Make handler available to all functions in the chain
|
||||
func.func_code = {"function_code": function_code}
|
||||
# Call the authenticated function
|
||||
if inspect.iscoroutinefunction(authenticated_func):
|
||||
return await authenticated_func(*args, **kwargs)
|
||||
return authenticated_func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -13,6 +13,8 @@ from typing import Any, Dict, List, Optional, Set
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi.routing import APIRoute
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
from AllConfigs.Token.config import Auth
|
||||
from AllConfigs.main import MainConfig as Config
|
||||
from create_routes import get_all_routers
|
||||
|
||||
@@ -62,17 +64,11 @@ class OpenAPISchemaCreator:
|
||||
"""
|
||||
return {
|
||||
"Bearer Auth": {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "JWT",
|
||||
"description": "Enter the token with the `Bearer: ` prefix",
|
||||
},
|
||||
"API Key": {
|
||||
"type": "apiKey",
|
||||
"in": "header",
|
||||
"name": "X-API-Key",
|
||||
"description": "Optional API key for service authentication",
|
||||
},
|
||||
"name": Auth.ACCESS_TOKEN_TAG,
|
||||
"description": "Enter: **'Bearer <JWT>'**, where JWT is the access token",
|
||||
}
|
||||
}
|
||||
|
||||
def _create_common_responses(self) -> Dict[str, Any]:
|
||||
@@ -198,7 +194,6 @@ class OpenAPISchemaCreator:
|
||||
if path in self.protected_routes and method in self.protected_routes[path]:
|
||||
schema["paths"][path][method]["security"] = [
|
||||
{"Bearer Auth": []},
|
||||
{"API Key": []},
|
||||
]
|
||||
schema["paths"][path][method]["responses"].update(
|
||||
self._create_common_responses()
|
||||
@@ -219,7 +214,7 @@ class OpenAPISchemaCreator:
|
||||
openapi_schema = get_openapi(
|
||||
title=Config.TITLE,
|
||||
description=Config.DESCRIPTION,
|
||||
version="1.0.0",
|
||||
version="1.1.1",
|
||||
routes=self.app.routes,
|
||||
tags=self.tags_metadata,
|
||||
)
|
||||
@@ -229,17 +224,15 @@ class OpenAPISchemaCreator:
|
||||
openapi_schema["components"] = {}
|
||||
|
||||
openapi_schema["components"]["securitySchemes"] = self._create_security_schemes()
|
||||
|
||||
# Configure route security and responses
|
||||
for route in self.app.routes:
|
||||
if isinstance(route, APIRoute) and route.include_in_schema:
|
||||
path = str(route.path)
|
||||
methods = [method.lower() for method in route.methods]
|
||||
|
||||
for method in methods:
|
||||
self.configure_route_security(path, method, openapi_schema)
|
||||
|
||||
# Add custom documentation extensions
|
||||
# # Add custom documentation extensions
|
||||
openapi_schema["x-documentation"] = {
|
||||
"postman_collection": "/docs/postman",
|
||||
"swagger_ui": "/docs",
|
||||
|
||||
Reference in New Issue
Block a user