updated Mongo Postgres Redis Controllers

This commit is contained in:
2025-04-01 13:37:36 +03:00
parent 5d30bc2701
commit 6b9e9050a2
16 changed files with 1700 additions and 48 deletions

View File

@@ -0,0 +1,106 @@
from typing import Type, TypeVar
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from fastapi import status
from fastapi.exceptions import HTTPException
from database import get_db
# Type variable for class methods returning self
T = TypeVar("T", bound="BaseAlchemyModel")
class BaseAlchemyModel:
"""
Controller of alchemy to database transactions.
Query: Query object for model
Session: Session object for model
Actions: save, flush, rollback, commit
"""
__abstract__ = True
@classmethod
def new_session(cls):
"""Get database session."""
return get_db()
@classmethod
def flush(cls: Type[T], db: Session) -> T:
"""
Flush the current session to the database.
Args:
db: Database session
Returns:
Self instance
Raises:
HTTPException: If database operation fails
"""
try:
db.flush()
return cls
except SQLAlchemyError as e:
raise HTTPException(
status_code=status.HTTP_406_NOT_ACCEPTABLE,
detail={
"message": "Database operation failed",
},
)
def destroy(self: Type[T], db: Session) -> None:
"""
Delete the record from the database.
Args:
db: Database session
"""
db.delete(self)
@classmethod
def save(cls: Type[T], db: Session) -> None:
"""
Commit changes to database.
Args:
db: Database session
Raises:
HTTPException: If commit fails
"""
try:
db.commit()
db.flush()
except SQLAlchemyError as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_406_NOT_ACCEPTABLE,
detail={
"message": "Alchemy save operation failed",
"error": str(e),
},
)
except Exception as e:
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"message": "Unknown exception raised.",
"error": str(e),
},
)
@classmethod
def rollback(cls: Type[T], db: Session) -> None:
"""
Rollback current transaction.
Args:
db: Database session
"""
db.rollback()

View File

