middleware and respnse models updated
This commit is contained in:
@@ -13,7 +13,7 @@ class PasswordModule:
|
||||
return str(uuid.uuid4()) if str_std else uuid.uuid4()
|
||||
|
||||
@staticmethod
|
||||
def generate_token(length=32):
|
||||
def generate_token(length=32) -> str:
|
||||
letters = "abcdefghijklmnopqrstuvwxyz"
|
||||
merged_letters = [letter for letter in letters] + [
|
||||
letter.upper() for letter in letters
|
||||
@@ -27,17 +27,17 @@ class PasswordModule:
|
||||
return token_generated
|
||||
|
||||
@staticmethod
|
||||
def generate_access_token():
|
||||
def generate_access_token() -> str:
|
||||
return secrets.token_urlsafe(Auth.ACCESS_TOKEN_LENGTH)
|
||||
|
||||
@staticmethod
|
||||
def generate_refresher_token():
|
||||
def generate_refresher_token() -> str:
|
||||
return secrets.token_urlsafe(Auth.REFRESHER_TOKEN_LENGTH)
|
||||
|
||||
@staticmethod
|
||||
def create_hashed_password(domain: str, id_: str, password: str):
|
||||
def create_hashed_password(domain: str, id_: str, password: str) -> str:
|
||||
return hashlib.sha256(f"{domain}:{id_}:{password}".encode("utf-8")).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def check_password(cls, domain, id_, password, password_hashed):
|
||||
def check_password(cls, domain, id_, password, password_hashed) -> bool:
|
||||
return cls.create_hashed_password(domain, id_, password) == password_hashed
|
||||
|
||||
@@ -20,14 +20,17 @@ class CreateEndpointFromCluster:
|
||||
|
||||
def attach_router(self):
|
||||
method = getattr(self.router, self.method_endpoint.METHOD.lower())
|
||||
|
||||
|
||||
# Create a unique operation ID based on the endpoint path, method, and a unique identifier
|
||||
kwargs = {
|
||||
"path": self.method_endpoint.URL,
|
||||
"summary": self.method_endpoint.SUMMARY,
|
||||
"description": self.method_endpoint.DESCRIPTION,
|
||||
}
|
||||
if hasattr(self.method_endpoint, 'RESPONSE_MODEL') and self.method_endpoint.RESPONSE_MODEL is not None:
|
||||
if (
|
||||
hasattr(self.method_endpoint, "RESPONSE_MODEL")
|
||||
and self.method_endpoint.RESPONSE_MODEL is not None
|
||||
):
|
||||
kwargs["response_model"] = self.method_endpoint.RESPONSE_MODEL
|
||||
|
||||
|
||||
method(**kwargs)(self.method_endpoint.endpoint_callable)
|
||||
|
||||
@@ -2,4 +2,3 @@ from Services.Redis import RedisActions, AccessToken
|
||||
from Services.Redis.Models.cluster import RedisList
|
||||
|
||||
redis_list = RedisList(redis_key="test")
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from ApiLayers.ErrorHandlers import HTTPExceptionApi
|
||||
from ApiLayers.ApiValidations.Request.authentication import Login
|
||||
from ApiLayers.ApiLibrary.token.password_module import PasswordModule
|
||||
from ApiLayers.ApiLibrary.common.line_number import get_line_number_for_error
|
||||
from ApiLayers.ErrorHandlers import HTTPExceptionApi
|
||||
|
||||
|
||||
class UserLoginModule:
|
||||
@@ -11,13 +9,27 @@ class UserLoginModule:
|
||||
def __init__(self, request: "Request"):
|
||||
self.request = request
|
||||
self.user = None
|
||||
self.access_object = None
|
||||
self.access_token = None
|
||||
self.refresh_token = None
|
||||
|
||||
@property
|
||||
def as_dict(self) -> dict:
|
||||
return {
|
||||
"user": self.user,
|
||||
"access_object": self.access_object,
|
||||
"access_token": self.access_token,
|
||||
"refresh_token": self.refresh_token,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def check_user_exists(access_key: str):
|
||||
from ApiLayers.Schemas import Users
|
||||
|
||||
"""Check if user exists."""
|
||||
db_session = Users.new_session()
|
||||
"""
|
||||
Check if the user exists in the database.
|
||||
"""
|
||||
db_session = Users.new_session() # Check if user exists.
|
||||
if "@" in access_key:
|
||||
found_user: Users = Users.filter_one(
|
||||
Users.email == access_key.lower(), db=db_session
|
||||
@@ -31,39 +43,48 @@ class UserLoginModule:
|
||||
error_code="HTTP_400_BAD_REQUEST",
|
||||
lang="en",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="User not found",
|
||||
sys_msg="check_user_exists: User not found",
|
||||
)
|
||||
return found_user
|
||||
|
||||
def login_user_via_credentials(self, access_data: "Login") -> Dict[str, Any]:
|
||||
def login_user_via_credentials(self, access_data: "Login") -> None:
|
||||
from ApiLayers.ApiServices.Token.token_handler import TokenService
|
||||
from ApiLayers.Schemas import Users
|
||||
|
||||
"""
|
||||
Login the user via the credentials.
|
||||
"""
|
||||
|
||||
# Get the actual data from the BaseRequestModel if needed
|
||||
found_user: Users = self.check_user_exists(access_key=access_data.access_key)
|
||||
self.user = found_user
|
||||
if len(found_user.hash_password) < 5:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="HTTP_400_BAD_REQUEST",
|
||||
lang=found_user.lang,
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Invalid password create a password to user first",
|
||||
sys_msg="login_user_via_credentials: Invalid password create a password to user first",
|
||||
)
|
||||
# Check if the password is correct
|
||||
if PasswordModule.check_password(
|
||||
domain=access_data.domain,
|
||||
id_=found_user.uu_id,
|
||||
password=access_data.password,
|
||||
password_hashed=found_user.hash_password,
|
||||
domain=access_data.domain, id_=found_user.uu_id,
|
||||
password=access_data.password, password_hashed=found_user.hash_password,
|
||||
):
|
||||
return TokenService.set_access_token_to_redis(
|
||||
request=self.request,
|
||||
user=found_user,
|
||||
domain=access_data.domain,
|
||||
remember=access_data.remember_me,
|
||||
# Set the access token to the redis
|
||||
token_response = TokenService.set_access_token_to_redis(
|
||||
request=self.request, user=found_user, domain=access_data.domain, remember=access_data.remember_me,
|
||||
)
|
||||
# Set the user and token information to the instance
|
||||
self.user = found_user.get_dict()
|
||||
self.access_token = token_response.get("access_token")
|
||||
self.refresh_token = token_response.get("refresh_token")
|
||||
self.access_object = {
|
||||
"user_type": token_response.get("user_type", None),
|
||||
"selection_list": token_response.get("selection_list", {})
|
||||
}
|
||||
return None
|
||||
raise HTTPExceptionApi(
|
||||
error_code="HTTP_400_BAD_REQUEST",
|
||||
lang=found_user.lang,
|
||||
lang="tr",
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="login_user_via_credentials raised error",
|
||||
sys_msg="login_user_via_credentials: raised an unknown error",
|
||||
)
|
||||
|
||||
@@ -250,7 +250,7 @@ class TokenService:
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"user_type": UserType.employee.name,
|
||||
"companies_list": companies_list,
|
||||
"selection_list": companies_list,
|
||||
}
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
@@ -264,8 +264,8 @@ class TokenService:
|
||||
"""Remove all tokens for a user with specific domain."""
|
||||
redis_rows = cls._get_user_tokens(user)
|
||||
for redis_row in redis_rows.all:
|
||||
if redis_row.data.get("domain") == domain:
|
||||
RedisActions.delete_key(redis_row.key)
|
||||
if redis_row.row.get("domain") == domain:
|
||||
redis_row.delete()
|
||||
|
||||
@classmethod
|
||||
def remove_all_token(cls, user: Users) -> None:
|
||||
|
||||
@@ -2,17 +2,27 @@ from typing import Optional, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DefaultContext(BaseModel):
|
||||
...
|
||||
class DefaultContext(BaseModel): ...
|
||||
|
||||
|
||||
class EventContext(DefaultContext):
|
||||
|
||||
|
||||
auth: Any
|
||||
code: str
|
||||
url: str
|
||||
request: Optional[Any] = None
|
||||
|
||||
@property
|
||||
def base(self) -> dict[str, Any]:
|
||||
return {"url": self.url, "code": self.code}
|
||||
|
||||
|
||||
class AuthContext(DefaultContext):
|
||||
|
||||
auth: Any
|
||||
url: str
|
||||
request: Optional[Any] = None
|
||||
|
||||
@property
|
||||
def base(self) -> dict[str, Any]:
|
||||
return {"url": self.url}
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
from typing import Optional
|
||||
from ApiLayers.ApiValidations.Request import BaseModelRegular, PydanticBaseModel, ListOptions
|
||||
from ApiLayers.ApiValidations.Request import (
|
||||
BaseModelRegular,
|
||||
PydanticBaseModel,
|
||||
ListOptions,
|
||||
)
|
||||
|
||||
|
||||
class DecisionBookDecisionBookInvitations(BaseModelRegular):
|
||||
|
||||
@@ -15,85 +15,85 @@ class BaseEndpointResponse:
|
||||
from ApiLayers.AllConfigs.Redis.configs import RedisValidationKeys
|
||||
|
||||
language_model_key = f"{RedisValidationKeys.LANGUAGE_MODELS}:{RedisValidationKeys.RESPONSES}"
|
||||
language_model = RedisActions.get_json(list_keys=[language_model_key, self.code , self.lang])
|
||||
language_model = RedisActions.get_json(list_keys=[language_model_key, self.code, self.lang])
|
||||
if language_model.status:
|
||||
return language_model.first.as_dict
|
||||
raise ValueError("Language model not found")
|
||||
return {"message": f"{self.code} -> Language model not found"}
|
||||
|
||||
|
||||
class EndpointSuccessResponse(BaseEndpointResponse): # 1. 200 OK
|
||||
class EndpointSuccessResponse(BaseEndpointResponse): # 200 OK
|
||||
|
||||
def as_dict(self, data: Optional[dict] = None):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content=dict(completed=True, lang=self.lang, data=data, **self.response)
|
||||
content=dict(completed=True, **self.response, lang=self.lang, data=data),
|
||||
)
|
||||
|
||||
|
||||
class EndpointCreatedResponse(BaseEndpointResponse): # 2. 201 Created
|
||||
class EndpointCreatedResponse(BaseEndpointResponse): # 201 Created
|
||||
|
||||
def as_dict(self, data: Optional[dict] = None):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
content=dict(completed=True, lang=self.lang, data=data, **self.response)
|
||||
content=dict(completed=True, **self.response, lang=self.lang, data=data),
|
||||
)
|
||||
|
||||
|
||||
class EndpointAcceptedResponse(BaseEndpointResponse): # 3. 202 Accepted
|
||||
class EndpointAcceptedResponse(BaseEndpointResponse): # 202 Accepted
|
||||
|
||||
def as_dict(self, data: Optional[dict] = None):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_202_ACCEPTED,
|
||||
content=dict(completed=True, lang=self.lang, data=data, **self.response)
|
||||
content=dict(completed=True, **self.response, lang=self.lang, data=data),
|
||||
)
|
||||
|
||||
|
||||
class EndpointBadRequestResponse(BaseEndpointResponse): # 4. 400 Bad Request
|
||||
class EndpointBadRequestResponse(BaseEndpointResponse): # 400 Bad Request
|
||||
|
||||
def as_dict(self, data: Optional[dict] = None):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content=dict(completed=False, lang=self.lang, data=data, **self.response)
|
||||
content=dict(completed=False, **self.response, lang=self.lang, data=data),
|
||||
)
|
||||
|
||||
|
||||
class EndpointUnauthorizedResponse(BaseEndpointResponse): # 5. 401 Unauthorized
|
||||
class EndpointUnauthorizedResponse(BaseEndpointResponse): # 401 Unauthorized
|
||||
|
||||
def as_dict(self):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content=dict(completed=False, lang=self.lang, **self.response)
|
||||
content=dict(completed=False, **self.response, lang=self.lang),
|
||||
)
|
||||
|
||||
|
||||
class EndpointNotFoundResponse(BaseEndpointResponse): # 6. 404 Not Found
|
||||
class EndpointNotFoundResponse(BaseEndpointResponse): # 404 Not Found
|
||||
|
||||
def as_dict(self):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content=dict(completed=False, lang=self.lang, **self.response)
|
||||
content=dict(completed=False, **self.response, lang=self.lang),
|
||||
)
|
||||
|
||||
|
||||
class EndpointForbiddenResponse(BaseEndpointResponse): # 3. 403 Forbidden
|
||||
class EndpointForbiddenResponse(BaseEndpointResponse): # 403 Forbidden
|
||||
|
||||
def as_dict(self):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content=dict(completed=False, lang=self.lang, **self.response)
|
||||
content=dict(completed=False, **self.response, lang=self.lang),
|
||||
)
|
||||
|
||||
|
||||
class EndpointConflictResponse(BaseEndpointResponse): # 6. 409 Conflict
|
||||
class EndpointConflictResponse(BaseEndpointResponse): # 409 Conflict
|
||||
|
||||
def as_dict(self):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
content=dict(completed=False, lang=self.lang, **self.response)
|
||||
content=dict(completed=False, **self.response, lang=self.lang),
|
||||
)
|
||||
|
||||
|
||||
class EndpointTooManyRequestsResponse(BaseEndpointResponse): # 7. 429 Too Many Requests
|
||||
class EndpointTooManyRequestsResponse(BaseEndpointResponse): # 429 Too Many Requests
|
||||
|
||||
def __init__(self, retry_after: int, lang: str, code: str):
|
||||
super().__init__(lang=lang, code=code)
|
||||
@@ -103,16 +103,16 @@ class EndpointTooManyRequestsResponse(BaseEndpointResponse): # 7. 429 Too Ma
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
headers={"Retry-After": str(self.retry_after)},
|
||||
content=dict(completed=False, lang=self.lang, **self.response)
|
||||
content=dict(completed=False, **self.response, lang=self.lang),
|
||||
)
|
||||
|
||||
|
||||
class EndpointInternalErrorResponse(BaseEndpointResponse): # 7. 500 Internal Server Error
|
||||
class EndpointInternalErrorResponse(BaseEndpointResponse): # 500 Internal Server Error
|
||||
|
||||
def as_dict(self):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content=dict(completed=False, lang=self.lang, **self.response)
|
||||
content=dict(completed=False, **self.response, lang=self.lang),
|
||||
)
|
||||
|
||||
|
||||
@@ -121,5 +121,5 @@ class EndpointErrorResponse(BaseEndpointResponse):
|
||||
def as_dict(self):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_304_NOT_MODIFIED,
|
||||
content=dict(completed=False, lang=self.lang, **self.response)
|
||||
content=dict(completed=False, **self.response, lang=self.lang),
|
||||
)
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
from Services.Redis.Actions.actions import RedisActions
|
||||
from ApiLayers.AllConfigs.Redis.configs import RedisValidationKeys
|
||||
|
||||
|
||||
class HTTPExceptionApi(Exception):
|
||||
|
||||
def __init__(self, error_code: str, lang: str, loc: str = "", sys_msg: str = ""):
|
||||
"""
|
||||
Initialize the HTTPExceptionApi class.
|
||||
:param error_code: The error code. To retrieve the error message.
|
||||
:param lang: The language. Catch error msg from redis.
|
||||
:param loc: The location. To log where error occurred.
|
||||
:param sys_msg: The system message. To log the error message.
|
||||
"""
|
||||
self.error_code = error_code
|
||||
self.lang = lang
|
||||
self.loc = loc
|
||||
@@ -13,6 +21,12 @@ class HTTPExceptionApi(Exception):
|
||||
"""
|
||||
Retrieve the error message from the redis by the error code.
|
||||
"""
|
||||
error_msg = RedisActions.get_json(list_keys=["LANGUAGE_MODELS", "ERRORCODES", self.lang])
|
||||
if error_msg.status:
|
||||
return error_msg.first
|
||||
error_redis_key = (
|
||||
f"{RedisValidationKeys.LANGUAGE_MODELS}:{RedisValidationKeys.ERRORCODES}"
|
||||
)
|
||||
error_message = RedisActions.get_json(list_keys=[error_redis_key, self.lang])
|
||||
if error_message.status:
|
||||
error_message_dict = error_message.first.as_dict
|
||||
if error_message_dict.get(self.error_code, None):
|
||||
return error_message_dict.get(self.error_code)
|
||||
return f"System Message -> {self.sys_msg}"
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
AccountBooksLanguageModel = dict(
|
||||
tr={
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
BuildIbansLanguageModel = dict(
|
||||
tr={
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
DecisionBookBudgetBooksLanguageModel = dict(
|
||||
tr={
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
BuildTypesLanguageModel = dict(
|
||||
tr={
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
BuildDecisionBookLanguageModel = dict(
|
||||
tr={
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
RelationshipDutyCompanyLanguageModel = dict(
|
||||
tr={
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
DepartmentsLanguageModel = dict(
|
||||
tr={
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
StaffLanguageModel = dict(
|
||||
tr={
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
|
||||
EventsLanguageModel = dict(
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
UsersTokensLanguageModel = dict(
|
||||
tr={
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import CrudCollectionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.Mixins.crud_mixin import (
|
||||
CrudCollectionLanguageModel,
|
||||
)
|
||||
|
||||
EndpointRestrictionLanguageModel = dict(
|
||||
tr={
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from .merge_all_error_languages import MergedErrorLanguageModels
|
||||
|
||||
__all__ = ["MergedErrorLanguageModels"]
|
||||
@@ -92,11 +92,12 @@ class MiddlewareModule:
|
||||
auth_context = AuthContext(
|
||||
auth={"token": {"access_token": "", "refresher_token": "", "context": {}}},
|
||||
url=endpoint_url,
|
||||
request=request,
|
||||
)
|
||||
|
||||
# Set auth context on the wrapper function itself
|
||||
setattr(func, 'auth_context', auth_context)
|
||||
setattr(wrapper, 'auth_context', auth_context)
|
||||
setattr(func, "auth_context", auth_context)
|
||||
setattr(wrapper, "auth_context", auth_context)
|
||||
|
||||
# Call the original endpoint function
|
||||
if inspect.iscoroutinefunction(func):
|
||||
@@ -105,10 +106,11 @@ class MiddlewareModule:
|
||||
result = func(request, *args, **kwargs)
|
||||
|
||||
# Set auth context on the wrapper function itself
|
||||
setattr(func, 'auth_context', auth_context)
|
||||
setattr(wrapper, 'auth_context', auth_context)
|
||||
|
||||
setattr(func, "auth_context", auth_context)
|
||||
setattr(wrapper, "auth_context", auth_context)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class TokenEventMiddleware:
|
||||
Returns:
|
||||
Tuple[str, list[str]]: The access token and a list of reachable event codes
|
||||
"""
|
||||
# Get token context from request
|
||||
# Get token context from request
|
||||
access_token = TokenService.get_access_token_from_request(request_from_scope)
|
||||
if not access_token:
|
||||
raise HTTPExceptionApi(
|
||||
@@ -50,13 +50,19 @@ class TokenEventMiddleware:
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Token not found",
|
||||
)
|
||||
|
||||
|
||||
# Get token context from Redis by access token and collect reachable event codes
|
||||
token_context = TokenService.get_object_via_access_key(access_token=access_token)
|
||||
token_context = TokenService.get_object_via_access_key(
|
||||
access_token=access_token
|
||||
)
|
||||
if token_context.is_employee:
|
||||
reachable_event_codes: list[str] = token_context.selected_company.reachable_event_codes
|
||||
reachable_event_codes: list[str] = (
|
||||
token_context.selected_company.reachable_event_codes
|
||||
)
|
||||
elif token_context.is_occupant:
|
||||
reachable_event_codes: list[str] = token_context.selected_occupant.reachable_event_codes
|
||||
reachable_event_codes: list[str] = (
|
||||
token_context.selected_occupant.reachable_event_codes
|
||||
)
|
||||
else:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
@@ -67,12 +73,15 @@ class TokenEventMiddleware:
|
||||
return token_context, reachable_event_codes
|
||||
|
||||
@staticmethod
|
||||
def retrieve_intersected_event_code(request: Request, reachable_event_codes: list[str]) -> str:
|
||||
def retrieve_intersected_event_code(
|
||||
request: Request, reachable_event_codes: list[str]
|
||||
) -> str:
|
||||
"""
|
||||
Match an endpoint with accessible events.
|
||||
|
||||
Args:
|
||||
request: The endpoint to match
|
||||
reachable_event_codes: The list of event codes accessible to the user
|
||||
|
||||
Returns:
|
||||
Dict containing the endpoint registration data
|
||||
@@ -81,9 +90,7 @@ class TokenEventMiddleware:
|
||||
endpoint_url = str(request.url.path)
|
||||
# Get the endpoint URL for matching with events
|
||||
function_codes_of_endpoint = RedisActions.get_json(
|
||||
list_keys=[
|
||||
RedisCategoryKeys.METHOD_FUNCTION_CODES, "*", endpoint_url
|
||||
]
|
||||
list_keys=[RedisCategoryKeys.METHOD_FUNCTION_CODES, "*", endpoint_url]
|
||||
)
|
||||
function_code_list_of_event = function_codes_of_endpoint.first
|
||||
if not function_codes_of_endpoint.status:
|
||||
@@ -93,9 +100,11 @@ class TokenEventMiddleware:
|
||||
loc=get_line_number_for_error(),
|
||||
sys_msg="Function code not found",
|
||||
)
|
||||
|
||||
|
||||
# Intersect function codes with user accers objects available event codes
|
||||
intersected_code = list(set(function_code_list_of_event) & set(reachable_event_codes))
|
||||
intersected_code = list(
|
||||
set(function_code_list_of_event) & set(reachable_event_codes)
|
||||
)
|
||||
if not len(intersected_code) == 1:
|
||||
raise HTTPExceptionApi(
|
||||
error_code="",
|
||||
@@ -122,18 +131,23 @@ class TokenEventMiddleware:
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(request: Request, *args, **kwargs) -> Dict[str, Any]:
|
||||
|
||||
|
||||
# Get and validate token context from request
|
||||
# token_context, reachable_event_codes = cls.retrieve_access_content(request)
|
||||
token_context, reachable_event_codes = {"token": "context", "context": {}}, ["g1j8i6j7-9k4h-0h6l-4i3j-2j0k1k0j0i0k"]
|
||||
endpoint_url, reachable_event_code = cls.retrieve_intersected_event_code(request, reachable_event_codes)
|
||||
event_context = EventContext(auth=token_context, code=reachable_event_code, url=endpoint_url)
|
||||
reachable_event_codes = ["g1j8i6j7-9k4h-0h6l-4i3j-2j0k1k0j0i0k"]
|
||||
token_context = {"token": "context","context": {}}
|
||||
endpoint_url, reachable_event_code = cls.retrieve_intersected_event_code(
|
||||
request, reachable_event_codes
|
||||
)
|
||||
event_context = EventContext(
|
||||
auth=token_context, code=reachable_event_code, url=endpoint_url, request=request,
|
||||
)
|
||||
|
||||
# Get auth context from the authenticated function's wrapper
|
||||
if token_context is not None:
|
||||
setattr(wrapper, 'event_context', event_context)
|
||||
setattr(func, 'event_context', event_context)
|
||||
|
||||
setattr(wrapper, "event_context", event_context)
|
||||
setattr(func, "event_context", event_context)
|
||||
|
||||
# Execute the authenticated function and get its result
|
||||
if inspect.iscoroutinefunction(func):
|
||||
result = await func(request, *args, **kwargs)
|
||||
@@ -179,4 +193,4 @@ class TokenEventMiddleware:
|
||||
# result = func(request, *args, **kwargs)
|
||||
# return result
|
||||
|
||||
# return wrapper
|
||||
# return wrapper
|
||||
|
||||
@@ -24,7 +24,10 @@ from ApiLayers.ApiValidations.Request import (
|
||||
UpdateBuild,
|
||||
)
|
||||
|
||||
from ApiLayers.ApiValidations.Custom.token_objects import EmployeeTokenObject, OccupantTokenObject
|
||||
from ApiLayers.ApiValidations.Custom.token_objects import (
|
||||
EmployeeTokenObject,
|
||||
OccupantTokenObject,
|
||||
)
|
||||
from ApiLayers.LanguageModels.Database.building.build import (
|
||||
BuildTypesLanguageModel,
|
||||
Part2EmployeeLanguageModel,
|
||||
|
||||
@@ -17,7 +17,10 @@ from sqlalchemy import (
|
||||
from sqlalchemy.orm import mapped_column, relationship, Mapped
|
||||
|
||||
from ApiLayers.ApiLibrary.date_time_actions.date_functions import system_arrow
|
||||
from ApiLayers.ApiLibrary.extensions.select import SelectAction, SelectActionWithEmployee
|
||||
from ApiLayers.ApiLibrary.extensions.select import (
|
||||
SelectAction,
|
||||
SelectActionWithEmployee,
|
||||
)
|
||||
|
||||
from ApiLayers.AllConfigs.Token.config import Auth
|
||||
from ApiLayers.ApiServices.Login.user_login_handler import UserLoginModule
|
||||
@@ -138,6 +141,24 @@ class Users(CrudCollection, UserLoginModule, SelectAction):
|
||||
def is_employee(self):
|
||||
return str(self.email).split("@")[1] == Auth.ACCESS_EMAIL_EXT
|
||||
|
||||
@property
|
||||
def user_type(self):
|
||||
return "Occupant" if self.is_occupant else "Employee"
|
||||
|
||||
@classmethod
|
||||
def credentials(cls):
|
||||
db_session = cls.new_session()
|
||||
person_object: People = People.filter_by_one(db=db_session, system=True, id=cls.person_id).data
|
||||
if person_object:
|
||||
return {
|
||||
"person_id": person_object.id,
|
||||
"person_uu_id": str(person_object.uu_id),
|
||||
}
|
||||
return {
|
||||
"person_id": None,
|
||||
"person_uu_id": None,
|
||||
}
|
||||
|
||||
@property
|
||||
def password_expiry_ends(self):
|
||||
"""Calculates the expiry end date based on expiry begins and expires day"""
|
||||
@@ -153,16 +174,6 @@ class Users(CrudCollection, UserLoginModule, SelectAction):
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def is_super_user(self):
|
||||
"""Checks if the user is a superuser based on priority code"""
|
||||
return getattr(self.priority, "priority_code", 0) == 78
|
||||
|
||||
@property
|
||||
def is_user(self):
|
||||
"""Checks if the user is a regular user based on priority code"""
|
||||
return getattr(self.priority, "priority_code", 0) == 0
|
||||
|
||||
@classmethod
|
||||
def create_action(cls, create_user: InsertUsers, token_dict):
|
||||
db_session = cls.new_session()
|
||||
@@ -192,21 +203,7 @@ class Users(CrudCollection, UserLoginModule, SelectAction):
|
||||
created_user.reset_password_token(found_user=created_user)
|
||||
return created_user
|
||||
|
||||
@classmethod
|
||||
def credentials(cls):
|
||||
db_session = cls.new_session()
|
||||
person_object = People.filter_by_one(
|
||||
db=db_session, system=True, id=cls.person_id
|
||||
).data
|
||||
if person_object:
|
||||
return {
|
||||
"person_id": person_object.id,
|
||||
"person_uu_id": str(person_object.uu_id),
|
||||
}
|
||||
return {
|
||||
"person_id": None,
|
||||
"person_uu_id": None,
|
||||
}
|
||||
|
||||
|
||||
def get_employee_and_duty_details(self):
|
||||
from Schemas import Employees, Duties
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.orm import mapped_column, Mapped
|
||||
|
||||
from ApiLayers.LanguageModels.Database.rules.rules import EndpointRestrictionLanguageModel
|
||||
from ApiLayers.LanguageModels.Database.rules.rules import (
|
||||
EndpointRestrictionLanguageModel,
|
||||
)
|
||||
from Services.PostgresDb import CrudCollection
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user