wag-managment-api-service-v.../ApiLayers/Middleware/token_event_middleware.py

308 lines
11 KiB
Python

"""
Token event middleware for handling authentication and event tracking.
"""
import inspect
from functools import wraps
from typing import Callable, Dict, Any, Optional, Union
from fastapi import Request
from pydantic import BaseModel
from ApiLayers.ApiLibrary.common.line_number import get_line_number_for_error
from ApiLayers.ApiServices.Token.token_handler import TokenService
from ApiLayers.ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
from ApiLayers.Schemas import Events, EndpointRestriction
from .auth_middleware import MiddlewareModule
class EventFunctions:
    """
    Resolve which registered event an access token may trigger on an endpoint.

    Args:
        endpoint: Endpoint name, compared against
            ``EndpointRestriction.endpoint_name``.
        request: The incoming request the access token is read from.
    """

    def __init__(self, endpoint: str, request: Request):
        self.endpoint = endpoint
        self.request = request

    def retrieve_function_dict(self) -> Dict[str, Any]:
        """
        Retrieve the function dictionary for ``self.endpoint``.

        The access token is read from ``self.request``. Its reachable event
        codes come from the selected company (employee tokens) or the
        selected occupant (occupant tokens). They are intersected with the
        ``function_code`` values of the events registered for the endpoint,
        and exactly one code must match.

        Returns:
            Dictionary with the keys:
              - ``endpoint_url``: ``self.endpoint``
              - ``reachable_event_code``: the single matching event code
              - ``class``: ``function_class`` of the event that matched

        Raises:
            HTTPExceptionApi: If the token is missing, has no token context,
                or is neither an employee nor an occupant token. Also raised
                if the endpoint is not registered, has no events, or the
                token does not reach exactly one of its event codes.
        """
        access_token = TokenService.get_access_token_from_request(self.request)
        # Validate the token before using it to look up the token context.
        if not access_token:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )
        token_context = TokenService.get_object_via_access_key(
            access_token=access_token
        )
        if token_context is None:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )
        if token_context.is_employee:
            reachable_event_codes: list[str] = (
                token_context.selected_company.reachable_event_codes
            )
        elif token_context.is_occupant:
            reachable_event_codes: list[str] = (
                token_context.selected_occupant.reachable_event_codes
            )
        else:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )

        # NOTE(review): this session is never closed here; confirm whether
        # new_session() hands out a managed/scoped session or needs closing.
        db = EndpointRestriction.new_session()
        restriction = EndpointRestriction.filter_one(
            EndpointRestriction.endpoint_name == self.endpoint,
            db=db,
        ).data
        if not restriction:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Function code not found",
            )
        event_related = Events.filter_all(
            Events.endpoint_id == restriction.id,
            db=db,
        ).data
        if not event_related:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="No event is registered for this user.",
            )

        # Tolerate a missing code list instead of failing on set(None).
        reachable_codes = set(reachable_event_codes or ())
        intersected_code: set = {
            event.function_code for event in event_related
        } & reachable_codes
        # Exactly one reachable code is required; zero or several is rejected.
        if len(intersected_code) != 1:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="No event is registered for this user.",
            )
        matched_code = next(iter(intersected_code))
        # Report the class of the event that actually matched, not merely the
        # first event registered on the endpoint.
        matched_event = next(
            event for event in event_related if event.function_code == matched_code
        )
        return {
            "endpoint_url": self.endpoint,
            "reachable_event_code": matched_code,
            "class": matched_event.function_class,
        }
