293 lines
10 KiB
Python
293 lines
10 KiB
Python
"""
|
|
Token event middleware for handling authentication and event tracking.
|
|
"""
|
|
|
|
import inspect
|
|
|
|
from functools import wraps
|
|
from typing import Callable, Dict, Any, Optional, Union
|
|
from fastapi import Request
|
|
from pydantic import BaseModel
|
|
|
|
from ApiLibrary.common.line_number import get_line_number_for_error
|
|
from ApiServices.Token.token_handler import TokenService
|
|
from ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
|
|
from Schemas.rules.rules import EndpointRestriction
|
|
|
|
from .auth_middleware import MiddlewareModule
|
|
from Schemas import Events
|
|
|
|
|
|
|
|
class EventFunctions:
    """
    Resolve which registered event the caller may execute on one endpoint.

    An instance is bound to an endpoint name and an incoming request. The
    access token is read from the request, the caller's reachable event codes
    are taken from the token context (employee -> selected company, occupant
    -> selected occupant), and they are intersected with the function codes
    of the events registered for the endpoint. Exactly one code must match.
    """

    def __init__(self, endpoint: str, request: Request):
        # Endpoint name, compared against EndpointRestriction.endpoint_name.
        self.endpoint = endpoint
        # Incoming request the access token is read from.
        self.request = request

    def _get_reachable_event_codes(self) -> list[str]:
        """
        Return the event codes reachable by the token attached to the request.

        Raises:
            HTTPExceptionApi: if the request carries no access token, the
                token cannot be resolved to a context, or the context is
                neither an employee nor an occupant context.
        """
        access_token = TokenService.get_access_token_from_request(self.request)
        # Validate the token before resolving it, so a missing token yields
        # "Token not found" instead of failing inside the token lookup.
        if not access_token:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )

        token_context = TokenService.get_object_via_access_key(
            access_token=access_token
        )
        # NOTE(review): guard in case the lookup returns None for an unknown
        # key -- TODO confirm TokenService.get_object_via_access_key's contract.
        if token_context is None:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )

        if token_context.is_employee:
            reachable_event_codes = (
                token_context.selected_company.reachable_event_codes
            )
        elif token_context.is_occupant:
            reachable_event_codes = (
                token_context.selected_occupant.reachable_event_codes
            )
        else:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Token not found",
            )

        # An absent code list means "nothing reachable"; it is reported as
        # "No event is registered for this user." downstream instead of
        # crashing with a TypeError when converted to a set.
        return list(reachable_event_codes or [])

    def _resolve_event(self) -> Dict[str, Any]:
        """
        Match ``self.endpoint`` against the events reachable by the caller.

        Returns:
            Dict with keys ``"endpoint_url"`` (``self.endpoint``),
            ``"reachable_event_code"`` (the single function code shared by
            the caller and the endpoint's events) and ``"class"`` (the
            ``function_class`` of the matching event).

        Raises:
            HTTPExceptionApi: on any token problem (see
                ``_get_reachable_event_codes``); "Function code not found"
                when no EndpointRestriction row has this endpoint name;
                "No event is registered for this user." when the endpoint has
                no events or the caller does not match exactly one event code.
        """
        reachable_event_codes = set(self._get_reachable_event_codes())

        # NOTE(review): this session is never closed here; confirm whether
        # EndpointRestriction.new_session() returns a session that must be
        # closed explicitly by the caller.
        db = EndpointRestriction.new_session()
        restriction = EndpointRestriction.filter_one(
            EndpointRestriction.endpoint_name == self.endpoint,
            db=db,
        ).data
        if not restriction:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="Function code not found",
            )

        event_related = Events.filter_all(
            Events.endpoint_id == restriction.id,
            db=db,
        ).data
        if not event_related:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="No event is registered for this user.",
            )

        # Keep the event objects (not just their codes) so the returned class
        # belongs to the event that actually matched, not merely to the first
        # event registered on the endpoint.
        matched_events = [
            event
            for event in event_related
            if event.function_code in reachable_event_codes
        ]
        matched_codes = {event.function_code for event in matched_events}
        # Zero matches means no access; more than one is ambiguous.
        if len(matched_codes) != 1:
            raise HTTPExceptionApi(
                error_code="",
                lang="en",
                loc=get_line_number_for_error(),
                sys_msg="No event is registered for this user.",
            )

        (matched_code,) = matched_codes
        return {
            "endpoint_url": self.endpoint,
            "reachable_event_code": matched_code,
            "class": matched_events[0].function_class,
        }

    def match_endpoint_with_accesiable_event(self) -> Dict[str, Any]:
        """
        Match ``self.endpoint`` with the event the caller's token may access.

        Returns:
            Dict with ``"endpoint_url"``, ``"reachable_event_code"`` and
            ``"class"`` for the single matching event.

        Raises:
            HTTPExceptionApi: if the token is missing/invalid, the endpoint is
                unknown, or the caller does not match exactly one event.
        """
        return self._resolve_event()

    def retrieve_function_dict(self) -> Dict[str, Any]:
        """
        Retrieve the event/function dictionary for ``self.endpoint``.

        Returns:
            Dict with ``"endpoint_url"``, ``"reachable_event_code"`` and
            ``"class"`` for the single matching event.

        Raises:
            HTTPExceptionApi: if the token is missing/invalid, the endpoint is
                unknown, or the caller does not match exactly one event.
        """
        return self._resolve_event()
