# Source: prod-wag-backend-automate-s.../Controllers/Postgres/crud.py
# (303 lines, 9.5 KiB, Python)

import arrow
import datetime
from typing import Optional, Any, Dict
from decimal import Decimal
from fastapi.exceptions import HTTPException
from sqlalchemy import TIMESTAMP, NUMERIC
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm import Session, Mapped
class MetaData:
    """
    Outcome flags for a query/CRUD operation.

    All three flags default to ``False`` and are stored as class attributes,
    so a fresh instance reads ``False`` for each flag until one is assigned
    on that instance.
    """

    created: bool
    updated: bool
    deleted: bool
    created = updated = deleted = False
class CRUDModel:
"""
Base class for CRUD operations on PostgreSQL models.

Meant to be combined with SQLAlchemy models. The lookup helpers filter on
``expiry_starts`` / ``expiry_ends`` and read ``deleted`` (assumption: every
subclass provides the columns listed in ``required_columns`` — confirm).
Failures are reported by raising fastapi ``HTTPException`` through
``raise_exception``. The write helpers (``create_or_abort``,
``find_or_create``, ``update``) call ``db.flush()`` and leave committing
to the caller.

Features:
- User credential tracking
- Metadata tracking for operations
- Type-safe field handling
- Automatic timestamp management
- Soft delete support
"""
# Presumably tells SQLAlchemy's declarative system not to map this class
# itself -- confirm it is combined with a declarative base.
__abstract__ = True
# creds: Credentials = None
# NOTE(review): a single MetaData instance created at class-definition time.
# It is shared by every subclass/instance that does not override it, and
# find_or_create / update overwrite its flags in place, so the flags only
# describe the most recent operation on any model (likely unreliable when
# requests run concurrently).
meta_data: MetaData = MetaData()
# Column name -> expected type for CRUD operations. Values mix a SQLAlchemy
# column type (TIMESTAMP) with plain Python types (str, int, bool).
required_columns = {
"expiry_starts": TIMESTAMP,
"expiry_ends": TIMESTAMP,
"created_by": str,
"created_by_id": int,
"updated_by": str,
"updated_by_id": int,
"deleted": bool,
}
@classmethod
def raise_exception(
    cls, message: str = "Exception raised.", status_code: int = 400
):
    """
    Abort the current operation with an ``HTTPException``.

    This method never returns normally.

    Args:
        message: Text stored under the ``"message"`` key of the error detail.
        status_code: HTTP status code of the exception (default 400).

    Raises:
        HTTPException: always, with ``detail={"message": message}``.
    """
    error_detail = {"message": message}
    raise HTTPException(detail=error_detail, status_code=status_code)
@classmethod
def create_or_abort(cls, db: Session, **kwargs):
    """
    Create a new record, aborting if a matching active record already exists.

    A record matches when it is currently active
    (``expiry_starts <= now < expiry_ends``) and every keyword argument that
    names an attribute of the model equals the stored value.

    Args:
        db: Database session
        **kwargs: Record fields; all of them are assigned to the new record.

    Returns:
        The new record, added to the session and flushed (not committed).

    Raises:
        HTTPException: 400 if a matching record already exists (deleted or
            not); 500 if the lookup or creation fails for any other reason.
            The session is rolled back in both cases.
    """
    try:
        # One timestamp so both bounds of the active window use the same "now".
        now = str(arrow.now())
        query = db.query(cls).filter(
            cls.expiry_ends > now,
            cls.expiry_starts <= now,
        )
        for key, value in kwargs.items():
            if hasattr(cls, key):
                query = query.filter(getattr(cls, key) == value)
        already_record = query.first()

        # Handle existing record
        if already_record and already_record.deleted:
            cls.raise_exception("Record already exists and is deleted")
        elif already_record:
            cls.raise_exception("Record already exists")

        # Create new record
        created_record = cls()
        for key, value in kwargs.items():
            setattr(created_record, key, value)
        db.add(created_record)
        db.flush()
        return created_record
    except HTTPException:
        # The duplicate checks above raise HTTPException(400). Without this
        # clause the generic handler below would swallow it and re-raise it
        # as a 500 with the exception's repr as the message.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        cls.raise_exception(f"Failed to create record: {str(e)}", status_code=500)
