updated login/select

This commit is contained in:
2025-01-27 21:43:36 +03:00
parent b88f910a43
commit c0bd9c1685
18 changed files with 257 additions and 275 deletions

View File

@@ -2,6 +2,7 @@ from ApiLayers.AllConfigs import HostConfig
class WagRedis:
REDIS_HOST = HostConfig.MAIN_HOST
REDIS_PASSWORD: str = "commercial_redis_password"
REDIS_PORT: int = 11222
@@ -10,10 +11,8 @@ class WagRedis:
@classmethod
def as_dict(cls):
return dict(
host=WagRedis.REDIS_HOST,
password=WagRedis.REDIS_PASSWORD,
port=WagRedis.REDIS_PORT,
db=WagRedis.REDIS_DB,
host=WagRedis.REDIS_HOST, password=WagRedis.REDIS_PASSWORD,
port=WagRedis.REDIS_PORT, db=WagRedis.REDIS_DB,
)

View File

@@ -50,7 +50,6 @@ class UserLoginModule:
def login_user_via_credentials(self, access_data: "Login") -> None:
from ApiLayers.ApiServices.Token.token_handler import TokenService
from ApiLayers.Schemas import Users
"""
Login the user via the credentials.
"""

View File

@@ -265,7 +265,7 @@ class TokenService:
redis_rows = cls._get_user_tokens(user)
for redis_row in redis_rows.all:
if redis_row.row.get("domain") == domain:
redis_row.delete()
RedisActions.delete([redis_row.key])
@classmethod
def remove_all_token(cls, user: Users) -> None:
@@ -284,16 +284,11 @@ class TokenService:
"""Set access token to redis and handle user session."""
cls.remove_token_with_domain(user=user, domain=domain)
Users.client_arrow = DateTimeLocal(is_client=True, timezone=user.local_timezone)
db_session = UsersTokens.new_session()
# Handle login based on user type
if user.is_occupant:
login_dict = cls.do_occupant_login(
request=request, user=user, domain=domain
)
login_dict, db_session = {}, UsersTokens.new_session()
if user.is_occupant: # Handle login based on user type
login_dict = cls.do_occupant_login(request=request, user=user, domain=domain)
elif user.is_employee:
login_dict = cls.do_employee_login(
request=request, user=user, domain=domain
)
login_dict = cls.do_employee_login(request=request, user=user, domain=domain)
# Handle remember me functionality
if remember:
@@ -321,10 +316,7 @@ class TokenService:
).query.delete(synchronize_session=False)
user.remember_me = False
user.save(db=db_session)
return {
**login_dict,
"user": user.get_dict(),
}
return {**login_dict, "user": user.get_dict()}
@classmethod
def update_token_at_redis(
@@ -424,9 +416,8 @@ class TokenService:
sys_msg="Access token token is not found or unable to retrieve",
)
if redis_object := redis_response.first:
redis_object_dict = redis_object.data
access_token_obj.userUUID = redis_object_dict.get("user_uu_id")
return cls._process_redis_object(redis_object_dict)
access_token_obj.userUUID = redis_object.get("user_uu_id")
return cls._process_redis_object(redis_object)
@classmethod
def get_object_via_user_uu_id(cls, user_id: str) -> T:
@@ -434,7 +425,7 @@ class TokenService:
access_token = AccessToken(userUUID=user_id)
redis_response = RedisActions.get_json(list_keys=access_token.to_list())
if redis_object := redis_response.first.data:
if redis_object := redis_response.first.row:
access_token.userUUID = redis_object.get("user_uu_id")
return cls._process_redis_object(redis_object)

View File

@@ -59,11 +59,10 @@ class OccupantToken(BaseModel):
responsible_employee_uuid: Optional[str] = None
reachable_event_codes: Optional[list[str]] = None # ID list of reachable modules
reachable_event_endpoints: Optional[list[str]] = None
class CompanyToken(BaseModel): # Required Company Object for an employee
class CompanyToken(BaseModel):
# Selection of the company for an employee is made by the user
company_id: int
company_uu_id: str
@@ -82,14 +81,12 @@ class CompanyToken(BaseModel): # Required Company Object for an employee
bulk_duties_id: int
reachable_event_codes: Optional[list[str]] = None # ID list of reachable modules
reachable_event_endpoints: Optional[list[str]] = None
class OccupantTokenObject(ApplicationToken):
# Occupant Token Object -> Requires selection of the occupant type for a specific build part
available_occupants: dict = None
selected_occupant: Optional[OccupantToken] = None # Selected Occupant Type
@property

View File