|
|
|
|
|
|
class TokenEventMiddleware:
    """
    Decorators that attach token and event context to endpoint functions.

    This class provides:
    - ``event_required``: resolves the event the caller may run on the
      decorated endpoint and exposes it as ``func_code``.
    - ``validation_required``: authenticates via
      ``MiddlewareModule.auth_required`` and resolves the event for the
      endpoint named in the request payload, exposed as ``validation_code``.
    """

    @staticmethod
    def event_required(
        func: Callable[..., Dict[str, Any]]
    ) -> Callable[..., Dict[str, Any]]:
        """
        Decorate an endpoint so each call resolves the caller's event first.

        On every call the wrapper reads ``url_of_endpoint`` from the decorated
        function and resolves the matching event for the request's token via
        ``EventFunctions.match_endpoint_with_accesiable_event``. It stores the
        resulting dict as ``func_code`` on both the original function and the
        wrapper, then calls the function.

        NOTE(review): ``MiddlewareModule.auth_required`` is currently NOT
        applied by this decorator (it was disabled); only the token lookup
        performed inside ``EventFunctions`` runs. Confirm this is intended.

        Args:
            func: The endpoint function to decorate (sync or async).

        Returns:
            Callable: An async wrapper taking ``(request, *args, **kwargs)``.

        Raises:
            HTTPExceptionApi: if the function has no ``url_of_endpoint``
                attribute, or if event resolution fails.
        """
        # Authentication is intentionally not applied here (see docstring).
        authenticated_func = func

        @wraps(authenticated_func)
        async def wrapper(request: Request, *args, **kwargs) -> Dict[str, Any]:
            # Read at call time so the attribute may be attached to the
            # function after decoration.
            endpoint_url = getattr(authenticated_func, "url_of_endpoint", None)
            if not endpoint_url:
                raise HTTPExceptionApi(
                    error_code="",
                    lang="en",
                    loc=get_line_number_for_error(),
                    sys_msg="Function code not found",
                )

            func_code = EventFunctions(
                endpoint_url, request
            ).match_endpoint_with_accesiable_event()
            # Expose the result on the original function and on the wrapper:
            # @wraps copies __dict__ only once, at decoration time, so an
            # attribute set later on ``func`` is not visible via the wrapper.
            # NOTE(review): per-request data on a shared function object is
            # not isolated between concurrent requests.
            func.func_code = func_code
            wrapper.func_code = func_code

            if inspect.iscoroutinefunction(authenticated_func):
                return await authenticated_func(request, *args, **kwargs)
            return authenticated_func(request, *args, **kwargs)

        return wrapper

    @staticmethod
    def validation_required(
        func: Callable[..., Dict[str, Any]]
    ) -> Callable[..., Dict[str, Any]]:
        """
        Decorate an endpoint that validates access to another endpoint.

        The function is first wrapped with ``MiddlewareModule.auth_required``.
        On every call the wrapper:
        1. reads the target endpoint name from ``kwargs["data"].data["endpoint"]``;
        2. resolves the caller's event for it via
           ``EventFunctions.retrieve_function_dict`` and stores the result as
           ``wrapper.validation_code``;
        3. calls the authenticated function, awaiting the result if it is a
           coroutine;
        4. copies the authenticated function's ``auth`` attribute onto the
           wrapper, the original function and the authenticated function.

        Args:
            func: The endpoint function to decorate (sync or async).

        Returns:
            Callable: An async wrapper taking ``(request, *args, **kwargs)``.

        Raises:
            HTTPExceptionApi: "Endpoint not found" if the payload names no
                endpoint, or any error raised while resolving the event.
        """
        authenticated_func = MiddlewareModule.auth_required(func)

        @wraps(authenticated_func)
        async def wrapper(
            request: Request, *args: Any, **kwargs: Any
        ) -> Union[Dict[str, Any], BaseModel]:
            # A missing ``data`` kwarg, a missing ``.data`` payload, or a
            # payload without ``.get`` all mean "no endpoint requested"; report
            # it as the API error instead of crashing with AttributeError.
            payload = getattr(kwargs.get("data", None), "data", None)
            payload_get = getattr(payload, "get", None)
            endpoint_asked = (
                payload_get("endpoint", None) if callable(payload_get) else None
            )
            if not endpoint_asked:
                raise HTTPExceptionApi(
                    error_code="",
                    lang="en",
                    loc=get_line_number_for_error(),
                    sys_msg="Endpoint not found",
                )

            # NOTE(review): per-request data on a shared function object is
            # not isolated between concurrent requests.
            wrapper.validation_code = EventFunctions(
                endpoint_asked, request
            ).retrieve_function_dict()

            if inspect.iscoroutinefunction(authenticated_func):
                result = await authenticated_func(request, *args, **kwargs)
            else:
                result = authenticated_func(request, *args, **kwargs)

            # Read after the call; presumably auth_required attaches ``auth``
            # while running -- TODO confirm against MiddlewareModule.
            function_auth = getattr(authenticated_func, "auth", None)
            wrapper.auth = function_auth
            func.auth = function_auth
            authenticated_func.auth = function_auth

            # A non-coroutine function may still return an awaitable.
            if inspect.iscoroutine(result):
                result = await result
            return result

        return wrapper
|