wag-services-and-backend-la.../ApiLayers/Middleware/token_event_middleware.py

199 lines
7.3 KiB
Python

"""
Token event middleware for handling authentication and event tracking.
"""
import inspect
from functools import wraps
from typing import Callable, Dict, Any, Optional, Tuple, Union
from fastapi import Request
from pydantic import BaseModel
from ApiLayers.ApiLibrary.common.line_number import get_line_number_for_error
from ApiLayers.ApiServices.Token.token_handler import TokenService
from ApiLayers.ApiValidations.Custom.wrapper_contexts import EventContext
from ApiLayers.ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
from ApiLayers.Schemas import Events, EndpointRestriction
from ApiLayers.AllConfigs.Redis.configs import RedisCategoryKeys
from Services.Redis.Actions.actions import RedisActions
from .auth_middleware import MiddlewareModule
class TokenEventMiddleware:
    """
    Module containing token and event handling functionality.

    This class provides:
    - Token and event context management
    - Event validation decorator for endpoints
    """

    @staticmethod
    def retrieve_access_content(request_from_scope: Request) -> Tuple[Any, list[str]]:
        """
        Resolve the caller's token context and the event codes it can reach.

        Args:
            request_from_scope: The FastAPI request object

        Returns:
            Tuple[Any, list[str]]: The token context object returned by
            ``TokenService.get_object_via_access_key`` (NOT the raw token
            string) and the reachable event codes of the selected company
            (employee token) or selected occupant (occupant token).

        Raises:
            HTTPExceptionApi: If the request carries no access token, or the
                token context is neither an employee nor an occupant.
        """
        access_token = TokenService.get_access_token_from_request(request_from_scope)
        if not access_token:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )

        # Look up the token context by access token.
        # presumably backed by Redis - verify against TokenService.
        token_context = TokenService.get_object_via_access_key(
            access_token=access_token
        )

        # Declared once so both branches bind the same annotated name.
        reachable_event_codes: list[str]
        if token_context.is_employee:
            reachable_event_codes = token_context.selected_company.reachable_event_codes
        elif token_context.is_occupant:
            reachable_event_codes = (
                token_context.selected_occupant.reachable_event_codes
            )
        else:
            # NOTE(review): the message says "Token not found" although the
            # token exists and only its user type is unrecognised; the text is
            # kept unchanged because callers may depend on it.
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )
        return token_context, reachable_event_codes

    @staticmethod
    def retrieve_intersected_event_code(
        request: Request, reachable_event_codes: list[str]
    ) -> Tuple[str, str]:
        """
        Match the requested endpoint with the single event the user may trigger.

        Args:
            request: The incoming request; its URL path identifies the endpoint
            reachable_event_codes: The list of event codes accessible to the user

        Returns:
            Tuple[str, str]: The endpoint URL path and the one event code that
            is both registered for the endpoint and reachable by the user.

        Raises:
            HTTPExceptionApi: If no function codes are stored for the endpoint,
                or the intersection with the user's codes does not contain
                exactly one code (zero matches and multiple matches are both
                rejected).
        """
        endpoint_url = str(request.url.path)

        # Fetch the function codes registered for this endpoint URL
        function_codes_of_endpoint = RedisActions.get_json(
            list_keys=[RedisCategoryKeys.METHOD_FUNCTION_CODES, "*", endpoint_url]
        )
        # Check the lookup status BEFORE reading the payload so a failed lookup
        # always surfaces as the API error below.
        if not function_codes_of_endpoint.status:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Function code not found",
            )
        function_code_list_of_event = function_codes_of_endpoint.first

        # Intersect the endpoint's function codes with the user's reachable
        # event codes. A missing (None) collection is treated as empty so it
        # ends in the API error below instead of a bare TypeError.
        intersected_code = list(
            set(function_code_list_of_event or ()) & set(reachable_event_codes or ())
        )
        if len(intersected_code) != 1:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="No event is registered for this user.",
            )
        return endpoint_url, intersected_code[0]

    @classmethod
    def event_required(cls, func: Callable) -> Callable:
        """
        Decorator for endpoints with token and event requirements.

        On every call the wrapper:
        1. Resolves the token context via ``retrieve_access_content``
        2. Resolves the matching event code via ``retrieve_intersected_event_code``
        3. Builds an ``EventContext`` and attaches it as ``event_context`` to
           both the wrapper and the decorated function
        4. Calls the decorated function (awaiting it when it is a coroutine
           function) and returns its result

        Args:
            func: The function to be decorated; it receives ``request`` as its
                first positional argument

        Returns:
            Callable: The async wrapper with token and event handling

        Raises:
            HTTPExceptionApi: Propagated from the token / event lookups.
        """

        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs) -> Dict[str, Any]:
            # Get and validate token context from request
            token_context, reachable_event_codes = cls.retrieve_access_content(request)
            endpoint_url, reachable_event_code = cls.retrieve_intersected_event_code(
                request, reachable_event_codes
            )
            event_context = EventContext(
                auth=token_context,
                code=reachable_event_code,
                url=endpoint_url,
                request=request,
            )

            # Expose the context on the function objects so the endpoint can
            # read it.
            # NOTE(review): `wrapper` and `func` are shared by every request,
            # so each call overwrites the previous context; a concurrent
            # request may replace it while this one is suspended at an await.
            # Confirm that consumers read it before their first await.
            if token_context is not None:
                setattr(wrapper, "event_context", event_context)
                setattr(func, "event_context", event_context)

            # Execute the decorated function and return its result
            if inspect.iscoroutinefunction(func):
                return await func(request, *args, **kwargs)
            return func(request, *args, **kwargs)

        return wrapper
# event_required already attaches the event context (which carries the function code) to the wrapper.
# @classmethod
# def validation_required(cls, func: Callable[..., Dict[str, Any]]) -> Callable[..., Dict[str, Any]]:
# """
# Decorator for endpoints with token and event requirements.
# This decorator:
# 1. First validates authentication using MiddlewareModule.auth_required
# 2. Then adds event tracking context
# Args:
# func: The function to be decorated
# Returns:
# Callable: The wrapped function with both auth and event handling
# """
# @wraps(func)
# async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Union[Dict[str, Any], BaseModel]:
# # Get and validate token context from request
# token_context, reachable_event_codes = cls.retrieve_access_content(request)
# endpoint_url, reachable_event_code = cls.retrieve_intersected_event_code(request, reachable_event_codes)
# # Get auth context from the authenticated function's wrapper
# if token_context is not None:
# setattr(wrapper, 'auth', token_context)
# setattr(wrapper, 'url', endpoint_url)
# setattr(wrapper, 'func_code', reachable_event_code)
# # Execute the authenticated function and get its result
# if inspect.iscoroutinefunction(func):
# result = await func(request, *args, **kwargs)
# else:
# result = func(request, *args, **kwargs)
# return result
# return wrapper