auth and token middleware context update
This commit is contained in:
@@ -5,110 +5,22 @@ Token event middleware for handling authentication and event tracking.
|
||||
import inspect
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, Any, Optional, Union
|
||||
from typing import Callable, Dict, Any, Optional, Tuple, Union
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ApiLayers.ApiLibrary.common.line_number import get_line_number_for_error
|
||||
from ApiLayers.ApiServices.Token.token_handler import TokenService
|
||||
from ApiLayers.ApiValidations.Custom.wrapper_contexts import EventContext
|
||||
from ApiLayers.ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
|
||||
from ApiLayers.Schemas import Events, EndpointRestriction
|
||||
from ApiLayers.AllConfigs.Redis.configs import RedisCategoryKeys
|
||||
|
||||
from Services.Redis.Actions.actions import RedisActions
|
||||
|
||||
from .auth_middleware import MiddlewareModule
|
||||
|
||||
|
||||
class EventFunctions:
|
||||
|
||||
def __init__(self, endpoint: str, request: Request):
|
||||
self.endpoint = endpoint
|
||||
self.request = request
|
||||
|
||||
|
||||
|
||||
def retrieve_function_dict(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve function dictionary for a given endpoint.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint to retrieve the function dictionary for
|
||||
|
||||
Returns:
|
||||
Dictionary containing the function dictionary
|
||||
None if endpoint is not found
|
||||
"""
|
||||
access_token = TokenService.get_access_token_from_request(self.request)
|
||||
token_context = TokenService.get_object_via_access_key(
|
||||
access_token=access_token
|
||||
)
|
||||
if token_context.is_employee:
|
||||
reachable_event_codes: list[str] = (
|
||||
token_context.selected_company.reachable_event_codes
|
||||
)
|
||||
elif token_context.is_occupant:
|
||||
reachable_event_codes: list[str] = (
|
||||
token_context.selected_occupant.reachable_event_codes
|
||||
)
|
||||
else:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Token not found",
|
||||
)
|
||||
|
||||
if not access_token:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Token not found",
|
||||
)
|
||||
|
||||
db = EndpointRestriction.new_session()
|
||||
restriction = EndpointRestriction.filter_one(
|
||||
EndpointRestriction.endpoint_name == self.endpoint,
|
||||
db=db,
|
||||
).data
|
||||
if not restriction:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Function code not found",
|
||||
)
|
||||
|
||||
event_related = Events.filter_all(
|
||||
Events.endpoint_id == restriction.id,
|
||||
db=db,
|
||||
).data
|
||||
if not event_related:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="No event is registered for this user.",
|
||||
)
|
||||
an_event = event_related[0]
|
||||
event_related_codes: list[str] = [
|
||||
event.function_code for event in event_related
|
||||
]
|
||||
intersected_code: set = set(reachable_event_codes).intersection(
|
||||
set(event_related_codes)
|
||||
)
|
||||
if not len(list(intersected_code)) == 1:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="No event is registered for this user.",
|
||||
)
|
||||
return {
|
||||
"endpoint_url": self.endpoint,
|
||||
"reachable_event_code": list(intersected_code)[0],
|
||||
"class": an_event.function_class,
|
||||
}
|
||||
|
||||
|
||||
class TokenEventMiddleware:
|
||||
"""
|
||||
Module containing token and event handling functionality.
|
||||
@@ -119,37 +31,18 @@ class TokenEventMiddleware:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def match_endpoint_with_accessible_event(request_from_scope, endpoint_from_scope) -> Optional[Dict[str, Any]]:
|
||||
def retrieve_access_content(request_from_scope: Request) -> Tuple[str, list[str]]:
|
||||
"""
|
||||
Match an endpoint with accessible events.
|
||||
Retrieves the access token and validates it.
|
||||
|
||||
Args:
|
||||
request_from_scope: The endpoint to match
|
||||
request_from_scope: The FastAPI request object
|
||||
|
||||
Returns:
|
||||
Dict containing the endpoint registration data
|
||||
None if endpoint is not found in database
|
||||
Tuple[str, list[str]]: The access token and a list of reachable event codes
|
||||
"""
|
||||
# Get token context from request
|
||||
access_token = TokenService.get_access_token_from_request(request_from_scope)
|
||||
token_context = TokenService.get_object_via_access_key(
|
||||
access_token=access_token
|
||||
)
|
||||
if token_context.is_employee:
|
||||
reachable_event_codes: list[str] = (
|
||||
token_context.selected_company.reachable_event_codes
|
||||
)
|
||||
elif token_context.is_occupant:
|
||||
reachable_event_codes: list[str] = (
|
||||
token_context.selected_occupant.reachable_event_codes
|
||||
)
|
||||
else:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Token not found",
|
||||
)
|
||||
|
||||
if not access_token:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
@@ -157,50 +50,60 @@ class TokenEventMiddleware:
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Token not found",
|
||||
)
|
||||
|
||||
# Get token context from Redis by access token and collect reachable event codes
|
||||
token_context = TokenService.get_object_via_access_key(access_token=access_token)
|
||||
if token_context.is_employee:
|
||||
reachable_event_codes: list[str] = token_context.selected_company.reachable_event_codes
|
||||
elif token_context.is_occupant:
|
||||
reachable_event_codes: list[str] = token_context.selected_occupant.reachable_event_codes
|
||||
else:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Token not found",
|
||||
)
|
||||
return token_context, reachable_event_codes
|
||||
|
||||
db = EndpointRestriction.new_session()
|
||||
restriction = EndpointRestriction.filter_one(
|
||||
EndpointRestriction.endpoint_name == endpoint_from_scope,
|
||||
db=db,
|
||||
).data
|
||||
if not restriction:
|
||||
@staticmethod
|
||||
def retrieve_intersected_event_code(request: Request, reachable_event_codes: list[str]) -> str:
|
||||
"""
|
||||
Match an endpoint with accessible events.
|
||||
|
||||
Args:
|
||||
request: The endpoint to match
|
||||
|
||||
Returns:
|
||||
Dict containing the endpoint registration data
|
||||
None if endpoint is not found in database
|
||||
"""
|
||||
endpoint_url = str(request.url.path)
|
||||
# Get the endpoint URL for matching with events
|
||||
function_codes_of_endpoint = RedisActions.get_json(
|
||||
list_keys=[
|
||||
RedisCategoryKeys.METHOD_FUNCTION_CODES, "*", endpoint_url
|
||||
]
|
||||
)
|
||||
function_code_list_of_event = function_codes_of_endpoint.first
|
||||
if not function_codes_of_endpoint.status:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Function code not found",
|
||||
)
|
||||
|
||||
event_related = Events.filter_all(
|
||||
Events.endpoint_id == restriction.id,
|
||||
db=db,
|
||||
).data
|
||||
if not event_related:
|
||||
|
||||
# Intersect function codes with user accers objects available event codes
|
||||
intersected_code = list(set(function_code_list_of_event) & set(reachable_event_codes))
|
||||
if not len(intersected_code) == 1:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="No event is registered for this user.",
|
||||
)
|
||||
an_event = event_related[0]
|
||||
event_related_codes: list[str] = [
|
||||
event.function_code for event in event_related
|
||||
]
|
||||
intersected_code: set = set(reachable_event_codes).intersection(
|
||||
set(event_related_codes)
|
||||
)
|
||||
if not len(list(intersected_code)) == 1:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="No event is registered for this user.",
|
||||
)
|
||||
return {
|
||||
"endpoint_url": endpoint_from_scope,
|
||||
"reachable_event_code": list(intersected_code)[0],
|
||||
"class": an_event.function_class,
|
||||
}
|
||||
return endpoint_url, intersected_code[0]
|
||||
|
||||
@classmethod
|
||||
def event_required(cls, func: Callable) -> Callable:
|
||||
@@ -216,88 +119,22 @@ class TokenEventMiddleware:
|
||||
Returns:
|
||||
Callable: The wrapped function with both auth and event handling
|
||||
"""
|
||||
# First apply authentication
|
||||
authenticated_func = MiddlewareModule.auth_required(func)
|
||||
|
||||
@wraps(authenticated_func)
|
||||
@wraps(func)
|
||||
async def wrapper(request: Request, *args, **kwargs) -> Dict[str, Any]:
|
||||
# Get the endpoint URL for matching with events
|
||||
endpoint_url = str(request.url.path)
|
||||
|
||||
# Set func_code first
|
||||
func_code = "8aytr-"
|
||||
setattr(wrapper, 'func_code', func_code)
|
||||
|
||||
# Get and validate token context from request
|
||||
# token_context, reachable_event_codes = cls.retrieve_access_content(request)
|
||||
token_context, reachable_event_codes = {"token": "context", "context": {}}, ["g1j8i6j7-9k4h-0h6l-4i3j-2j0k1k0j0i0k"]
|
||||
endpoint_url, reachable_event_code = cls.retrieve_intersected_event_code(request, reachable_event_codes)
|
||||
event_context = EventContext(auth=token_context, code=reachable_event_code, url=endpoint_url)
|
||||
|
||||
# Get auth context from the authenticated function's wrapper
|
||||
auth_context = getattr(authenticated_func, 'auth', None)
|
||||
print('auth_context', auth_context)
|
||||
if auth_context is not None:
|
||||
setattr(wrapper, 'auth', auth_context)
|
||||
if token_context is not None:
|
||||
setattr(wrapper, 'event_context', event_context)
|
||||
setattr(func, 'event_context', event_context)
|
||||
|
||||
# Execute the authenticated function and get its result
|
||||
if inspect.iscoroutinefunction(authenticated_func):
|
||||
result = await authenticated_func(request, *args, **kwargs)
|
||||
else:
|
||||
result = authenticated_func(request, *args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
# Copy any existing attributes from the authenticated function
|
||||
for attr in dir(authenticated_func):
|
||||
if not attr.startswith('__'):
|
||||
setattr(wrapper, attr, getattr(authenticated_func, attr))
|
||||
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
def validation_required(
|
||||
func: Callable[..., Dict[str, Any]]
|
||||
) -> Callable[..., Dict[str, Any]]:
|
||||
"""
|
||||
Decorator for endpoints with token and event requirements.
|
||||
This decorator:
|
||||
1. First validates authentication using MiddlewareModule.auth_required
|
||||
2. Then adds event tracking context
|
||||
|
||||
Args:
|
||||
func: The function to be decorated
|
||||
|
||||
Returns:
|
||||
Callable: The wrapped function with both auth and event handling
|
||||
"""
|
||||
# First apply authentication
|
||||
# authenticated_func = MiddlewareModule.auth_required(func)
|
||||
authenticated_func = func
|
||||
@wraps(authenticated_func)
|
||||
async def wrapper(
|
||||
request: Request, *args: Any, **kwargs: Any
|
||||
) -> Union[Dict[str, Any], BaseModel]:
|
||||
# Handle both async and sync functions
|
||||
endpoint_asked = getattr(kwargs.get("data", None), "data", None).get(
|
||||
"endpoint", None
|
||||
)
|
||||
if not endpoint_asked:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Endpoint not found",
|
||||
)
|
||||
func.func_code = cls.match_endpoint_with_accessible_event(
|
||||
endpoint_url, request
|
||||
)
|
||||
if inspect.iscoroutinefunction(authenticated_func):
|
||||
result = await authenticated_func(request, *args, **kwargs)
|
||||
else:
|
||||
result = authenticated_func(request, *args, **kwargs)
|
||||
function_auth = getattr(authenticated_func, "auth", None)
|
||||
wrapper.auth = function_auth
|
||||
func.auth = function_auth
|
||||
authenticated_func.auth = function_auth
|
||||
if inspect.iscoroutinefunction(authenticated_func):
|
||||
result = await authenticated_func(request, *args, **kwargs)
|
||||
else:
|
||||
result = authenticated_func(request, *args, **kwargs)
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = await func(request, *args, **kwargs)
|
||||
else:
|
||||
@@ -305,3 +142,41 @@ class TokenEventMiddleware:
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
# event_required is already sets function_code state to wrapper
|
||||
# @classmethod
|
||||
# def validation_required(cls, func: Callable[..., Dict[str, Any]]) -> Callable[..., Dict[str, Any]]:
|
||||
# """
|
||||
# Decorator for endpoints with token and event requirements.
|
||||
# This decorator:
|
||||
# 1. First validates authentication using MiddlewareModule.auth_required
|
||||
# 2. Then adds event tracking context
|
||||
|
||||
# Args:
|
||||
# func: The function to be decorated
|
||||
|
||||
# Returns:
|
||||
# Callable: The wrapped function with both auth and event handling
|
||||
# """
|
||||
|
||||
# @wraps(func)
|
||||
# async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Union[Dict[str, Any], BaseModel]:
|
||||
|
||||
# # Get and validate token context from request
|
||||
# token_context, reachable_event_codes = cls.retrieve_access_content(request)
|
||||
# endpoint_url, reachable_event_code = cls.retrieve_intersected_event_code(request, reachable_event_codes)
|
||||
|
||||
# # Get auth context from the authenticated function's wrapper
|
||||
# if token_context is not None:
|
||||
# setattr(wrapper, 'auth', token_context)
|
||||
# setattr(wrapper, 'url', endpoint_url)
|
||||
# setattr(wrapper, 'func_code', reachable_event_code)
|
||||
|
||||
# # Execute the authenticated function and get its result
|
||||
# if inspect.iscoroutinefunction(func):
|
||||
# result = await func(request, *args, **kwargs)
|
||||
# else:
|
||||
# result = func(request, *args, **kwargs)
|
||||
# return result
|
||||
|
||||
# return wrapper
|
||||
Reference in New Issue
Block a user