auth endpoints added

This commit is contained in:
berkay 2025-04-03 14:19:34 +03:00
parent 3583d178e9
commit ee405133be
37 changed files with 976 additions and 570 deletions

View File

@ -25,4 +25,4 @@ COPY /Schemas/identity /Schemas/identity
ENV PYTHONPATH=/ PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1 ENV PYTHONPATH=/ PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1
# Run the application using the configured uvicorn server # Run the application using the configured uvicorn server
CMD ["poetry", "run", "python", "ApiServices/TemplateService/app.py"] CMD ["poetry", "run", "python", "ApiServices/AuthService/app.py"]

View File

@ -2,11 +2,12 @@ import uvicorn
from config import api_config from config import api_config
from ApiServices.TemplateService.create_app import create_app from ApiServices.AuthService.create_app import create_app
# from prometheus_fastapi_instrumentator import Instrumentator # from prometheus_fastapi_instrumentator import Instrumentator
app = create_app() # Create FastAPI application app = create_app() # Create FastAPI application
# Instrumentator().instrument(app=app).expose(app=app) # Setup Prometheus metrics # Instrumentator().instrument(app=app).expose(app=app) # Setup Prometheus metrics

View File

@ -8,9 +8,9 @@ class Configs(BaseSettings):
""" """
PATH: str = "" PATH: str = ""
HOST: str = "", HOST: str = ("",)
PORT: int = 0, PORT: int = (0,)
LOG_LEVEL: str = "info", LOG_LEVEL: str = ("info",)
RELOAD: int = 0 RELOAD: int = 0
ACCESS_TOKEN_TAG: str = "" ACCESS_TOKEN_TAG: str = ""
@ -36,7 +36,7 @@ class Configs(BaseSettings):
"host": self.HOST, "host": self.HOST,
"port": int(self.PORT), "port": int(self.PORT),
"log_level": self.LOG_LEVEL, "log_level": self.LOG_LEVEL,
"reload": bool(self.RELOAD) "reload": bool(self.RELOAD),
} }
@property @property

View File

@ -3,16 +3,16 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from ApiServices.TemplateService.create_route import RouteRegisterController from ApiServices.AuthService.create_route import RouteRegisterController
from ApiServices.TemplateService.endpoints.routes import get_routes from ApiServices.AuthService.endpoints.routes import get_routes
from ApiServices.TemplateService.open_api_creator import create_openapi_schema from ApiServices.AuthService.open_api_creator import create_openapi_schema
from ApiServices.TemplateService.middlewares.token_middleware import token_middleware from ApiServices.AuthService.middlewares.token_middleware import token_middleware
from ApiServices.TemplateService.config import template_api_config from ApiServices.AuthService.config import api_config
def create_app(): def create_app():
application = FastAPI(**template_api_config.api_info) application = FastAPI(**api_config.api_info)
# application.mount( # application.mount(
# "/application/static", # "/application/static",
# StaticFiles(directory="application/static"), # StaticFiles(directory="application/static"),
@ -20,7 +20,7 @@ def create_app():
# ) # )
application.add_middleware( application.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=template_api_config.ALLOW_ORIGINS, allow_origins=api_config.ALLOW_ORIGINS,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],

View File

@ -0,0 +1,322 @@
import uuid
from typing import Union
from fastapi import APIRouter, Request, status, Header
from fastapi.responses import JSONResponse
from ApiServices.AuthService.config import api_config
from ApiServices.AuthService.validations.request.authentication.login_post import (
RequestLogin,
RequestSelectLiving,
RequestSelectOccupant, RequestCreatePassword, RequestChangePassword, RequestForgotPasswordPhone,
RequestForgotPasswordEmail,
)
auth_route = APIRouter(
prefix="/authentication",
tags=["Authentication Cluster"],
)
@auth_route.post(
path="/login",
summary="Login via domain and access key : [email] | [phone]",
description="Login Route",
)
def authentication_login_post(
request: Request,
data: RequestLogin,
language: str = Header(None, alias="language"),
domain: str = Header(None, alias="domain"),
):
"""
Authentication Login Route with Post Method
"""
headers = {
"language": language or "",
"domain": domain or "",
"eys-ext": f"{str(uuid.uuid4())}",
}
if not domain or not language:
return JSONResponse(
content={"error": "EYS_0001"},
status_code=status.HTTP_406_NOT_ACCEPTABLE,
headers=headers,
)
return JSONResponse(
content={**data.model_dump()},
status_code=status.HTTP_202_ACCEPTED,
headers=headers,
)
@auth_route.post(
path="/select",
summary="Select company or occupant type",
description="Selection of users company or occupant type",
)
def authentication_select_post(
request: Request,
data: Union[RequestSelectOccupant, RequestSelectLiving],
language: str = Header(None, alias="language"),
domain: str = Header(None, alias="domain"),
):
"""
Authentication Select Route with Post Method
"""
token = request.headers.get(api_config.ACCESS_TOKEN_TAG, None)
headers = {
"language": language or "",
"domain": domain or "",
"eys-ext": f"{str(uuid.uuid4())}",
"token": token,
}
if not domain or not language:
return JSONResponse(
content={"error": "EYS_0001"},
status_code=status.HTTP_406_NOT_ACCEPTABLE,
headers=headers,
)
return JSONResponse(
content=data.model_dump(),
status_code=status.HTTP_202_ACCEPTED,
headers=headers,
)
@auth_route.get(
path="/logout",
summary="Logout user",
description="Logout only single session of user which domain is provided",
)
def authentication_logout_post(
request: Request,
language: str = Header(None, alias="language"),
domain: str = Header(None, alias="domain"),
):
"""
Logout user from the system
"""
token = request.headers.get(api_config.ACCESS_TOKEN_TAG, None)
headers = {
"language": language or "",
"domain": domain or "",
"eys-ext": f"{str(uuid.uuid4())}",
"token": token,
}
if not domain or not language:
return JSONResponse(
content={"error": "EYS_0003"},
status_code=status.HTTP_406_NOT_ACCEPTABLE,
headers=headers,
)
return JSONResponse(
content={},
status_code=status.HTTP_202_ACCEPTED,
headers=headers,
)
@auth_route.get(
path="/disconnect",
summary="Disconnect all sessions",
description="Disconnect all sessions of user in access token",
)
def authentication_disconnect_post(
request: Request,
language: str = Header(None, alias="language"),
domain: str = Header(None, alias="domain"),
):
"""
Disconnect all sessions of user in access token
"""
token = request.headers.get(api_config.ACCESS_TOKEN_TAG, None)
headers = {
"language": language or "",
"domain": domain or "",
"eys-ext": f"{str(uuid.uuid4())}",
"token": token,
}
if not domain or not language:
return JSONResponse(
content={"error": "EYS_0003"},
status_code=status.HTTP_406_NOT_ACCEPTABLE,
headers=headers,
)
return JSONResponse(
content={},
status_code=status.HTTP_202_ACCEPTED,
headers=headers,
)
@auth_route.get(
path="/token/check",
summary="Check if token is valid",
description="Check if access token is valid for user",
)
def authentication_token_check_post(
request: Request,
language: str = Header(None, alias="language"),
domain: str = Header(None, alias="domain"),
):
"""
Check if access token is valid for user
"""
token = request.headers.get(api_config.ACCESS_TOKEN_TAG, None)
headers = {
"language": language or "",
"domain": domain or "",
"eys-ext": f"{str(uuid.uuid4())}",
"token": token,
}
if not domain or not language:
return JSONResponse(
content={"error": "EYS_0003"},
status_code=status.HTTP_406_NOT_ACCEPTABLE,
headers=headers,
)
return JSONResponse(
content={},
status_code=status.HTTP_202_ACCEPTED,
headers=headers,
)
@auth_route.get(
path="/token/refresh",
summary="Refresh if token is valid",
description="Refresh if access token is valid for user",
)
def authentication_token_refresh_post(
request: Request,
language: str = Header(None, alias="language"),
domain: str = Header(None, alias="domain"),
):
"""
Refresh if access token is valid for user
"""
headers = {
"language": language or "",
"domain": domain or "",
"eys-ext": f"{str(uuid.uuid4())}",
}
if not domain or not language:
return JSONResponse(
content={"error": "EYS_0003"},
status_code=status.HTTP_406_NOT_ACCEPTABLE,
headers=headers,
)
return JSONResponse(
content={},
status_code=status.HTTP_202_ACCEPTED,
headers=headers,
)
@auth_route.post(
path="/password/create",
summary="Create password with access token",
description="Create password",
)
def authentication_password_create_post(
request: Request,
data: RequestCreatePassword,
language: str = Header(None, alias="language"),
domain: str = Header(None, alias="domain"),
):
"""
Authentication create password Route with Post Method
"""
token = request.headers.get(api_config.ACCESS_TOKEN_TAG, None)
headers = {
"language": language or "",
"domain": domain or "",
"eys-ext": f"{str(uuid.uuid4())}",
"token": token,
}
if not domain or not language:
return JSONResponse(
content={"error": "EYS_0001"},
status_code=status.HTTP_406_NOT_ACCEPTABLE,
headers=headers,
)
return JSONResponse(
content={**data.model_dump()},
status_code=status.HTTP_202_ACCEPTED,
headers=headers,
)
@auth_route.post(
path="/password/change",
summary="Change password with access token",
description="Change password",
)
def authentication_password_change_post(
request: Request,
data: RequestChangePassword,
language: str = Header(None, alias="language"),
domain: str = Header(None, alias="domain"),
):
"""
Authentication change password Route with Post Method
"""
token = request.headers.get(api_config.ACCESS_TOKEN_TAG, None)
headers = {
"language": language or "",
"domain": domain or "",
"eys-ext": f"{str(uuid.uuid4())}",
"token": token,
}
if not domain or not language:
return JSONResponse(
content={"error": "EYS_0001"},
status_code=status.HTTP_406_NOT_ACCEPTABLE,
headers=headers,
)
return JSONResponse(
content={**data.model_dump()},
status_code=status.HTTP_202_ACCEPTED,
headers=headers,
)
@auth_route.post(
path="/password/reset",
summary="Reset password with access token",
description="Reset password",
)
def authentication_password_reset_post(
request: Request,
data: Union[RequestForgotPasswordEmail, RequestForgotPasswordPhone],
language: str = Header(None, alias="language"),
domain: str = Header(None, alias="domain"),
):
"""
Authentication reset password Route with Post Method
"""
headers = {
"language": language or "",
"domain": domain or "",
"eys-ext": f"{str(uuid.uuid4())}",
}
if not domain or not language:
return JSONResponse(
content={"error": "EYS_0001"},
status_code=status.HTTP_406_NOT_ACCEPTABLE,
headers=headers,
)
return JSONResponse(
content={**data.model_dump()},
status_code=status.HTTP_202_ACCEPTED,
headers=headers,
)

View File

@ -1,9 +1,9 @@
from fastapi import APIRouter from fastapi import APIRouter
from .test_template.route import test_template_route from ApiServices.AuthService.endpoints.auth.route import auth_route
def get_routes() -> list[APIRouter]: def get_routes() -> list[APIRouter]:
return [test_template_route] return [auth_route]
def get_safe_endpoint_urls() -> list[tuple[str, str]]: def get_safe_endpoint_urls() -> list[tuple[str, str]]:
@ -15,6 +15,5 @@ def get_safe_endpoint_urls() -> list[tuple[str, str]]:
("/auth/register", "POST"), ("/auth/register", "POST"),
("/auth/login", "POST"), ("/auth/login", "POST"),
("/metrics", "GET"), ("/metrics", "GET"),
("/test/template", "GET"), ("/authentication/login", "POST"),
("/test/template", "POST"), ]
]

View File

@ -1,40 +0,0 @@
from fastapi import APIRouter, Request, Response
test_template_route = APIRouter(prefix="/test", tags=["Test"])
@test_template_route.get(path="/template", description="Test Template Route")
def test_template(request: Request, response: Response):
"""
Test Template Route
"""
headers = dict(request.headers)
response.headers["X-Header"] = "Test Header GET"
return {
"completed": True,
"message": "Test Template Route",
"info": {
"host": headers.get("host", "Not Found"),
"user_agent": headers.get("user-agent", "Not Found"),
},
}
@test_template_route.post(
path="/template",
description="Test Template Route with Post Method",
)
def test_template_post(request: Request, response: Response):
"""
Test Template Route with Post Method
"""
headers = dict(request.headers)
response.headers["X-Header"] = "Test Header POST"
return {
"completed": True,
"message": "Test Template Route with Post Method",
"info": {
"host": headers.get("host", "Not Found"),
"user_agent": headers.get("user-agent", "Not Found"),
},
}

View File

@ -1,5 +1,7 @@
from fastapi import Request, Response from fastapi import Request, status
from ApiServices.TemplateService.endpoints.routes import get_safe_endpoint_urls from fastapi.responses import JSONResponse
from ..endpoints.routes import get_safe_endpoint_urls
from ..config import api_config
async def token_middleware(request: Request, call_next): async def token_middleware(request: Request, call_next):
@ -9,9 +11,14 @@ async def token_middleware(request: Request, call_next):
if base_url in safe_endpoints: if base_url in safe_endpoints:
return await call_next(request) return await call_next(request)
token = request.headers.get("Authorization") token = request.headers.get(api_config.ACCESS_TOKEN_TAG, None)
if not token: if not token:
return Response(content="Missing token", status_code=400) return JSONResponse(
content={
"error": "EYS_0002",
},
status_code=status.HTTP_401_UNAUTHORIZED,
)
response = await call_next(request) response = await call_next(request)
return response return response