@@ -21,7 +21,7 @@ class BaseEndpointResponse:
return {"message": f"{self.code} -> Language model not found"}
class EndpointSuccessResponse(BaseEndpointResponse): # 200 OK
class EndpointSuccessResponse(BaseEndpointResponse): # 200 OK
def as_dict(self, data: Optional[dict] = None):
return JSONResponse(
@@ -30,7 +30,7 @@ class EndpointSuccessResponse(BaseEndpointResponse): # 200 OK
)
class EndpointCreatedResponse(BaseEndpointResponse): # 201 Created
class EndpointCreatedResponse(BaseEndpointResponse): # 201 Created
def as_dict(self, data: Optional[dict] = None):
return JSONResponse(
@@ -39,7 +39,7 @@ class EndpointCreatedResponse(BaseEndpointResponse): # 201 Created
)
class EndpointAcceptedResponse(BaseEndpointResponse): # 202 Accepted
class EndpointAcceptedResponse(BaseEndpointResponse): # 202 Accepted
def as_dict(self, data: Optional[dict] = None):
return JSONResponse(
@@ -48,7 +48,16 @@ class EndpointAcceptedResponse(BaseEndpointResponse): # 202 Accepted
)
class EndpointBadRequestResponse(BaseEndpointResponse): # 400 Bad Request
class EndpointNotModifiedResponse(BaseEndpointResponse): # 304 Not Modified
def as_dict(self):
return JSONResponse(
status_code=status.HTTP_304_NOT_MODIFIED,
content=dict(completed=False, **self.response, lang=self.lang),
)
class EndpointBadRequestResponse(BaseEndpointResponse): # 400 Bad Request
def as_dict(self, data: Optional[dict] = None):
return JSONResponse(
@@ -57,7 +66,7 @@ class EndpointBadRequestResponse(BaseEndpointResponse): # 400 Bad Reques
)
class EndpointUnauthorizedResponse(BaseEndpointResponse): # 401 Unauthorized
class EndpointUnauthorizedResponse(BaseEndpointResponse): # 401 Unauthorized
def as_dict(self):
return JSONResponse(
@@ -66,16 +75,7 @@ class EndpointUnauthorizedResponse(BaseEndpointResponse): # 401 Unauthoriz
)
class EndpointNotFoundResponse(BaseEndpointResponse): # 404 Not Found
def as_dict(self):
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content=dict(completed=False, **self.response, lang=self.lang),
)
class EndpointForbiddenResponse(BaseEndpointResponse): # 403 Forbidden
class EndpointForbiddenResponse(BaseEndpointResponse): # 403 Forbidden
def as_dict(self):
return JSONResponse(
@@ -84,7 +84,34 @@ class EndpointForbiddenResponse(BaseEndpointResponse): # 403 Forbidden
)
class EndpointConflictResponse(BaseEndpointResponse): # 409 Conflict
class EndpointNotFoundResponse(BaseEndpointResponse): # 404 Not Found
def as_dict(self):
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content=dict(completed=False, **self.response, lang=self.lang),
)
class EndpointMethodNotAllowedResponse(BaseEndpointResponse): # 405 Method Not Allowed
def as_dict(self):
return JSONResponse(
status_code=status.HTTP_405_METHOD_NOT_ALLOWED,
content=dict(completed=False, **self.response, lang=self.lang),
)
class EndpointNotAcceptableResponse(BaseEndpointResponse): # 406 Not Acceptable
def as_dict(self):
return JSONResponse(
status_code=status.HTTP_406_NOT_ACCEPTABLE,
content=dict(completed=False, **self.response, lang=self.lang),
)
class EndpointConflictResponse(BaseEndpointResponse): # 409 Conflict
def as_dict(self):
return JSONResponse(
@@ -93,7 +120,16 @@ class EndpointConflictResponse(BaseEndpointResponse): # 409 Conflict
)
class EndpointTooManyRequestsResponse(BaseEndpointResponse): # 429 Too Many Requests
class EndpointUnprocessableEntityResponse(BaseEndpointResponse): # 422 Unprocessable Entity
def as_dict(self, data: Optional[dict] = None):
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=dict(completed=False, **self.response, lang=self.lang, data=data),
)
class EndpointTooManyRequestsResponse(BaseEndpointResponse): # 429 Too Many Requests
def __init__(self, retry_after: int, lang: str, code: str):
super().__init__(lang=lang, code=code)
@@ -107,19 +143,10 @@ class EndpointTooManyRequestsResponse(BaseEndpointResponse): # 429 Too Many R
)
class EndpointInternalErrorResponse(BaseEndpointResponse): # 500 Internal Server Error
class EndpointInternalErrorResponse(BaseEndpointResponse): # 500 Internal Server Error
def as_dict(self):
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=dict(completed=False, **self.response, lang=self.lang),
)
class EndpointErrorResponse(BaseEndpointResponse):
def as_dict(self):
return JSONResponse(
status_code=status.HTTP_304_NOT_MODIFIED,
content=dict(completed=False, **self.response, lang=self.lang),
)

