""" Token event middleware for handling authentication and event tracking. """ import inspect from functools import wraps from typing import Callable, Dict, Any, Optional, Union from fastapi import Request from pydantic import BaseModel from ApiLibrary.common.line_number import get_line_number_for_error from ApiServices.Token.token_handler import TokenService from ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi from Schemas.rules.rules import EndpointRestriction from .auth_middleware import MiddlewareModule from Schemas import Events class EventFunctions: def __init__(self, endpoint: str, request: Request): self.endpoint = endpoint self.request = request def match_endpoint_with_accesiable_event(self) -> Optional[Dict[str, Any]]: """ Match an endpoint with accessible events. Args: endpoint: The endpoint to match Returns: Dict containing the endpoint registration data None if endpoint is not found in database """ access_token = TokenService.get_access_token_from_request(self.request) token_context = TokenService.get_object_via_access_key( access_token=access_token ) if token_context.is_employee: reachable_event_codes: list[str] = ( token_context.selected_company.reachable_event_codes ) elif token_context.is_occupant: reachable_event_codes: list[str] = ( token_context.selected_occupant.reachable_event_codes ) else: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="Token not found", ) if not access_token: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="Token not found", ) db = EndpointRestriction.new_session() restriction = EndpointRestriction.filter_one( EndpointRestriction.endpoint_name == self.endpoint, db=db, ).data if not restriction: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="Function code not found", ) event_related = Events.filter_all( Events.endpoint_id == restriction.id, db=db, ).data if not event_related: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="No event is registered for this user.", ) an_event = event_related[0] event_related_codes: list[str] = [ event.function_code for event in event_related ] intersected_code: set = set(reachable_event_codes).intersection( set(event_related_codes) ) if not len(list(intersected_code)) == 1: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="No event is registered for this user.", ) return { "endpoint_url": self.endpoint, "reachable_event_code": list(intersected_code)[0], "class": an_event.function_class, } def retrieve_function_dict(self) -> Optional[Dict[str, Any]]: """ Retrieve function dictionary for a given endpoint. Args: endpoint: The endpoint to retrieve the function dictionary for Returns: Dictionary containing the function dictionary None if endpoint is not found """ access_token = TokenService.get_access_token_from_request(self.request) token_context = TokenService.get_object_via_access_key( access_token=access_token ) if token_context.is_employee: reachable_event_codes: list[str] = ( token_context.selected_company.reachable_event_codes ) elif token_context.is_occupant: reachable_event_codes: list[str] = ( token_context.selected_occupant.reachable_event_codes ) else: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="Token not found", ) if not access_token: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="Token not found", ) db = EndpointRestriction.new_session() restriction = EndpointRestriction.filter_one( EndpointRestriction.endpoint_name == self.endpoint, db=db, ).data if not restriction: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="Function code not found", ) event_related = Events.filter_all( Events.endpoint_id == restriction.id, db=db, ).data if not event_related: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="No event is registered for this user.", ) an_event = event_related[0] event_related_codes: list[str] = [ event.function_code for event in event_related ] intersected_code: set = set(reachable_event_codes).intersection( set(event_related_codes) ) if not len(list(intersected_code)) == 1: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="No event is registered for this user.", ) return { "endpoint_url": self.endpoint, "reachable_event_code": list(intersected_code)[0], "class": an_event.function_class, } class TokenEventMiddleware: """ Module containing token and event handling functionality. This class provides: - Token and event context management - Event validation decorator for endpoints """ @staticmethod def event_required( func: Callable[..., Dict[str, Any]] ) -> Callable[..., Dict[str, Any]]: """ Decorator for endpoints with token and event requirements. This decorator: 1. First validates authentication using MiddlewareModule.auth_required 2. Then adds event tracking context Args: func: The function to be decorated Returns: Callable: The wrapped function with both auth and event handling """ # # First apply authentication # authenticated_func = MiddlewareModule.auth_required(func) authenticated_func = func @wraps(authenticated_func) async def wrapper(request: Request, *args, **kwargs) -> Dict[str, Any]: # Get function code from the function's metadata endpoint_url = getattr(authenticated_func, "url_of_endpoint", {}) if not endpoint_url: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="Function code not found", ) # Make handler available to all functions in the chain func.func_code = EventFunctions( endpoint_url, request ).match_endpoint_with_accesiable_event() # Call the authenticated function if inspect.iscoroutinefunction(authenticated_func): return await authenticated_func(request, *args, **kwargs) return authenticated_func(request, *args, **kwargs) return wrapper @staticmethod def validation_required( func: Callable[..., Dict[str, Any]] ) -> Callable[..., Dict[str, Any]]: """ Decorator for endpoints with token and event requirements. This decorator: 1. First validates authentication using MiddlewareModule.auth_required 2. Then adds event tracking context Args: func: The function to be decorated Returns: Callable: The wrapped function with both auth and event handling """ # First apply authentication authenticated_func = MiddlewareModule.auth_required(func) @wraps(authenticated_func) async def wrapper( request: Request, *args: Any, **kwargs: Any ) -> Union[Dict[str, Any], BaseModel]: # Handle both async and sync functions endpoint_asked = getattr(kwargs.get("data", None), "data", None).get( "endpoint", None ) if not endpoint_asked: raise HTTPExceptionApi( error_code="", lang="en", loc=get_line_number_for_error(), sys_msg="Endpoint not found", ) wrapper.validation_code = EventFunctions( endpoint_asked, request ).retrieve_function_dict() if inspect.iscoroutinefunction(authenticated_func): result = await authenticated_func(request, *args, **kwargs) else: result = authenticated_func(request, *args, **kwargs) function_auth = getattr(authenticated_func, "auth", None) wrapper.auth = function_auth func.auth = function_auth authenticated_func.auth = function_auth # If result is a coroutine, await it if inspect.iscoroutine(result): result = await result return result return wrapper