@@ -0,0 +1,322 @@
import arrow
import datetime
from typing import Optional, Any, Dict, List
from sqlalchemy.orm import Session, Mapped
from pydantic import BaseModel
from fastapi.exceptions import HTTPException
from decimal import Decimal
from sqlalchemy import TIMESTAMP, NUMERIC
from sqlalchemy.orm.attributes import InstrumentedAttribute
class Credentials(BaseModel):
"""
Class to store user credentials.
"""
person_id: int
person_name: str
full_name: Optional[str] = None
class MetaData:
"""
Class to store metadata for a query.
"""
created: bool = False
updated: bool = False
class CRUDModel:
"""
Base class for CRUD operations on PostgreSQL models.
Features:
- User credential tracking
- Metadata tracking for operations
- Type-safe field handling
- Automatic timestamp management
- Soft delete support
"""
__abstract__ = True
creds: Credentials = None
meta_data: MetaData = MetaData()
# Define required columns for CRUD operations
required_columns = {
'expiry_starts': TIMESTAMP,
'expiry_ends': TIMESTAMP,
'created_by': str,
'created_by_id': int,
'updated_by': str,
'updated_by_id': int,
'deleted': bool
}
@classmethod
def create_credentials(cls, record_created) -> None:
"""
Save user credentials for tracking.
Args:
record_created: Record that created or updated
"""
if not cls.creds:
return
if getattr(cls.creds, "person_id", None) and getattr(cls.creds, "person_name", None):
record_created.created_by_id = cls.creds.person_id
record_created.created_by = cls.creds.person_name
@classmethod
def raise_exception(cls, message: str = "Exception raised.", status_code: int = 400):
"""
Raise HTTP exception with custom message and status code.
Args:
message: Error message
status_code: HTTP status code
"""
raise HTTPException(
status_code=status_code,
detail={"message": message}
)
@classmethod
def create_or_abort(cls, db: Session, **kwargs):
"""
Create a new record or abort if it already exists.
Args:
db: Database session
**kwargs: Record fields
Returns:
New record if successfully created
Raises:
HTTPException: If record already exists or creation fails
"""
try:
# Search for existing record
query = db.query(cls).filter(
cls.expiry_ends > str(arrow.now()),
cls.expiry_starts <= str(arrow.now()),
)
for key, value in kwargs.items():
if hasattr(cls, key):
query = query.filter(getattr(cls, key) == value)
already_record = query.first()
# Handle existing record
if already_record and already_record.deleted:
cls.raise_exception("Record already exists and is deleted")
elif already_record:
cls.raise_exception("Record already exists")
# Create new record
created_record = cls()
for key, value in kwargs.items():
setattr(created_record, key, value)
cls.create_credentials(created_record)
db.add(created_record)
db.flush()
return created_record
except Exception as e:
db.rollback()
cls.raise_exception(f"Failed to create record: {str(e)}", status_code=500)
@classmethod
def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
"""
Process a field value based on its type and convert it to the appropriate format.
Args:
val: Field value
key: Field name
Returns:
Tuple of (should_include, processed_value)
"""
try:
key_ = cls.__annotations__.get(key, None)
is_primary = key in getattr(cls, 'primary_keys', [])
row_attr = bool(getattr(getattr(cls, key), "foreign_keys", None))
# Skip primary keys and foreign keys
if is_primary or row_attr:
return False, None
if val is None: # Handle None values
return True, None
if str(key[-5:]).lower() == "uu_id": # Special handling for UUID fields
return True, str(val)
if key_: # Handle typed fields
if key_ == Mapped[int]:
return True, int(val)
elif key_ == Mapped[bool]:
return True, bool(val)
elif key_ == Mapped[float] or key_ == Mapped[NUMERIC]:
return True, round(float(val), 3)
elif key_ == Mapped[TIMESTAMP]:
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))
elif key_ == Mapped[str]:
return True, str(val)
else: # Handle based on Python types
if isinstance(val, datetime.datetime):
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))
elif isinstance(val, bool):
return True, bool(val)
elif isinstance(val, (float, Decimal)):
return True, round(float(val), 3)
elif isinstance(val, int):
return True, int(val)
elif isinstance(val, str):
return True, str(val)
elif val is None:
return True, None
return False, None
except Exception as e:
return False, None
def get_dict(self, exclude_list: Optional[list[InstrumentedAttribute]] = None) -> Dict[str, Any]:
"""
Convert model instance to dictionary with customizable fields.
Args:
exclude_list: List of fields to exclude from the dictionary
Returns:
Dictionary representation of the model
"""
try:
return_dict: Dict[str, Any] = {}
exclude_list = exclude_list or []
exclude_list = [exclude_arg.key for exclude_arg in exclude_list]
# Get all column names from the model
columns = [col.name for col in self.__table__.columns]
columns_set = set(columns)
# Filter columns
columns_list = set([col for col in columns_set if str(col)[-2:] != "id"])
columns_extend = set(
col for col in columns_set if str(col)[-5:].lower() == "uu_id"
)
columns_list = set(columns_list) | set(columns_extend)
columns_list = list(set(columns_list) - set(exclude_list))
for key in columns_list:
val = getattr(self, key)
correct, value_of_database = self.iterate_over_variables(val, key)
if correct:
return_dict[key] = value_of_database
return return_dict
except Exception as e:
return {}
@classmethod
def find_or_create(
cls,
db: Session,
exclude_args: Optional[list[InstrumentedAttribute]] = None,
**kwargs,
):
"""
Find an existing record matching the criteria or create a new one.
Args:
db: Database session
exclude_args: Keys to exclude from search
**kwargs: Search/creation criteria
Returns:
Existing or newly created record
"""
try:
# Search for existing record
query = db.query(cls).filter(
cls.expiry_ends > str(arrow.now()),
cls.expiry_starts <= str(arrow.now()),
)
exclude_args = exclude_args or []
exclude_args = [exclude_arg.key for exclude_arg in exclude_args]
for key, value in kwargs.items():
if hasattr(cls, key) and key not in exclude_args:
query = query.filter(getattr(cls, key) == value)
already_record = query.first()
if already_record: # Handle existing record
cls.meta_data.created = False
return already_record
# Create new record
created_record = cls()
for key, value in kwargs.items():
setattr(created_record, key, value)
cls.create_credentials(created_record)
db.add(created_record)
db.flush()
cls.meta_data.created = True
return created_record
except Exception as e:
db.rollback()
cls.raise_exception(f"Failed to find or create record: {str(e)}", status_code=500)
def update(self, db: Session, **kwargs):
"""
Update the record with new values.
Args:
db: Database session
**kwargs: Fields to update
Returns:
Updated record
Raises:
HTTPException: If update fails
"""
try:
for key, value in kwargs.items():
setattr(self, key, value)
self.update_credentials()
db.flush()
self.meta_data.updated = True
return self
except Exception as e:
self.meta_data.updated = False
db.rollback()
self.raise_exception(f"Failed to update record: {str(e)}", status_code=500)
def update_credentials(self) -> None:
"""
Save user credentials for tracking.
"""
if not self.creds:
return
person_id = getattr(self.creds, "person_id", None)
person_name = getattr(self.creds, "person_name", None)
if person_id and person_name:
self.updated_by_id = self.creds.person_id
self.updated_by = self.creds.person_name