@classmethod
def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
    """
    Convert a single field value into a serialisable form.

    Rules, applied in order:
    1. Fields listed in ``cls.primary_keys``, or whose class attribute has
       truthy ``foreign_keys``, are rejected.
    2. ``None`` is kept as ``None``.
    3. Fields whose name ends in ``uu_id`` (case-insensitive) become ``str``.
    4. If ``cls.__annotations__`` has a truthy entry for ``key``, it decides
       the conversion: ``Mapped[int]`` -> int, ``Mapped[bool]`` -> bool,
       ``Mapped[float]``/``Mapped[NUMERIC]`` -> float rounded to 3 places,
       ``Mapped[TIMESTAMP]`` -> formatted string, ``Mapped[str]`` -> str.
       Any other annotation rejects the field.
    5. Otherwise the runtime type of ``val`` decides, with the same
       conversions (``datetime.datetime`` is formatted like TIMESTAMP);
       values of other types are rejected.

    Args:
        val: Field value
        key: Field (attribute) name

    Returns:
        Tuple of ``(should_include, processed_value)``; ``(False, None)``
        means the field should be left out, including when any lookup or
        conversion raises.
    """
    timestamp_format = "YYYY-MM-DD HH:mm:ss ZZ"
    try:
        # NOTE(review): only ``cls.__annotations__`` is consulted, so
        # annotations declared solely on a base class may not be seen.
        key_ = cls.__annotations__.get(key, None)
        is_primary = key in getattr(cls, "primary_keys", [])
        is_foreign = bool(getattr(getattr(cls, key), "foreign_keys", None))
        # Skip primary keys and foreign keys
        if is_primary or is_foreign:
            return False, None
        if val is None:
            return True, None
        if key[-5:].lower() == "uu_id":  # UUID fields are exposed as text
            return True, str(val)

        if key_:  # Handle typed fields
            if key_ == Mapped[int]:
                return True, int(val)
            if key_ == Mapped[bool]:
                return True, bool(val)
            if key_ == Mapped[float] or key_ == Mapped[NUMERIC]:
                return True, round(float(val), 3)
            if key_ == Mapped[TIMESTAMP]:
                return True, arrow.get(str(val)).format(timestamp_format)
            if key_ == Mapped[str]:
                return True, str(val)
            # NOTE(review): other annotations (e.g. Mapped[Optional[str]])
            # drop the field entirely -- confirm this is intended.
            return False, None

        # Untyped fields: decide by runtime type. ``bool`` is tested before
        # ``int`` because ``bool`` is a subclass of ``int``.
        if isinstance(val, datetime.datetime):
            return True, arrow.get(str(val)).format(timestamp_format)
        if isinstance(val, bool):
            return True, bool(val)
        if isinstance(val, (float, Decimal)):
            return True, round(float(val), 3)
        if isinstance(val, int):
            return True, int(val)
        if isinstance(val, str):
            return True, str(val)
        return False, None
    except Exception:
        # Best effort: a failed lookup or conversion just drops the field.
        return False, None
