276 lines
9.5 KiB
Python
276 lines
9.5 KiB
Python
import datetime
import logging
from decimal import Decimal
from typing import Any, TypeVar, Type, Union, Optional

import arrow
from sqlalchemy import Column, Integer, String, Float, ForeignKey, UUID, TIMESTAMP, Boolean, SmallInteger, Numeric, func, text, NUMERIC, ColumnExpressionArgument
from sqlalchemy.orm import InstrumentedAttribute, Mapped, mapped_column, Query, Session
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy_mixins.serialize import SerializeMixin
from sqlalchemy_mixins.repr import ReprMixin
from sqlalchemy_mixins.smartquery import SmartQueryMixin
from sqlalchemy_mixins.activerecord import ActiveRecordMixin

from api_controllers.postgres.engine import get_db, Base
|
|
|
|
|
|
# Type variable bound to CrudMixin, used to type classmethods (e.g. ``convert``)
# that are invoked on CrudMixin subclasses. The name string must match the variable.
T = TypeVar("T", bound="CrudMixin")
|
|
|
|
|
|
class BasicMixin(Base, ActiveRecordMixin, SerializeMixin, ReprMixin, SmartQueryMixin):
    """
    Abstract declarative base combining the sqlalchemy_mixins helpers
    (active record, serialization, repr, smart query) with project-specific
    serialization (``get_dict``) and filter building (``convert``).
    """

    __abstract__ = True

    # Bind ReprMixin's __repr__ explicitly so it takes precedence over any
    # __repr__ provided by bases that come earlier in the MRO.
    __repr__ = ReprMixin.__repr__

    @classmethod
    def new_session(cls):
        """Return the result of ``get_db()``, the project's database session provider."""
        return get_db()

    @classmethod
    def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
        """
        Process a field value and convert it to a serializable representation.

        Conversion is driven first by the field's ``Mapped[...]`` annotation,
        looked up via ``cls.__annotations__``. On Python 3.10+ that mapping holds
        only annotations declared on ``cls`` itself, not inherited ones. When
        there is no annotation, or it is not one of the recognised forms, the
        runtime Python type of ``val`` decides the conversion.

        Args:
            val: Field value.
            key: Field (attribute) name.

        Returns:
            Tuple of (should_include, processed_value). Primary keys, foreign
            keys, unsupported value types and values whose conversion raises
            are reported as ``(False, None)``.
        """
        try:
            annotation = cls.__annotations__.get(key, None)
            is_primary = key in getattr(cls, "primary_keys", [])
            is_foreign = bool(getattr(getattr(cls, key), "foreign_keys", None))

            # Primary and foreign keys are internal identifiers; never serialize them.
            if is_primary or is_foreign:
                return False, None

            if val is None:
                return True, None

            # Columns named ``uu_id`` / ``*_uu_id`` are emitted as text.
            if key[-5:].lower() == "uu_id":
                return True, str(val)

            if annotation:
                if annotation == Mapped[int]:
                    return True, int(val)
                if annotation == Mapped[bool]:
                    return True, bool(val)
                if annotation in (Mapped[float], Mapped[NUMERIC]):
                    return True, round(float(val), 3)
                if annotation == Mapped[TIMESTAMP]:
                    return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss"))
                if annotation == Mapped[str]:
                    return True, str(val)
                # Unrecognised annotation (e.g. Mapped[Optional[str]]): fall
                # through to the runtime-type rules instead of dropping the field.

            if isinstance(val, datetime.datetime):
                return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss"))
            # bool must be tested before int because bool is a subclass of int.
            if isinstance(val, bool):
                return True, bool(val)
            if isinstance(val, (float, Decimal)):
                return True, round(float(val), 3)
            if isinstance(val, int):
                return True, int(val)
            if isinstance(val, str):
                return True, str(val)
            return False, None

        except Exception:
            # Best effort: a field that cannot be converted is left out, but the
            # failure is logged rather than silently discarded.
            logging.getLogger(__name__).warning(
                "Could not serialize field %r of %s", key, cls.__name__, exc_info=True
            )
            return False, None

    @classmethod
    def convert(cls: Type[T], smart_options: dict[str, Any], validate_model: Any = None) -> Optional[tuple[BinaryExpression, ...]]:
        """
        Convert smart options to SQLAlchemy filter expressions.

        Args:
            smart_options: Keyword filter options forwarded to ``cls.filter_expr``.
            validate_model: Currently unused; accepted for backward compatibility.

        Returns:
            Tuple of SQLAlchemy filter expressions.

        Raises:
            ValueError: If ``filter_expr`` rejects the options. The message
                lists the model's valid column and relationship names.
        """
        try:
            # Let filter_expr validate the keys while building the expressions.
            return tuple(cls.filter_expr(**smart_options))
        except Exception as e:
            valid_columns: set[str] = set()
            relationship_names: set[str] = set()

            if hasattr(cls, "__table__") and hasattr(cls.__table__, "columns"):
                valid_columns = {column.key for column in cls.__table__.columns}

            if hasattr(cls, "__mapper__") and hasattr(cls.__mapper__, "relationships"):
                relationship_names = {rel.key for rel in cls.__mapper__.relationships}

            # Names are sorted so the message is deterministic and easy to scan.
            error_msg = (
                f"Error in filter expression: {str(e)}\n"
                f"Attempted to filter with: {smart_options}\n"
                f"Valid columns are: {', '.join(sorted(valid_columns))}\n"
                f"Valid relationships are: {', '.join(sorted(relationship_names))}"
            )
            raise ValueError(error_msg) from e

    def get_dict(
        self,
        exclude_list: Optional[list[Union[InstrumentedAttribute, str]]] = None,
    ) -> dict[str, Any]:
        """
        Convert the model instance to a dictionary of serialized column values.

        Columns named ``id`` or ending in ``_id`` are omitted, except columns
        whose name ends in ``uu_id`` (case-insensitive). Each remaining value is
        passed through ``iterate_over_variables``, which may drop it as well.

        Args:
            exclude_list: Columns to leave out, given either as mapped attributes
                (anything with a ``.key``, e.g. ``Model.name``) or as plain
                attribute-name strings.

        Returns:
            Dictionary of column name -> serialized value, in table column order.
            If serialization fails, an empty dict is returned and a warning is logged.
        """
        try:
            excluded = {getattr(item, "key", item) for item in exclude_list or ()}
            result: dict[str, Any] = {}

            for column in self.__table__.columns:
                name = column.name
                is_uuid = name[-5:].lower() == "uu_id"
                is_identifier = name == "id" or name.endswith("_id")
                if name in excluded or (is_identifier and not is_uuid):
                    continue

                # NOTE(review): the DB column name doubles as the attribute name
                # here; a column mapped under a different attribute name would
                # raise and produce an empty result.
                include, value = self.iterate_over_variables(getattr(self, name), name)
                if include:
                    result[name] = value

            return result

        except Exception:
            logging.getLogger(__name__).warning(
                "get_dict failed for %s; returning an empty dict",
                type(self).__name__,
                exc_info=True,
            )
            return {}