View File

@@ -9,12 +9,12 @@ from sqlalchemy.orm import declarative_base, sessionmaker, scoped_session, Sessi
# Configure the database engine with proper pooling
engine = create_engine(
postgres_configs.url,
pool_pre_ping=True, # Verify connection before using
pool_size=20, # Maximum number of permanent connections
max_overflow=10, # Maximum number of additional connections
pool_recycle=600, # Recycle connections after 1 hour
pool_timeout=30, # Wait up to 30 seconds for a connection
echo=True, # Set to True for debugging SQL queries
pool_pre_ping=True,
pool_size=10, # Reduced from 20 to better match your CPU cores
max_overflow=5, # Reduced from 10 to prevent too many connections
pool_recycle=600, # Keep as is
pool_timeout=30, # Keep as is
echo=True, # Consider setting to False in production
)

View File

@@ -0,0 +1,285 @@
"""
Advanced filtering functionality for SQLAlchemy models.
This module provides a comprehensive set of filtering capabilities for SQLAlchemy models,
including pagination, ordering, and complex query building.
"""
from __future__ import annotations
import arrow
from typing import Any, TypeVar, Type, Union, Optional
from sqlalchemy import ColumnExpressionArgument
from sqlalchemy.orm import Query, Session
from sqlalchemy.sql.elements import BinaryExpression
from response import PostgresResponse
T = TypeVar("T", bound="QueryModel")
class QueryModel:
__abstract__ = True
pre_query = None
@classmethod
def _query(cls: Type[T], db: Session) -> Query:
"""Returns the query to use in the model."""
return cls.pre_query if cls.pre_query else db.query(cls)
@classmethod
def add_new_arg_to_args(
cls: Type[T],
args_list: tuple[BinaryExpression, ...],
argument: str,
value: BinaryExpression
) -> tuple[BinaryExpression, ...]:
"""
Add a new argument to the query arguments if it doesn't exist.
Args:
args_list: Existing query arguments
argument: Key of the argument to check for
value: New argument value to add
Returns:
Updated tuple of query arguments
"""
# Convert to set to remove duplicates while preserving order
new_args = list(dict.fromkeys(
arg for arg in args_list
if isinstance(arg, BinaryExpression)
))
# Check if argument already exists
if not any(
getattr(getattr(arg, "left", None), "key", None) == argument
for arg in new_args
):
new_args.append(value)
return tuple(new_args)
@classmethod
def get_not_expired_query_arg(
cls: Type[T],
args: tuple[BinaryExpression, ...]
) -> tuple[BinaryExpression, ...]:
"""
Add expiry date filtering to the query arguments.
Args:
args: Existing query arguments
Returns:
Updated tuple of query arguments with expiry filters
Raises:
AttributeError: If model does not have expiry_starts or expiry_ends columns
"""
try:
current_time = str(arrow.now())
starts = cls.expiry_starts <= current_time
ends = cls.expiry_ends > current_time
args = cls.add_new_arg_to_args(args, "expiry_ends", ends)
args = cls.add_new_arg_to_args(args, "expiry_starts", starts)
return args
except AttributeError as e:
raise AttributeError(
f"Model {cls.__name__} must have expiry_starts and expiry_ends columns"
) from e
@classmethod
def produce_query_to_add(cls: Type[T], filter_list, args):
"""
Adds query to main filter options
Args:
filter_list: Dictionary containing query parameters
args: Existing query arguments to add to
Returns:
Updated query arguments tuple
"""
if filter_list.get("query"):
for smart_iter in cls.filter_expr(**filter_list["query"]):
if key := getattr(getattr(smart_iter, "left", None), "key", None):
args = cls.add_new_arg_to_args(args, key, smart_iter)
return args
@classmethod
def convert(
cls: Type[T],
smart_options: dict[str, Any],
validate_model: Any = None
) -> Optional[tuple[BinaryExpression, ...]]:
"""
Convert smart options to SQLAlchemy filter expressions.
Args:
smart_options: Dictionary of filter options
validate_model: Optional model to validate against
Returns:
Tuple of SQLAlchemy filter expressions or None if validation fails
"""
if validate_model is not None:
# Add validation logic here if needed
pass
return tuple(cls.filter_expr(**smart_options))
@classmethod
def filter_by_one(
cls: Type[T],
db: Session,
system: bool = False,
**kwargs: Any
) -> PostgresResponse[T]:
"""
Filter single record by keyword arguments.
Args:
db: Database session
system: If True, skip status filtering
**kwargs: Filter criteria
Returns:
Query response with single record
"""
if "is_confirmed" not in kwargs and not system:
kwargs["is_confirmed"] = True
kwargs.pop("system", None)
query = cls._query(db).filter_by(**kwargs)
return PostgresResponse(
model=cls,
pre_query=cls._query(db),
query=query,
is_array=False
)
@classmethod
def filter_one(
cls: Type[T],
*args: Union[BinaryExpression, ColumnExpressionArgument],
db: Session,
) -> PostgresResponse[T]:
"""
Filter single record by expressions.
Args:
db: Database session
*args: Filter expressions
Returns:
Query response with single record
"""
args = cls.get_not_expired_query_arg(args)
query = cls._query(db).filter(*args)
return PostgresResponse(
model=cls,
pre_query=cls._query(db),
query=query,
is_array=False
)
@classmethod
def filter_one_system(
cls: Type[T],
*args: Union[BinaryExpression, ColumnExpressionArgument],
db: Session,
) -> PostgresResponse[T]:
"""
Filter single record by expressions without status filtering.
Args:
db: Database session
*args: Filter expressions
Returns:
Query response with single record
"""
query = cls._query(db).filter(*args)
return PostgresResponse(
model=cls,
pre_query=cls._query(db),
query=query,
is_array=False
)
@classmethod
def filter_all_system(
cls: Type[T],
*args: Union[BinaryExpression, ColumnExpressionArgument],
db: Session,
) -> PostgresResponse[T]:
"""
Filter multiple records by expressions without status filtering.
Args:
db: Database session
*args: Filter expressions
Returns:
Query response with matching records
"""
query = cls._query(db).filter(*args)
return PostgresResponse(
model=cls,
pre_query=cls._query(db),
query=query,
is_array=True
)
@classmethod
def filter_all(
cls: Type[T],
*args: Union[BinaryExpression, ColumnExpressionArgument],
db: Session,
) -> PostgresResponse[T]:
"""
Filter multiple records by expressions.
Args:
db: Database session
*args: Filter expressions
Returns:
Query response with matching records
"""
args = cls.get_not_expired_query_arg(args)
query = cls._query(db).filter(*args)
return PostgresResponse(
model=cls,
pre_query=cls._query(db),
query=query,
is_array=True
)
@classmethod
def filter_by_all_system(
cls: Type[T],
db: Session,
**kwargs: Any
) -> PostgresResponse[T]:
"""
Filter multiple records by keyword arguments.
Args:
db: Database session
**kwargs: Filter criteria
Returns:
Query response with matching records
"""
query = cls._query(db).filter_by(**kwargs)
return PostgresResponse(
model=cls,
pre_query=cls._query(db),
query=query,
is_array=True
)

