auth and token middleware context update

This commit is contained in:
2025-01-26 20:18:06 +03:00
parent 3d5a43220e
commit a7e48d8755
17 changed files with 265 additions and 345 deletions

View File

@@ -17,6 +17,12 @@ class WagRedis:
)
class RedisAuthKeys:
AUTH: str = "AUTH"
OCCUPANT: str = "OCCUPANT"
EMPLOYEE: str = "EMPLOYEE"
class RedisCategoryKeys:
REBUILD: str = "REBUILD"
ENDPOINT2CLASS: str = "ENDPOINT2CLASS"
@@ -28,6 +34,3 @@ class RedisCategoryKeys:
MENU_FIRST_LAYER: str = "MENU_FIRST_LAYER"
PAGE_MAPPER: str = "PAGE_MAPPER"
MENU_MAPPER: str = "MENU_MAPPER"
AUTH: str = "AUTH"
OCCUPANT: str = "OCCUPANT"
EMPLOYEE: str = "EMPLOYEE"

View File

@@ -16,23 +16,17 @@ class CreateEndpointFromCluster:
def __init__(self, **kwargs):
self.router: CategoryCluster = kwargs.get("router")
self.method_endpoint: MethodToEvent = kwargs.get("method_endpoint")
self.unique_id = str(uuid.uuid4())[:8] # Use first 8 chars of UUID for brevity
self.attach_router()
def attach_router(self):
method = getattr(self.router, self.method_endpoint.METHOD.lower())
# Create a unique operation ID based on the endpoint path, method, and a unique identifier
base_path = self.method_endpoint.URL.strip('/').replace('/', '_').replace('-', '_')
operation_id = f"{base_path}_{self.method_endpoint.METHOD.lower()}_{self.unique_id}"
kwargs = {
"path": self.method_endpoint.URL,
"summary": self.method_endpoint.SUMMARY,
"description": self.method_endpoint.DESCRIPTION,
"operation_id": operation_id
}
if hasattr(self.method_endpoint, 'RESPONSE_MODEL') and self.method_endpoint.RESPONSE_MODEL is not None:
kwargs["response_model"] = self.method_endpoint.RESPONSE_MODEL

View File

@@ -1,6 +1,6 @@
"""Token service for handling authentication tokens and user sessions."""
from typing import List, Union, TypeVar, Dict, Any, Optional, TYPE_CHECKING
from typing import List, Union, TypeVar, Dict, Any, TYPE_CHECKING
from ApiLayers.AllConfigs.Token.config import Auth
from ApiLayers.ApiLibrary.common.line_number import get_line_number_for_error

View File

@@ -0,0 +1,18 @@
from typing import Optional, Any
from pydantic import BaseModel
class DefaultContext(BaseModel):
...
class EventContext(DefaultContext):
auth: Any
code: str
url: str
class AuthContext(DefaultContext):
auth: Any
url: str

View File

@@ -15,6 +15,7 @@ from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from ApiLayers.ApiLibrary.common.line_number import get_line_number_for_error
from ApiLayers.ApiValidations.Custom.wrapper_contexts import AuthContext
from ApiLayers.ErrorHandlers.ErrorHandlers.api_exc_handler import HTTPExceptionApi
from ApiLayers.AllConfigs.Token.config import Auth
from ApiLayers.ApiServices.Token.token_handler import TokenService
@@ -87,20 +88,25 @@ class MiddlewareModule:
@wraps(func)
async def wrapper(request: Request, *args, **kwargs):
# Get and validate token context from request
auth_context = {
"is_employee": False,
"is_occupant": False,
"context": {}
}
endpoint_url = str(request.url.path)
auth_context = AuthContext(
auth={"token": {"access_token": "", "refresher_token": "", "context": {}}},
url=endpoint_url,
)
# Set auth context on the wrapper function itself
setattr(wrapper, 'auth', auth_context)
setattr(func, 'auth_context', auth_context)
setattr(wrapper, 'auth_context', auth_context)
# Call the original endpoint function
if inspect.iscoroutinefunction(func):
result = await func(request, *args, **kwargs)
else:
result = func(request, *args, **kwargs)
# Set auth context on the wrapper function itself
setattr(func, 'auth_context', auth_context)
setattr(wrapper, 'auth_context', auth_context)
return result
return wrapper

View File