|
|
|
|
|
|
|
|
class CrudMixin(BasicMixin):
    """
    Abstract base mixin providing the common identity and timestamp columns
    for PostgreSQL models.

    Columns:
        - id: integer primary key.
        - uu_id: UUID generated by the database (``gen_random_uuid()``),
          unique and indexed.
        - expiry_starts / expiry_ends: record validity window.
        - created_at / updated_at: creation and last-update timestamps;
          updated_at is refreshed with ``now()`` on every ORM update.

    Serialization (``get_dict``) and filter building (``convert``) are
    inherited from BasicMixin.
    """

    __abstract__ = True

    # Primary and reference fields
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    uu_id: Mapped[str] = mapped_column(
        UUID,
        server_default=text("gen_random_uuid()"),
        index=True,
        unique=True,
        comment="Unique identifier UUID",
    )

    # Record validity window
    expiry_starts: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        comment="Record validity start timestamp",
    )
    expiry_ends: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP(timezone=True),
        default=str(arrow.get("2099-12-31")),
        # The database-side default matches the ORM default above. Rows inserted
        # outside the ORM (raw SQL, Core bulk inserts) then get the same
        # far-future end date instead of now(), which would expire them at
        # creation.
        server_default="2099-12-31 00:00:00+00",
        comment="Record validity end timestamp",
    )

    # Timestamps
    created_at: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        nullable=False,
        index=True,
        comment="Record creation timestamp",
    )
    updated_at: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
        index=True,
        comment="Last update timestamp",
    )
|
|
|
|
|
|
class CrudCollection(CrudMixin):
    """
    Abstract model base that adds bookkeeping columns on top of CrudMixin.

    Added columns, in declaration order (the order is significant for DDL):
        - ref_id, replication_id: external reference and replication identifiers.
        - cryp_uu_id: cryptographic UUID.
        - created_/updated_/confirmed_credentials_token: credential tokens
          stored for each kind of modification.
        - is_confirmed, deleted, active: confirmation, soft-delete and
          activity status flags.
        - is_notification_send, is_email_send: notification delivery flags.

    The id/uu_id and timestamp columns are inherited from CrudMixin.
    """

    __abstract__ = True
    __repr__ = ReprMixin.__repr__

    # --- external references ------------------------------------------------
    ref_id: Mapped[str] = mapped_column(
        String(100),
        nullable=True,
        index=True,
        comment="External reference ID",
    )
    replication_id: Mapped[int] = mapped_column(
        SmallInteger,
        server_default="0",
        comment="Replication identifier",
    )

    # --- cryptographic identifier -------------------------------------------
    cryp_uu_id: Mapped[str] = mapped_column(
        String,
        nullable=True,
        index=True,
        comment="Cryptographic UUID",
    )

    # --- credential tokens per modification ---------------------------------
    created_credentials_token: Mapped[str] = mapped_column(
        String,
        nullable=True,
        comment="Created Credentials token",
    )
    updated_credentials_token: Mapped[str] = mapped_column(
        String,
        nullable=True,
        comment="Updated Credentials token",
    )
    confirmed_credentials_token: Mapped[str] = mapped_column(
        String,
        nullable=True,
        comment="Confirmed Credentials token",
    )

    # --- status flags -------------------------------------------------------
    is_confirmed: Mapped[bool] = mapped_column(
        Boolean,
        server_default="0",
        comment="Record confirmation status",
    )
    deleted: Mapped[bool] = mapped_column(
        Boolean,
        server_default="0",
        comment="Soft delete flag",
    )
    active: Mapped[bool] = mapped_column(
        Boolean,
        server_default="1",
        comment="Record active status",
    )
    is_notification_send: Mapped[bool] = mapped_column(
        Boolean,
        server_default="0",
        comment="Notification sent flag",
    )
    is_email_send: Mapped[bool] = mapped_column(
        Boolean,
        server_default="0",
        comment="Email sent flag",
    )
|
|
|