View File

@ -3,8 +3,8 @@ from fastapi import FastAPI
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from ApiServices.TemplateService.config import template_api_config from ApiServices.AuthService.config import api_config
from ApiServices.TemplateService.endpoints.routes import get_safe_endpoint_urls from ApiServices.AuthService.endpoints.routes import get_safe_endpoint_urls
class OpenAPISchemaCreator: class OpenAPISchemaCreator:
@ -36,7 +36,7 @@ class OpenAPISchemaCreator:
"BearerAuth": { "BearerAuth": {
"type": "apiKey", "type": "apiKey",
"in": "header", "in": "header",
"name": template_api_config.ACCESS_TOKEN_TAG, "name": api_config.ACCESS_TOKEN_TAG,
"description": "Enter: **'Bearer <JWT>'**, where JWT is the access token", "description": "Enter: **'Bearer <JWT>'**, where JWT is the access token",
} }
} }
@ -73,9 +73,9 @@ class OpenAPISchemaCreator:
Dict[str, Any]: Complete OpenAPI schema Dict[str, Any]: Complete OpenAPI schema
""" """
openapi_schema = get_openapi( openapi_schema = get_openapi(
title=template_api_config.TITLE, title=api_config.TITLE,
description=template_api_config.DESCRIPTION, description=api_config.DESCRIPTION,
version=template_api_config.VERSION, version=api_config.VERSION,
routes=self.app.routes, routes=self.app.routes,
) )
@ -83,9 +83,7 @@ class OpenAPISchemaCreator:
if "components" not in openapi_schema: if "components" not in openapi_schema:
openapi_schema["components"] = {} openapi_schema["components"] = {}
openapi_schema["components"][ openapi_schema["components"]["securitySchemes"] = self.create_security_schemes()
"securitySchemes"
] = self.create_security_schemes()
# Configure route security and responses # Configure route security and responses
for route in self.app.routes: for route in self.app.routes:
@ -115,4 +113,4 @@ def create_openapi_schema(app: FastAPI) -> Dict[str, Any]:
Dict[str, Any]: Complete OpenAPI schema Dict[str, Any]: Complete OpenAPI schema
""" """
creator = OpenAPISchemaCreator(app) creator = OpenAPISchemaCreator(app)
return creator.create_schema() return creator.create_schema()

View File

@ -0,0 +1,38 @@
from typing import Optional
from pydantic import BaseModel
class RequestLogin(BaseModel):
access_key: str
password: str
remember_me: Optional[bool]
class RequestSelectOccupant(BaseModel):
company_uu_id: str
class RequestSelectLiving(BaseModel):
build_living_space_uu_id: str
class RequestCreatePassword(BaseModel):
password_token: str
password: str
re_password: str
class RequestChangePassword(BaseModel):
old_password: str
password: str
re_password: str
class RequestForgotPasswordEmail(BaseModel):
email: str
class RequestForgotPasswordPhone(BaseModel):
phone_number: str

View File

@ -3,10 +3,11 @@ import uvicorn
from config import api_config from config import api_config
from ApiServices.TemplateService.create_app import create_app from ApiServices.TemplateService.create_app import create_app
# from prometheus_fastapi_instrumentator import Instrumentator # from prometheus_fastapi_instrumentator import Instrumentator
app = create_app() # Create FastAPI application app = create_app() # Create FastAPI application
# Instrumentator().instrument(app=app).expose(app=app) # Setup Prometheus metrics # Instrumentator().instrument(app=app).expose(app=app) # Setup Prometheus metrics

View File

@ -8,9 +8,9 @@ class Configs(BaseSettings):
""" """
PATH: str = "" PATH: str = ""
HOST: str = "", HOST: str = ("",)
PORT: int = 0, PORT: int = (0,)
LOG_LEVEL: str = "info", LOG_LEVEL: str = ("info",)
RELOAD: int = 0 RELOAD: int = 0
ACCESS_TOKEN_TAG: str = "" ACCESS_TOKEN_TAG: str = ""
@ -36,7 +36,7 @@ class Configs(BaseSettings):
"host": self.HOST, "host": self.HOST,
"port": int(self.PORT), "port": int(self.PORT),
"log_level": self.LOG_LEVEL, "log_level": self.LOG_LEVEL,
"reload": bool(self.RELOAD) "reload": bool(self.RELOAD),
} }
@property @property

View File

@ -17,4 +17,4 @@ def get_safe_endpoint_urls() -> list[tuple[str, str]]:
("/metrics", "GET"), ("/metrics", "GET"),
("/test/template", "GET"), ("/test/template", "GET"),
("/test/template", "POST"), ("/test/template", "POST"),
] ]

View File

@ -1,5 +1,7 @@
from fastapi import Request, Response from fastapi import Request, status
from ApiServices.TemplateService.endpoints.routes import get_safe_endpoint_urls from fastapi.responses import JSONResponse
from ..endpoints.routes import get_safe_endpoint_urls
from ..config import api_config
async def token_middleware(request: Request, call_next): async def token_middleware(request: Request, call_next):
@ -9,9 +11,14 @@ async def token_middleware(request: Request, call_next):
if base_url in safe_endpoints: if base_url in safe_endpoints:
return await call_next(request) return await call_next(request)
token = request.headers.get("Authorization") token = request.headers.get(api_config.ACCESS_TOKEN_TAG, None)
if not token: if not token:
return Response(content="Missing token", status_code=400) return JSONResponse(
content={
"error": "EYS_0002",
},
status_code=status.HTTP_401_UNAUTHORIZED,
)
response = await call_next(request) response = await call_next(request)
return response return response

View File

@ -83,9 +83,7 @@ class OpenAPISchemaCreator:
if "components" not in openapi_schema: if "components" not in openapi_schema:
openapi_schema["components"] = {} openapi_schema["components"] = {}
openapi_schema["components"][ openapi_schema["components"]["securitySchemes"] = self.create_security_schemes()
"securitySchemes"
] = self.create_security_schemes()
# Configure route security and responses # Configure route security and responses
for route in self.app.routes: for route in self.app.routes:
@ -115,4 +113,4 @@ def create_openapi_schema(app: FastAPI) -> Dict[str, Any]:
Dict[str, Any]: Complete OpenAPI schema Dict[str, Any]: Complete OpenAPI schema
""" """
creator = OpenAPISchemaCreator(app) creator = OpenAPISchemaCreator(app)
return creator.create_schema() return creator.create_schema()

View File

