prod-wag-backend-automate-s.../Controllers/Postgres/crud.py

323 lines
10 KiB
Python

import datetime
import logging
from decimal import Decimal
from typing import Optional, Any, Dict, List

import arrow
from fastapi.exceptions import HTTPException
from pydantic import BaseModel
from sqlalchemy import TIMESTAMP, NUMERIC
from sqlalchemy.orm import Session, Mapped
from sqlalchemy.orm.attributes import InstrumentedAttribute
class Credentials(BaseModel):
    """
    Identity of the user on whose behalf a record is written.

    CRUDModel copies ``person_id`` and ``person_name`` into the
    ``created_by*`` / ``updated_by*`` fields of the records it touches.
    """

    # Numeric identifier of the acting person.
    person_id: int
    # Name stored alongside the identifier on tracked records.
    person_name: str
    # Optional display name; defaults to None when not supplied.
    full_name: Optional[str] = None
class MetaData:
    """
    Outcome flags recorded for a query operation.

    Both flags start out as False at class level; assigning to one on an
    instance shadows the class default for that instance only.
    """

    # Set to True when an operation inserted a new record.
    created: bool = False
    # Set to True when an operation modified an existing record.
    updated: bool = False
class CRUDModel:
    """
    Base class for CRUD operations on PostgreSQL models.

    Features:
        - User credential tracking
        - Metadata tracking for operations
        - Type-safe field handling
        - Automatic timestamp management
        - Soft delete support

    The query helpers read ``expiry_starts``, ``expiry_ends`` and ``deleted``
    directly, so concrete models must provide the columns listed in
    ``required_columns``.

    NOTE: ``creds`` and ``meta_data`` are class-level attributes. The single
    ``MetaData`` instance created below is shared by every class and instance
    that does not assign its own, so its ``created`` / ``updated`` flags only
    describe the most recent operation that set them.
    """

    __abstract__ = True

    # Credentials of the acting user; while None, no created_by/updated_by
    # stamping takes place.
    creds: Optional[Credentials] = None
    # Outcome flags written by find_or_create() and update().
    meta_data: MetaData = MetaData()

    # Define required columns for CRUD operations
    required_columns = {
        "expiry_starts": TIMESTAMP,
        "expiry_ends": TIMESTAMP,
        "created_by": str,
        "created_by_id": int,
        "updated_by": str,
        "updated_by_id": int,
        "deleted": bool,
    }

    @classmethod
    def create_credentials(cls, record_created) -> None:
        """
        Stamp ``created_by`` / ``created_by_id`` on a record from ``cls.creds``.

        Nothing is written when no credentials are set, or when either
        ``person_id`` or ``person_name`` is missing or falsy.

        Args:
            record_created: Record being created; modified in place.
        """
        creds = cls.creds
        if not creds:
            return
        person_id = getattr(creds, "person_id", None)
        person_name = getattr(creds, "person_name", None)
        if person_id and person_name:
            record_created.created_by_id = person_id
            record_created.created_by = person_name

    @classmethod
    def raise_exception(cls, message: str = "Exception raised.", status_code: int = 400):
        """
        Raise HTTP exception with custom message and status code.

        Args:
            message: Error message, returned as ``detail["message"]``
            status_code: HTTP status code

        Raises:
            HTTPException: Always.
        """
        raise HTTPException(
            status_code=status_code,
            detail={"message": message},
        )

    @staticmethod
    def _format_timestamp(val: Any) -> str:
        """Render a timestamp-like value as ``YYYY-MM-DD HH:mm:ss ZZ`` text."""
        return str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))

    @classmethod
    def _active_query(cls, db: Session, filters: Dict[str, Any], skip_keys=()):
        """
        Build a query for currently valid rows that match ``filters``.

        A row is selected when ``expiry_starts <= now < expiry_ends``. Each
        entry of ``filters`` that names an attribute of the class (and is not
        listed in ``skip_keys``) is added as an equality condition; other
        entries are ignored.
        """
        # Evaluate "now" once so both bounds are compared to the same instant.
        now = str(arrow.now())
        query = db.query(cls).filter(
            cls.expiry_ends > now,
            cls.expiry_starts <= now,
        )
        for key, value in filters.items():
            if key not in skip_keys and hasattr(cls, key):
                query = query.filter(getattr(cls, key) == value)
        return query

    @classmethod
    def _insert_new(cls, db: Session, values: Dict[str, Any]):
        """Instantiate a record, set ``values`` on it, stamp the creator and flush."""
        record = cls()
        for key, value in values.items():
            setattr(record, key, value)
        cls.create_credentials(record)
        db.add(record)
        db.flush()
        return record

    @classmethod
    def create_or_abort(cls, db: Session, **kwargs):
        """
        Create a new record or abort if it already exists.

        Args:
            db: Database session
            **kwargs: Record fields; also used as the duplicate-search criteria

        Returns:
            New record if successfully created (flushed, not committed)

        Raises:
            HTTPException: 400 if a matching active record already exists,
                500 if the lookup or creation fails unexpectedly. The session
                is rolled back in both cases.
        """
        try:
            already_record = cls._active_query(db, kwargs).first()
            # Handle existing record
            if already_record and already_record.deleted:
                cls.raise_exception("Record already exists and is deleted")
            elif already_record:
                cls.raise_exception("Record already exists")
            return cls._insert_new(db, kwargs)
        except HTTPException:
            # Deliberate HTTP errors (e.g. the 400 above) must reach the caller
            # unchanged instead of being rewrapped as a 500. The rollback is
            # kept so the session ends up in the same state as before this fix.
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            cls.raise_exception(f"Failed to create record: {str(e)}", status_code=500)

    @classmethod
    def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
        """
        Process a field value based on its type and convert it to the appropriate format.

        Primary keys (names listed in ``cls.primary_keys``) and attributes
        carrying foreign keys are skipped. Fields whose name ends in ``uu_id``
        are stringified. Otherwise the conversion follows the class annotation
        for ``key`` when one exists, else the runtime type of ``val``.

        Args:
            val: Field value
            key: Field name

        Returns:
            Tuple of (should_include, processed_value). ``(False, None)`` is
            also returned when the value cannot be processed.
        """
        try:
            annotation = cls.__annotations__.get(key, None)
            is_primary = key in getattr(cls, "primary_keys", [])
            has_foreign_keys = bool(getattr(getattr(cls, key), "foreign_keys", None))

            # Skip primary keys and foreign keys
            if is_primary or has_foreign_keys:
                return False, None
            if val is None:  # Handle None values
                return True, None
            if key[-5:].lower() == "uu_id":  # Special handling for UUID fields
                return True, str(val)

            if annotation:  # Handle typed fields
                if annotation == Mapped[int]:
                    return True, int(val)
                if annotation == Mapped[bool]:
                    return True, bool(val)
                if annotation in (Mapped[float], Mapped[NUMERIC]):
                    return True, round(float(val), 3)
                if annotation == Mapped[TIMESTAMP]:
                    return True, cls._format_timestamp(val)
                if annotation == Mapped[str]:
                    return True, str(val)
                # Any other annotation: the field is left out.
                return False, None

            # Handle based on Python types. bool is tested before int because
            # bool is a subclass of int.
            if isinstance(val, datetime.datetime):
                return True, cls._format_timestamp(val)
            if isinstance(val, bool):
                return True, bool(val)
            if isinstance(val, (float, Decimal)):
                return True, round(float(val), 3)
            if isinstance(val, int):
                return True, int(val)
            if isinstance(val, str):
                return True, str(val)
            return False, None
        except Exception:
            # Best effort: an unconvertible field is dropped, but leave a trace.
            logging.getLogger(__name__).debug(
                "Skipping field %r of %s: conversion failed",
                key,
                cls.__name__,
                exc_info=True,
            )
            return False, None

    def get_dict(self, exclude_list: Optional[list[InstrumentedAttribute]] = None) -> Dict[str, Any]:
        """
        Convert model instance to dictionary with customizable fields.

        Columns whose name ends in ``id`` are omitted unless the name ends in
        ``uu_id``. Remaining values are converted with
        ``iterate_over_variables``. Keys follow the table's column order.

        Args:
            exclude_list: List of fields to exclude from the dictionary

        Returns:
            Dictionary representation of the model, or an empty dictionary if
            the conversion fails.
        """
        try:
            excluded = {exclude_arg.key for exclude_arg in (exclude_list or [])}
            return_dict: Dict[str, Any] = {}
            for column in self.__table__.columns:
                name = column.name
                is_uu_id = str(name)[-5:].lower() == "uu_id"
                if str(name)[-2:] == "id" and not is_uu_id:
                    continue
                if name in excluded:
                    continue
                correct, value_of_database = self.iterate_over_variables(
                    getattr(self, name), name
                )
                if correct:
                    return_dict[name] = value_of_database
            return return_dict
        except Exception:
            # Best effort: callers receive {} on failure, but log the cause.
            logging.getLogger(__name__).exception(
                "Failed to convert %s to dict", type(self).__name__
            )
            return {}

    @classmethod
    def find_or_create(
        cls,
        db: Session,
        exclude_args: Optional[list[InstrumentedAttribute]] = None,
        **kwargs,
    ):
        """
        Find an existing record matching the criteria or create a new one.

        ``cls.meta_data.created`` is set to False when an existing record is
        returned and to True when a new one is created.

        Args:
            db: Database session
            exclude_args: Keys to exclude from search (still set on a new record)
            **kwargs: Search/creation criteria

        Returns:
            Existing or newly created record (flushed, not committed)

        Raises:
            HTTPException: 500 if the lookup or creation fails unexpectedly;
                an HTTPException raised while processing propagates unchanged.
                The session is rolled back in both cases.
        """
        try:
            skip_keys = {exclude_arg.key for exclude_arg in (exclude_args or [])}
            already_record = cls._active_query(db, kwargs, skip_keys).first()
            if already_record:  # Handle existing record
                cls.meta_data.created = False
                return already_record
            created_record = cls._insert_new(db, kwargs)
            cls.meta_data.created = True
            return created_record
        except HTTPException:
            # Keep deliberate HTTP errors intact rather than masking them as 500.
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            cls.raise_exception(f"Failed to find or create record: {str(e)}", status_code=500)

    def update(self, db: Session, **kwargs):
        """
        Update the record with new values.

        ``meta_data.updated`` is set to True on success and False on failure.

        Args:
            db: Database session
            **kwargs: Fields to update

        Returns:
            Updated record (flushed, not committed)

        Raises:
            HTTPException: 500 if the update fails unexpectedly; an
                HTTPException raised while processing propagates unchanged.
                The session is rolled back in both cases.
        """
        try:
            for key, value in kwargs.items():
                setattr(self, key, value)
            self.update_credentials()
            db.flush()
            self.meta_data.updated = True
            return self
        except HTTPException:
            # Keep deliberate HTTP errors intact rather than masking them as 500.
            self.meta_data.updated = False
            db.rollback()
            raise
        except Exception as e:
            self.meta_data.updated = False
            db.rollback()
            self.raise_exception(f"Failed to update record: {str(e)}", status_code=500)

    def update_credentials(self) -> None:
        """
        Stamp ``updated_by`` / ``updated_by_id`` on this record from ``self.creds``.

        Nothing is written when no credentials are set, or when either
        ``person_id`` or ``person_name`` is missing or falsy.
        """
        creds = self.creds
        if not creds:
            return
        person_id = getattr(creds, "person_id", None)
        person_name = getattr(creds, "person_name", None)
        if person_id and person_name:
            self.updated_by_id = person_id
            self.updated_by = person_name