validations updated

This commit is contained in:
2025-01-19 21:06:00 +03:00
parent d6785ed36f
commit 8e34497c80
49 changed files with 2642 additions and 142 deletions

View File

@@ -14,6 +14,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from ApiLibrary.common.line_number import get_line_number_for_error
from ErrorHandlers.ErrorHandlers.api_exc_handler import HTTPExceptionApi
from AllConfigs.Token.config import Auth
import inspect
@@ -86,9 +87,8 @@ class MiddlewareModule:
async def wrapper(request: Request, *args, **kwargs):
# Get and validate token context from request
# Create auth context and Attach auth context to both wrapper and original function
func.auth = cls.get_user_from_request(
request
) # This ensures the context is available in both places
func.auth = cls.get_user_from_request(request)
wrapper.auth = func.auth
# Call the original endpoint function
if inspect.iscoroutinefunction(func):
return await func(request, *args, **kwargs)
@@ -141,10 +141,22 @@ class LoggerTimingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Log the request
print(f"Handling request: {request.method} {request.url}")
import arrow
headers = dict(request.headers)
response = await call_next(request)
# Log the response
print(
f"Completed request: {request.method} {request.url} with status {response.status_code}"
"Loggers :",
{
"url": request.url,
"method": request.method,
"access_token": headers.get(Auth.ACCESS_TOKEN_TAG, ""),
"referer": headers.get("referer", ""),
"origin": headers.get("origin", ""),
"user-agent": headers.get("user-agent", ""),
"datetime": arrow.now().format("YYYY-MM-DD HH:mm:ss ZZ"),
"status_code": response.status_code,
},
)
return response

View File