@ -17,29 +17,30 @@ def test_basic_crud_operations():
try: try:
with mongo_handler.collection("users") as users_collection: with mongo_handler.collection("users") as users_collection:
# Insert multiple documents # Insert multiple documents
users_collection.insert_many([ users_collection.insert_many(
{"username": "john", "email": "john@example.com", "role": "user"}, [
{"username": "jane", "email": "jane@example.com", "role": "admin"}, {"username": "john", "email": "john@example.com", "role": "user"},
{"username": "bob", "email": "bob@example.com", "role": "user"} {"username": "jane", "email": "jane@example.com", "role": "admin"},
]) {"username": "bob", "email": "bob@example.com", "role": "user"},
]
)
# Find with multiple conditions # Find with multiple conditions
admin_users = list(users_collection.find({"role": "admin"})) admin_users = list(users_collection.find({"role": "admin"}))
# Update multiple documents # Update multiple documents
update_result = users_collection.update_many( update_result = users_collection.update_many(
{"role": "user"}, {"role": "user"}, {"$set": {"last_login": datetime.now().isoformat()}}
{"$set": {"last_login": datetime.now().isoformat()}}
) )
# Delete documents # Delete documents
delete_result = users_collection.delete_many({"username": "bob"}) delete_result = users_collection.delete_many({"username": "bob"})
success = ( success = (
len(admin_users) == 1 and len(admin_users) == 1
admin_users[0]["username"] == "jane" and and admin_users[0]["username"] == "jane"
update_result.modified_count == 2 and and update_result.modified_count == 2
delete_result.deleted_count == 1 and delete_result.deleted_count == 1
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -54,35 +55,32 @@ def test_nested_documents():
try: try:
with mongo_handler.collection("products") as products_collection: with mongo_handler.collection("products") as products_collection:
# Insert a product with nested data # Insert a product with nested data
products_collection.insert_one({ products_collection.insert_one(
"name": "Laptop", {
"price": 999.99, "name": "Laptop",
"specs": { "price": 999.99,
"cpu": "Intel i7", "specs": {"cpu": "Intel i7", "ram": "16GB", "storage": "512GB SSD"},
"ram": "16GB", "in_stock": True,
"storage": "512GB SSD" "tags": ["electronics", "computers", "laptops"],
}, }
"in_stock": True, )
"tags": ["electronics", "computers", "laptops"]
})
# Find with nested field query # Find with nested field query
laptop = products_collection.find_one({"specs.cpu": "Intel i7"}) laptop = products_collection.find_one({"specs.cpu": "Intel i7"})
# Update nested field # Update nested field
update_result = products_collection.update_one( update_result = products_collection.update_one(
{"name": "Laptop"}, {"name": "Laptop"}, {"$set": {"specs.ram": "32GB"}}
{"$set": {"specs.ram": "32GB"}}
) )
# Verify the update # Verify the update
updated_laptop = products_collection.find_one({"name": "Laptop"}) updated_laptop = products_collection.find_one({"name": "Laptop"})
success = ( success = (
laptop is not None and laptop is not None
laptop["specs"]["ram"] == "16GB" and and laptop["specs"]["ram"] == "16GB"
update_result.modified_count == 1 and and update_result.modified_count == 1
updated_laptop["specs"]["ram"] == "32GB" and updated_laptop["specs"]["ram"] == "32GB"
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -97,16 +95,18 @@ def test_array_operations():
try: try:
with mongo_handler.collection("orders") as orders_collection: with mongo_handler.collection("orders") as orders_collection:
# Insert an order with array of items # Insert an order with array of items
orders_collection.insert_one({ orders_collection.insert_one(
"order_id": "ORD001", {
"customer": "john", "order_id": "ORD001",
"items": [ "customer": "john",
{"product": "Laptop", "quantity": 1}, "items": [
{"product": "Mouse", "quantity": 2} {"product": "Laptop", "quantity": 1},
], {"product": "Mouse", "quantity": 2},
"total": 1099.99, ],
"status": "pending" "total": 1099.99,
}) "status": "pending",
}
)
# Find orders containing specific items # Find orders containing specific items
laptop_orders = list(orders_collection.find({"items.product": "Laptop"})) laptop_orders = list(orders_collection.find({"items.product": "Laptop"}))
@ -114,17 +114,17 @@ def test_array_operations():
# Update array elements # Update array elements
update_result = orders_collection.update_one( update_result = orders_collection.update_one(
{"order_id": "ORD001"}, {"order_id": "ORD001"},
{"$push": {"items": {"product": "Keyboard", "quantity": 1}}} {"$push": {"items": {"product": "Keyboard", "quantity": 1}}},
) )
# Verify the update # Verify the update
updated_order = orders_collection.find_one({"order_id": "ORD001"}) updated_order = orders_collection.find_one({"order_id": "ORD001"})
success = ( success = (
len(laptop_orders) == 1 and len(laptop_orders) == 1
update_result.modified_count == 1 and and update_result.modified_count == 1
len(updated_order["items"]) == 3 and and len(updated_order["items"]) == 3
updated_order["items"][-1]["product"] == "Keyboard" and updated_order["items"][-1]["product"] == "Keyboard"
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -139,23 +139,32 @@ def test_aggregation():
try: try:
with mongo_handler.collection("sales") as sales_collection: with mongo_handler.collection("sales") as sales_collection:
# Insert sample sales data # Insert sample sales data
sales_collection.insert_many([ sales_collection.insert_many(
{"product": "Laptop", "amount": 999.99, "date": datetime.now()}, [
{"product": "Mouse", "amount": 29.99, "date": datetime.now()}, {"product": "Laptop", "amount": 999.99, "date": datetime.now()},
{"product": "Keyboard", "amount": 59.99, "date": datetime.now()} {"product": "Mouse", "amount": 29.99, "date": datetime.now()},
]) {"product": "Keyboard", "amount": 59.99, "date": datetime.now()},
]
)
# Calculate total sales by product # Calculate total sales by product
pipeline = [ pipeline = [{"$group": {"_id": "$product", "total": {"$sum": "$amount"}}}]
{"$group": {"_id": "$product", "total": {"$sum": "$amount"}}}
]
sales_summary = list(sales_collection.aggregate(pipeline)) sales_summary = list(sales_collection.aggregate(pipeline))
success = ( success = (
len(sales_summary) == 3 and len(sales_summary) == 3
any(item["_id"] == "Laptop" and item["total"] == 999.99 for item in sales_summary) and and any(
any(item["_id"] == "Mouse" and item["total"] == 29.99 for item in sales_summary) and item["_id"] == "Laptop" and item["total"] == 999.99
any(item["_id"] == "Keyboard" and item["total"] == 59.99 for item in sales_summary) for item in sales_summary
)
and any(
item["_id"] == "Mouse" and item["total"] == 29.99
for item in sales_summary
)
and any(
item["_id"] == "Keyboard" and item["total"] == 59.99
for item in sales_summary
)
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -174,11 +183,15 @@ def test_index_operations():
users_collection.create_index([("username", 1), ("role", 1)]) users_collection.create_index([("username", 1), ("role", 1)])
# Insert initial document # Insert initial document
users_collection.insert_one({"username": "test_user", "email": "test@example.com"}) users_collection.insert_one(
{"username": "test_user", "email": "test@example.com"}
)
# Try to insert duplicate email (should fail) # Try to insert duplicate email (should fail)
try: try:
users_collection.insert_one({"username": "test_user2", "email": "test@example.com"}) users_collection.insert_one(
{"username": "test_user2", "email": "test@example.com"}
)
success = False # Should not reach here success = False # Should not reach here
except Exception: except Exception:
success = True success = True
@ -196,49 +209,49 @@ def test_complex_queries():
try: try:
with mongo_handler.collection("products") as products_collection: with mongo_handler.collection("products") as products_collection:
# Insert test data # Insert test data
products_collection.insert_many([ products_collection.insert_many(
{ [
"name": "Expensive Laptop", {
"price": 999.99, "name": "Expensive Laptop",
"tags": ["electronics", "computers"], "price": 999.99,
"in_stock": True "tags": ["electronics", "computers"],
}, "in_stock": True,
{ },
"name": "Cheap Mouse", {
"price": 29.99, "name": "Cheap Mouse",
"tags": ["electronics", "peripherals"], "price": 29.99,
"in_stock": True "tags": ["electronics", "peripherals"],
} "in_stock": True,
]) },
]
)
# Find products with price range and specific tags # Find products with price range and specific tags
expensive_electronics = list(products_collection.find({ expensive_electronics = list(
"price": {"$gt": 500}, products_collection.find(
"tags": {"$in": ["electronics"]}, {
"in_stock": True "price": {"$gt": 500},
})) "tags": {"$in": ["electronics"]},
"in_stock": True,
}
)
)
# Update with multiple conditions # Update with multiple conditions
update_result = products_collection.update_many( update_result = products_collection.update_many(
{ {"price": {"$lt": 100}, "in_stock": True},
"price": {"$lt": 100}, {"$set": {"discount": 0.1}, "$inc": {"price": -10}},
"in_stock": True
},
{
"$set": {"discount": 0.1},
"$inc": {"price": -10}
}
) )
# Verify the update # Verify the update
updated_product = products_collection.find_one({"name": "Cheap Mouse"}) updated_product = products_collection.find_one({"name": "Cheap Mouse"})
success = ( success = (
len(expensive_electronics) == 1 and len(expensive_electronics) == 1
expensive_electronics[0]["name"] == "Expensive Laptop" and and expensive_electronics[0]["name"] == "Expensive Laptop"
update_result.modified_count == 1 and and update_result.modified_count == 1
updated_product["price"] == 19.99 and and updated_product["price"] == 19.99
updated_product["discount"] == 0.1 and updated_product["discount"] == 0.1
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -250,19 +263,19 @@ def test_complex_queries():
def run_all_tests(): def run_all_tests():
"""Run all MongoDB tests and report results.""" """Run all MongoDB tests and report results."""
print("Starting MongoDB tests...") print("Starting MongoDB tests...")
# Clean up any existing test data before starting # Clean up any existing test data before starting
cleanup_test_data() cleanup_test_data()
tests = [ tests = [
test_basic_crud_operations, test_basic_crud_operations,
test_nested_documents, test_nested_documents,
test_array_operations, test_array_operations,
test_aggregation, test_aggregation,
test_index_operations, test_index_operations,
test_complex_queries test_complex_queries,
] ]
passed_list, not_passed_list = [], [] passed_list, not_passed_list = [], []
passed, failed = 0, 0 passed, failed = 0, 0
@ -282,9 +295,9 @@ def run_all_tests():
not_passed_list.append(f"Test {test.__name__} failed") not_passed_list.append(f"Test {test.__name__} failed")
print(f"\nTest Results: {passed} passed, {failed} failed") print(f"\nTest Results: {passed} passed, {failed} failed")
print('Passed Tests:') print("Passed Tests:")
print("\n".join(passed_list)) print("\n".join(passed_list))
print('Failed Tests:') print("Failed Tests:")
print("\n".join(not_passed_list)) print("\n".join(not_passed_list))
return passed, failed return passed, failed

View File

@ -14,6 +14,7 @@ class Credentials(BaseModel):
""" """
Class to store user credentials. Class to store user credentials.
""" """
person_id: int person_id: int
person_name: str person_name: str
full_name: Optional[str] = None full_name: Optional[str] = None
@ -23,6 +24,7 @@ class MetaData:
""" """
Class to store metadata for a query. Class to store metadata for a query.
""" """
created: bool = False created: bool = False
updated: bool = False updated: bool = False
@ -30,7 +32,7 @@ class MetaData:
class CRUDModel: class CRUDModel:
""" """
Base class for CRUD operations on PostgreSQL models. Base class for CRUD operations on PostgreSQL models.
Features: Features:
- User credential tracking - User credential tracking
- Metadata tracking for operations - Metadata tracking for operations
@ -38,21 +40,21 @@ class CRUDModel:
- Automatic timestamp management - Automatic timestamp management
- Soft delete support - Soft delete support
""" """
__abstract__ = True __abstract__ = True
creds: Credentials = None creds: Credentials = None
meta_data: MetaData = MetaData() meta_data: MetaData = MetaData()
# Define required columns for CRUD operations # Define required columns for CRUD operations
required_columns = { required_columns = {
'expiry_starts': TIMESTAMP, "expiry_starts": TIMESTAMP,
'expiry_ends': TIMESTAMP, "expiry_ends": TIMESTAMP,
'created_by': str, "created_by": str,
'created_by_id': int, "created_by_id": int,
'updated_by': str, "updated_by": str,
'updated_by_id': int, "updated_by_id": int,
'deleted': bool "deleted": bool,
} }
@classmethod @classmethod
@ -65,24 +67,25 @@ class CRUDModel:
""" """
if not cls.creds: if not cls.creds:
return return
if getattr(cls.creds, "person_id", None) and getattr(cls.creds, "person_name", None): if getattr(cls.creds, "person_id", None) and getattr(
cls.creds, "person_name", None
):
record_created.created_by_id = cls.creds.person_id record_created.created_by_id = cls.creds.person_id
record_created.created_by = cls.creds.person_name record_created.created_by = cls.creds.person_name
@classmethod @classmethod
def raise_exception(cls, message: str = "Exception raised.", status_code: int = 400): def raise_exception(
cls, message: str = "Exception raised.", status_code: int = 400
):
""" """
Raise HTTP exception with custom message and status code. Raise HTTP exception with custom message and status code.
Args: Args:
message: Error message message: Error message
status_code: HTTP status code status_code: HTTP status code
""" """
raise HTTPException( raise HTTPException(status_code=status_code, detail={"message": message})
status_code=status_code,
detail={"message": message}
)
@classmethod @classmethod
def create_or_abort(cls, db: Session, **kwargs): def create_or_abort(cls, db: Session, **kwargs):
@ -111,7 +114,7 @@ class CRUDModel:
query = query.filter(getattr(cls, key) == value) query = query.filter(getattr(cls, key) == value)
already_record = query.first() already_record = query.first()
# Handle existing record # Handle existing record
if already_record and already_record.deleted: if already_record and already_record.deleted:
cls.raise_exception("Record already exists and is deleted") cls.raise_exception("Record already exists and is deleted")
@ -122,12 +125,12 @@ class CRUDModel:
created_record = cls() created_record = cls()
for key, value in kwargs.items(): for key, value in kwargs.items():
setattr(created_record, key, value) setattr(created_record, key, value)
cls.create_credentials(created_record) cls.create_credentials(created_record)
db.add(created_record) db.add(created_record)
db.flush() db.flush()
return created_record return created_record
except Exception as e: except Exception as e:
db.rollback() db.rollback()
cls.raise_exception(f"Failed to create record: {str(e)}", status_code=500) cls.raise_exception(f"Failed to create record: {str(e)}", status_code=500)
@ -146,7 +149,7 @@ class CRUDModel:
""" """
try: try:
key_ = cls.__annotations__.get(key, None) key_ = cls.__annotations__.get(key, None)
is_primary = key in getattr(cls, 'primary_keys', []) is_primary = key in getattr(cls, "primary_keys", [])
row_attr = bool(getattr(getattr(cls, key), "foreign_keys", None)) row_attr = bool(getattr(getattr(cls, key), "foreign_keys", None))
# Skip primary keys and foreign keys # Skip primary keys and foreign keys
@ -167,12 +170,16 @@ class CRUDModel:
elif key_ == Mapped[float] or key_ == Mapped[NUMERIC]: elif key_ == Mapped[float] or key_ == Mapped[NUMERIC]:
return True, round(float(val), 3) return True, round(float(val), 3)
elif key_ == Mapped[TIMESTAMP]: elif key_ == Mapped[TIMESTAMP]:
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ")) return True, str(
arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ")
)
elif key_ == Mapped[str]: elif key_ == Mapped[str]:
return True, str(val) return True, str(val)
else: # Handle based on Python types else: # Handle based on Python types
if isinstance(val, datetime.datetime): if isinstance(val, datetime.datetime):
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ")) return True, str(
arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ")
)
elif isinstance(val, bool): elif isinstance(val, bool):
return True, bool(val) return True, bool(val)
elif isinstance(val, (float, Decimal)): elif isinstance(val, (float, Decimal)):
@ -185,17 +192,19 @@ class CRUDModel:
return True, None return True, None
return False, None return False, None
except Exception as e: except Exception as e:
return False, None return False, None
def get_dict(self, exclude_list: Optional[list[InstrumentedAttribute]] = None) -> Dict[str, Any]: def get_dict(
self, exclude_list: Optional[list[InstrumentedAttribute]] = None
) -> Dict[str, Any]:
""" """
Convert model instance to dictionary with customizable fields. Convert model instance to dictionary with customizable fields.
Args: Args:
exclude_list: List of fields to exclude from the dictionary exclude_list: List of fields to exclude from the dictionary
Returns: Returns:
Dictionary representation of the model Dictionary representation of the model
""" """
@ -207,7 +216,7 @@ class CRUDModel:
# Get all column names from the model # Get all column names from the model
columns = [col.name for col in self.__table__.columns] columns = [col.name for col in self.__table__.columns]
columns_set = set(columns) columns_set = set(columns)
# Filter columns # Filter columns
columns_list = set([col for col in columns_set if str(col)[-2:] != "id"]) columns_list = set([col for col in columns_set if str(col)[-2:] != "id"])
columns_extend = set( columns_extend = set(
@ -223,7 +232,7 @@ class CRUDModel:
return_dict[key] = value_of_database return_dict[key] = value_of_database
return return_dict return return_dict
except Exception as e: except Exception as e:
return {} return {}
@ -251,10 +260,10 @@ class CRUDModel:
cls.expiry_ends > str(arrow.now()), cls.expiry_ends > str(arrow.now()),
cls.expiry_starts <= str(arrow.now()), cls.expiry_starts <= str(arrow.now()),
) )
exclude_args = exclude_args or [] exclude_args = exclude_args or []
exclude_args = [exclude_arg.key for exclude_arg in exclude_args] exclude_args = [exclude_arg.key for exclude_arg in exclude_args]
for key, value in kwargs.items(): for key, value in kwargs.items():
if hasattr(cls, key) and key not in exclude_args: if hasattr(cls, key) and key not in exclude_args:
query = query.filter(getattr(cls, key) == value) query = query.filter(getattr(cls, key) == value)
@ -268,16 +277,18 @@ class CRUDModel:
created_record = cls() created_record = cls()
for key, value in kwargs.items(): for key, value in kwargs.items():
setattr(created_record, key, value) setattr(created_record, key, value)
cls.create_credentials(created_record) cls.create_credentials(created_record)
db.add(created_record) db.add(created_record)
db.flush() db.flush()
cls.meta_data.created = True cls.meta_data.created = True
return created_record return created_record
except Exception as e: except Exception as e:
db.rollback() db.rollback()
cls.raise_exception(f"Failed to find or create record: {str(e)}", status_code=500) cls.raise_exception(
f"Failed to find or create record: {str(e)}", status_code=500
)
def update(self, db: Session, **kwargs): def update(self, db: Session, **kwargs):
""" """
@ -301,7 +312,7 @@ class CRUDModel:
db.flush() db.flush()
self.meta_data.updated = True self.meta_data.updated = True
return self return self
except Exception as e: except Exception as e:
self.meta_data.updated = False self.meta_data.updated = False
db.rollback() db.rollback()
@ -313,10 +324,10 @@ class CRUDModel:
""" """
if not self.creds: if not self.creds:
return return
person_id = getattr(self.creds, "person_id", None) person_id = getattr(self.creds, "person_id", None)
person_name = getattr(self.creds, "person_name", None) person_name = getattr(self.creds, "person_name", None)
if person_id and person_name: if person_id and person_name:
self.updated_by_id = self.creds.person_id self.updated_by_id = self.creds.person_id
self.updated_by = self.creds.person_name self.updated_by = self.creds.person_name

View File

@ -10,11 +10,11 @@ from sqlalchemy.orm import declarative_base, sessionmaker, scoped_session, Sessi
engine = create_engine( engine = create_engine(
postgres_configs.url, postgres_configs.url,
pool_pre_ping=True, pool_pre_ping=True,
pool_size=10, # Reduced from 20 to better match your CPU cores pool_size=10, # Reduced from 20 to better match your CPU cores
max_overflow=5, # Reduced from 10 to prevent too many connections max_overflow=5, # Reduced from 10 to prevent too many connections
pool_recycle=600, # Keep as is pool_recycle=600, # Keep as is
pool_timeout=30, # Keep as is pool_timeout=30, # Keep as is
echo=False, # Consider setting to False in production echo=False, # Consider setting to False in production
) )

View File

@ -35,51 +35,49 @@ class QueryModel:
@classmethod @classmethod
def add_new_arg_to_args( def add_new_arg_to_args(
cls: Type[T], cls: Type[T],
args_list: tuple[BinaryExpression, ...], args_list: tuple[BinaryExpression, ...],
argument: str, argument: str,
value: BinaryExpression value: BinaryExpression,
) -> tuple[BinaryExpression, ...]: ) -> tuple[BinaryExpression, ...]:
""" """
Add a new argument to the query arguments if it doesn't exist. Add a new argument to the query arguments if it doesn't exist.
Args: Args:
args_list: Existing query arguments args_list: Existing query arguments
argument: Key of the argument to check for argument: Key of the argument to check for
value: New argument value to add value: New argument value to add
Returns: Returns:
Updated tuple of query arguments Updated tuple of query arguments
""" """
# Convert to set to remove duplicates while preserving order # Convert to set to remove duplicates while preserving order
new_args = list(dict.fromkeys( new_args = list(
arg for arg in args_list dict.fromkeys(arg for arg in args_list if isinstance(arg, BinaryExpression))
if isinstance(arg, BinaryExpression) )
))
# Check if argument already exists # Check if argument already exists
if not any( if not any(
getattr(getattr(arg, "left", None), "key", None) == argument getattr(getattr(arg, "left", None), "key", None) == argument
for arg in new_args for arg in new_args
): ):
new_args.append(value) new_args.append(value)
return tuple(new_args) return tuple(new_args)
@classmethod @classmethod
def get_not_expired_query_arg( def get_not_expired_query_arg(
cls: Type[T], cls: Type[T], args: tuple[BinaryExpression, ...]
args: tuple[BinaryExpression, ...]
) -> tuple[BinaryExpression, ...]: ) -> tuple[BinaryExpression, ...]:
""" """
Add expiry date filtering to the query arguments. Add expiry date filtering to the query arguments.
Args: Args:
args: Existing query arguments args: Existing query arguments
Returns: Returns:
Updated tuple of query arguments with expiry filters Updated tuple of query arguments with expiry filters
Raises: Raises:
AttributeError: If model does not have expiry_starts or expiry_ends columns AttributeError: If model does not have expiry_starts or expiry_ends columns
""" """
@ -87,21 +85,21 @@ class QueryModel:
current_time = str(arrow.now()) current_time = str(arrow.now())
# Only add expiry filters if they don't already exist # Only add expiry filters if they don't already exist
if not any( if not any(
getattr(getattr(arg, "left", None), "key", None) == "expiry_ends" getattr(getattr(arg, "left", None), "key", None) == "expiry_ends"
for arg in args for arg in args
): ):
ends = cls.expiry_ends > current_time ends = cls.expiry_ends > current_time
args = cls.add_new_arg_to_args(args, "expiry_ends", ends) args = cls.add_new_arg_to_args(args, "expiry_ends", ends)
if not any( if not any(
getattr(getattr(arg, "left", None), "key", None) == "expiry_starts" getattr(getattr(arg, "left", None), "key", None) == "expiry_starts"
for arg in args for arg in args
): ):
starts = cls.expiry_starts <= current_time starts = cls.expiry_starts <= current_time
args = cls.add_new_arg_to_args(args, "expiry_starts", starts) args = cls.add_new_arg_to_args(args, "expiry_starts", starts)
return args return args
except AttributeError as e: except AttributeError as e:
raise AttributeError( raise AttributeError(
f"Model {cls.__name__} must have expiry_starts and expiry_ends columns" f"Model {cls.__name__} must have expiry_starts and expiry_ends columns"
@ -111,7 +109,7 @@ class QueryModel:
def produce_query_to_add(cls: Type[T], filter_list: dict, args: tuple) -> tuple: def produce_query_to_add(cls: Type[T], filter_list: dict, args: tuple) -> tuple:
""" """
Adds query to main filter options Adds query to main filter options
Args: Args:
filter_list: Dictionary containing query parameters filter_list: Dictionary containing query parameters
args: Existing query arguments to add to args: Existing query arguments to add to
@ -122,11 +120,11 @@ class QueryModel:
try: try:
if not filter_list or not isinstance(filter_list, dict): if not filter_list or not isinstance(filter_list, dict):
return args return args
query_params = filter_list.get("query") query_params = filter_list.get("query")
if not query_params or not isinstance(query_params, dict): if not query_params or not isinstance(query_params, dict):
return args return args
for key, value in query_params.items(): for key, value in query_params.items():
if hasattr(cls, key): if hasattr(cls, key):
# Create a new filter expression # Create a new filter expression
@ -134,39 +132,34 @@ class QueryModel:
# Add it to args if it doesn't exist # Add it to args if it doesn't exist
args = cls.add_new_arg_to_args(args, key, filter_expr) args = cls.add_new_arg_to_args(args, key, filter_expr)
return args return args
except Exception as e: except Exception as e:
print(f"Error in produce_query_to_add: {str(e)}") print(f"Error in produce_query_to_add: {str(e)}")
return args return args
@classmethod @classmethod
def convert( def convert(
cls: Type[T], cls: Type[T], smart_options: dict[str, Any], validate_model: Any = None
smart_options: dict[str, Any],
validate_model: Any = None
) -> Optional[tuple[BinaryExpression, ...]]: ) -> Optional[tuple[BinaryExpression, ...]]:
""" """
Convert smart options to SQLAlchemy filter expressions. Convert smart options to SQLAlchemy filter expressions.
Args: Args:
smart_options: Dictionary of filter options smart_options: Dictionary of filter options
validate_model: Optional model to validate against validate_model: Optional model to validate against
Returns: Returns:
Tuple of SQLAlchemy filter expressions or None if validation fails Tuple of SQLAlchemy filter expressions or None if validation fails
""" """
if validate_model is not None: if validate_model is not None:
# Add validation logic here if needed # Add validation logic here if needed
pass pass
return tuple(cls.filter_expr(**smart_options)) return tuple(cls.filter_expr(**smart_options))
@classmethod @classmethod
def filter_by_one( def filter_by_one(
cls: Type[T], cls: Type[T], db: Session, system: bool = False, **kwargs: Any
db: Session,
system: bool = False,
**kwargs: Any
) -> PostgresResponse[T]: ) -> PostgresResponse[T]:
""" """
Filter single record by keyword arguments. Filter single record by keyword arguments.
@ -181,30 +174,28 @@ class QueryModel:
""" """
# Get base query (either pre_query or new query) # Get base query (either pre_query or new query)
base_query = cls._query(db) base_query = cls._query(db)
# Create the final query by applying filters # Create the final query by applying filters
query = base_query query = base_query
# Add keyword filters first # Add keyword filters first
query = query.filter_by(**kwargs) query = query.filter_by(**kwargs)
# Add status filters if not system query # Add status filters if not system query
if not system: if not system:
query = query.filter( query = query.filter(
cls.is_confirmed == True, cls.is_confirmed == True, cls.deleted == False, cls.active == True
cls.deleted == False,
cls.active == True
) )
# Add expiry filters last # Add expiry filters last
args = cls.get_not_expired_query_arg(()) args = cls.get_not_expired_query_arg(())
query = query.filter(*args) query = query.filter(*args)
return PostgresResponse( return PostgresResponse(
model=cls, model=cls,
pre_query=base_query, # Use the base query for pre_query pre_query=base_query, # Use the base query for pre_query
query=query, query=query,
is_array=False is_array=False,
) )
@classmethod @classmethod
@ -225,29 +216,27 @@ class QueryModel:
""" """
# Get base query (either pre_query or new query) # Get base query (either pre_query or new query)
base_query = cls._query(db) base_query = cls._query(db)
# Create the final query by applying filters # Create the final query by applying filters
query = base_query query = base_query
# Add expression filters first # Add expression filters first
query = query.filter(*args) query = query.filter(*args)
# Add status filters # Add status filters
query = query.filter( query = query.filter(
cls.is_confirmed == True, cls.is_confirmed == True, cls.deleted == False, cls.active == True
cls.deleted == False,
cls.active == True
) )
# Add expiry filters last # Add expiry filters last
args = cls.get_not_expired_query_arg(()) args = cls.get_not_expired_query_arg(())
query = query.filter(*args) query = query.filter(*args)
return PostgresResponse( return PostgresResponse(
model=cls, model=cls,
pre_query=base_query, # Use the base query for pre_query pre_query=base_query, # Use the base query for pre_query
query=query, query=query,
is_array=False is_array=False,
) )
@classmethod @classmethod
@ -268,22 +257,22 @@ class QueryModel:
""" """
# Get base query (either pre_query or new query) # Get base query (either pre_query or new query)
base_query = cls._query(db) base_query = cls._query(db)
# Create the final query by applying filters # Create the final query by applying filters
query = base_query query = base_query
# Add expression filters first # Add expression filters first
query = query.filter(*args) query = query.filter(*args)
# Add expiry filters last # Add expiry filters last
args = cls.get_not_expired_query_arg(()) args = cls.get_not_expired_query_arg(())
query = query.filter(*args) query = query.filter(*args)
return PostgresResponse( return PostgresResponse(
model=cls, model=cls,
pre_query=base_query, # Use the base query for pre_query pre_query=base_query, # Use the base query for pre_query
query=query, query=query,
is_array=False is_array=False,
) )
@classmethod @classmethod
@ -304,22 +293,22 @@ class QueryModel:
""" """
# Get base query (either pre_query or new query) # Get base query (either pre_query or new query)
base_query = cls._query(db) base_query = cls._query(db)
# Create the final query by applying filters # Create the final query by applying filters
query = base_query query = base_query
# Add expression filters first # Add expression filters first
query = query.filter(*args) query = query.filter(*args)
# Add expiry filters last # Add expiry filters last
args = cls.get_not_expired_query_arg(()) args = cls.get_not_expired_query_arg(())
query = query.filter(*args) query = query.filter(*args)
return PostgresResponse( return PostgresResponse(
model=cls, model=cls,
pre_query=base_query, # Use the base query for pre_query pre_query=base_query, # Use the base query for pre_query
query=query, query=query,
is_array=True is_array=True,
) )
@classmethod @classmethod
@ -340,36 +329,32 @@ class QueryModel:
""" """
# Get base query (either pre_query or new query) # Get base query (either pre_query or new query)
base_query = cls._query(db) base_query = cls._query(db)
# Create the final query by applying filters # Create the final query by applying filters
query = base_query query = base_query
# Add expression filters first # Add expression filters first
query = query.filter(*args) query = query.filter(*args)
# Add status filters # Add status filters
query = query.filter( query = query.filter(
cls.is_confirmed == True, cls.is_confirmed == True, cls.deleted == False, cls.active == True
cls.deleted == False,
cls.active == True
) )
# Add expiry filters last # Add expiry filters last
args = cls.get_not_expired_query_arg(()) args = cls.get_not_expired_query_arg(())
query = query.filter(*args) query = query.filter(*args)
return PostgresResponse( return PostgresResponse(
model=cls, model=cls,
pre_query=base_query, # Use the base query for pre_query pre_query=base_query, # Use the base query for pre_query
query=query, query=query,
is_array=True is_array=True,
) )
@classmethod @classmethod
def filter_by_all_system( def filter_by_all_system(
cls: Type[T], cls: Type[T], db: Session, **kwargs: Any
db: Session,
**kwargs: Any
) -> PostgresResponse[T]: ) -> PostgresResponse[T]:
""" """
Filter multiple records by keyword arguments without status filtering. Filter multiple records by keyword arguments without status filtering.
@ -383,29 +368,27 @@ class QueryModel:
""" """
# Get base query (either pre_query or new query) # Get base query (either pre_query or new query)
base_query = cls._query(db) base_query = cls._query(db)
# Create the final query by applying filters # Create the final query by applying filters
query = base_query query = base_query
# Add keyword filters first # Add keyword filters first
query = query.filter_by(**kwargs) query = query.filter_by(**kwargs)
# Add expiry filters last # Add expiry filters last
args = cls.get_not_expired_query_arg(()) args = cls.get_not_expired_query_arg(())
query = query.filter(*args) query = query.filter(*args)
return PostgresResponse( return PostgresResponse(
model=cls, model=cls,
pre_query=base_query, # Use the base query for pre_query pre_query=base_query, # Use the base query for pre_query
query=query, query=query,
is_array=True is_array=True,
) )
@classmethod @classmethod
def filter_by_one_system( def filter_by_one_system(
cls: Type[T], cls: Type[T], db: Session, **kwargs: Any
db: Session,
**kwargs: Any
) -> PostgresResponse[T]: ) -> PostgresResponse[T]:
""" """
Filter single record by keyword arguments without status filtering. Filter single record by keyword arguments without status filtering.

View File

@ -5,34 +5,35 @@ from Controllers.Postgres.database import Base, engine
def generate_table_in_postgres(): def generate_table_in_postgres():
"""Create the endpoint_restriction table in PostgreSQL if it doesn't exist.""" """Create the endpoint_restriction table in PostgreSQL if it doesn't exist."""
# Create all tables defined in the Base metadata # Create all tables defined in the Base metadata
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
return True return True
def cleanup_test_data(): def cleanup_test_data():
"""Clean up test data from the database.""" """Clean up test data from the database."""
with EndpointRestriction.new_session() as db_session: with EndpointRestriction.new_session() as db_session:
try: try:
# Get all test records # Get all test records
test_records = EndpointRestriction.filter_all( test_records = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_code.like("TEST%"), EndpointRestriction.endpoint_code.like("TEST%"), db=db_session
db=db_session
).data ).data
# Delete each record using the same session # Delete each record using the same session
for record in test_records: for record in test_records:
# Merge the record into the current session if it's not already attached # Merge the record into the current session if it's not already attached
if record not in db_session: if record not in db_session:
record = db_session.merge(record) record = db_session.merge(record)
db_session.delete(record) db_session.delete(record)
db_session.commit() db_session.commit()
except Exception as e: except Exception as e:
print(f"Error cleaning up test data: {str(e)}") print(f"Error cleaning up test data: {str(e)}")
db_session.rollback() db_session.rollback()
raise e raise e
def create_sample_endpoint_restriction(endpoint_code=None): def create_sample_endpoint_restriction(endpoint_code=None):
"""Create a sample endpoint restriction for testing.""" """Create a sample endpoint restriction for testing."""
if endpoint_code is None: if endpoint_code is None:
@ -43,13 +44,12 @@ def create_sample_endpoint_restriction(endpoint_code=None):
try: try:
# First check if record exists # First check if record exists
existing = EndpointRestriction.filter_one( existing = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_code == endpoint_code, EndpointRestriction.endpoint_code == endpoint_code, db=db_session
db=db_session
) )
if existing and existing.data: if existing and existing.data:
return existing.data return existing.data
# If not found, create new record # If not found, create new record
endpoint = EndpointRestriction.find_or_create( endpoint = EndpointRestriction.find_or_create(
endpoint_function="test_function", endpoint_function="test_function",
@ -77,6 +77,7 @@ def create_sample_endpoint_restriction(endpoint_code=None):
db_session.rollback() db_session.rollback()
raise e raise e
def test_filter_by_one(): def test_filter_by_one():
"""Test filtering a single record by keyword arguments.""" """Test filtering a single record by keyword arguments."""
print("\nTesting filter_by_one...") print("\nTesting filter_by_one...")
@ -84,22 +85,20 @@ def test_filter_by_one():
try: try:
# Set up pre_query first # Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all( EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET", EndpointRestriction.endpoint_method == "GET", db=db_session
db=db_session
).query ).query
sample_endpoint = create_sample_endpoint_restriction("TEST001") sample_endpoint = create_sample_endpoint_restriction("TEST001")
result = EndpointRestriction.filter_by_one( result = EndpointRestriction.filter_by_one(
db=db_session, db=db_session, endpoint_code="TEST001"
endpoint_code="TEST001"
) )
# Test PostgresResponse properties # Test PostgresResponse properties
success = ( success = (
result is not None and result is not None
result.count == 1 and and result.count == 1
result.total_count == 1 and and result.total_count == 1
result.is_list is False and result.is_list is False
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -107,6 +106,7 @@ def test_filter_by_one():
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def test_filter_by_one_system(): def test_filter_by_one_system():
"""Test filtering a single record by keyword arguments without status filtering.""" """Test filtering a single record by keyword arguments without status filtering."""
print("\nTesting filter_by_one_system...") print("\nTesting filter_by_one_system...")
@ -114,23 +114,20 @@ def test_filter_by_one_system():
try: try:
# Set up pre_query first # Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all( EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET", EndpointRestriction.endpoint_method == "GET", db=db_session
db=db_session
).query ).query
sample_endpoint = create_sample_endpoint_restriction("TEST002") sample_endpoint = create_sample_endpoint_restriction("TEST002")
result = EndpointRestriction.filter_by_one( result = EndpointRestriction.filter_by_one(
db=db_session, db=db_session, endpoint_code="TEST002", system=True
endpoint_code="TEST002",
system=True
) )
# Test PostgresResponse properties # Test PostgresResponse properties
success = ( success = (
result is not None and result is not None
result.count == 1 and and result.count == 1
result.total_count == 1 and and result.total_count == 1
result.is_list is False and result.is_list is False
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -138,6 +135,7 @@ def test_filter_by_one_system():
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def test_filter_one(): def test_filter_one():
"""Test filtering a single record by expressions.""" """Test filtering a single record by expressions."""
print("\nTesting filter_one...") print("\nTesting filter_one...")
@ -145,22 +143,20 @@ def test_filter_one():
try: try:
# Set up pre_query first # Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all( EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET", EndpointRestriction.endpoint_method == "GET", db=db_session
db=db_session
).query ).query
sample_endpoint = create_sample_endpoint_restriction("TEST003") sample_endpoint = create_sample_endpoint_restriction("TEST003")
result = EndpointRestriction.filter_one( result = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_code == "TEST003", EndpointRestriction.endpoint_code == "TEST003", db=db_session
db=db_session
) )
# Test PostgresResponse properties # Test PostgresResponse properties
success = ( success = (
result is not None and result is not None
result.count == 1 and and result.count == 1
result.total_count == 1 and and result.total_count == 1
result.is_list is False and result.is_list is False
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -168,6 +164,7 @@ def test_filter_one():
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def test_filter_one_system(): def test_filter_one_system():
"""Test filtering a single record by expressions without status filtering.""" """Test filtering a single record by expressions without status filtering."""
print("\nTesting filter_one_system...") print("\nTesting filter_one_system...")
@ -175,22 +172,20 @@ def test_filter_one_system():
try: try:
# Set up pre_query first # Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all( EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET", EndpointRestriction.endpoint_method == "GET", db=db_session
db=db_session
).query ).query
sample_endpoint = create_sample_endpoint_restriction("TEST004") sample_endpoint = create_sample_endpoint_restriction("TEST004")
result = EndpointRestriction.filter_one_system( result = EndpointRestriction.filter_one_system(
EndpointRestriction.endpoint_code == "TEST004", EndpointRestriction.endpoint_code == "TEST004", db=db_session
db=db_session
) )
# Test PostgresResponse properties # Test PostgresResponse properties
success = ( success = (
result is not None and result is not None
result.count == 1 and and result.count == 1
result.total_count == 1 and and result.total_count == 1
result.is_list is False and result.is_list is False
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -198,6 +193,7 @@ def test_filter_one_system():
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def test_filter_all(): def test_filter_all():
"""Test filtering multiple records by expressions.""" """Test filtering multiple records by expressions."""
print("\nTesting filter_all...") print("\nTesting filter_all...")
@ -205,25 +201,23 @@ def test_filter_all():
try: try:
# Set up pre_query first # Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all( EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET", EndpointRestriction.endpoint_method == "GET", db=db_session
db=db_session
).query ).query
# Create two endpoint restrictions # Create two endpoint restrictions
endpoint1 = create_sample_endpoint_restriction("TEST005") endpoint1 = create_sample_endpoint_restriction("TEST005")
endpoint2 = create_sample_endpoint_restriction("TEST006") endpoint2 = create_sample_endpoint_restriction("TEST006")
result = EndpointRestriction.filter_all( result = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method.in_(["GET", "GET"]), EndpointRestriction.endpoint_method.in_(["GET", "GET"]), db=db_session
db=db_session
) )
# Test PostgresResponse properties # Test PostgresResponse properties
success = ( success = (
result is not None and result is not None
result.count == 2 and and result.count == 2
result.total_count == 2 and and result.total_count == 2
result.is_list is True and result.is_list is True
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -231,6 +225,7 @@ def test_filter_all():
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def test_filter_all_system(): def test_filter_all_system():
"""Test filtering multiple records by expressions without status filtering.""" """Test filtering multiple records by expressions without status filtering."""
print("\nTesting filter_all_system...") print("\nTesting filter_all_system...")
@ -238,25 +233,23 @@ def test_filter_all_system():
try: try:
# Set up pre_query first # Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all( EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET", EndpointRestriction.endpoint_method == "GET", db=db_session
db=db_session
).query ).query
# Create two endpoint restrictions # Create two endpoint restrictions
endpoint1 = create_sample_endpoint_restriction("TEST007") endpoint1 = create_sample_endpoint_restriction("TEST007")
endpoint2 = create_sample_endpoint_restriction("TEST008") endpoint2 = create_sample_endpoint_restriction("TEST008")
result = EndpointRestriction.filter_all_system( result = EndpointRestriction.filter_all_system(
EndpointRestriction.endpoint_method.in_(["GET", "GET"]), EndpointRestriction.endpoint_method.in_(["GET", "GET"]), db=db_session
db=db_session
) )
# Test PostgresResponse properties # Test PostgresResponse properties
success = ( success = (
result is not None and result is not None
result.count == 2 and and result.count == 2
result.total_count == 2 and and result.total_count == 2
result.is_list is True and result.is_list is True
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -264,6 +257,7 @@ def test_filter_all_system():
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def test_filter_by_all_system(): def test_filter_by_all_system():
"""Test filtering multiple records by keyword arguments without status filtering.""" """Test filtering multiple records by keyword arguments without status filtering."""
print("\nTesting filter_by_all_system...") print("\nTesting filter_by_all_system...")
@ -271,25 +265,23 @@ def test_filter_by_all_system():
try: try:
# Set up pre_query first # Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all( EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET", EndpointRestriction.endpoint_method == "GET", db=db_session
db=db_session
).query ).query
# Create two endpoint restrictions # Create two endpoint restrictions
endpoint1 = create_sample_endpoint_restriction("TEST009") endpoint1 = create_sample_endpoint_restriction("TEST009")
endpoint2 = create_sample_endpoint_restriction("TEST010") endpoint2 = create_sample_endpoint_restriction("TEST010")
result = EndpointRestriction.filter_by_all_system( result = EndpointRestriction.filter_by_all_system(
db=db_session, db=db_session, endpoint_method="GET"
endpoint_method="GET"
) )
# Test PostgresResponse properties # Test PostgresResponse properties
success = ( success = (
result is not None and result is not None
result.count == 2 and and result.count == 2
result.total_count == 2 and and result.total_count == 2
result.is_list is True and result.is_list is True
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -297,23 +289,32 @@ def test_filter_by_all_system():
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def test_get_not_expired_query_arg(): def test_get_not_expired_query_arg():
"""Test adding expiry date filtering to query arguments.""" """Test adding expiry date filtering to query arguments."""
print("\nTesting get_not_expired_query_arg...") print("\nTesting get_not_expired_query_arg...")
with EndpointRestriction.new_session() as db_session: with EndpointRestriction.new_session() as db_session:
try: try:
# Create a sample endpoint with a unique code # Create a sample endpoint with a unique code
endpoint_code = f"TEST{int(arrow.now().timestamp())}{arrow.now().microsecond}" endpoint_code = (
f"TEST{int(arrow.now().timestamp())}{arrow.now().microsecond}"
)
sample_endpoint = create_sample_endpoint_restriction(endpoint_code) sample_endpoint = create_sample_endpoint_restriction(endpoint_code)
# Test the query argument generation # Test the query argument generation
args = EndpointRestriction.get_not_expired_query_arg(()) args = EndpointRestriction.get_not_expired_query_arg(())
# Verify the arguments # Verify the arguments
success = ( success = (
len(args) == 2 and len(args) == 2
any(str(arg).startswith("endpoint_restriction.expiry_starts") for arg in args) and and any(
any(str(arg).startswith("endpoint_restriction.expiry_ends") for arg in args) str(arg).startswith("endpoint_restriction.expiry_starts")
for arg in args
)
and any(
str(arg).startswith("endpoint_restriction.expiry_ends")
for arg in args
)
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
@ -321,27 +322,33 @@ def test_get_not_expired_query_arg():
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def test_add_new_arg_to_args(): def test_add_new_arg_to_args():
"""Test adding new arguments to query arguments.""" """Test adding new arguments to query arguments."""
print("\nTesting add_new_arg_to_args...") print("\nTesting add_new_arg_to_args...")
try: try:
args = (EndpointRestriction.endpoint_code == "TEST001",) args = (EndpointRestriction.endpoint_code == "TEST001",)
new_arg = EndpointRestriction.endpoint_method == "GET" new_arg = EndpointRestriction.endpoint_method == "GET"
updated_args = EndpointRestriction.add_new_arg_to_args(args, "endpoint_method", new_arg) updated_args = EndpointRestriction.add_new_arg_to_args(
args, "endpoint_method", new_arg
)
success = len(updated_args) == 2 success = len(updated_args) == 2
# Test duplicate prevention # Test duplicate prevention
duplicate_arg = EndpointRestriction.endpoint_method == "GET" duplicate_arg = EndpointRestriction.endpoint_method == "GET"
updated_args = EndpointRestriction.add_new_arg_to_args(updated_args, "endpoint_method", duplicate_arg) updated_args = EndpointRestriction.add_new_arg_to_args(
updated_args, "endpoint_method", duplicate_arg
)
success = success and len(updated_args) == 2 # Should not add duplicate success = success and len(updated_args) == 2 # Should not add duplicate
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
except Exception as e: except Exception as e:
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def test_produce_query_to_add(): def test_produce_query_to_add():
"""Test adding query parameters to filter options.""" """Test adding query parameters to filter options."""
print("\nTesting produce_query_to_add...") print("\nTesting produce_query_to_add...")
@ -349,36 +356,31 @@ def test_produce_query_to_add():
try: try:
sample_endpoint = create_sample_endpoint_restriction("TEST001") sample_endpoint = create_sample_endpoint_restriction("TEST001")
filter_list = { filter_list = {
"query": { "query": {"endpoint_method": "GET", "endpoint_code": "TEST001"}
"endpoint_method": "GET",
"endpoint_code": "TEST001"
}
} }
args = () args = ()
updated_args = EndpointRestriction.produce_query_to_add(filter_list, args) updated_args = EndpointRestriction.produce_query_to_add(filter_list, args)
success = len(updated_args) == 2 success = len(updated_args) == 2
result = EndpointRestriction.filter_all( result = EndpointRestriction.filter_all(*updated_args, db=db_session)
*updated_args,
db=db_session
)
# Test PostgresResponse properties # Test PostgresResponse properties
success = ( success = (
success and success
result is not None and and result is not None
result.count == 1 and and result.count == 1
result.total_count == 1 and and result.total_count == 1
result.is_list is True and result.is_list is True
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
except Exception as e: except Exception as e:
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def test_get_dict(): def test_get_dict():
"""Test the get_dict() function for single-record filters.""" """Test the get_dict() function for single-record filters."""
print("\nTesting get_dict...") print("\nTesting get_dict...")
@ -386,51 +388,50 @@ def test_get_dict():
try: try:
# Set up pre_query first # Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all( EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET", EndpointRestriction.endpoint_method == "GET", db=db_session
db=db_session
).query ).query
# Create a sample endpoint # Create a sample endpoint
endpoint_code = "TEST_DICT_001" endpoint_code = "TEST_DICT_001"
sample_endpoint = create_sample_endpoint_restriction(endpoint_code) sample_endpoint = create_sample_endpoint_restriction(endpoint_code)
# Get the endpoint using filter_one # Get the endpoint using filter_one
result = EndpointRestriction.filter_one( result = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_code == endpoint_code, EndpointRestriction.endpoint_code == endpoint_code, db=db_session
db=db_session
) )
# Get the data and convert to dict # Get the data and convert to dict
data = result.data data = result.data
data_dict = data.get_dict() data_dict = data.get_dict()
# Test dictionary properties # Test dictionary properties
success = ( success = (
data_dict is not None and data_dict is not None
isinstance(data_dict, dict) and and isinstance(data_dict, dict)
data_dict.get("endpoint_code") == endpoint_code and and data_dict.get("endpoint_code") == endpoint_code
data_dict.get("endpoint_method") == "GET" and and data_dict.get("endpoint_method") == "GET"
data_dict.get("endpoint_function") == "test_function" and and data_dict.get("endpoint_function") == "test_function"
data_dict.get("endpoint_name") == "Test Endpoint" and and data_dict.get("endpoint_name") == "Test Endpoint"
data_dict.get("endpoint_desc") == "Test Description" and and data_dict.get("endpoint_desc") == "Test Description"
data_dict.get("is_confirmed") is True and and data_dict.get("is_confirmed") is True
data_dict.get("active") is True and and data_dict.get("active") is True
data_dict.get("deleted") is False and data_dict.get("deleted") is False
) )
print(f"Test {'passed' if success else 'failed'}") print(f"Test {'passed' if success else 'failed'}")
return success return success
except Exception as e: except Exception as e:
print(f"Test failed with exception: {e}") print(f"Test failed with exception: {e}")
return False return False
def run_all_tests(): def run_all_tests():
"""Run all tests and report results.""" """Run all tests and report results."""
print("Starting EndpointRestriction tests...") print("Starting EndpointRestriction tests...")
# Clean up any existing test data before starting # Clean up any existing test data before starting
cleanup_test_data() cleanup_test_data()
tests = [ tests = [
test_filter_by_one, test_filter_by_one,
test_filter_by_one_system, test_filter_by_one_system,
@ -442,7 +443,7 @@ def run_all_tests():
test_get_not_expired_query_arg, test_get_not_expired_query_arg,
test_add_new_arg_to_args, test_add_new_arg_to_args,
test_produce_query_to_add, test_produce_query_to_add,
test_get_dict # Added new test test_get_dict, # Added new test
] ]
passed_list, not_passed_list = [], [] passed_list, not_passed_list = [], []
passed, failed = 0, 0 passed, failed = 0, 0
@ -453,33 +454,24 @@ def run_all_tests():
try: try:
if test(): if test():
passed += 1 passed += 1
passed_list.append( passed_list.append(f"Test {test.__name__} passed")
f"Test {test.__name__} passed"
)
else: else:
failed += 1 failed += 1
not_passed_list.append( not_passed_list.append(f"Test {test.__name__} failed")
f"Test {test.__name__} failed"
)
except Exception as e: except Exception as e:
print(f"Test {test.__name__} failed with exception: {e}") print(f"Test {test.__name__} failed with exception: {e}")
failed += 1 failed += 1
not_passed_list.append( not_passed_list.append(f"Test {test.__name__} failed")
f"Test {test.__name__} failed"
)
print(f"\nTest Results: {passed} passed, {failed} failed") print(f"\nTest Results: {passed} passed, {failed} failed")
print('Passed Tests:') print("Passed Tests:")
print( print("\n".join(passed_list))
"\n".join(passed_list) print("Failed Tests:")
) print("\n".join(not_passed_list))
print('Failed Tests:')
print(
"\n".join(not_passed_list)
)
return passed, failed return passed, failed
if __name__ == "__main__": if __name__ == "__main__":
generate_table_in_postgres() generate_table_in_postgres()
run_all_tests() run_all_tests()

View File

@ -27,4 +27,3 @@ class EndpointRestriction(CrudCollection):
endpoint_code: Mapped[str] = mapped_column( endpoint_code: Mapped[str] = mapped_column(
String, server_default="", unique=True, comment="Unique code for the endpoint" String, server_default="", unique=True, comment="Unique code for the endpoint"
) )

View File

@ -15,7 +15,7 @@ from typing import Union, Dict, List, Optional, Any, TypeVar
from Controllers.Redis.connection import redis_cli from Controllers.Redis.connection import redis_cli
T = TypeVar('T', Dict[str, Any], List[Any]) T = TypeVar("T", Dict[str, Any], List[Any])
class RedisKeyError(Exception): class RedisKeyError(Exception):
@ -277,18 +277,18 @@ class RedisRow:
""" """
if not key: if not key:
raise RedisKeyError("Cannot set empty key") raise RedisKeyError("Cannot set empty key")
# Convert to string for validation # Convert to string for validation
key_str = key.decode() if isinstance(key, bytes) else str(key) key_str = key.decode() if isinstance(key, bytes) else str(key)
# Validate key length (Redis has a 512MB limit for keys) # Validate key length (Redis has a 512MB limit for keys)
if len(key_str) > 512 * 1024 * 1024: if len(key_str) > 512 * 1024 * 1024:
raise RedisKeyError("Key exceeds maximum length of 512MB") raise RedisKeyError("Key exceeds maximum length of 512MB")
# Validate key format (basic check for invalid characters) # Validate key format (basic check for invalid characters)
if any(c in key_str for c in ['\n', '\r', '\t', '\0']): if any(c in key_str for c in ["\n", "\r", "\t", "\0"]):
raise RedisKeyError("Key contains invalid characters") raise RedisKeyError("Key contains invalid characters")
self.key = key if isinstance(key, bytes) else str(key).encode() self.key = key if isinstance(key, bytes) else str(key).encode()
@property @property

View File

@ -5,11 +5,12 @@ class Configs(BaseSettings):
""" """
MongoDB configuration settings. MongoDB configuration settings.
""" """
HOST: str = "" HOST: str = ""
PASSWORD: str = "" PASSWORD: str = ""
PORT: int = 0 PORT: int = 0
DB: int = 0 DB: int = 0
def as_dict(self): def as_dict(self):
return dict( return dict(
host=self.HOST, host=self.HOST,

View File

@ -98,9 +98,7 @@ class RedisConn:
err = e err = e
return False return False
def set_connection( def set_connection(self, **kwargs) -> Redis:
self, **kwargs
) -> Redis:
""" """
Recreate Redis connection with new parameters. Recreate Redis connection with new parameters.

View File

@ -14,6 +14,7 @@ def example_set_json() -> None:
result = RedisActions.set_json(list_keys=keys, value=data, expires=expiry) result = RedisActions.set_json(list_keys=keys, value=data, expires=expiry)
print("Set JSON with expiry:", result.as_dict()) print("Set JSON with expiry:", result.as_dict())
def example_get_json() -> None: def example_get_json() -> None:
"""Example of retrieving JSON data from Redis.""" """Example of retrieving JSON data from Redis."""
# Example 1: Get all matching keys # Example 1: Get all matching keys
@ -25,11 +26,16 @@ def example_get_json() -> None:
result = RedisActions.get_json(list_keys=keys, limit=5) result = RedisActions.get_json(list_keys=keys, limit=5)
print("Get JSON with limit:", result.as_dict()) print("Get JSON with limit:", result.as_dict())
def example_get_json_iterator() -> None: def example_get_json_iterator() -> None:
"""Example of using the JSON iterator for large datasets.""" """Example of using the JSON iterator for large datasets."""
keys = ["user", "profile", "*"] keys = ["user", "profile", "*"]
for row in RedisActions.get_json_iterator(list_keys=keys): for row in RedisActions.get_json_iterator(list_keys=keys):
print("Iterating over JSON row:", row.as_dict if isinstance(row.as_dict, dict) else row.as_dict) print(
"Iterating over JSON row:",
row.as_dict if isinstance(row.as_dict, dict) else row.as_dict,
)
def example_delete_key() -> None: def example_delete_key() -> None:
"""Example of deleting a specific key.""" """Example of deleting a specific key."""
@ -37,12 +43,14 @@ def example_delete_key() -> None:
result = RedisActions.delete_key(key) result = RedisActions.delete_key(key)
print("Delete specific key:", result) print("Delete specific key:", result)
def example_delete() -> None: def example_delete() -> None:
"""Example of deleting multiple keys matching a pattern.""" """Example of deleting multiple keys matching a pattern."""
keys = ["user", "profile", "*"] keys = ["user", "profile", "*"]
result = RedisActions.delete(list_keys=keys) result = RedisActions.delete(list_keys=keys)
print("Delete multiple keys:", result) print("Delete multiple keys:", result)
def example_refresh_ttl() -> None: def example_refresh_ttl() -> None:
"""Example of refreshing TTL for a key.""" """Example of refreshing TTL for a key."""
key = "user:profile:123" key = "user:profile:123"
@ -50,48 +58,53 @@ def example_refresh_ttl() -> None:
result = RedisActions.refresh_ttl(key=key, expires=new_expiry) result = RedisActions.refresh_ttl(key=key, expires=new_expiry)
print("Refresh TTL:", result.as_dict()) print("Refresh TTL:", result.as_dict())
def example_key_exists() -> None: def example_key_exists() -> None:
"""Example of checking if a key exists.""" """Example of checking if a key exists."""
key = "user:profile:123" key = "user:profile:123"
exists = RedisActions.key_exists(key) exists = RedisActions.key_exists(key)
print(f"Key {key} exists:", exists) print(f"Key {key} exists:", exists)
def example_resolve_expires_at() -> None: def example_resolve_expires_at() -> None:
"""Example of resolving expiry time for a key.""" """Example of resolving expiry time for a key."""
from Controllers.Redis.base import RedisRow from Controllers.Redis.base import RedisRow
redis_row = RedisRow() redis_row = RedisRow()
redis_row.set_key("user:profile:123") redis_row.set_key("user:profile:123")
print(redis_row.keys) print(redis_row.keys)
expires_at = RedisActions.resolve_expires_at(redis_row) expires_at = RedisActions.resolve_expires_at(redis_row)
print("Resolve expires at:", expires_at) print("Resolve expires at:", expires_at)
def run_all_examples() -> None: def run_all_examples() -> None:
"""Run all example functions to demonstrate RedisActions functionality.""" """Run all example functions to demonstrate RedisActions functionality."""
print("\n=== Redis Actions Examples ===\n") print("\n=== Redis Actions Examples ===\n")
print("1. Setting JSON data:") print("1. Setting JSON data:")
example_set_json() example_set_json()
print("\n2. Getting JSON data:") print("\n2. Getting JSON data:")
example_get_json() example_get_json()
print("\n3. Using JSON iterator:") print("\n3. Using JSON iterator:")
example_get_json_iterator() example_get_json_iterator()
# print("\n4. Deleting specific key:") # print("\n4. Deleting specific key:")
# example_delete_key() # example_delete_key()
# #
# print("\n5. Deleting multiple keys:") # print("\n5. Deleting multiple keys:")
# example_delete() # example_delete()
print("\n6. Refreshing TTL:") print("\n6. Refreshing TTL:")
example_refresh_ttl() example_refresh_ttl()
print("\n7. Checking key existence:") print("\n7. Checking key existence:")
example_key_exists() example_key_exists()
print("\n8. Resolving expiry time:") print("\n8. Resolving expiry time:")
example_resolve_expires_at() example_resolve_expires_at()
if __name__ == "__main__": if __name__ == "__main__":
run_all_examples() run_all_examples()

View File

@ -67,7 +67,7 @@ class RedisResponse:
# Process single RedisRow # Process single RedisRow
if isinstance(data, RedisRow): if isinstance(data, RedisRow):
result = {**main_dict} result = {**main_dict}
if hasattr(data, 'keys') and hasattr(data, 'row'): if hasattr(data, "keys") and hasattr(data, "row"):
if not isinstance(data.keys, str): if not isinstance(data.keys, str):
raise ValueError("RedisRow keys must be string type") raise ValueError("RedisRow keys must be string type")
result[data.keys] = data.row result[data.keys] = data.row
@ -80,7 +80,11 @@ class RedisResponse:
# Handle list of RedisRow objects # Handle list of RedisRow objects
rows_dict = {} rows_dict = {}
for row in data: for row in data:
if isinstance(row, RedisRow) and hasattr(row, 'keys') and hasattr(row, 'row'): if (
isinstance(row, RedisRow)
and hasattr(row, "keys")
and hasattr(row, "row")
):
if not isinstance(row.keys, str): if not isinstance(row.keys, str):
raise ValueError("RedisRow keys must be string type") raise ValueError("RedisRow keys must be string type")
rows_dict[row.keys] = row.row rows_dict[row.keys] = row.row
@ -137,10 +141,10 @@ class RedisResponse:
if isinstance(self.data, list) and self.data: if isinstance(self.data, list) and self.data:
item = self.data[0] item = self.data[0]
if isinstance(item, RedisRow) and hasattr(item, 'row'): if isinstance(item, RedisRow) and hasattr(item, "row"):
return item.row return item.row
return item return item
elif isinstance(self.data, RedisRow) and hasattr(self.data, 'row'): elif isinstance(self.data, RedisRow) and hasattr(self.data, "row"):
return self.data.row return self.data.row
elif isinstance(self.data, dict): elif isinstance(self.data, dict):
return self.data return self.data
@ -168,16 +172,16 @@ class RedisResponse:
"success": self.status, "success": self.status,
"message": self.message, "message": self.message,
} }
if self.error: if self.error:
response["error"] = self.error response["error"] = self.error
if self.data is not None: if self.data is not None:
if self.data_type == "row" and hasattr(self.data, 'to_dict'): if self.data_type == "row" and hasattr(self.data, "to_dict"):
response["data"] = self.data.to_dict() response["data"] = self.data.to_dict()
elif self.data_type == "list": elif self.data_type == "list":
try: try:
if all(hasattr(item, 'to_dict') for item in self.data): if all(hasattr(item, "to_dict") for item in self.data):
response["data"] = [item.to_dict() for item in self.data] response["data"] = [item.to_dict() for item in self.data]
else: else:
response["data"] = self.data response["data"] = self.data
@ -192,5 +196,5 @@ class RedisResponse:
return { return {
"success": False, "success": False,
"message": "Error formatting response", "message": "Error formatting response",
"error": str(e) "error": str(e),
} }

View File

@ -15,11 +15,15 @@ class PasswordModule:
@staticmethod @staticmethod
def generate_token(length=32) -> str: def generate_token(length=32) -> str:
letters = "abcdefghijklmnopqrstuvwxyz" letters = "abcdefghijklmnopqrstuvwxyz"
merged_letters = [letter for letter in letters] + [letter.upper() for letter in letters] merged_letters = [letter for letter in letters] + [
letter.upper() for letter in letters
]
token_generated = secrets.token_urlsafe(length) token_generated = secrets.token_urlsafe(length)
for i in str(token_generated): for i in str(token_generated):
if i not in merged_letters: if i not in merged_letters:
token_generated = token_generated.replace(i, random.choice(merged_letters), 1) token_generated = token_generated.replace(
i, random.choice(merged_letters), 1
)
return token_generated return token_generated
@classmethod @classmethod

View File

@ -573,4 +573,3 @@ class AccountRecords(CrudCollection):
# ) # )
# ) # )
# print("is all dues_type", payment_dict["dues_type"], paid_value) # print("is all dues_type", payment_dict["dues_type"], paid_value)

View File

@ -6,7 +6,8 @@ from sqlalchemy import (
Boolean, Boolean,
BigInteger, BigInteger,
Integer, Integer,
Text, or_, Text,
or_,
) )
from sqlalchemy.orm import mapped_column, Mapped from sqlalchemy.orm import mapped_column, Mapped
from Controllers.Postgres.mixin import CrudCollection from Controllers.Postgres.mixin import CrudCollection
@ -107,7 +108,7 @@ class Addresses(CrudCollection):
post_code_list = RelationshipEmployee2PostCode.filter_all( post_code_list = RelationshipEmployee2PostCode.filter_all(
RelationshipEmployee2PostCode.employee_id RelationshipEmployee2PostCode.employee_id
== token_dict.selected_company.employee_id, == token_dict.selected_company.employee_id,
db=db_session db=db_session,
).data ).data
post_code_id_list = [post_code.member_id for post_code in post_code_list] post_code_id_list = [post_code.member_id for post_code in post_code_list]
if not post_code_id_list: if not post_code_id_list:
@ -118,7 +119,9 @@ class Addresses(CrudCollection):
# status_code=404, # status_code=404,
# detail="User has no post code registered. User can not list addresses.", # detail="User has no post code registered. User can not list addresses.",
# ) # )
cls.pre_query = cls.filter_all(cls.post_code_id.in_(post_code_id_list), db=db_session).query cls.pre_query = cls.filter_all(
cls.post_code_id.in_(post_code_id_list), db=db_session
).query
filter_cls = cls.filter_all(*filter_expr or [], db=db_session) filter_cls = cls.filter_all(*filter_expr or [], db=db_session)
cls.pre_query = None cls.pre_query = None
return filter_cls.data return filter_cls.data

View File

@ -244,7 +244,7 @@ class Build(CrudCollection):
livable_parts = BuildParts.filter_all( livable_parts = BuildParts.filter_all(
BuildParts.build_id == self.id, BuildParts.build_id == self.id,
BuildParts.human_livable == True, BuildParts.human_livable == True,
db=db_session db=db_session,
) )
if not livable_parts.data: if not livable_parts.data:
raise HTTPException( raise HTTPException(
@ -260,8 +260,7 @@ class Build(CrudCollection):
for part in self.parts: for part in self.parts:
building_types = {} building_types = {}
build_type = BuildTypes.filter_by_one( build_type = BuildTypes.filter_by_one(
system=True, id=part.build_part_type_id, system=True, id=part.build_part_type_id, db=db_session
db=db_session
).data ).data
if build_type.type_code in building_types: if build_type.type_code in building_types:
building_types[build_type.type_code]["list"].append(part.part_no) building_types[build_type.type_code]["list"].append(part.part_no)
@ -354,7 +353,9 @@ class BuildParts(CrudCollection):
if build_type := BuildTypes.filter_by_one( if build_type := BuildTypes.filter_by_one(
system=True, id=self.part_type_id, db=db_session system=True, id=self.part_type_id, db=db_session
).data: ).data:
return f"{str(build_type.type_name).upper()} : {str(self.part_no).upper()}" return (
f"{str(build_type.type_name).upper()} : {str(self.part_no).upper()}"
)
return f"Undefined:{str(build_type.type_name).upper()}" return f"Undefined:{str(build_type.type_name).upper()}"
@ -430,7 +431,7 @@ class BuildLivingSpace(CrudCollection):
), ),
cls.start_date < formatted_date - timedelta(days=add_days), cls.start_date < formatted_date - timedelta(days=add_days),
cls.stop_date > formatted_date + timedelta(days=add_days), cls.stop_date > formatted_date + timedelta(days=add_days),
db=db_session db=db_session,
) )
return living_spaces.data, living_spaces.count return living_spaces.data, living_spaces.count
@ -625,4 +626,3 @@ class BuildPersonProviding(CrudCollection):
), ),
{"comment": "People providing services for building"}, {"comment": "People providing services for building"},
) )

View File

@ -92,6 +92,7 @@ class BuildDecisionBook(CrudCollection):
@classmethod @classmethod
def retrieve_active_rbm(cls): def retrieve_active_rbm(cls):
from Schemas.building.build import Build from Schemas.building.build import Build
with cls.new_session() as db_session: with cls.new_session() as db_session:
related_build = Build.find_one(id=cls.build_id) related_build = Build.find_one(id=cls.build_id)
related_date = arrow.get(related_build.build_date) related_date = arrow.get(related_build.build_date)
@ -103,7 +104,7 @@ class BuildDecisionBook(CrudCollection):
cls.expiry_ends <= date_processed, cls.expiry_ends <= date_processed,
cls.decision_type == "RBM", cls.decision_type == "RBM",
cls.build_id == related_build.id, cls.build_id == related_build.id,
db=db_session db=db_session,
).data ).data
if not book: if not book:
cls.raise_http_exception( cls.raise_http_exception(
@ -220,7 +221,8 @@ class BuildDecisionBookInvitations(CrudCollection):
first_book_invitation = BuildDecisionBookInvitations.filter_one( first_book_invitation = BuildDecisionBookInvitations.filter_one(
BuildDecisionBookInvitations.build_id BuildDecisionBookInvitations.build_id
== token_dict.selected_occupant.build_id, == token_dict.selected_occupant.build_id,
BuildDecisionBookInvitations.decision_book_id == selected_decision_book.id, BuildDecisionBookInvitations.decision_book_id
== selected_decision_book.id,
BuildDecisionBookInvitations.invitation_attempt == 1, BuildDecisionBookInvitations.invitation_attempt == 1,
db=db_session, db=db_session,
).data ).data
@ -247,11 +249,15 @@ class BuildDecisionBookInvitations(CrudCollection):
second_book_invitation = BuildDecisionBookInvitations.filter_one_system( second_book_invitation = BuildDecisionBookInvitations.filter_one_system(
BuildDecisionBookInvitations.build_id BuildDecisionBookInvitations.build_id
== token_dict.selected_occupant.build_id, == token_dict.selected_occupant.build_id,
BuildDecisionBookInvitations.decision_book_id == selected_decision_book.id, BuildDecisionBookInvitations.decision_book_id
== selected_decision_book.id,
BuildDecisionBookInvitations.invitation_attempt == 2, BuildDecisionBookInvitations.invitation_attempt == 2,
db=db_session, db=db_session,
).data ).data
if not valid_invite_count >= need_attend_count and not second_book_invitation: if (
not valid_invite_count >= need_attend_count
and not second_book_invitation
):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=f"In order meeting to be held, {math.ceil(need_attend_count)} people must attend " detail=f"In order meeting to be held, {math.ceil(need_attend_count)} people must attend "
@ -336,7 +342,7 @@ class BuildDecisionBookPerson(CrudCollection):
with self.new_session() as db_session: with self.new_session() as db_session:
all_decision_book_people = self.filter_all_system( all_decision_book_people = self.filter_all_system(
BuildDecisionBookPersonOccupants.invite_id == self.invite_id, BuildDecisionBookPersonOccupants.invite_id == self.invite_id,
db=db_session db=db_session,
) )
BuildDecisionBookPersonOccupants.pre_query = all_decision_book_people.query BuildDecisionBookPersonOccupants.pre_query = all_decision_book_people.query
return BuildDecisionBookPersonOccupants.filter_all_system( return BuildDecisionBookPersonOccupants.filter_all_system(
@ -346,8 +352,9 @@ class BuildDecisionBookPerson(CrudCollection):
def get_occupant_types(self): def get_occupant_types(self):
with self.new_session() as db_session: with self.new_session() as db_session:
if occupants := BuildDecisionBookPersonOccupants.filter_all( if occupants := BuildDecisionBookPersonOccupants.filter_all(
BuildDecisionBookPersonOccupants.build_decision_book_person_id == self.id, BuildDecisionBookPersonOccupants.build_decision_book_person_id
db=db_session == self.id,
db=db_session,
).data: ).data:
return occupants return occupants
return return
@ -355,7 +362,8 @@ class BuildDecisionBookPerson(CrudCollection):
def check_occupant_type(self, occupant_type): def check_occupant_type(self, occupant_type):
with self.new_session() as db_session: with self.new_session() as db_session:
book_person_occupant_type = BuildDecisionBookPersonOccupants.filter_one( book_person_occupant_type = BuildDecisionBookPersonOccupants.filter_one(
BuildDecisionBookPersonOccupants.build_decision_book_person_id == self.id, BuildDecisionBookPersonOccupants.build_decision_book_person_id
== self.id,
BuildDecisionBookPersonOccupants.occupant_type_id == occupant_type.id, BuildDecisionBookPersonOccupants.occupant_type_id == occupant_type.id,
BuildDecisionBookPersonOccupants.active == True, BuildDecisionBookPersonOccupants.active == True,
BuildDecisionBookPersonOccupants.is_confirmed == True, BuildDecisionBookPersonOccupants.is_confirmed == True,

View File

@ -66,13 +66,12 @@ class RelationshipDutyCompany(CrudCollection):
) )
list_match_company_id = [] list_match_company_id = []
send_duties = Duties.filter_one( send_duties = Duties.filter_one(
Duties.uu_id == data.duty_uu_id, Duties.uu_id == data.duty_uu_id, db=db_session
db=db_session
) )
send_user_duties = Duties.filter_one( send_user_duties = Duties.filter_one(
Duties.duties_id == send_duties.id, Duties.duties_id == send_duties.id,
Duties.company_id == token_duties_id, Duties.company_id == token_duties_id,
db=db_session db=db_session,
) )
if not send_user_duties: if not send_user_duties:
raise Exception( raise Exception(
@ -81,14 +80,13 @@ class RelationshipDutyCompany(CrudCollection):
for company_uu_id in list(data.match_company_uu_id): for company_uu_id in list(data.match_company_uu_id):
company = Companies.filter_one( company = Companies.filter_one(
Companies.uu_id == company_uu_id, Companies.uu_id == company_uu_id, db=db_session
db=db_session
) )
bulk_company = RelationshipDutyCompany.filter_one( bulk_company = RelationshipDutyCompany.filter_one(
RelationshipDutyCompany.owner_id == token_company_id, RelationshipDutyCompany.owner_id == token_company_id,
RelationshipDutyCompany.relationship_type == "Bulk", RelationshipDutyCompany.relationship_type == "Bulk",
RelationshipDutyCompany.member_id == company.id, RelationshipDutyCompany.member_id == company.id,
db=db_session db=db_session,
) )
if not bulk_company: if not bulk_company:
raise Exception( raise Exception(
@ -105,7 +103,7 @@ class RelationshipDutyCompany(CrudCollection):
parent_id=match_company_id.parent_id, parent_id=match_company_id.parent_id,
relationship_type="Commercial", relationship_type="Commercial",
show_only=False, show_only=False,
db=db_session db=db_session,
) )
@classmethod @classmethod
@ -116,13 +114,12 @@ class RelationshipDutyCompany(CrudCollection):
) )
list_match_company_id = [] list_match_company_id = []
send_duties = Duties.filter_one( send_duties = Duties.filter_one(
Duties.uu_id == data.duty_uu_id, Duties.uu_id == data.duty_uu_id, db=db_session
db=db_session
) )
send_user_duties = Duties.filter_one( send_user_duties = Duties.filter_one(
Duties.duties_id == send_duties.id, Duties.duties_id == send_duties.id,
Duties.company_id == token_duties_id, Duties.company_id == token_duties_id,
db=db_session db=db_session,
) )
if not send_user_duties: if not send_user_duties:
raise Exception( raise Exception(
@ -131,14 +128,13 @@ class RelationshipDutyCompany(CrudCollection):
for company_uu_id in list(data.match_company_uu_id): for company_uu_id in list(data.match_company_uu_id):
company = Companies.filter_one( company = Companies.filter_one(
Companies.uu_id == company_uu_id, Companies.uu_id == company_uu_id, db=db_session
db=db_session
) )
bulk_company = RelationshipDutyCompany.filter_one( bulk_company = RelationshipDutyCompany.filter_one(
RelationshipDutyCompany.owner_id == token_company_id, RelationshipDutyCompany.owner_id == token_company_id,
RelationshipDutyCompany.relationship_type == "Bulk", RelationshipDutyCompany.relationship_type == "Bulk",
RelationshipDutyCompany.member_id == company.id, RelationshipDutyCompany.member_id == company.id,
db=db_session db=db_session,
) )
if not bulk_company: if not bulk_company:
raise Exception( raise Exception(
@ -151,7 +147,7 @@ class RelationshipDutyCompany(CrudCollection):
Duties.init_a_company_default_duties( Duties.init_a_company_default_duties(
company_id=match_company_id.id, company_id=match_company_id.id,
company_uu_id=str(match_company_id.uu_id), company_uu_id=str(match_company_id.uu_id),
db=db_session db=db_session,
) )
RelationshipDutyCompany.find_or_create( RelationshipDutyCompany.find_or_create(
owner_id=token_company_id, owner_id=token_company_id,
@ -160,7 +156,7 @@ class RelationshipDutyCompany(CrudCollection):
parent_id=match_company_id.parent_id, parent_id=match_company_id.parent_id,
relationship_type="Organization", relationship_type="Organization",
show_only=False, show_only=False,
db=db_session db=db_session,
) )
__table_args__ = ( __table_args__ = (
@ -236,4 +232,3 @@ class Companies(CrudCollection):
Index("_company_ndx_02", formal_name, public_name), Index("_company_ndx_02", formal_name, public_name),
{"comment": "Company Information"}, {"comment": "Company Information"},
) )

View File

@ -13,12 +13,20 @@ class Staff(CrudCollection):
__tablename__ = "staff" __tablename__ = "staff"
__exclude__fields__ = [] __exclude__fields__ = []
staff_description: Mapped[str] = mapped_column(String, server_default="", comment="Staff Description") staff_description: Mapped[str] = mapped_column(
staff_name: Mapped[str] = mapped_column(String, nullable=False, comment="Staff Name") String, server_default="", comment="Staff Description"
staff_code: Mapped[str] = mapped_column(String, nullable=False, comment="Staff Code") )
staff_name: Mapped[str] = mapped_column(
String, nullable=False, comment="Staff Name"
)
staff_code: Mapped[str] = mapped_column(
String, nullable=False, comment="Staff Code"
)
duties_id: Mapped[int] = mapped_column(ForeignKey("duties.id"), nullable=False) duties_id: Mapped[int] = mapped_column(ForeignKey("duties.id"), nullable=False)
duties_uu_id: Mapped[str] = mapped_column(String, nullable=False, comment="Duty UUID") duties_uu_id: Mapped[str] = mapped_column(
String, nullable=False, comment="Duty UUID"
)
__table_args__ = ({"comment": "Staff Information"},) __table_args__ = ({"comment": "Staff Information"},)
@ -29,9 +37,13 @@ class Employees(CrudCollection):
__exclude__fields__ = [] __exclude__fields__ = []
staff_id: Mapped[int] = mapped_column(ForeignKey("staff.id")) staff_id: Mapped[int] = mapped_column(ForeignKey("staff.id"))
staff_uu_id: Mapped[str] = mapped_column(String, nullable=False, comment="Staff UUID") staff_uu_id: Mapped[str] = mapped_column(
String, nullable=False, comment="Staff UUID"
)
people_id: Mapped[int] = mapped_column(ForeignKey("people.id"), nullable=True) people_id: Mapped[int] = mapped_column(ForeignKey("people.id"), nullable=True)
people_uu_id: Mapped[str] = mapped_column(String, nullable=True, comment="People UUID") people_uu_id: Mapped[str] = mapped_column(
String, nullable=True, comment="People UUID"
)
__table_args__ = ( __table_args__ = (
Index("employees_ndx_00", people_id, staff_id, unique=True), Index("employees_ndx_00", people_id, staff_id, unique=True),
@ -44,10 +56,18 @@ class EmployeeHistory(CrudCollection):
__tablename__ = "employee_history" __tablename__ = "employee_history"
__exclude__fields__ = [] __exclude__fields__ = []
staff_id: Mapped[int] = mapped_column(ForeignKey("staff.id"), nullable=False, comment="Staff ID") staff_id: Mapped[int] = mapped_column(
staff_uu_id: Mapped[str] = mapped_column(String, nullable=False, comment="Staff UUID") ForeignKey("staff.id"), nullable=False, comment="Staff ID"
people_id: Mapped[int] = mapped_column(ForeignKey("people.id"), nullable=False, comment="People ID") )
people_uu_id: Mapped[str] = mapped_column(String, nullable=False, comment="People UUID") staff_uu_id: Mapped[str] = mapped_column(
String, nullable=False, comment="Staff UUID"
)
people_id: Mapped[int] = mapped_column(
ForeignKey("people.id"), nullable=False, comment="People ID"
)
people_uu_id: Mapped[str] = mapped_column(
String, nullable=False, comment="People UUID"
)
__table_args__ = ( __table_args__ = (
Index("_employee_history_ndx_00", people_id, staff_id), Index("_employee_history_ndx_00", people_id, staff_id),
@ -67,7 +87,9 @@ class EmployeesSalaries(CrudCollection):
Numeric(20, 6), nullable=False, comment="Net Salary" Numeric(20, 6), nullable=False, comment="Net Salary"
) )
people_id: Mapped[int] = mapped_column(ForeignKey("people.id"), nullable=False) people_id: Mapped[int] = mapped_column(ForeignKey("people.id"), nullable=False)
people_uu_id: Mapped[str] = mapped_column(String, nullable=False, comment="People UUID") people_uu_id: Mapped[str] = mapped_column(
String, nullable=False, comment="People UUID"
)
__table_args__ = ( __table_args__ = (
Index("_employee_salaries_ndx_00", people_id, "expiry_starts"), Index("_employee_salaries_ndx_00", people_id, "expiry_starts"),

View File

@ -110,9 +110,7 @@ class Services(CrudCollection):
def retrieve_service_via_occupant_code(cls, occupant_code): def retrieve_service_via_occupant_code(cls, occupant_code):
with cls.new_session() as db_session: with cls.new_session() as db_session:
occupant_type = OccupantTypes.filter_by_one( occupant_type = OccupantTypes.filter_by_one(
system=True, system=True, occupant_code=occupant_code, db=db_session
occupant_code=occupant_code,
db=db_session
).data ).data
if not occupant_type: if not occupant_type:
cls.raise_http_exception( cls.raise_http_exception(
@ -124,8 +122,7 @@ class Services(CrudCollection):
}, },
) )
return cls.filter_one( return cls.filter_one(
cls.related_responsibility == occupant_type.occupant_code, cls.related_responsibility == occupant_type.occupant_code, db=db_session
db=db_session
).data ).data
__table_args__ = ({"comment": "Services Information"},) __table_args__ = ({"comment": "Services Information"},)

View File

@ -431,4 +431,3 @@ class Contracts(CrudCollection):
Index("_contract_ndx_01", contract_code, unique=True), Index("_contract_ndx_01", contract_code, unique=True),
{"comment": "Contract Information"}, {"comment": "Contract Information"},
) )

View File

@ -40,15 +40,19 @@ class ApiEnumDropdown(CrudCollection):
if search := cls.filter_one_system( if search := cls.filter_one_system(
cls.enum_class.in_(["DebitTypes"]), cls.enum_class.in_(["DebitTypes"]),
cls.uu_id == search_uu_id, cls.uu_id == search_uu_id,
db=db_session db=db_session,
).data: ).data:
return search return search
elif search_debit: elif search_debit:
if search := cls.filter_one( if search := cls.filter_one(
cls.enum_class.in_(["DebitTypes"]), cls.key == search_debit, db=db_session cls.enum_class.in_(["DebitTypes"]),
cls.key == search_debit,
db=db_session,
).data: ).data:
return search return search
return cls.filter_all_system(cls.enum_class.in_(["DebitTypes"]), db=db_session).data return cls.filter_all_system(
cls.enum_class.in_(["DebitTypes"]), db=db_session
).data
@classmethod @classmethod
def get_due_types(cls): def get_due_types(cls):
@ -56,7 +60,7 @@ class ApiEnumDropdown(CrudCollection):
if due_list := cls.filter_all_system( if due_list := cls.filter_all_system(
cls.enum_class == "BuildDuesTypes", cls.enum_class == "BuildDuesTypes",
cls.key.in_(["BDT-A", "BDT-D"]), cls.key.in_(["BDT-A", "BDT-D"]),
db=db_session db=db_session,
).data: ).data:
return [due.uu_id.__str__() for due in due_list] return [due.uu_id.__str__() for due in due_list]
# raise HTTPException( # raise HTTPException(
@ -71,17 +75,19 @@ class ApiEnumDropdown(CrudCollection):
if search := cls.filter_one_system( if search := cls.filter_one_system(
cls.enum_class.in_(["BuildDuesTypes"]), cls.enum_class.in_(["BuildDuesTypes"]),
cls.uu_id == search_uu_id, cls.uu_id == search_uu_id,
db=db_session db=db_session,
).data: ).data:
return search return search
elif search_management: elif search_management:
if search := cls.filter_one_system( if search := cls.filter_one_system(
cls.enum_class.in_(["BuildDuesTypes"]), cls.enum_class.in_(["BuildDuesTypes"]),
cls.key == search_management, cls.key == search_management,
db=db_session db=db_session,
).data: ).data:
return search return search
return cls.filter_all_system(cls.enum_class.in_(["BuildDuesTypes"]), db=db_session).data return cls.filter_all_system(
cls.enum_class.in_(["BuildDuesTypes"]), db=db_session
).data
def get_enum_dict(self): def get_enum_dict(self):
return { return {

View File

@ -50,11 +50,39 @@ services:
ports: ports:
- "11222:6379" - "11222:6379"
template_service: # template_service:
container_name: template_service # container_name: template_service
# build:
# context: .
# dockerfile: ApiServices/TemplateService/Dockerfile
# networks:
# - wag-services
# env_file:
# - api_env.env
# environment:
# - API_PATH=app:app
# - API_HOST=0.0.0.0
# - API_PORT=8000
# - API_LOG_LEVEL=info
# - API_RELOAD=1
# - API_ACCESS_TOKEN_TAG=1
# - API_APP_NAME=evyos-template-api-gateway
# - API_TITLE=WAG API Template Api Gateway
# - API_FORGOT_LINK=https://template_service/forgot-password
# - API_DESCRIPTION=This api is serves as web template api gateway only to evyos web services.
# - API_APP_URL=https://template_service
# ports:
# - "8000:8000"
# depends_on:
# - postgres-service
# - mongo_service
# - redis_service
auth_service:
container_name: auth_service
build: build:
context: . context: .
dockerfile: ApiServices/TemplateService/Dockerfile dockerfile: ApiServices/AuthService/Dockerfile
networks: networks:
- wag-services - wag-services
env_file: env_file:
@ -62,17 +90,17 @@ services:
environment: environment:
- API_PATH=app:app - API_PATH=app:app
- API_HOST=0.0.0.0 - API_HOST=0.0.0.0
- API_PORT=8000 - API_PORT=8001
- API_LOG_LEVEL=info - API_LOG_LEVEL=info
- API_RELOAD=1 - API_RELOAD=1
- API_ACCESS_TOKEN_TAG=1 - API_ACCESS_TOKEN_TAG=eys-acs-tkn
- API_APP_NAME=evyos-template-api-gateway - API_APP_NAME=evyos-auth-api-gateway
- API_TITLE=WAG API Template Api Gateway - API_TITLE=WAG API Auth Api Gateway
- API_FORGOT_LINK=https://template_service/forgot-password - API_FORGOT_LINK=https://auth_service/forgot-password
- API_DESCRIPTION=This api is serves as web template api gateway only to evyos web services. - API_DESCRIPTION=This api is serves as web auth api gateway only to evyos web services.
- API_APP_URL=https://template_service - API_APP_URL=https://auth_service
ports: ports:
- "8000:8000" - "8001:8001"
depends_on: depends_on:
- postgres-service - postgres-service
- mongo_service - mongo_service