View File

@@ -1,20 +1,40 @@
import json
from typing import Any, Union, Awaitable
from fastapi import Request, WebSocket
from fastapi.responses import Response
from pydantic import ValidationError
from ApiLayers.LanguageModels.Errors.merge_all_error_languages import (
MergedErrorLanguageModels,
)
from fastapi import Request, WebSocket, status
from fastapi.responses import Response, JSONResponse
from ApiLayers.LanguageModels.Errors.merge_all_error_languages import MergedErrorLanguageModels
from ApiLayers.ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
from ApiLayers.ErrorHandlers.bases import BaseErrorModelClass
def validation_exception_handler(request, exc: ValidationError) -> JSONResponse:
"""
{"message": [{
"type": "missing", "location": ["company_uu_id"], "message": "Field required",
"input": {"invalid_key_input": "e9869a25"}
}], "request": "http://0.0.0.0:41575/authentication/select", "title": "EmployeeSelection"
}
Validation error on pydantic model of each event validation
"""
validation_messages, validation_list = exc.errors() or [], []
for validation in validation_messages:
validation_list.append({
"type": dict(validation).get("type"),
"location": dict(validation).get("loc"),
"message": dict(validation).get("msg"), # todo change message with language message
"input": dict(validation).get("input"),
})
error_response_dict = dict(message=validation_list, request=str(request.url.path), title=exc.title)
return JSONResponse(
content=error_response_dict, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
)
class HTTPExceptionApiHandler:
def __init__(
self,
response_model: Any,
):
def __init__(self, response_model: Any):
self.RESPONSE_MODEL: Any = response_model
@staticmethod
@@ -26,7 +46,6 @@ class HTTPExceptionApiHandler:
@staticmethod
def retrieve_error_message(exc: HTTPExceptionApi, error_languages) -> str:
from ApiLayers.ErrorHandlers import DEFAULT_ERROR
return error_languages.get(str(exc.error_code).upper(), DEFAULT_ERROR)
async def handle_exception(

View File

@@ -27,9 +27,7 @@ class MiddlewareModule:
"""
@staticmethod
def get_user_from_request(
request: Request,
) -> dict:
def get_user_from_request(request: Request) -> dict:
"""
Get authenticated token context from request.
@@ -54,7 +52,6 @@ class MiddlewareModule:
loc=get_line_number_for_error(),
sys_msg="TokenService: Token Context couldnt retrieved from redis",
)
return token_context
@classmethod
@@ -89,11 +86,8 @@ class MiddlewareModule:
async def wrapper(request: Request, *args, **kwargs):
# Get and validate token context from request
endpoint_url = str(request.url.path)
auth_context = AuthContext(
auth={"token": {"access_token": "", "refresher_token": "", "context": {}}},
url=endpoint_url,
request=request,
)
token_context = cls.get_user_from_request(request=request)
auth_context = AuthContext(auth=token_context, url=endpoint_url, request=request)
# Set auth context on the wrapper function itself
setattr(func, "auth_context", auth_context)

View File

