# Source: prod-wag-backend-automate-s.../Controllers/Postgres/filter.py
#
# (421 lines, 13 KiB, Python)

"""
Advanced filtering functionality for SQLAlchemy models.
This module provides a comprehensive set of filtering capabilities for SQLAlchemy models,
including pagination, ordering, and complex query building.
"""
from __future__ import annotations

import logging
from typing import Any, TypeVar, Type, Union, Optional

import arrow
from sqlalchemy import ColumnExpressionArgument
from sqlalchemy.orm import Query, Session
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.elements import ClauseElement

from Controllers.Postgres.response import PostgresResponse
T = TypeVar("T", bound="QueryModel")

logger = logging.getLogger(__name__)


class QueryModel:
    """Mixin providing status- and expiry-aware query helpers for models.

    Concrete subclasses are presumably SQLAlchemy declarative models (this
    class is marked ``__abstract__`` and queried via ``db.query(cls)``).
    The non-``system`` helpers rely on ``is_confirmed``, ``deleted`` and
    ``active`` columns; every ``filter_*`` helper relies on
    ``expiry_starts`` and ``expiry_ends`` columns.
    """

    __abstract__ = True
    # Optional prepared query used instead of ``db.query(cls)`` as the base
    # of every filter helper (see ``_query``).
    pre_query: Optional[Query] = None

    @classmethod
    def _query(cls: Type[T], db: Session) -> Query:
        """Return the base query for this model, bound to ``db``.

        If ``pre_query`` is set, a copy of it re-bound to ``db`` is returned
        (``Query.with_session`` is generative, so ``pre_query`` itself is not
        modified); otherwise a fresh ``db.query(cls)`` is returned.
        """
        if cls.pre_query is not None:
            # Re-bind so the prepared query runs on the caller's session
            # rather than on whichever session it was originally built with.
            return cls.pre_query.with_session(db)
        return db.query(cls)

    @staticmethod
    def _is_sql_clause(arg: Any) -> bool:
        """Return True if ``arg`` is a SQL expression usable as a filter.

        Accepts Core clause elements (``BinaryExpression``, ``or_()`` /
        ``and_()`` lists, ``text()`` ...) and ORM attributes exposing
        ``__clause_element__``. Rejects plain Python values such as the
        ``bool`` produced by comparing a non-column class attribute.
        """
        return hasattr(arg, "__clause_element__") or isinstance(arg, ClauseElement)

    @staticmethod
    def _has_arg(args: tuple, key: str) -> bool:
        """Return True if any expression in ``args`` filters on column ``key``.

        Matching uses ``arg.left.key`` (the left-hand column of a binary
        comparison); expressions without that shape never match.
        """
        return any(
            getattr(getattr(arg, "left", None), "key", None) == key for arg in args
        )

    @classmethod
    def add_new_arg_to_args(
        cls: Type[T],
        args_list: tuple[BinaryExpression, ...],
        argument: str,
        value: BinaryExpression,
    ) -> tuple[BinaryExpression, ...]:
        """
        Add a new argument to the query arguments if it doesn't exist.

        Existing arguments are de-duplicated (by identity, preserving order)
        and non-SQL values are dropped. Every SQL expression is kept, not
        only ``BinaryExpression`` instances, so filters such as ``or_(...)``
        are not silently discarded.

        Args:
            args_list: Existing query arguments
            argument: Column key to check for (matched on ``arg.left.key``)
            value: Argument appended when no existing argument filters on
                ``argument``

        Returns:
            Updated tuple of query arguments
        """
        seen: set[int] = set()
        new_args = []
        for arg in args_list:
            # De-duplicate by identity: SQLAlchemy overloads ``==`` on column
            # expressions, so equality would build SQL rather than compare.
            if id(arg) in seen or not cls._is_sql_clause(arg):
                continue
            seen.add(id(arg))
            new_args.append(arg)
        if not cls._has_arg(new_args, argument):
            new_args.append(value)
        return tuple(new_args)

    @classmethod
    def _missing_expiry_filters(cls: Type[T], args: tuple) -> list[tuple[str, Any]]:
        """Build the validity-window filters not already present in ``args``.

        Returns ``(column_key, expression)`` pairs for
        ``expiry_ends > now`` and ``expiry_starts <= now``, skipping any
        bound whose column is already filtered in ``args`` so callers can
        override the default window.

        Raises:
            AttributeError: If model does not have expiry_starts or
                expiry_ends columns
        """
        try:
            # NOTE(review): ``now`` is sent as a string with the local UTC
            # offset; confirm the column type interprets the offset as intended.
            current_time = str(arrow.now())
            missing = []
            if not cls._has_arg(args, "expiry_ends"):
                missing.append(("expiry_ends", cls.expiry_ends > current_time))
            if not cls._has_arg(args, "expiry_starts"):
                missing.append(("expiry_starts", cls.expiry_starts <= current_time))
            return missing
        except AttributeError as e:
            raise AttributeError(
                f"Model {cls.__name__} must have expiry_starts and expiry_ends columns"
            ) from e

    @classmethod
    def get_not_expired_query_arg(
        cls: Type[T], args: tuple[BinaryExpression, ...]
    ) -> tuple[BinaryExpression, ...]:
        """
        Add expiry date filtering to the query arguments.

        A default bound is only added for an expiry column that ``args`` does
        not already filter on.

        Args:
            args: Existing query arguments

        Returns:
            Updated tuple of query arguments with expiry filters

        Raises:
            AttributeError: If model does not have expiry_starts or expiry_ends columns
        """
        for key, expression in cls._missing_expiry_filters(args):
            args = cls.add_new_arg_to_args(args, key, expression)
        return args

    @classmethod
    def produce_query_to_add(cls: Type[T], filter_list: dict, args: tuple) -> tuple:
        """
        Adds query to main filter options.

        Each ``key: value`` pair in ``filter_list["query"]`` becomes an
        equality filter ``getattr(cls, key) == value``. The filter is only
        added when the model has such an attribute, the comparison yields a
        SQL expression, and ``args`` does not already filter on that key.

        Args:
            filter_list: Dictionary containing query parameters
            args: Existing query arguments to add to

        Returns:
            Updated query arguments tuple. On error, the arguments
            accumulated so far are returned and the error is logged.
        """
        if not filter_list or not isinstance(filter_list, dict):
            return args
        query_params = filter_list.get("query")
        if not query_params or not isinstance(query_params, dict):
            return args
        try:
            for key, value in query_params.items():
                if not hasattr(cls, key):
                    continue
                filter_expr = getattr(cls, key) == value
                # Comparing a non-column attribute (method, plain class
                # attribute) yields a Python bool, not SQL -- skip it rather
                # than turning it into a constant WHERE clause.
                if not cls._is_sql_clause(filter_expr):
                    continue
                args = cls.add_new_arg_to_args(args, key, filter_expr)
        except Exception:
            # Deliberately best-effort: keep what was built so far, but leave
            # a traceback in the logs instead of printing to stdout.
            logger.exception("Error in produce_query_to_add")
        return args

    @classmethod
    def convert(
        cls: Type[T], smart_options: dict[str, Any], validate_model: Any = None
    ) -> tuple[BinaryExpression, ...]:
        """
        Convert smart options to SQLAlchemy filter expressions.

        Delegates to ``cls.filter_expr``, which is not defined on this class
        (presumably supplied by a smart-query mixin on the model -- confirm).

        Args:
            smart_options: Filter options, passed as keyword arguments to
                ``filter_expr``
            validate_model: Unused; kept for backward compatibility

        Returns:
            Tuple of SQLAlchemy filter expressions

        Raises:
            ValueError: If the expressions cannot be built; the message lists
                the model's valid columns and relationships.
        """
        try:
            # Let SQLAlchemy handle the validation by attempting to build them.
            return tuple(cls.filter_expr(**smart_options))
        except Exception as e:
            valid_columns: list[str] = []
            relationship_names: list[str] = []
            if hasattr(cls, "__table__") and hasattr(cls.__table__, "columns"):
                valid_columns = sorted({column.key for column in cls.__table__.columns})
            if hasattr(cls, "__mapper__") and hasattr(cls.__mapper__, "relationships"):
                relationship_names = sorted(
                    {rel.key for rel in cls.__mapper__.relationships}
                )
            # Sorted so the message is stable between runs.
            error_msg = f"Error in filter expression: {str(e)}\n"
            error_msg += f"Attempted to filter with: {smart_options}\n"
            error_msg += f"Valid columns are: {', '.join(valid_columns)}\n"
            error_msg += f"Valid relationships are: {', '.join(relationship_names)}"
            raise ValueError(error_msg) from e

    @classmethod
    def _status_filters(cls: Type[T]) -> tuple:
        """Filters selecting records that are confirmed, not deleted and active."""
        return (
            cls.is_confirmed == True,  # noqa: E712 -- builds SQL, not a Python test
            cls.deleted == False,  # noqa: E712
            cls.active == True,  # noqa: E712
        )

    @classmethod
    def _build_response(
        cls: Type[T],
        db: Session,
        *,
        args: tuple = (),
        filter_by: Optional[dict[str, Any]] = None,
        with_status: bool,
        is_array: bool,
    ) -> PostgresResponse[T]:
        """Apply caller, status and expiry filters and wrap the result.

        Filters are applied in this order: keyword criteria (``filter_by``),
        expression criteria (``args``), status filters (when
        ``with_status``), then the default expiry window for every expiry
        column not already constrained by ``args``.
        """
        base_query = cls._query(db)
        query = base_query
        if filter_by:
            query = query.filter_by(**filter_by)
        if args:
            query = query.filter(*args)
        if with_status:
            query = query.filter(*cls._status_filters())
        expiry_filters = [expr for _, expr in cls._missing_expiry_filters(args)]
        query = query.filter(*expiry_filters)
        return PostgresResponse(
            model=cls,
            pre_query=base_query,  # the unfiltered base query
            query=query,
            is_array=is_array,
        )

    @classmethod
    def filter_by_one(
        cls: Type[T], db: Session, system: bool = False, **kwargs: Any
    ) -> PostgresResponse[T]:
        """
        Filter single record by keyword arguments.

        The default expiry window is always applied for keyword filters.

        Args:
            db: Database session
            system: If True, skip status filtering
            **kwargs: Filter criteria

        Returns:
            Query response with single record
        """
        return cls._build_response(
            db, filter_by=kwargs, with_status=not system, is_array=False
        )

    @classmethod
    def filter_one(
        cls: Type[T],
        *args: Union[BinaryExpression, ColumnExpressionArgument],
        db: Session,
    ) -> PostgresResponse[T]:
        """
        Filter single record by expressions.

        Args:
            db: Database session
            *args: Filter expressions; a filter on ``expiry_starts`` or
                ``expiry_ends`` replaces the corresponding default bound

        Returns:
            Query response with single record
        """
        return cls._build_response(db, args=args, with_status=True, is_array=False)

    @classmethod
    def filter_one_system(
        cls: Type[T],
        *args: Union[BinaryExpression, ColumnExpressionArgument],
        db: Session,
    ) -> PostgresResponse[T]:
        """
        Filter single record by expressions without status filtering.

        Args:
            db: Database session
            *args: Filter expressions; a filter on ``expiry_starts`` or
                ``expiry_ends`` replaces the corresponding default bound

        Returns:
            Query response with single record
        """
        return cls._build_response(db, args=args, with_status=False, is_array=False)

    @classmethod
    def filter_all_system(
        cls: Type[T],
        *args: Union[BinaryExpression, ColumnExpressionArgument],
        db: Session,
    ) -> PostgresResponse[T]:
        """
        Filter multiple records by expressions without status filtering.

        Args:
            db: Database session
            *args: Filter expressions; a filter on ``expiry_starts`` or
                ``expiry_ends`` replaces the corresponding default bound

        Returns:
            Query response with matching records
        """
        return cls._build_response(db, args=args, with_status=False, is_array=True)

    @classmethod
    def filter_all(
        cls: Type[T],
        *args: Union[BinaryExpression, ColumnExpressionArgument],
        db: Session,
    ) -> PostgresResponse[T]:
        """
        Filter multiple records by expressions.

        Args:
            db: Database session
            *args: Filter expressions; a filter on ``expiry_starts`` or
                ``expiry_ends`` replaces the corresponding default bound

        Returns:
            Query response with matching records
        """
        return cls._build_response(db, args=args, with_status=True, is_array=True)

    @classmethod
    def filter_by_all_system(
        cls: Type[T], db: Session, **kwargs: Any
    ) -> PostgresResponse[T]:
        """
        Filter multiple records by keyword arguments without status filtering.

        The default expiry window is always applied for keyword filters.

        Args:
            db: Database session
            **kwargs: Filter criteria

        Returns:
            Query response with matching records
        """
        return cls._build_response(
            db, filter_by=kwargs, with_status=False, is_array=True
        )

    @classmethod
    def filter_by_one_system(
        cls: Type[T], db: Session, **kwargs: Any
    ) -> PostgresResponse[T]:
        """
        Filter single record by keyword arguments without status filtering.

        Args:
            db: Database session
            **kwargs: Filter criteria

        Returns:
            Query response with single record
        """
        # Delegate to filter_by_one with system=True to avoid duplication.
        return cls.filter_by_one(db=db, system=True, **kwargs)