View File

@@ -0,0 +1,339 @@
import arrow
from schema import EndpointRestriction
def create_sample_endpoint_restriction():
"""Create a sample endpoint restriction for testing."""
with EndpointRestriction.new_session() as db_session:
endpoint = EndpointRestriction.find_or_create(
endpoint_function="test_function",
endpoint_name="Test Endpoint",
endpoint_method="GET",
endpoint_desc="Test Description",
endpoint_code="TEST001",
is_confirmed=True,
expiry_starts=arrow.now().shift(days=-1),
expiry_ends=arrow.now().shift(days=1)
)
endpoint.save(db=db_session)
return endpoint
def test_filter_by_one():
"""Test filtering a single record by keyword arguments."""
print("\nTesting filter_by_one...")
with EndpointRestriction.new_session() as db_session:
sample_endpoint = create_sample_endpoint_restriction()
result = EndpointRestriction.filter_by_one(
db=db_session,
endpoint_code="TEST001"
)
# Test PostgresResponse properties
success = (
result.count == 1 and
result.total_count == 1 and
result.data is not None and
result.data.endpoint_code == "TEST001" and
result.is_list is False and
isinstance(result.data_as_dict, dict) and
result.data_as_dict.get("endpoint_code") == "TEST001"
)
print(f"Test {'passed' if success else 'failed'}")
return success
def test_filter_by_one_system():
"""Test filtering a single record by keyword arguments without status filtering."""
print("\nTesting filter_by_one_system...")
with EndpointRestriction.new_session() as db_session:
sample_endpoint = create_sample_endpoint_restriction()
result = EndpointRestriction.filter_by_one(
db=db_session,
endpoint_code="TEST001",
system=True
)
# Test PostgresResponse properties
success = (
result.count == 1 and
result.total_count == 1 and
result.data is not None and
result.data.endpoint_code == "TEST001" and
result.is_list is False and
isinstance(result.data_as_dict, dict) and
result.data_as_dict.get("endpoint_code") == "TEST001"
)
print(f"Test {'passed' if success else 'failed'}")
return success
def test_filter_one():
"""Test filtering a single record by expressions."""
print("\nTesting filter_one...")
with EndpointRestriction.new_session() as db_session:
sample_endpoint = create_sample_endpoint_restriction()
result = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_code == "TEST001",
db=db_session
)
# Test PostgresResponse properties
success = (
result.count == 1 and
result.total_count == 1 and
result.data is not None and
result.data.endpoint_code == "TEST001" and
result.is_list is False and
isinstance(result.data_as_dict, dict) and
result.data_as_dict.get("endpoint_code") == "TEST001"
)
print(f"Test {'passed' if success else 'failed'}")
return success
def test_filter_one_system():
"""Test filtering a single record by expressions without status filtering."""
print("\nTesting filter_one_system...")
with EndpointRestriction.new_session() as db_session:
sample_endpoint = create_sample_endpoint_restriction()
result = EndpointRestriction.filter_one_system(
EndpointRestriction.endpoint_code == "TEST001",
db=db_session
)
# Test PostgresResponse properties
success = (
result.count == 1 and
result.total_count == 1 and
result.data is not None and
result.data.endpoint_code == "TEST001" and
result.is_list is False and
isinstance(result.data_as_dict, dict) and
result.data_as_dict.get("endpoint_code") == "TEST001"
)
print(f"Test {'passed' if success else 'failed'}")
return success
def test_filter_all():
"""Test filtering multiple records by expressions."""
print("\nTesting filter_all...")
with EndpointRestriction.new_session() as db_session:
# Create two endpoint restrictions
endpoint1 = create_sample_endpoint_restriction()
endpoint2 = EndpointRestriction.find_or_create(
endpoint_function="test_function2",
endpoint_name="Test Endpoint 2",
endpoint_method="POST",
endpoint_desc="Test Description 2",
endpoint_code="TEST002",
is_confirmed=True,
expiry_starts=arrow.now().shift(days=-1),
expiry_ends=arrow.now().shift(days=1)
)
result = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method.in_(["GET", "POST"]),
db=db_session
)
# Test PostgresResponse properties
success = (
result.count == 2 and
result.total_count == 2 and
len(result.data) == 2 and
{r.endpoint_code for r in result.data} == {"TEST001", "TEST002"} and
result.is_list is True and
isinstance(result.data_as_dict, list) and
len(result.data_as_dict) == 2
)
print(f"Test {'passed' if success else 'failed'}")
return success
def test_filter_all_system():
"""Test filtering multiple records by expressions without status filtering."""
print("\nTesting filter_all_system...")
with EndpointRestriction.new_session() as db_session:
# Create two endpoint restrictions
endpoint1 = create_sample_endpoint_restriction()
endpoint2 = EndpointRestriction.find_or_create(
endpoint_function="test_function2",
endpoint_name="Test Endpoint 2",
endpoint_method="POST",
endpoint_desc="Test Description 2",
endpoint_code="TEST002",
is_confirmed=True,
expiry_starts=arrow.now().shift(days=-1),
expiry_ends=arrow.now().shift(days=1)
)
result = EndpointRestriction.filter_all_system(
EndpointRestriction.endpoint_method.in_(["GET", "POST"]),
db=db_session
)
# Test PostgresResponse properties
success = (
result.count == 2 and
result.total_count == 2 and
len(result.data) == 2 and
{r.endpoint_code for r in result.data} == {"TEST001", "TEST002"} and
result.is_list is True and
isinstance(result.data_as_dict, list) and
len(result.data_as_dict) == 2
)
print(f"Test {'passed' if success else 'failed'}")
return success
def test_filter_by_all_system():
"""Test filtering multiple records by keyword arguments."""
print("\nTesting filter_by_all_system...")
with EndpointRestriction.new_session() as db_session:
# Create two endpoint restrictions
endpoint1 = create_sample_endpoint_restriction()
endpoint2 = EndpointRestriction.find_or_create(
endpoint_function="test_function2",
endpoint_name="Test Endpoint 2",
endpoint_method="POST",
endpoint_desc="Test Description 2",
endpoint_code="TEST002",
is_confirmed=True,
expiry_starts=arrow.now().shift(days=-1),
expiry_ends=arrow.now().shift(days=1)
)
result = EndpointRestriction.filter_by_all_system(
db=db_session,
endpoint_method="POST"
)
# Test PostgresResponse properties
success = (
result.count == 1 and
result.total_count == 1 and
len(result.data) == 1 and
result.data[0].endpoint_code == "TEST002" and
result.is_list is True and
isinstance(result.data_as_dict, list) and
len(result.data_as_dict) == 1
)
print(f"Test {'passed' if success else 'failed'}")
return success
def test_get_not_expired_query_arg():
"""Test expiry date filtering in query arguments."""
print("\nTesting get_not_expired_query_arg...")
with EndpointRestriction.new_session() as db_session:
# Create active and expired endpoints
active_endpoint = create_sample_endpoint_restriction()
expired_endpoint = EndpointRestriction.find_or_create(
endpoint_function="expired_function",
endpoint_name="Expired Endpoint",
endpoint_method="GET",
endpoint_desc="Expired Description",
endpoint_code="EXP001",
is_confirmed=True,
expiry_starts=arrow.now().shift(days=-2),
expiry_ends=arrow.now().shift(days=-1)
)
result = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_code.in_(["TEST001", "EXP001"]),
db=db_session
)
# Test PostgresResponse properties
success = (
result.count == 1 and
result.total_count == 1 and
len(result.data) == 1 and
result.data[0].endpoint_code == "TEST001" and
result.is_list is True and
isinstance(result.data_as_dict, list) and
len(result.data_as_dict) == 1 and
result.data_as_dict[0].get("endpoint_code") == "TEST001"
)
print(f"Test {'passed' if success else 'failed'}")
return success
def test_add_new_arg_to_args():
"""Test adding new arguments to query arguments."""
print("\nTesting add_new_arg_to_args...")
args = (EndpointRestriction.endpoint_code == "TEST001",)
new_arg = EndpointRestriction.endpoint_method == "GET"
updated_args = EndpointRestriction.add_new_arg_to_args(args, "endpoint_method", new_arg)
success = len(updated_args) == 2
# Test duplicate prevention
duplicate_arg = EndpointRestriction.endpoint_method == "GET"
updated_args = EndpointRestriction.add_new_arg_to_args(updated_args, "endpoint_method", duplicate_arg)
success = success and len(updated_args) == 2 # Should not add duplicate
print(f"Test {'passed' if success else 'failed'}")
return success
def test_produce_query_to_add():
"""Test adding query parameters to filter options."""
print("\nTesting produce_query_to_add...")
with EndpointRestriction.new_session() as db_session:
sample_endpoint = create_sample_endpoint_restriction()
filter_list = {
"query": {
"endpoint_method": "GET",
"endpoint_code": "TEST001"
}
}
args = ()
updated_args = EndpointRestriction.produce_query_to_add(filter_list, args)
success = len(updated_args) == 2
result = EndpointRestriction.filter_all(
*updated_args,
db=db_session
)
# Test PostgresResponse properties
success = (
success and
result.count == 1 and
result.total_count == 1 and
len(result.data) == 1 and
result.data[0].endpoint_code == "TEST001" and
result.is_list is True and
isinstance(result.data_as_dict, list) and
len(result.data_as_dict) == 1 and
result.data_as_dict[0].get("endpoint_code") == "TEST001"
)
print(f"Test {'passed' if success else 'failed'}")
return success
def run_all_tests():
"""Run all tests and report results."""
print("Starting EndpointRestriction tests...")
tests = [
test_filter_by_one,
test_filter_by_one_system,
test_filter_one,
test_filter_one_system,
test_filter_all,
test_filter_all_system,
test_filter_by_all_system,
test_get_not_expired_query_arg,
test_add_new_arg_to_args,
test_produce_query_to_add
]
passed = 0
failed = 0
for test in tests:
if test():
passed += 1
else:
failed += 1
print(f"\nTest Summary:")
print(f"Total tests: {len(tests)}")
print(f"Passed: {passed}")
print(f"Failed: {failed}")
if __name__ == "__main__":
run_all_tests()