@@ -2,11 +2,196 @@
Token event middleware for handling authentication and event tracking.
"""
from functools import wraps
from typing import Callable, Dict, Any
from .auth_middleware import MiddlewareModule
import inspect
from functools import wraps
from typing import Callable, Dict, Any, Optional, Union
from fastapi import Request
from pydantic import BaseModel
from ApiLibrary.common.line_number import get_line_number_for_error
from ApiServices.Token.token_handler import TokenService
from ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
from Schemas.rules.rules import EndpointRestriction
from .auth_middleware import MiddlewareModule
from Schemas import Events
class EventFunctions:
def __init__(self, endpoint: str, request: Request):
self.endpoint = endpoint
self.request = request
def match_endpoint_with_accesiable_event(self) -> Optional[Dict[str, Any]]:
"""
Match an endpoint with accessible events.
Args:
endpoint: The endpoint to match
Returns:
Dict containing the endpoint registration data
None if endpoint is not found in database
"""
access_token = TokenService.get_access_token_from_request(self.request)
token_context = TokenService.get_object_via_access_key(
access_token=access_token
)
if token_context.is_employee:
reachable_event_codes: list[str] = (
token_context.selected_company.reachable_event_codes
)
elif token_context.is_occupant:
reachable_event_codes: list[str] = (
token_context.selected_occupant.reachable_event_codes
)
else:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Token not found",
)
if not access_token:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Token not found",
)
db = EndpointRestriction.new_session()
restriction = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_name == self.endpoint,
db=db,
).data
if not restriction:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Function code not found",
)
event_related = Events.filter_all(
Events.endpoint_id == restriction.id,
db=db,
).data
if not event_related:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="No event is registered for this user.",
)
an_event = event_related[0]
event_related_codes: list[str] = [
event.function_code for event in event_related
]
intersected_code: set = set(reachable_event_codes).intersection(
set(event_related_codes)
)
if not len(list(intersected_code)) == 1:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="No event is registered for this user.",
)
return {
"endpoint_url": self.endpoint,
"reachable_event_code": list(intersected_code)[0],
"class": an_event.function_class,
}
def retrieve_function_dict(self) -> Optional[Dict[str, Any]]:
"""
Retrieve function dictionary for a given endpoint.
Args:
endpoint: The endpoint to retrieve the function dictionary for
Returns:
Dictionary containing the function dictionary
None if endpoint is not found
"""
access_token = TokenService.get_access_token_from_request(self.request)
token_context = TokenService.get_object_via_access_key(
access_token=access_token
)
if token_context.is_employee:
reachable_event_codes: list[str] = (
token_context.selected_company.reachable_event_codes
)
elif token_context.is_occupant:
reachable_event_codes: list[str] = (
token_context.selected_occupant.reachable_event_codes
)
else:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Token not found",
)
if not access_token:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Token not found",
)
db = EndpointRestriction.new_session()
restriction = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_name == self.endpoint,
db=db,
).data
if not restriction:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Function code not found",
)
event_related = Events.filter_all(
Events.endpoint_id == restriction.id,
db=db,
).data
if not event_related:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="No event is registered for this user.",
)
an_event = event_related[0]
event_related_codes: list[str] = [
event.function_code for event in event_related
]
intersected_code: set = set(reachable_event_codes).intersection(
set(event_related_codes)
)
if not len(list(intersected_code)) == 1:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="No event is registered for this user.",
)
return {
"endpoint_url": self.endpoint,
"reachable_event_code": list(intersected_code)[0],
"class": an_event.function_class,
}
class TokenEventMiddleware:
"""
@@ -37,17 +222,71 @@ class TokenEventMiddleware:
authenticated_func = MiddlewareModule.auth_required(func)
@wraps(authenticated_func)
async def wrapper(*args, **kwargs) -> Dict[str, Any]:
# Create handler with context
function_code = (
"7192c2aa-5352-4e36-98b3-dafb7d036a3d" # Keep function_code as URL
)
async def wrapper(request: Request, *args, **kwargs) -> Dict[str, Any]:
# Get function code from the function's metadata
endpoint_url = getattr(authenticated_func, "url_of_endpoint", {})
if not endpoint_url:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Function code not found",
)
# Make handler available to all functions in the chain
func.func_code = {"function_code": function_code}
func.func_code = EventFunctions(endpoint_url, request).match_endpoint_with_accesiable_event()
# Call the authenticated function
if inspect.iscoroutinefunction(authenticated_func):
return await authenticated_func(*args, **kwargs)
return authenticated_func(*args, **kwargs)
return await authenticated_func(request, *args, **kwargs)
return authenticated_func(request, *args, **kwargs)
return wrapper
@staticmethod
def validation_required(
func: Callable[..., Dict[str, Any]]
) -> Callable[..., Dict[str, Any]]:
"""
Decorator for endpoints with token and event requirements.
This decorator:
1. First validates authentication using MiddlewareModule.auth_required
2. Then adds event tracking context
Args:
func: The function to be decorated
Returns:
Callable: The wrapped function with both auth and event handling
"""
# First apply authentication
authenticated_func = MiddlewareModule.auth_required(func)
@wraps(authenticated_func)
async def wrapper(
request: Request, *args: Any, **kwargs: Any
) -> Union[Dict[str, Any], BaseModel]:
# Handle both async and sync functions
endpoint_asked = getattr(kwargs.get("data", None), "data", None).get("endpoint", None)
if not endpoint_asked:
raise HTTPExceptionApi(
error_code="",
lang="en",
loc=get_line_number_for_error(),
sys_msg="Endpoint not found",
)
wrapper.validation_code = EventFunctions(endpoint_asked, request).retrieve_function_dict()
if inspect.iscoroutinefunction(authenticated_func):
result = await authenticated_func(request, *args, **kwargs)
else:
result = authenticated_func(request, *args, **kwargs)
function_auth = getattr(authenticated_func, "auth", None)
wrapper.auth = function_auth
func.auth = function_auth
authenticated_func.auth = function_auth
# If result is a coroutine, await it
if inspect.iscoroutine(result):
result = await result
return result
return wrapper