@@ -5,110 +5,22 @@ Token event middleware for handling authentication and event tracking.
import inspect
from functools import wraps
from typing import Callable, Dict, Any, Optional, Union
from typing import Callable, Dict, Any, Optional, Tuple, Union
from fastapi import Request
from pydantic import BaseModel
from ApiLayers.ApiLibrary.common.line_number import get_line_number_for_error
from ApiLayers.ApiServices.Token.token_handler import TokenService
from ApiLayers.ApiValidations.Custom.wrapper_contexts import EventContext
from ApiLayers.ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
from ApiLayers.Schemas import Events, EndpointRestriction
from ApiLayers.AllConfigs.Redis.configs import RedisCategoryKeys
from Services.Redis.Actions.actions import RedisActions
from .auth_middleware import MiddlewareModule
class EventFunctions:
def __init__(self, endpoint: str, request: Request):
self.endpoint = endpoint
self.request = request
def retrieve_function_dict(self) -> Optional[Dict[str, Any]]:
"""
Retrieve function dictionary for a given endpoint.
Args:
endpoint: The endpoint to retrieve the function dictionary for
Returns:
Dictionary containing the function dictionary
None if endpoint is not found
"""
access_token = TokenService.get_access_token_from_request(self.request)
token_context = TokenService.get_object_via_access_key(
access_token=access_token
)
if token_context.is_employee:
reachable_event_codes: list[str] = (
token_context.selected_company.reachable_event_codes
)
elif token_context.is_occupant:
reachable_event_codes: list[str] = (
token_context.selected_occupant.reachable_event_codes
)
else:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Token not found",
)
if not access_token:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Token not found",
)
db = EndpointRestriction.new_session()
restriction = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_name == self.endpoint,
db=db,
).data
if not restriction:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Function code not found",
)
event_related = Events.filter_all(
Events.endpoint_id == restriction.id,
db=db,
).data
if not event_related:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="No event is registered for this user.",
)
an_event = event_related[0]
event_related_codes: list[str] = [
event.function_code for event in event_related
]
intersected_code: set = set(reachable_event_codes).intersection(
set(event_related_codes)
)
if not len(list(intersected_code)) == 1:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="No event is registered for this user.",
)
return {
"endpoint_url": self.endpoint,
"reachable_event_code": list(intersected_code)[0],
"class": an_event.function_class,
}
class TokenEventMiddleware:
"""
Module containing token and event handling functionality.
@@ -119,37 +31,18 @@ class TokenEventMiddleware:
"""
@staticmethod
def match_endpoint_with_accessible_event(request_from_scope, endpoint_from_scope) -> Optional[Dict[str, Any]]:
def retrieve_access_content(request_from_scope: Request) -> Tuple[str, list[str]]:
"""
Match an endpoint with accessible events.
Retrieves the access token and validates it.
Args:
request_from_scope: The endpoint to match
request_from_scope: The FastAPI request object
Returns:
Dict containing the endpoint registration data
None if endpoint is not found in database
Tuple[str, list[str]]: The access token and a list of reachable event codes
"""
# Get token context from request
access_token = TokenService.get_access_token_from_request(request_from_scope)
token_context = TokenService.get_object_via_access_key(
access_token=access_token
)
if token_context.is_employee:
reachable_event_codes: list[str] = (
token_context.selected_company.reachable_event_codes
)
elif token_context.is_occupant:
reachable_event_codes: list[str] = (
token_context.selected_occupant.reachable_event_codes
)
else:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Token not found",
)
if not access_token:
raise HTTPExceptionApi(
error_code="",
@@ -157,50 +50,60 @@ class TokenEventMiddleware:
loc=get_line_number_for_error(),
sys_msg="Token not found",
)
# Get token context from Redis by access token and collect reachable event codes
token_context = TokenService.get_object_via_access_key(access_token=access_token)
if token_context.is_employee:
reachable_event_codes: list[str] = token_context.selected_company.reachable_event_codes
elif token_context.is_occupant:
reachable_event_codes: list[str] = token_context.selected_occupant.reachable_event_codes
else:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Token not found",
)
return token_context, reachable_event_codes
db = EndpointRestriction.new_session()
restriction = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_name == endpoint_from_scope,
db=db,
).data
if not restriction:
@staticmethod
def retrieve_intersected_event_code(request: Request, reachable_event_codes: list[str]) -> str:
"""
Match an endpoint with accessible events.
Args:
request: The endpoint to match
Returns:
Dict containing the endpoint registration data
None if endpoint is not found in database
"""
endpoint_url = str(request.url.path)
# Get the endpoint URL for matching with events
function_codes_of_endpoint = RedisActions.get_json(
list_keys=[
RedisCategoryKeys.METHOD_FUNCTION_CODES, "*", endpoint_url
]
)
function_code_list_of_event = function_codes_of_endpoint.first
if not function_codes_of_endpoint.status:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Function code not found",
)
event_related = Events.filter_all(
Events.endpoint_id == restriction.id,
db=db,
).data
if not event_related:
# Intersect function codes with user accers objects available event codes
intersected_code = list(set(function_code_list_of_event) & set(reachable_event_codes))
if not len(intersected_code) == 1:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="No event is registered for this user.",
)
an_event = event_related[0]
event_related_codes: list[str] = [
event.function_code for event in event_related
]
intersected_code: set = set(reachable_event_codes).intersection(
set(event_related_codes)
)
if not len(list(intersected_code)) == 1:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="No event is registered for this user.",
)
return {
"endpoint_url": endpoint_from_scope,
"reachable_event_code": list(intersected_code)[0],
"class": an_event.function_class,
}
return endpoint_url, intersected_code[0]
@classmethod
def event_required(cls, func: Callable) -> Callable:
@@ -216,88 +119,22 @@ class TokenEventMiddleware:
Returns:
Callable: The wrapped function with both auth and event handling
"""
# First apply authentication
authenticated_func = MiddlewareModule.auth_required(func)
@wraps(authenticated_func)
@wraps(func)
async def wrapper(request: Request, *args, **kwargs) -> Dict[str, Any]:
# Get the endpoint URL for matching with events
endpoint_url = str(request.url.path)
# Set func_code first
func_code = "8aytr-"
setattr(wrapper, 'func_code', func_code)
# Get and validate token context from request
# token_context, reachable_event_codes = cls.retrieve_access_content(request)
token_context, reachable_event_codes = {"token": "context", "context": {}}, ["g1j8i6j7-9k4h-0h6l-4i3j-2j0k1k0j0i0k"]
endpoint_url, reachable_event_code = cls.retrieve_intersected_event_code(request, reachable_event_codes)
event_context = EventContext(auth=token_context, code=reachable_event_code, url=endpoint_url)
# Get auth context from the authenticated function's wrapper
auth_context = getattr(authenticated_func, 'auth', None)
print('auth_context', auth_context)
if auth_context is not None:
setattr(wrapper, 'auth', auth_context)
if token_context is not None:
setattr(wrapper, 'event_context', event_context)
setattr(func, 'event_context', event_context)
# Execute the authenticated function and get its result
if inspect.iscoroutinefunction(authenticated_func):
result = await authenticated_func(request, *args, **kwargs)
else:
result = authenticated_func(request, *args, **kwargs)
return result
# Copy any existing attributes from the authenticated function
for attr in dir(authenticated_func):
if not attr.startswith('__'):
setattr(wrapper, attr, getattr(authenticated_func, attr))
return wrapper
@staticmethod
def validation_required(
func: Callable[..., Dict[str, Any]]
) -> Callable[..., Dict[str, Any]]:
"""
Decorator for endpoints with token and event requirements.
This decorator:
1. First validates authentication using MiddlewareModule.auth_required
2. Then adds event tracking context
Args:
func: The function to be decorated
Returns:
Callable: The wrapped function with both auth and event handling
"""
# First apply authentication
# authenticated_func = MiddlewareModule.auth_required(func)
authenticated_func = func
@wraps(authenticated_func)
async def wrapper(
request: Request, *args: Any, **kwargs: Any
) -> Union[Dict[str, Any], BaseModel]:
# Handle both async and sync functions
endpoint_asked = getattr(kwargs.get("data", None), "data", None).get(
"endpoint", None
)
if not endpoint_asked:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Endpoint not found",
)
func.func_code = cls.match_endpoint_with_accessible_event(
endpoint_url, request
)
if inspect.iscoroutinefunction(authenticated_func):
result = await authenticated_func(request, *args, **kwargs)
else:
result = authenticated_func(request, *args, **kwargs)
function_auth = getattr(authenticated_func, "auth", None)
wrapper.auth = function_auth
func.auth = function_auth
authenticated_func.auth = function_auth
if inspect.iscoroutinefunction(authenticated_func):
result = await authenticated_func(request, *args, **kwargs)
else:
result = authenticated_func(request, *args, **kwargs)
if inspect.iscoroutinefunction(func):
result = await func(request, *args, **kwargs)
else:
@@ -305,3 +142,41 @@ class TokenEventMiddleware:
return result
return wrapper
# event_required is already sets function_code state to wrapper
# @classmethod
# def validation_required(cls, func: Callable[..., Dict[str, Any]]) -> Callable[..., Dict[str, Any]]:
# """
# Decorator for endpoints with token and event requirements.
# This decorator:
# 1. First validates authentication using MiddlewareModule.auth_required
# 2. Then adds event tracking context
# Args:
# func: The function to be decorated
# Returns:
# Callable: The wrapped function with both auth and event handling
# """
# @wraps(func)
# async def wrapper(request: Request, *args: Any, **kwargs: Any) -> Union[Dict[str, Any], BaseModel]:
# # Get and validate token context from request
# token_context, reachable_event_codes = cls.retrieve_access_content(request)
# endpoint_url, reachable_event_code = cls.retrieve_intersected_event_code(request, reachable_event_codes)
# # Get auth context from the authenticated function's wrapper
# if token_context is not None:
# setattr(wrapper, 'auth', token_context)
# setattr(wrapper, 'url', endpoint_url)
# setattr(wrapper, 'func_code', reachable_event_code)
# # Execute the authenticated function and get its result
# if inspect.iscoroutinefunction(func):
# result = await func(request, *args, **kwargs)
# else:
# result = func(request, *args, **kwargs)
# return result
# return wrapper