"""
Token event middleware for handling authentication and event tracking.
"""

import inspect

from functools import wraps
from typing import Callable, Dict, Any, Optional, Tuple, Union

from fastapi import Request
from pydantic import BaseModel

from ApiLayers.ApiLibrary.common.line_number import get_line_number_for_error
from ApiLayers.ApiServices.Token.token_handler import TokenService
from ApiLayers.ApiValidations.Custom.wrapper_contexts import EventContext
from ApiLayers.ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
from ApiLayers.Schemas import Events, EndpointRestriction
from ApiLayers.AllConfigs.Redis.configs import RedisCategoryKeys
from Services.Redis.Actions.actions import RedisActions

from .auth_middleware import MiddlewareModule


class TokenEventMiddleware:
    """
    Module containing token and event handling functionality.

    This class provides:
    - Token and event context management
    - Event validation decorator for endpoints
    """

    @staticmethod
    def retrieve_access_content(request_from_scope: Request) -> Tuple[Any, list[str]]:
        """
        Retrieve the token context for a request and the event codes it can reach.

        Args:
            request_from_scope: The FastAPI request object.

        Returns:
            Tuple[Any, list[str]]: The token context object returned by
            ``TokenService.get_object_via_access_key`` and the list of event
            codes reachable through its selected company (employee tokens) or
            selected occupant (occupant tokens).

        Raises:
            HTTPExceptionApi: If the request carries no access token, no token
                context is found for it, or the context is neither an employee
                nor an occupant context.
        """
        access_token = TokenService.get_access_token_from_request(request_from_scope)
        if not access_token:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )

        # Look up the token context (stored in Redis) by its access token.
        token_context = TokenService.get_object_via_access_key(
            access_token=access_token
        )
        # Without this guard a missing context would surface as an
        # AttributeError on `.is_employee` instead of an API error.
        if token_context is None:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )

        reachable_event_codes: list[str]
        if token_context.is_employee:
            reachable_event_codes = token_context.selected_company.reachable_event_codes
        elif token_context.is_occupant:
            reachable_event_codes = token_context.selected_occupant.reachable_event_codes
        else:
            # NOTE(review): the message reuses "Token not found" although the
            # token exists but has an unsupported context type.
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )
        return token_context, reachable_event_codes

    @staticmethod
    def retrieve_intersected_event_code(
        request: Request, reachable_event_codes: list[str]
    ) -> Tuple[str, str]:
        """
        Match the requested endpoint with the event codes the user can reach.

        Args:
            request: The incoming request; its URL path identifies the endpoint.
            reachable_event_codes: The event codes accessible to the user.

        Returns:
            Tuple[str, str]: The endpoint URL path and the single event code
            that is both registered for the endpoint and reachable by the user.

        Raises:
            HTTPExceptionApi: If no function codes are registered for the
                endpoint, or if the intersection does not contain exactly one
                event code.
        """
        endpoint_url = str(request.url.path)
        # Function codes are stored under a key built from the category,
        # a wildcard segment and the endpoint URL.
        function_codes_of_endpoint = RedisActions.get_json(
            list_keys=[RedisCategoryKeys.METHOD_FUNCTION_CODES, "*", endpoint_url]
        )
        # Check the lookup status BEFORE touching `.first`, so a failed lookup
        # raises the intended API error.
        if not function_codes_of_endpoint.status:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Function code not found",
            )
        # `or []` keeps an empty/None payload from crashing `set(...)`; that
        # case then falls through to the "no event" error below.
        function_code_list_of_event = function_codes_of_endpoint.first or []

        # Intersect the endpoint's function codes with the event codes
        # available to the user's access object.
        intersected_code = list(
            set(function_code_list_of_event) & set(reachable_event_codes)
        )
        if len(intersected_code) != 1:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="No event is registered for this user.",
            )
        return endpoint_url, intersected_code[0]

    @classmethod
    def event_required(cls, func: Callable) -> Callable:
        """
        Decorate an endpoint that requires a valid token and a matching event.

        The returned async wrapper does the following:
        1. Resolves the token context and reachable event codes via
           ``retrieve_access_content``.
        2. Resolves the single event code for the endpoint via
           ``retrieve_intersected_event_code``.
        3. Builds an ``EventContext``, attaches it as ``event_context`` to both
           the wrapper and the wrapped function, then calls the wrapped
           function. The wrapped function is awaited if it is a coroutine
           function.

        Args:
            func: The endpoint function to be decorated. It is called with the
                request as its first positional argument.

        Returns:
            Callable: The async wrapper.
        """

        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs) -> Any:
            # Both helpers raise HTTPExceptionApi on failure, so token_context
            # is guaranteed to be non-None past this point.
            token_context, reachable_event_codes = cls.retrieve_access_content(request)
            endpoint_url, reachable_event_code = cls.retrieve_intersected_event_code(
                request, reachable_event_codes
            )
            event_context = EventContext(
                auth=token_context,
                code=reachable_event_code,
                url=endpoint_url,
                request=request,
            )
            # NOTE(review): these attributes live on shared function objects, so
            # concurrent requests overwrite each other's context. An endpoint
            # that reads `event_context` after an `await` may see another
            # request's auth data. Consider a per-request carrier (e.g. request
            # state or a contextvar) once callers can be migrated.
            setattr(wrapper, "event_context", event_context)
            setattr(func, "event_context", event_context)

            if inspect.iscoroutinefunction(func):
                return await func(request, *args, **kwargs)
            return func(request, *args, **kwargs)

        return wrapper

    # NOTE: a former commented-out `validation_required` decorator was removed;
    # `event_required` already attaches the resolved auth/url/event code to the
    # wrapped function via `event_context`.