class TokenEventMiddleware:
    """
    Module containing token and event handling functionality.

    This class provides:
    - Token and event context management
    - Event validation decorators for endpoints
    """

    @staticmethod
    def match_endpoint_with_accessible_event(
        request_from_scope, endpoint_from_scope
    ) -> Optional[Dict[str, Any]]:
        """
        Match an endpoint with the events reachable by the request's token.

        This was a verbatim copy of ``EventFunctions.retrieve_function_dict``.
        It now delegates to that method so the lookup logic lives in one place.

        Args:
            request_from_scope: The incoming request carrying the access token.
            endpoint_from_scope: Endpoint name to look up in
                ``EndpointRestriction``.

        Returns:
            The dictionary produced by ``EventFunctions.retrieve_function_dict``
            (keys ``endpoint_url``, ``reachable_event_code`` and ``class``).

        Raises:
            HTTPExceptionApi: Propagated from the lookup when the token, the
                endpoint or a single reachable event cannot be resolved.
        """
        return EventFunctions(
            endpoint=endpoint_from_scope, request=request_from_scope
        ).retrieve_function_dict()

    @classmethod
    def event_required(cls, func: Callable) -> Callable:
        """
        Decorator for endpoints with token and event requirements.

        The decorated function is first wrapped with
        ``MiddlewareModule.auth_required``. On every call the returned
        wrapper:
        1. sets ``wrapper.func_code`` (currently a hard-coded placeholder),
        2. calls/awaits the authenticated function,
        3. copies the authenticated function's ``auth`` attribute, if any,
           onto ``wrapper``.

        Args:
            func: The function to be decorated.

        Returns:
            Callable: The async wrapper with both auth and event handling.
        """
        # First apply authentication
        authenticated_func = MiddlewareModule.auth_required(func)
        is_async = inspect.iscoroutinefunction(authenticated_func)

        @wraps(authenticated_func)
        async def wrapper(request: Request, *args, **kwargs) -> Dict[str, Any]:
            # TODO(review): placeholder code; presumably this should be derived
            # from match_endpoint_with_accessible_event for request.url.path.
            setattr(wrapper, "func_code", "8aytr-")
            if is_async:
                result = await authenticated_func(request, *args, **kwargs)
            else:
                result = authenticated_func(request, *args, **kwargs)
            # Read the auth context only after authentication ran for this
            # request. auth_required presumably sets `.auth` during the call, so
            # reading it beforehand yields the previous request's value. Confirm
            # against MiddlewareModule.
            auth_context = getattr(authenticated_func, "auth", None)
            if auth_context is not None:
                setattr(wrapper, "auth", auth_context)
            return result

        # Copy any existing attributes from the authenticated function
        for attr in dir(authenticated_func):
            if not attr.startswith("__"):
                setattr(wrapper, attr, getattr(authenticated_func, attr))
        return wrapper

    @classmethod
    def validation_required(
        cls, func: Callable[..., Dict[str, Any]]
    ) -> Callable[..., Dict[str, Any]]:
        """
        Decorator that resolves the requested endpoint's event before a call.

        Authentication is NOT applied here; ``func`` is wrapped as-is. On every
        call the returned wrapper:
        1. reads ``kwargs["data"].data["endpoint"]``,
        2. stores the result of ``match_endpoint_with_accessible_event`` for
           that endpoint on ``func.func_code``,
        3. calls/awaits ``func`` exactly once,
        4. copies ``func.auth`` (or ``None``) onto ``wrapper.auth``.

        Args:
            func: The function to be decorated.

        Returns:
            Callable: The async wrapper.

        Raises:
            HTTPExceptionApi: If no endpoint is provided in the request data,
                or propagated from the endpoint/event lookup.
        """
        is_async = inspect.iscoroutinefunction(func)

        @wraps(func)
        async def wrapper(
            request: Request, *args: Any, **kwargs: Any
        ) -> Union[Dict[str, Any], BaseModel]:
            # A missing `data` kwarg (or one without `.data`) must surface as
            # "Endpoint not found", not as an AttributeError on None.
            payload = getattr(kwargs.get("data", None), "data", None)
            endpoint_asked = (
                payload.get("endpoint", None) if payload is not None else None
            )
            if not endpoint_asked:
                raise HTTPExceptionApi(
                    error_code="",
                    lang="en",
                    loc=get_line_number_for_error(),
                    sys_msg="Endpoint not found",
                )
            # Argument order matches the signature (request, endpoint).
            func.func_code = cls.match_endpoint_with_accessible_event(
                request, endpoint_asked
            )
            # Execute the wrapped function once; it previously ran three times.
            if is_async:
                result = await func(request, *args, **kwargs)
            else:
                result = func(request, *args, **kwargs)
            wrapper.auth = getattr(func, "auth", None)
            return result

        return wrapper