updated Empty Runner

This commit is contained in:
2025-03-22 00:30:51 +03:00
parent d9dd8ac244
commit adfa5868a0
45 changed files with 9370 additions and 1 deletions

View File

@@ -0,0 +1,25 @@
from typing import Any, Dict, List, Optional
from functools import wraps
# from pymongo.errors import (
# ConnectionFailure,
# OperationFailure,
# ServerSelectionTimeoutError,
# PyMongoError,
# )
def mongo_error_wrapper(func):
    """Wrap a MongoDB operation to translate driver errors.

    NOTE(review): the pymongo exception imports are currently commented
    out, so this decorator is a transparent pass-through for now; the
    wrapped callable is invoked unchanged.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        """Invoke the wrapped MongoDB operation with the given arguments."""
        return func(*args, **kwargs)

    return wrapper

View File

@@ -0,0 +1,178 @@
from typing import Optional, Dict, Any, List, TypeVar, Iterator
from contextlib import contextmanager
from pymongo import MongoClient
from pymongo.collection import Collection
from Configs.mongo import MongoConfig
from Services.MongoService.handlers import mongo_error_wrapper
class MongoBase:
    """Base class for MongoDB connection and operations.

    Mixins read ``self.collection``; MongoProvider overrides it as a
    property backed by ``_collection`` (set via ``use_collection``).
    """
    # Active pymongo collection; defaults to None until bound by a subclass.
    collection: Optional[Collection] = None
class MongoErrorHandler:
    """Error handler for MongoDB operations.

    NOTE(review): placeholder — no handling logic is implemented yet.
    """
    ...
class MongoInsertMixin(MongoBase):
    """Mixin for MongoDB insert operations.

    CONSISTENCY FIX: the find/update/delete/aggregate mixins all decorate
    their methods with ``mongo_error_wrapper``; the insert methods were the
    only ones left undecorated.
    """

    @mongo_error_wrapper
    def insert_one(self, document: Dict[str, Any]):
        """Insert a single document into the collection.

        Args:
            document: Mapping stored as one document.

        Returns:
            The driver's insert-one result object.
        """
        return self.collection.insert_one(document)

    @mongo_error_wrapper
    def insert_many(self, documents: List[Dict[str, Any]]):
        """Insert multiple documents into the collection.

        Args:
            documents: List of mappings to store.

        Returns:
            The driver's insert-many result object.
        """
        return self.collection.insert_many(documents)
class MongoFindMixin(MongoBase):
    """Mixin exposing read operations on the active collection."""

    @mongo_error_wrapper
    def find_one(
        self,
        filter_query: Dict[str, Any],
        projection: Optional[Dict[str, Any]] = None,
    ):
        """Return the first document matching ``filter_query`` (or None)."""
        return self.collection.find_one(filter_query, projection)

    @mongo_error_wrapper
    def find_many(
        self,
        filter_query: Dict[str, Any],
        projection: Optional[Dict[str, Any]] = None,
        sort: Optional[List[tuple[str, int]]] = None,
        limit: Optional[int] = None,
        skip: Optional[int] = None,
    ):
        """Return all matching documents, honouring sort/skip/limit.

        Each pagination option is applied to the cursor only when provided.
        """
        results = self.collection.find(filter_query, projection)
        if sort:
            results = results.sort(sort)
        if skip:
            results = results.skip(skip)
        if limit:
            results = results.limit(limit)
        return list(results)
class MongoUpdateMixin(MongoBase):
    """Mixin exposing update operations on the active collection."""

    @mongo_error_wrapper
    def update_one(
        self,
        filter_query: Dict[str, Any],
        update_data: Dict[str, Any],
        upsert: bool = False,
    ):
        """Apply ``update_data`` to the first document matching the filter.

        When ``upsert`` is True a missing document is created instead.
        """
        return self.collection.update_one(filter_query, update_data, upsert=upsert)

    @mongo_error_wrapper
    def update_many(
        self,
        filter_query: Dict[str, Any],
        update_data: Dict[str, Any],
        upsert: bool = False,
    ):
        """Apply ``update_data`` to every document matching the filter.

        When ``upsert`` is True a missing document is created instead.
        """
        return self.collection.update_many(filter_query, update_data, upsert=upsert)
class MongoDeleteMixin(MongoBase):
    """Mixin exposing delete operations on the active collection."""

    @mongo_error_wrapper
    def delete_one(self, filter_query: Dict[str, Any]):
        """Remove the first document matching ``filter_query``."""
        return self.collection.delete_one(filter_query)

    @mongo_error_wrapper
    def delete_many(self, filter_query: Dict[str, Any]):
        """Remove every document matching ``filter_query``."""
        return self.collection.delete_many(filter_query)
class MongoAggregateMixin(MongoBase):
    """Mixin for MongoDB aggregation operations."""
    @mongo_error_wrapper
    def aggregate(self, collection: Collection, pipeline: List[Dict[str, Any]]):
        """Execute an aggregation pipeline on the given collection.

        NOTE(review): unlike the other mixins, this takes an explicit
        ``collection`` argument instead of using ``self.collection`` —
        confirm whether that asymmetry is intentional.

        Args:
            collection: Collection to run the pipeline against.
            pipeline: List of aggregation stage documents.

        Returns:
            The cursor produced by the aggregation.
        """
        result = collection.aggregate(pipeline)
        return result
class MongoProvider(
    MongoUpdateMixin,
    MongoInsertMixin,
    MongoFindMixin,
    MongoDeleteMixin,
    MongoAggregateMixin,
):
    """Main MongoDB actions class that inherits all CRUD operation mixins.

    Provides a unified interface for all MongoDB operations while managing
    collections whose names are the storage-reason parts joined by a
    delimiter.
    """

    def __init__(
        self, client: MongoClient, database: str, storage_reason: list[str]
    ):
        """Initialize MongoDB actions with client and collection info.

        Args:
            client: MongoDB client.
            database: Database name to use.
            storage_reason: Name parts used to build the collection name.
        """
        self.delimiter = "|"
        self._client = client
        self._database = database
        self._storage_reason: list[str] = storage_reason
        self._collection = None
        self.use_collection(storage_reason)

    @staticmethod
    @contextmanager
    def mongo_client() -> Iterator[MongoClient]:
        """
        Context provider for a MongoDB client.

        Example:
            with MongoProvider.mongo_client() as client:
                db = client["your_database"]
                print(db.list_collection_names())

        Yields:
            A connected MongoClient; closed on exit.
        """
        client = MongoClient(MongoConfig.URL)
        try:
            client.admin.command("ping")  # Fail fast if the server is unreachable
            yield client
        finally:
            client.close()  # Always return the connection

    @property
    def collection(self) -> Collection:
        """Get current MongoDB collection."""
        return self._collection

    def use_collection(self, storage_name_list: list[str]) -> None:
        """Switch to a different collection.

        The collection name is the parts joined by ``self.delimiter``.

        Args:
            storage_name_list: Name parts for the collection.

        Raises:
            ValueError: If the list is empty (would yield an invalid empty
                collection name) or a part contains the delimiter.
        """
        if not storage_name_list:
            raise ValueError("Storage reason list cannot be empty")
        parts = [str(part) for part in storage_name_list]
        for part in parts:
            if self.delimiter in part:
                raise ValueError(f"Storage reason cannot contain delimiter : {self.delimiter}")
        # str.join replaces the previous manual "+= delimiter" loop + slice.
        self._collection = self._client[self._database][self.delimiter.join(parts)]

View File

@@ -0,0 +1,108 @@
from typing import Type, TypeVar
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from fastapi import status
from fastapi.exceptions import HTTPException
from Services.PostgresService.database import get_db
# Type variable for class methods returning self
T = TypeVar("T", bound="FilterAttributes")
class BaseAlchemyModel:
    """
    Controller of alchemy to database transactions.
    Query: Query object for model
    Session: Session object for model
    Actions: save, flush, rollback, commit
    """
    __abstract__ = True

    @classmethod
    def new_session(cls) -> Session:
        """Get database session.

        WARNING(review): ``get_db`` is a contextmanager that closes the
        session in its ``finally`` block, so the session returned here has
        already been closed by the time the caller receives it — confirm
        the intended usage.
        """
        with get_db() as session:
            return session

    @classmethod
    def flush(cls: Type[T], db: Session) -> T:
        """
        Flush the current session to the database.
        Args:
            db: Database session
        Returns:
            The model class itself (for chaining)
        Raises:
            HTTPException: If database operation fails
        """
        try:
            db.flush()
            return cls
        except SQLAlchemyError as e:
            # Mirror save(): roll back so the session stays usable and
            # surface the driver error detail; chain the original cause.
            db.rollback()
            raise HTTPException(
                status_code=status.HTTP_406_NOT_ACCEPTABLE,
                detail={
                    "message": "Database operation failed",
                    "error": str(e),
                },
            ) from e

    def destroy(self, db: Session) -> None:
        """
        Delete the record from the database.
        Args:
            db: Database session
        """
        db.delete(self)

    @classmethod
    def save(cls: Type[T], db: Session) -> None:
        """
        Commit changes to database.
        Args:
            db: Database session
        Raises:
            HTTPException: If commit fails
        """
        try:
            db.commit()
            db.flush()
        except SQLAlchemyError as e:
            db.rollback()
            raise HTTPException(
                status_code=status.HTTP_406_NOT_ACCEPTABLE,
                detail={
                    "message": "Alchemy save operation failed",
                    "error": str(e),
                },
            ) from e
        except Exception as e:
            db.rollback()
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={
                    "message": "Unknown exception raised.",
                    "error": str(e),
                },
            ) from e

    @classmethod
    def rollback(cls: Type[T], db: Session) -> None:
        """
        Rollback current transaction.
        Args:
            db: Database session
        """
        db.rollback()

View File

@@ -0,0 +1,249 @@
import arrow
import datetime
from typing import Optional, Any, Dict
from sqlalchemy.orm import Session, Mapped
from pydantic import BaseModel
from fastapi.exceptions import HTTPException
from decimal import Decimal
from sqlalchemy import TIMESTAMP, NUMERIC
from sqlalchemy.orm.attributes import InstrumentedAttribute
class Credentials(BaseModel):
    """
    Class to store user credentials used for audit tracking
    (created_by / updated_by fields on records).
    """
    # Identifier of the acting person.
    person_id: int
    # Short / login name of the acting person.
    person_name: str
    # Optional human-readable display name.
    full_name: Optional[str] = None
class MetaData:
    """
    Class to store metadata for a query.

    NOTE(review): CRUDModel keeps a single class-level instance of this,
    so these flags are shared across all callers — per-request state would
    race under concurrency; confirm this is intended.
    """
    # True when the last find_or_create created a new record.
    created: bool = False
    # True when the last update flushed successfully.
    updated: bool = False
class CRUDModel:
    __abstract__ = True

    # Credentials of the acting user; set externally before CRUD calls.
    creds: Credentials = None
    # Shared (class-level) flags describing the outcome of the last call.
    meta_data: MetaData = MetaData()

    @classmethod
    def create_credentials(cls, record_created) -> None:
        """
        Save user credentials for tracking.
        Args:
            record_created: Record that was created or updated
        """
        if getattr(cls.creds, "person_id", None) and getattr(cls.creds, "person_name", None):
            record_created.created_by_id = cls.creds.person_id
            record_created.created_by = cls.creds.person_name
        return

    @classmethod
    def raise_exception(cls):
        """Raise the generic 400 error used by create_or_abort."""
        raise HTTPException(
            status_code=400,
            detail={
                "message": "Exception raised.",
            },
        )

    @classmethod
    def create_or_abort(cls, db: Session, **kwargs):
        """
        Create a new record or abort if it already exists.
        Args:
            db: Database session
            **kwargs: Record fields
        Returns:
            New record if successfully created
        Raises:
            HTTPException: If a matching, currently-valid record exists
        """
        # Search for an existing, currently-valid record
        query = db.query(cls).filter(
            cls.expiry_ends > str(arrow.now()), cls.expiry_starts <= str(arrow.now()),
        )
        for key, value in kwargs.items():
            if hasattr(cls, key):
                query = query.filter(getattr(cls, key) == value)
        already_record = query.first()
        # BUGFIX: the previous `already_record or already_record.deleted`
        # raised AttributeError on None whenever no record was found, so
        # the creation path below was unreachable.
        if already_record is not None:
            cls.raise_exception()
        # Create new record
        created_record = cls()
        for key, value in kwargs.items():
            setattr(created_record, key, value)
        cls.create_credentials(created_record)
        db.add(created_record)
        db.flush()
        return created_record

    @classmethod
    def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
        """
        Process a field value based on its type and convert it to the
        appropriate serializable format.
        Args:
            val: Field value
            key: Field name
        Returns:
            Tuple of (should_include, processed_value)
        """
        key_ = cls.__annotations__.get(key, None)
        is_primary = key in cls.primary_keys
        row_attr = bool(getattr(getattr(cls, key), "foreign_keys", None))
        # Skip primary keys and foreign keys
        if is_primary or row_attr:
            return False, None
        if val is None:  # None values are passed through unchanged
            return True, None
        if str(key[-5:]).lower() == "uu_id":  # UUIDs serialize as strings
            return True, str(val)
        if key_:  # Annotation-driven conversion
            if key_ == Mapped[int]:
                return True, int(val)
            elif key_ == Mapped[bool]:
                return True, bool(val)
            elif key_ == Mapped[float] or key_ == Mapped[NUMERIC]:
                return True, round(float(val), 3)
            elif key_ == Mapped[TIMESTAMP]:
                return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))
            elif key_ == Mapped[str]:
                return True, str(val)
        else:  # Fall back to runtime Python types
            if isinstance(val, datetime.datetime):
                return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))
            elif isinstance(val, bool):
                return True, bool(val)
            elif isinstance(val, (float, Decimal)):
                return True, round(float(val), 3)
            elif isinstance(val, int):
                return True, int(val)
            elif isinstance(val, str):
                return True, str(val)
            # (dead `val is None` branch removed: None is handled above)
        return False, None

    def get_dict(self, exclude_list: Optional[list[InstrumentedAttribute]] = None) -> Dict[str, Any]:
        """
        Convert model instance to dictionary with customizable fields.
        Returns:
            Dictionary representation of the model.
            Only UUID fields and fields not ending in "id" are included,
            minus anything in exclude_list.
        """
        return_dict: Dict[str, Any] = {}
        exclude_keys = {exclude_arg.key for exclude_arg in (exclude_list or [])}
        columns_set = set(self.columns)
        # Drop *id columns, then re-admit *uu_id columns.
        base_cols = {col for col in columns_set if str(col)[-2:] != "id"}
        uuid_cols = {col for col in columns_set if str(col)[-5:].lower() == "uu_id"}
        for key in (base_cols | uuid_cols) - exclude_keys:
            correct, value_of_database = self.iterate_over_variables(getattr(self, key), key)
            if correct:
                return_dict[key] = value_of_database
        return return_dict

    @classmethod
    def find_or_create(
        cls, db: Session, exclude_args: Optional[list[InstrumentedAttribute]] = None, **kwargs
    ):
        """
        Find an existing record matching the criteria or create a new one.
        Args:
            db: Database session
            exclude_args: Keys to exclude from search
            **kwargs: Search/creation criteria
        Returns:
            Existing or newly created record; meta_data.created reflects
            which path was taken.
        """
        # Search for an existing, currently-valid record
        query = db.query(cls).filter(
            cls.expiry_ends > str(arrow.now()), cls.expiry_starts <= str(arrow.now()),
        )
        excluded_keys = [exclude_arg.key for exclude_arg in (exclude_args or [])]
        for key, value in kwargs.items():
            if hasattr(cls, key) and key not in excluded_keys:
                query = query.filter(getattr(cls, key) == value)
        already_record = query.first()
        if already_record:  # Handle existing record
            cls.meta_data.created = False
            return already_record
        # Create new record
        created_record = cls()
        for key, value in kwargs.items():
            setattr(created_record, key, value)
        cls.create_credentials(created_record)
        db.add(created_record)
        db.flush()
        cls.meta_data.created = True
        return created_record

    def update(self, db: Session, **kwargs):
        """
        Update the record with new values.
        Args:
            db: Database session
            **kwargs: Fields to update
        Returns:
            Updated record (self); meta_data.updated reflects success.
        """
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.update_credentials()
        try:
            db.flush()
            self.meta_data.updated = True
        except Exception as e:
            # TODO(review): replace debug print with proper logging.
            print('Error:', e)
            self.meta_data.updated = False
            db.rollback()
        return self

    def update_credentials(self) -> None:
        """
        Save user credentials for tracking.
        """
        # Update modification tracking only when both identity fields exist.
        person_id = getattr(self.creds, "person_id", None)
        person_name = getattr(self.creds, "person_name", None)
        if person_id and person_name:
            self.updated_by_id = self.creds.person_id
            self.updated_by = self.creds.person_name
        return

View File

@@ -0,0 +1,202 @@
"""
Advanced filtering functionality for SQLAlchemy models.
This module provides a comprehensive set of filtering capabilities for SQLAlchemy models,
including pagination, ordering, and complex query building.
"""
from __future__ import annotations
import arrow
from typing import Any, TypeVar, Type, Union, Optional
from sqlalchemy import ColumnExpressionArgument
from sqlalchemy.orm import Query, Session
from sqlalchemy.sql.elements import BinaryExpression
from Services.PostgresService.controllers.response_controllers import PostgresResponse
T = TypeVar("T", bound="QueryModel")
class QueryModel:
    # Optional pre-built query used in place of db.query(cls).
    pre_query = None
    __abstract__ = True

    @classmethod
    def _query(cls: Type[T], db: Session) -> Query:
        """Returns the query to use in the model."""
        return cls.pre_query if cls.pre_query else db.query(cls)

    @classmethod
    def add_new_arg_to_args(cls: Type[T], args_list, argument, value):
        """
        Append ``value`` to the expression tuple unless an expression on
        column ``argument`` is already present.
        Args:
            args_list: Iterable of existing filter expressions.
            argument: Column key the new expression filters on.
            value: Expression to add.
        Returns:
            Tuple of de-duplicated filter expressions.
        """
        new_arg_list = list(
            set(
                args_
                for args_ in list(args_list)
                if isinstance(args_, BinaryExpression)
            )
        )

        def arg_left(arg_obj):
            # Column key on the left-hand side of a binary expression.
            return getattr(getattr(arg_obj, "left", None), "key", None)

        if not any(True for arg in new_arg_list if arg_left(arg_obj=arg) == argument):
            new_arg_list.append(value)
        return tuple(new_arg_list)

    @classmethod
    def get_not_expired_query_arg(cls: Type[T], arg):
        """Add expiry_starts and expiry_ends constraints to the query args."""
        starts = cls.expiry_starts <= str(arrow.now())
        ends = cls.expiry_ends > str(arrow.now())
        arg = cls.add_new_arg_to_args(arg, "expiry_ends", ends)
        arg = cls.add_new_arg_to_args(arg, "expiry_starts", starts)
        return arg

    @classmethod
    def produce_query_to_add(cls: Type[T], filter_list, args=()):
        """
        Merge filter expressions built from ``filter_list["query"]`` into
        ``args``.

        BUGFIX: the previous version referenced a local ``args`` that was
        never initialized (guaranteed NameError when a query was present)
        and returned nothing; ``args`` is now a parameter (default empty)
        and the merged tuple is returned.

        Args:
            filter_list: Mapping that may contain a "query" dict of
                smart-query filters.
            args: Existing tuple of filter expressions to extend.
        Returns:
            Tuple of filter expressions including the produced ones.
        """
        if filter_list.get("query"):
            for smart_iter in cls.filter_expr(**filter_list["query"]):
                if key := getattr(getattr(smart_iter, "left", None), "key", None):
                    args = cls.add_new_arg_to_args(args, key, smart_iter)
        return args

    @classmethod
    def convert(
        cls: Type[T], smart_options: dict, validate_model: Any = None
    ) -> Optional[tuple[BinaryExpression]]:
        """Convert smart-query options into filter expressions.

        Returns None when a validate_model is supplied (no conversion path
        implemented for that case yet).
        """
        if not validate_model:
            return tuple(cls.filter_expr(**smart_options))

    @classmethod
    def filter_by_one(
        cls: Type[T], db: Session, system: bool = False, **kwargs
    ) -> PostgresResponse:
        """
        Filter single record by keyword arguments.
        Args:
            db: Database session
            system: If True, skip is_confirmed filtering
            **kwargs: Filter criteria
        Returns:
            Query response with single record
        """
        if "is_confirmed" not in kwargs and not system:
            kwargs["is_confirmed"] = True
        kwargs.pop("system", None)
        query = cls._query(db).filter_by(**kwargs)
        return PostgresResponse(
            model=cls, pre_query=cls._query(db), query=query, is_array=False
        )

    @classmethod
    def filter_one(
        cls: Type[T],
        *args: Union[BinaryExpression, ColumnExpressionArgument],
        db: Session,
    ) -> PostgresResponse:
        """
        Filter single record by expressions (non-expired records only).
        Args:
            db: Database session
            args: Filter expressions
        Returns:
            Query response with single record
        """
        args = cls.get_not_expired_query_arg(args)
        query = db.query(cls).filter(*args)
        pre_query = cls._query(db=db).filter(*args)
        return PostgresResponse(
            model=cls, pre_query=pre_query, query=query, is_array=False
        )

    @classmethod
    def filter_one_system(
        cls,
        *args: Union[BinaryExpression, ColumnExpressionArgument],
        db: Session,
    ):
        """
        Filter single record by expressions without expiry filtering.
        Args:
            *args: Filter expressions
            db: Database session
        Returns:
            Query response with single record
        """
        query = cls._query(db=db).filter(*args)
        return PostgresResponse(
            model=cls, pre_query=cls._query(db=db), query=query, is_array=False
        )

    @classmethod
    def filter_all_system(
        cls: Type[T],
        *args: Union[BinaryExpression, ColumnExpressionArgument],
        db: Session,
    ) -> PostgresResponse:
        """
        Filter multiple records by expressions without expiry filtering.
        Args:
            db: Database session
            args: Filter expressions
        Returns:
            Query response with matching records
        """
        query = cls._query(db)
        query = query.filter(*args)
        return PostgresResponse(
            model=cls, pre_query=cls._query(db), query=query, is_array=True
        )

    @classmethod
    def filter_all(
        cls: Type[T],
        *args: Union[BinaryExpression, ColumnExpressionArgument],
        db: Session,
    ) -> PostgresResponse:
        """
        Filter multiple records by expressions (non-expired records only).
        Args:
            db: Database session
            args: Filter expressions
        Returns:
            Query response with matching records
        """
        args = cls.get_not_expired_query_arg(args)
        pre_query = cls._query(db=db)
        query = cls._query(db=db).filter(*args)
        return PostgresResponse(
            model=cls, pre_query=pre_query, query=query, is_array=True
        )

    @classmethod
    def filter_by_all_system(cls: Type[T], db: Session, **kwargs) -> PostgresResponse:
        """
        Filter multiple records by keyword arguments.
        Args:
            db: Database session
            **kwargs: Filter criteria
        Returns:
            Query response with matching records
        """
        query = cls._query(db).filter_by(**kwargs)
        return PostgresResponse(
            model=cls, pre_query=cls._query(db), query=query, is_array=True
        )

View File

@@ -0,0 +1,135 @@
import arrow
from sqlalchemy import (
TIMESTAMP,
func,
text,
UUID,
String,
Integer,
Boolean,
SmallInteger,
)
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy_mixins.serialize import SerializeMixin
from sqlalchemy_mixins.repr import ReprMixin
from sqlalchemy_mixins.smartquery import SmartQueryMixin
from Services.PostgresService.controllers.core_controllers import BaseAlchemyModel
from Services.PostgresService.controllers.crud_controllers import CRUDModel
from Services.PostgresService.controllers.filter_controllers import QueryModel
from Services.PostgresService.database import Base
class BasicMixin(
    Base,
    BaseAlchemyModel,
    CRUDModel,
    SerializeMixin,
    ReprMixin,
    SmartQueryMixin,
    QueryModel,
):
    """Abstract base combining the declarative Base with transaction
    control, CRUD helpers, serialization, repr and smart-query behaviour.
    """
    __abstract__ = True
    # Pin ReprMixin's __repr__ explicitly so it is the one used
    # regardless of the MRO of the other bases.
    __repr__ = ReprMixin.__repr__
class CrudMixin(BasicMixin):
    """
    Base mixin providing CRUD operations and common fields for PostgreSQL models.
    Features:
    - Automatic timestamps (created_at, updated_at)
    - Soft delete capability
    - User tracking (created_by, updated_by)
    - Data serialization
    - Multi-language support
    """
    __abstract__ = True
    # Primary and reference fields
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    uu_id: Mapped[str] = mapped_column(
        UUID,
        server_default=text("gen_random_uuid()"),
        index=True,
        unique=True,
        comment="Unique identifier UUID",
    )
    # Common timestamp fields for all models
    expiry_starts: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        nullable=False,
        comment="Record validity start timestamp",
    )
    # NOTE(review): the ORM-side default is the 2099-12-31 sentinel while
    # server_default is func.now() — rows inserted outside the ORM would
    # be immediately expired; confirm this asymmetry is intended.
    expiry_ends: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP(timezone=True),
        default=str(arrow.get("2099-12-31")),
        server_default=func.now(),
        comment="Record validity end timestamp",
    )
class CrudCollection(CrudMixin):
    """
    Full-featured model class with all common fields.
    Includes:
    - UUID and reference ID
    - Timestamps
    - User tracking
    - Confirmation status
    - Soft delete
    - Notification flags
    """
    __abstract__ = True
    # Pin ReprMixin's __repr__ explicitly (same as BasicMixin).
    __repr__ = ReprMixin.__repr__
    ref_id: Mapped[str] = mapped_column(
        String(100), nullable=True, index=True, comment="External reference ID"
    )
    # Timestamps
    created_at: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        nullable=False,
        index=True,
        comment="Record creation timestamp",
    )
    updated_at: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
        index=True,
        comment="Last update timestamp",
    )
    # Audit columns (populated by CRUDModel.create_credentials /
    # update_credentials from the active Credentials).
    created_by: Mapped[str] = mapped_column(
        String, nullable=True, comment="Creator name"
    )
    created_by_id: Mapped[int] = mapped_column(
        Integer, nullable=True, comment="Creator ID"
    )
    updated_by: Mapped[str] = mapped_column(
        String, nullable=True, comment="Last modifier name"
    )
    updated_by_id: Mapped[int] = mapped_column(
        Integer, nullable=True, comment="Last modifier ID"
    )
    # Status flags
    # NOTE(review): string server defaults "0"/"1" for booleans — confirm
    # they coerce correctly on the target database.
    replication_id: Mapped[int] = mapped_column(
        SmallInteger, server_default="0", comment="Replication identifier"
    )
    deleted: Mapped[bool] = mapped_column(
        Boolean, server_default="0", comment="Soft delete flag"
    )
    active: Mapped[bool] = mapped_column(
        Boolean, server_default="1", comment="Record active status"
    )

View File

@@ -0,0 +1,249 @@
from __future__ import annotations
from typing import Any, Dict, Optional, Union
from sqlalchemy import desc, asc
from pydantic import BaseModel
# from application.validations.request.list_options.list_options import ListOptions
from Services.PostgresService.controllers.response_controllers import PostgresResponse
from Configs.api import ApiConfigs
class ListOptions:
    """Stand-in for the commented-out request ListOptions model.

    NOTE(review): this stub accepts no constructor kwargs, so
    ``QueryOptions`` calling ``ListOptions(**data)`` on a dict will raise
    TypeError — restore the real import before use.
    """
    ...
class PaginationConfig(BaseModel):
    """
    Configuration for pagination settings.
    Attributes:
        page: Current page number (default: 1)
        size: Items per page (default: 10)
        order_field: Fields to order by (default: ["uu_id"])
        order_type: Order directions (default: ["asc"])
    """
    page: int = 1
    size: int = 10
    order_field: Optional[Union[tuple[str], list[str]]] = None
    order_type: Optional[Union[tuple[str], list[str]]] = None
    def __init__(self, **data):
        """Apply list defaults after pydantic validation; using None as
        the declared default avoids a shared mutable class-level list."""
        super().__init__(**data)
        if self.order_field is None:
            self.order_field = ["uu_id"]
        if self.order_type is None:
            self.order_type = ["asc"]
class Pagination:
    """
    Handles pagination logic for query results.
    Manages page size, current page, ordering, and calculates total pages
    and items based on the data source.
    Attributes:
        DEFAULT_SIZE: Default number of items per page
        MIN_SIZE: Minimum allowed page size
        MAX_SIZE: Maximum allowed page size
    """
    DEFAULT_SIZE: int = int(ApiConfigs.DEFAULT_SIZE or 10)
    MIN_SIZE: int = int(ApiConfigs.MIN_SIZE or 5)
    MAX_SIZE: int = int(ApiConfigs.MAX_SIZE or 50)

    def __init__(self, data: PostgresResponse):
        """Seed pagination state from a query response."""
        self._data = data
        self.size: int = self.DEFAULT_SIZE
        self.page: int = 1
        self.orderField: Optional[Union[tuple[str], list[str]]] = ["uu_id"]
        self.orderType: Optional[Union[tuple[str], list[str]]] = ["asc"]
        self.page_count: int = 1
        self.total_pages: int = 1
        self.total_count: int = self._data.count
        self.all_count: int = self._data.total_count
        self._update_page_counts()

    def change(self, config: PaginationConfig) -> None:
        """Update pagination settings from config (size clamped to limits)."""
        self.size = (
            config.size
            if self.MIN_SIZE <= config.size <= self.MAX_SIZE
            else self.DEFAULT_SIZE
        )
        self.page = config.page
        self.orderField = config.order_field
        self.orderType = config.order_type
        self._update_page_counts()

    def feed(self, data: PostgresResponse) -> None:
        """Recalculate pagination for a new data source."""
        self._data = data
        self._update_page_counts()

    def _update_page_counts(self) -> None:
        """Update page counts and clamp the current page into range."""
        self.total_count = self._data.count
        self.all_count = self._data.total_count
        self.size = (
            self.size
            if self.MIN_SIZE <= self.size <= self.MAX_SIZE
            else self.DEFAULT_SIZE
        )
        self.total_pages = max(1, (self.total_count + self.size - 1) // self.size)
        self.page = max(1, min(self.page, self.total_pages))
        # The last page may be partially filled.
        self.page_count = (
            self.total_count % self.size
            if self.page == self.total_pages and self.total_count % self.size
            else self.size
        )

    def refresh(self) -> None:
        """Recompute counts from the current data source."""
        self._update_page_counts()

    def reset(self) -> None:
        """Reset pagination state to defaults.

        BUGFIX: order fields are lists everywhere else in this class; the
        previous version reset them to bare strings.
        """
        self.size = self.DEFAULT_SIZE
        self.page = 1
        self.orderField = ["uu_id"]
        self.orderType = ["asc"]

    def as_dict(self) -> Dict[str, Any]:
        """Convert pagination state to dictionary format."""
        self.refresh()
        return {
            "size": self.size,
            "page": self.page,
            "allCount": self.all_count,
            "totalCount": self.total_count,
            "totalPages": self.total_pages,
            "pageCount": self.page_count,
            "orderField": self.orderField,
            "orderType": self.orderType,
        }
class PaginationResult:
    """
    Result of a paginated query.
    Contains the query result and pagination state.
    data: PostgresResponse of query results
    pagination: Pagination state
    Attributes:
        _query: Original query object
        pagination: Pagination state
    """

    def __init__(
        self, data: PostgresResponse, pagination: Pagination, response_model: Any = None
    ):
        """Capture the response, pagination state and optional output model."""
        self._data = data
        self._query = data.query
        self._core_query = data.core_query
        self.pagination = pagination
        self.response_type = data.is_list
        self.limit = self.pagination.size
        self.offset = self.pagination.size * (self.pagination.page - 1)
        self.order_by = self.pagination.orderField
        self.order_type = self.pagination.orderType
        self.response_model = response_model

    def dynamic_order_by(self):
        """
        Dynamically order a query by multiple fields.
        Returns:
            Ordered query object.
        Raises:
            ValueError: If order fields and directions differ in length.
        """
        if len(self.order_by) != len(self.order_type):
            raise ValueError(
                "Order by fields and order types must have the same length."
            )
        if not self._data.data:
            return self._core_query
        # These are loop-invariant; the previous version recomputed them
        # for every ordering field (and logged debug prints).
        columns = self._data.data[0].filterable_attributes
        entity = self._core_query.column_descriptions[0]["entity"]
        for field, direction in zip(self.order_by, self.order_type):
            if field in columns:
                # Any direction starting with "d" is treated as descending.
                ordering = desc if direction.lower().startswith("d") else asc
                self._core_query = self._core_query.order_by(
                    ordering(getattr(entity, field))
                )
        return self._core_query

    @property
    def data(self) -> Union[list, dict]:
        """Ordered, paginated results — a list of dicts (or a single dict
        for non-list responses), optionally passed through response_model."""
        query_ordered = self.dynamic_order_by()
        query_paginated = query_ordered.limit(self.limit).offset(self.offset)
        queried_data = (
            query_paginated.all() if self.response_type else query_paginated.first()
        )
        data = (
            [result.get_dict() for result in queried_data]
            if self.response_type
            else queried_data.get_dict()
        )
        if self.response_model:
            # BUGFIX: the previous version iterated `data` even when it was
            # a single dict, which would have iterated its keys.
            if self.response_type:
                return [self.response_model(**item).model_dump() for item in data]
            return self.response_model(**data).model_dump()
        return data
class QueryOptions:
    """Normalizes list options and translates them into filter expressions."""

    def __init__(
        self,
        table,
        data: Union[dict, ListOptions] = None,
        model_query: Optional[Any] = None,
    ):
        """
        Args:
            table: Model class providing ``convert``.
            data: Raw options dict or a ListOptions instance.
            model_query: Optional validation model for query keys.
        """
        self.table = table
        self.data = data
        self.model_query = model_query
        if isinstance(data, dict):
            self.data = ListOptions(**data)
        self.validate_query()
        # NOTE(review): ["created_at"] as an order *type* looks like a
        # field name — confirm these two defaults aren't swapped.
        if not self.data.order_type:
            self.data.order_type = ["created_at"]
        if not self.data.order_field:
            self.data.order_field = ["uu_id"]

    def validate_query(self):
        """Keep only query keys the validation model accepts.

        Keys may carry ``__op`` suffixes (smart-query operators); the bare
        key is validated while the original suffixed key is preserved.
        """
        if not self.data.query or not self.model_query:
            return ()
        cleaned_query, cleaned_query_by_model, last_dict = {}, {}, {}
        for key, value in self.data.query.items():
            bare_key = str(str(key).split("__")[0])
            cleaned_query[bare_key] = value
            cleaned_query_by_model[bare_key] = (key, value)
        cleaned_model = self.model_query(**cleaned_query)
        for bare_key in cleaned_query:
            if hasattr(cleaned_model, bare_key):
                last_dict[str(cleaned_query_by_model[str(bare_key)][0])] = str(
                    cleaned_query_by_model[str(bare_key)][1]
                )
        self.data.query = last_dict

    def convert(self) -> tuple:
        """Translate the validated query dict into filter expressions.

        BUGFIX: the previous condition (`not self.data or self.data.query`)
        returned () precisely when a query WAS present, so convert() could
        never yield expressions — contradicting its own usage example.

        Returns:
            Tuple of filter expressions (empty when there is no query).
        """
        if not self.data or not self.data.query:
            return ()
        return tuple(self.table.convert(self.data.query))

View File

@@ -0,0 +1,109 @@
"""
Response handler for PostgreSQL query results.
This module provides a wrapper class for SQLAlchemy query results,
adding convenience methods for accessing data and managing query state.
"""
from typing import Any, Dict, Optional, TypeVar, Generic, Union
from sqlalchemy.orm import Query
T = TypeVar("T")
class PostgresResponse(Generic[T]):
    """
    Wrapper for PostgreSQL/SQLAlchemy query results.
    Attributes:
        metadata: Additional metadata for the query
    Properties:
        count: Count of filtered results
        query: SQL string of the query
        as_dict: Convert response to dictionary format
    """

    def __init__(
        self,
        pre_query: Query,
        query: Query,
        model,
        is_array: bool = True,
        metadata: Any = None,
    ):
        """
        Args:
            pre_query: Unfiltered base query (used for total_count).
            query: Filtered query.
            model: Model class the query selects.
            is_array: Whether the response represents a list of records.
            metadata: Arbitrary extra data carried with the response.
        """
        self._core_class = model
        self._is_list = is_array
        self._query = query
        self._pre_query = pre_query
        self._count: Optional[int] = None
        self.metadata = metadata

    @property
    def core_class(self):
        """Model class backing this response."""
        return self._core_class

    @property
    def data(self) -> Union[T, list[T]]:
        """Query results: first record (or None), or all records."""
        if not self.is_list:
            first_item = self._query.first()
            return first_item if first_item else None
        # Single .all() call — the previous version executed the query twice.
        return self._query.all()

    @property
    def data_as_dict(self) -> Union[Dict[str, Any], list[Dict[str, Any]]] | None:
        """Query results serialized via each record's get_dict().

        BUGFIX: the single/list branches were inverted, and the single
        branch called .first() on a record instance (AttributeError).
        """
        if not self.count:
            return None
        if not self.is_list:
            first_item = self._query.first()
            return first_item.get_dict() if first_item else None
        all_items = self._query.all()
        return [result.get_dict() for result in all_items] if all_items else []

    @property
    def total_count(self) -> int:
        """Count of the unfiltered base query."""
        return self._pre_query.count() if self._pre_query else 0

    @property
    def count(self) -> int:
        """Count of the filtered query."""
        return self._query.count()

    @property
    def query(self) -> str:
        """SQL string rendering of the filtered query."""
        return str(self._query)

    @property
    def core_query(self) -> Query:
        """The underlying filtered Query object."""
        return self._query

    @property
    def is_list(self) -> bool:
        """Check if response is a list."""
        return self._is_list

    @property
    def as_dict(self) -> Dict[str, Any]:
        """Convert response to dictionary format."""
        if isinstance(self.data, list):
            return {
                "metadata": self.metadata,
                "is_list": self._is_list,
                "query": str(self.query),
                "count": self.count,
                "data": [result.get_dict() for result in self.data],
            }
        return {
            "metadata": self.metadata,
            "is_list": self._is_list,
            "query": str(self.query),
            "count": self.count,
            "data": self.data.get_dict() if self.data else {},
        }

View File

@@ -0,0 +1,62 @@
from contextlib import contextmanager
from functools import lru_cache
from typing import Generator
from Configs.postgres import Database
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker, scoped_session, Session
# Configure the database engine with proper pooling
# Configure the database engine with proper pooling
engine = create_engine(
    Database.DATABASE_URL,
    pool_pre_ping=True,  # Verify connection before using
    pool_size=20,  # Maximum number of permanent connections
    max_overflow=10,  # Maximum number of additional connections
    pool_recycle=3600,  # Recycle connections after 1 hour
    pool_timeout=30,  # Wait up to 30 seconds for a connection
    echo=True,  # NOTE(review): echo=True logs every SQL statement — disable for production
)
# Declarative base class that all ORM models inherit from.
Base = declarative_base()
# Create a cached session factory
@lru_cache()
def get_session_factory() -> scoped_session:
    """Build (once, via lru_cache) and return the thread-safe scoped
    session factory bound to the module-level engine."""
    return scoped_session(
        sessionmaker(
            bind=engine,
            autocommit=False,
            autoflush=False,
            expire_on_commit=False,  # Prevent expired object issues
        )
    )
# Get database session with proper connection management
@contextmanager
def get_db() -> Generator[Session, None, None]:
    """Get database session with proper connection management.
    This context manager ensures:
    - Proper connection pooling
    - Session cleanup
    - Connection return to pool
    - Thread safety
    Behaviour: commits on clean exit of the with-block, rolls back and
    re-raises on any exception, and always closes the session and removes
    it from the scoped-session registry.
    Yields:
        Session: SQLAlchemy session object
    """
    session_factory = get_session_factory()
    session = session_factory()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
        session_factory.remove()  # Clean up the session from the registry