View File

@@ -0,0 +1,153 @@
import arrow
from sqlalchemy import (
TIMESTAMP,
func,
text,
UUID,
String,
Integer,
Boolean,
SmallInteger,
)
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy_mixins.serialize import SerializeMixin
from sqlalchemy_mixins.repr import ReprMixin
from sqlalchemy_mixins.smartquery import SmartQueryMixin
from base import BaseAlchemyModel
from crud import CRUDModel
from filter import QueryModel
from database import Base
class BasicMixin(
Base,
BaseAlchemyModel,
CRUDModel,
SerializeMixin,
ReprMixin,
SmartQueryMixin,
QueryModel,
):
__abstract__ = True
__repr__ = ReprMixin.__repr__
class CrudMixin(BasicMixin):
"""
Base mixin providing CRUD operations and common fields for PostgreSQL models.
Features:
- Automatic timestamps (created_at, updated_at)
- Soft delete capability
- User tracking (created_by, updated_by)
- Data serialization
- Multi-language support
"""
__abstract__ = True
# Primary and reference fields
id: Mapped[int] = mapped_column(Integer, primary_key=True)
uu_id: Mapped[str] = mapped_column(
UUID,
server_default=text("gen_random_uuid()"),
index=True,
unique=True,
comment="Unique identifier UUID",
)
# Common timestamp fields for all models
expiry_starts: Mapped[TIMESTAMP] = mapped_column(
TIMESTAMP(timezone=True),
server_default=func.now(),
comment="Record validity start timestamp",
)
expiry_ends: Mapped[TIMESTAMP] = mapped_column(
TIMESTAMP(timezone=True),
default=str(arrow.get("2099-12-31")),
server_default=func.now(),
comment="Record validity end timestamp",
)
class CrudCollection(CrudMixin):
"""
Full-featured model class with all common fields.
Includes:
- UUID and reference ID
- Timestamps
- User tracking
- Confirmation status
- Soft delete
- Notification flags
"""
__abstract__ = True
__repr__ = ReprMixin.__repr__
ref_id: Mapped[str] = mapped_column(
String(100), nullable=True, index=True, comment="External reference ID"
)
# Timestamps
created_at: Mapped[TIMESTAMP] = mapped_column(
TIMESTAMP(timezone=True),
server_default=func.now(),
nullable=False,
index=True,
comment="Record creation timestamp",
)
updated_at: Mapped[TIMESTAMP] = mapped_column(
TIMESTAMP(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
index=True,
comment="Last update timestamp",
)
# Cryptographic and user tracking
cryp_uu_id: Mapped[str] = mapped_column(
String, nullable=True, index=True, comment="Cryptographic UUID"
)
created_by: Mapped[str] = mapped_column(
String, nullable=True, comment="Creator name"
)
created_by_id: Mapped[int] = mapped_column(
Integer, nullable=True, comment="Creator ID"
)
updated_by: Mapped[str] = mapped_column(
String, nullable=True, comment="Last modifier name"
)
updated_by_id: Mapped[int] = mapped_column(
Integer, nullable=True, comment="Last modifier ID"
)
confirmed_by: Mapped[str] = mapped_column(
String, nullable=True, comment="Confirmer name"
)
confirmed_by_id: Mapped[int] = mapped_column(
Integer, nullable=True, comment="Confirmer ID"
)
# Status flags
is_confirmed: Mapped[bool] = mapped_column(
Boolean, server_default="0", comment="Record confirmation status"
)
replication_id: Mapped[int] = mapped_column(
SmallInteger, server_default="0", comment="Replication identifier"
)
deleted: Mapped[bool] = mapped_column(
Boolean, server_default="0", comment="Soft delete flag"
)
active: Mapped[bool] = mapped_column(
Boolean, server_default="1", comment="Record active status"
)
is_notification_send: Mapped[bool] = mapped_column(
Boolean, server_default="0", comment="Notification sent flag"
)
is_email_send: Mapped[bool] = mapped_column(
Boolean, server_default="0", comment="Email sent flag"
)

View File

@@ -0,0 +1,109 @@
"""
Response handler for PostgreSQL query results.
This module provides a wrapper class for SQLAlchemy query results,
adding convenience methods for accessing data and managing query state.
"""
from typing import Any, Dict, Optional, TypeVar, Generic, Union
from sqlalchemy.orm import Query
T = TypeVar("T")
class PostgresResponse(Generic[T]):
"""
Wrapper for PostgreSQL/SQLAlchemy query results.
Attributes:
metadata: Additional metadata for the query
Properties:
count: Total count of results
query: Get query object
as_dict: Convert response to dictionary format
"""
def __init__(
self,
pre_query: Query,
query: Query,
model,
is_array: bool = True,
metadata: Any = None,
):
self._core_class = model
self._is_list = is_array
self._query = query
self._pre_query = pre_query
self._count: Optional[int] = None
self.metadata = metadata
@property
def core_class(self):
"""Get query object."""
return self._core_class
@property
def data(self) -> Union[T, list[T]]:
"""Get query results."""
if not self.is_list:
first_item = self._query.first()
return first_item if first_item else None
return self._query.all() if self._query.all() else []
@property
def data_as_dict(self) -> Union[Dict[str, Any], list[Dict[str, Any]]] | None:
"""Get query results as dictionary."""
if not self.count:
return None
if self.is_list:
first_item = self._query.first()
return first_item.get_dict() if first_item.first() else None
all_items = self._query.all()
return [result.get_dict() for result in all_items] if all_items else []
@property
def total_count(self) -> int:
"""Lazy load and return total count of results."""
return self._pre_query.count() if self._pre_query else 0
@property
def count(self) -> int:
"""Lazy load and return total count of results."""
return self._query.count()
@property
def query(self) -> str:
"""Get query object."""
return str(self._query)
@property
def core_query(self) -> Query:
"""Get query object."""
return self._query
@property
def is_list(self) -> bool:
"""Check if response is a list."""
return self._is_list
@property
def as_dict(self) -> Dict[str, Any]:
"""Convert response to dictionary format."""
if isinstance(self.data, list):
return {
"metadata": self.metadata,
"is_list": self._is_list,
"query": str(self.query),
"count": self.count,
"data": [result.get_dict() for result in self.data],
}
return {
"metadata": self.metadata,
"is_list": self._is_list,
"query": str(self.query),
"count": self.count,
"data": self.data.get_dict() if self.data else {},
}

View File

@@ -0,0 +1,30 @@
from sqlalchemy import String
from sqlalchemy.orm import mapped_column, Mapped
from mixin import CrudCollection
class EndpointRestriction(CrudCollection):
"""
Initialize Endpoint Restriction with default values
"""
__tablename__ = "endpoint_restriction"
__exclude__fields__ = []
endpoint_function: Mapped[str] = mapped_column(
String, server_default="", comment="Function name of the API endpoint"
)
endpoint_name: Mapped[str] = mapped_column(
String, server_default="", comment="Name of the API endpoint"
)
endpoint_method: Mapped[str] = mapped_column(
String, server_default="", comment="HTTP method used by the endpoint"
)
endpoint_desc: Mapped[str] = mapped_column(
String, server_default="", comment="Description of the endpoint"
)
endpoint_code: Mapped[str] = mapped_column(
String, server_default="", unique=True, comment="Unique code for the endpoint"
)