@@ -256,12 +256,11 @@ class Event2Employee(CrudCollection):
cls.employee_id == employee_id,
db=db,
).data
active_event_ids = Service2Events.filter_all(
active_event_ids = Service2Events.filter_all_system(
Service2Events.service_id.in_(
[event.event_service_id for event in employee_events]
),
db=db,
system=True,
).data
active_events = Events.filter_all(
Events.id.in_([event.event_id for event in active_event_ids]),
@@ -278,41 +277,41 @@ class Event2Employee(CrudCollection):
active_events.extend(events_extra)
return [event.function_code for event in active_events]
@classmethod
def get_event_endpoints(cls, employee_id: int) -> list:
from Schemas import EndpointRestriction
db = cls.new_session()
employee_events = cls.filter_all(
cls.employee_id == employee_id,
db=db,
).data
active_event_ids = Service2Events.filter_all(
Service2Events.service_id.in_(
[event.event_service_id for event in employee_events]
),
db=db,
system=True,
).data
active_events = Events.filter_all(
Events.id.in_([event.event_id for event in active_event_ids]),
db=db,
).data
if extra_events := Event2EmployeeExtra.filter_all(
Event2EmployeeExtra.employee_id == employee_id,
db=db,
).data:
events_extra = Events.filter_all(
Events.id.in_([event.event_id for event in extra_events]),
db=db,
).data
active_events.extend(events_extra)
endpoint_restrictions = EndpointRestriction.filter_all(
EndpointRestriction.id.in_([event.endpoint_id for event in active_events]),
db=db,
).data
return [event.endpoint_name for event in endpoint_restrictions]
# @classmethod
# def get_event_endpoints(cls, employee_id: int) -> list:
# from Schemas import EndpointRestriction
#
# db = cls.new_session()
# employee_events = cls.filter_all(
# cls.employee_id == employee_id,
# db=db,
# ).data
# active_event_ids = Service2Events.filter_all(
# Service2Events.service_id.in_(
# [event.event_service_id for event in employee_events]
# ),
# db=db,
# system=True,
# ).data
# active_events = Events.filter_all(
# Events.id.in_([event.event_id for event in active_event_ids]),
# db=db,
# ).data
# if extra_events := Event2EmployeeExtra.filter_all(
# Event2EmployeeExtra.employee_id == employee_id,
# db=db,
# ).data:
# events_extra = Events.filter_all(
# Events.id.in_([event.event_id for event in extra_events]),
# db=db,
# ).data
# active_events.extend(events_extra)
# endpoint_restrictions = EndpointRestriction.filter_all(
# EndpointRestriction.id.in_([event.endpoint_id for event in active_events]),
# db=db,
# ).data
# return [event.endpoint_name for event in endpoint_restrictions]
#
class Event2Occupant(CrudCollection):
"""
@@ -355,12 +354,11 @@ class Event2Occupant(CrudCollection):
cls.build_living_space_id == build_living_space_id,
db=db,
).data
active_event_ids = Service2Events.filter_all(
active_event_ids = Service2Events.filter_all_system(
Service2Events.service_id.in_(
[event.event_service_id for event in occupant_events]
),
db=db,
system=True,
).data
active_events = Events.filter_all(
Events.id.in_([event.event_id for event in active_event_ids]),
@@ -377,40 +375,40 @@ class Event2Occupant(CrudCollection):
active_events.extend(events_extra)
return [event.function_code for event in active_events]
@classmethod
def get_event_endpoints(cls, build_living_space_id) -> list:
from Schemas import EndpointRestriction
db = cls.new_session()
occupant_events = cls.filter_all(
cls.build_living_space_id == build_living_space_id,
db=db,
).data
active_event_ids = Service2Events.filter_all(
Service2Events.service_id.in_(
[event.event_service_id for event in occupant_events]
),
db=db,
system=True,
).data
active_events = Events.filter_all(
Events.id.in_([event.event_id for event in active_event_ids]),
db=db,
).data
if extra_events := Event2OccupantExtra.filter_all(
Event2OccupantExtra.build_living_space_id == build_living_space_id,
db=db,
).data:
events_extra = Events.filter_all(
Events.id.in_([event.event_id for event in extra_events]),
db=db,
).data
active_events.extend(events_extra)
endpoint_restrictions = EndpointRestriction.filter_all(
EndpointRestriction.id.in_([event.endpoint_id for event in active_events]),
db=db,
).data
return [event.endpoint_name for event in endpoint_restrictions]
# @classmethod
# def get_event_endpoints(cls, build_living_space_id) -> list:
# from Schemas import EndpointRestriction
#
# db = cls.new_session()
# occupant_events = cls.filter_all(
# cls.build_living_space_id == build_living_space_id,
# db=db,
# ).data
# active_event_ids = Service2Events.filter_all(
# Service2Events.service_id.in_(
# [event.event_service_id for event in occupant_events]
# ),
# db=db,
# system=True,
# ).data
# active_events = Events.filter_all(
# Events.id.in_([event.event_id for event in active_event_ids]),
# db=db,
# ).data
# if extra_events := Event2OccupantExtra.filter_all(
# Event2OccupantExtra.build_living_space_id == build_living_space_id,
# db=db,
# ).data:
# events_extra = Events.filter_all(
# Events.id.in_([event.event_id for event in extra_events]),
# db=db,
# ).data
# active_events.extend(events_extra)
# endpoint_restrictions = EndpointRestriction.filter_all(
# EndpointRestriction.id.in_([event.endpoint_id for event in active_events]),
# db=db,
# ).data
# return [event.endpoint_name for event in endpoint_restrictions]
class ModulePrice(CrudCollection):

View File

@@ -203,8 +203,6 @@ class Users(CrudCollection, UserLoginModule, SelectAction):
created_user.reset_password_token(found_user=created_user)
return created_user
def get_employee_and_duty_details(self):
from Schemas import Employees, Duties