def get_dict(
    self, exclude_list: Optional[list[InstrumentedAttribute]] = None
) -> Dict[str, Any]:
    """
    Convert the model instance to a dictionary of serialisable column values.

    Column selection:
    - ``id`` and columns whose name ends in ``_id`` are skipped, except UUID
      columns whose name ends in ``uu_id`` (case-insensitive), which are kept.
    - Columns named by ``exclude_list`` are skipped.
    Each remaining value is passed through ``iterate_over_variables``; values
    it rejects are omitted. Keys follow the table's column order.

    Args:
        exclude_list: Model attributes (e.g. ``Model.some_column``) to omit;
            they are matched against column names by their ``.key``.

    Returns:
        Dictionary representation of the model, or ``{}`` if building it
        fails for any reason.
    """
    try:
        excluded_keys = {attr.key for attr in (exclude_list or [])}
        return_dict: Dict[str, Any] = {}
        # Walk columns in table order so the output key order is stable
        # (a set-based walk depends on per-process string hashing).
        for column in self.__table__.columns:
            key = column.name
            if key in excluded_keys:
                continue
            is_uuid = key[-5:].lower() == "uu_id"
            # Match real identifier columns only; a bare two-letter "id"
            # suffix test would also drop names such as "paid" or "is_valid".
            is_identifier = key == "id" or key.endswith("_id")
            if is_identifier and not is_uuid:
                continue
            val = getattr(self, key)
            correct, value_of_database = self.iterate_over_variables(val, key)
            if correct:
                return_dict[key] = value_of_database
        return return_dict
    except Exception:
        # Best effort: any failure yields an empty dict instead of an error.
        return {}
@classmethod
def find_or_create(
    cls,
    db: Session,
    exclude_args: Optional[list[InstrumentedAttribute]] = None,
    include_args: Optional[list[InstrumentedAttribute]] = None,
    **kwargs,
):
    """
    Return an active record matching ``kwargs``, creating one if none exists.

    Only records whose ``expiry_starts <= now < expiry_ends`` are searched.
    A keyword argument becomes a search filter when it names an attribute of
    the model and either:
    - ``include_args`` is non-empty and contains it, or
    - ``include_args`` is empty and ``exclude_args`` does not contain it.
    NOTE(review): if no keyword qualifies, every active record matches and
    ``query.first()`` (no ORDER BY) returns an arbitrary one.

    Args:
        db: Database session
        exclude_args: Attributes whose keywords are ignored for matching.
        include_args: If given, only these attributes are used for matching.
        **kwargs: Search criteria; all of them are assigned to a new record.

    Returns:
        The existing or newly created (flushed, uncommitted) record. Sets
        ``cls.meta_data.created`` to ``False`` or ``True`` accordingly.

    Raises:
        HTTPException: 500 when the lookup or creation fails; the session is
            rolled back first.
    """
    try:
        query = db.query(cls).filter(
            cls.expiry_ends > str(arrow.now()),
            cls.expiry_starts <= str(arrow.now()),
        )
        excluded_keys = [attr.key for attr in (exclude_args or [])]
        included_keys = [attr.key for attr in (include_args or [])]

        for field_name, field_value in kwargs.items():
            if not hasattr(cls, field_name):
                continue
            if included_keys:
                use_as_filter = field_name in included_keys
            else:
                use_as_filter = field_name not in excluded_keys
            if use_as_filter:
                query = query.filter(getattr(cls, field_name) == field_value)

        existing = query.first()
        if existing:
            cls.meta_data.created = False
            return existing

        new_record = cls()
        for field_name, field_value in kwargs.items():
            setattr(new_record, field_name, field_value)
        db.add(new_record)
        db.flush()
        cls.meta_data.created = True
        return new_record
    except Exception as e:
        db.rollback()
        cls.raise_exception(
            f"Failed to find or create record: {str(e)}", status_code=500
        )
def update(self, db: Session, **kwargs):
    """
    Assign ``kwargs`` onto this record and flush the change.

    Args:
        db: Database session used for the flush (and rollback on failure).
        **kwargs: Attribute name/value pairs; each pair is passed to
            ``setattr`` without filtering.

    Returns:
        This same record after a successful flush.

    Raises:
        HTTPException: 500 (via ``raise_exception``) if an assignment or the
            flush fails; ``meta_data.updated`` is reset and the session is
            rolled back first.
    """
    try:
        for attribute, new_value in kwargs.items():
            setattr(self, attribute, new_value)
        db.flush()
        # NOTE(review): meta_data resolves to the class-level MetaData
        # instance, so this flag is shared with other records/models.
        self.meta_data.updated = True
        return self
    except Exception as e:
        self.meta_data.updated = False
        db.rollback()
        self.raise_exception(f"Failed to update record: {str(e)}", status_code=500)