323 lines
10 KiB
Python
323 lines
10 KiB
Python
import arrow
|
|
import datetime
|
|
|
|
from typing import Optional, Any, Dict, List
|
|
from sqlalchemy.orm import Session, Mapped
|
|
from pydantic import BaseModel
|
|
from fastapi.exceptions import HTTPException
|
|
from decimal import Decimal
|
|
from sqlalchemy import TIMESTAMP, NUMERIC
|
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
|
|
|
|
|
class Credentials(BaseModel):
    """
    Identity of the acting user, used to stamp audit columns on records.

    Attributes:
        person_id: Numeric identifier of the user.
        person_name: Name of the user.
        full_name: Optional full name; defaults to ``None``.
    """

    person_id: int
    person_name: str
    full_name: Optional[str] = None
|
|
|
|
|
|
class MetaData:
    """
    Outcome flags describing what a CRUD operation did.

    Both flags start out as ``False``.
    """

    # True when an operation inserted a new record.
    created: bool = False

    # True when an operation modified an existing record.
    updated: bool = False
|
|
|
|
|
|
class CRUDModel:
    """
    Base class for CRUD operations on PostgreSQL models.

    Features:
    - User credential tracking
    - Metadata tracking for operations
    - Type-safe field handling
    - Automatic timestamp management
    - Soft delete support
    """

    __abstract__ = True

    # Credentials of the acting user; audit columns are only stamped when set.
    creds: Optional[Credentials] = None

    # NOTE(review): this single MetaData instance is created once at class
    # definition time, so unless a subclass rebinds ``meta_data`` the
    # ``created`` / ``updated`` flags are shared by every model and instance
    # and reflect only the most recent operation.
    meta_data: MetaData = MetaData()

    # Define required columns for CRUD operations
    required_columns = {
        'expiry_starts': TIMESTAMP,
        'expiry_ends': TIMESTAMP,
        'created_by': str,
        'created_by_id': int,
        'updated_by': str,
        'updated_by_id': int,
        'deleted': bool,
    }

    @classmethod
    def create_credentials(cls, record_created) -> None:
        """
        Stamp the creator audit columns on a record.

        Does nothing when no credentials are configured or when either the
        person id or the person name is missing/falsy.

        Args:
            record_created: Record that was created or updated
        """
        if not cls.creds:
            return

        person_id = getattr(cls.creds, "person_id", None)
        person_name = getattr(cls.creds, "person_name", None)
        if person_id and person_name:
            record_created.created_by_id = person_id
            record_created.created_by = person_name

    @classmethod
    def raise_exception(cls, message: str = "Exception raised.", status_code: int = 400):
        """
        Raise HTTP exception with custom message and status code.

        Args:
            message: Error message
            status_code: HTTP status code

        Raises:
            HTTPException: Always, with ``detail={"message": message}``
        """
        raise HTTPException(
            status_code=status_code,
            detail={"message": message},
        )

    @classmethod
    def _first_active_match(cls, db: Session, criteria: Dict[str, Any], skip_keys=frozenset()):
        """Return the first record inside its expiry window that matches ``criteria``, or None."""
        # Evaluate "now" once so both expiry bounds compare against the same instant.
        now = str(arrow.now())
        query = db.query(cls).filter(
            cls.expiry_ends > now,
            cls.expiry_starts <= now,
        )

        # Criteria that are not attributes of the model are ignored.
        for key, value in criteria.items():
            if key not in skip_keys and hasattr(cls, key):
                query = query.filter(getattr(cls, key) == value)

        return query.first()

    @classmethod
    def _insert_record(cls, db: Session, values: Dict[str, Any]):
        """Build a record from ``values``, stamp creator credentials, add and flush it."""
        record = cls()
        for key, value in values.items():
            setattr(record, key, value)

        cls.create_credentials(record)
        db.add(record)
        db.flush()
        return record

    @classmethod
    def create_or_abort(cls, db: Session, **kwargs):
        """
        Create a new record or abort if it already exists.

        Args:
            db: Database session
            **kwargs: Record fields

        Returns:
            New record if successfully created

        Raises:
            HTTPException: 400 if a matching record already exists,
                500 (after a session rollback) if creation fails
        """
        try:
            already_record = cls._first_active_match(db, kwargs)

            if already_record:
                if already_record.deleted:
                    cls.raise_exception("Record already exists and is deleted")
                else:
                    cls.raise_exception("Record already exists")

            return cls._insert_record(db, kwargs)

        except HTTPException:
            # The "already exists" abort must reach the caller with its own
            # status code instead of being re-wrapped as a 500 below. Nothing
            # was written on this path, so the session is left untouched.
            raise
        except Exception as e:
            db.rollback()
            cls.raise_exception(f"Failed to create record: {str(e)}", status_code=500)

    @staticmethod
    def _format_timestamp(val: Any) -> str:
        """Render a timestamp-like value as ``YYYY-MM-DD HH:mm:ss ZZ``."""
        return str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))

    @classmethod
    def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
        """
        Process a field value based on its type and convert it to the appropriate format.

        Primary-key and foreign-key fields are skipped. Any failure while
        inspecting or converting the field results in the field being skipped.

        Args:
            val: Field value
            key: Field name

        Returns:
            Tuple of (should_include, processed_value)
        """
        try:
            annotation = cls.__annotations__.get(key, None)
            is_primary = key in getattr(cls, 'primary_keys', [])
            is_foreign = bool(getattr(getattr(cls, key), "foreign_keys", None))

            # Skip primary keys and foreign keys
            if is_primary or is_foreign:
                return False, None

            if val is None:  # Handle None values
                return True, None

            if key.lower().endswith("uu_id"):  # Special handling for UUID fields
                return True, str(val)

            if annotation:  # Handle typed fields
                # An annotation that matches none of the cases below causes
                # the field to be skipped; there is no fallback to the
                # runtime-type handling.
                if annotation == Mapped[int]:
                    return True, int(val)
                if annotation == Mapped[bool]:
                    return True, bool(val)
                if annotation == Mapped[float] or annotation == Mapped[NUMERIC]:
                    return True, round(float(val), 3)
                if annotation == Mapped[TIMESTAMP]:
                    return True, cls._format_timestamp(val)
                if annotation == Mapped[str]:
                    return True, str(val)
                return False, None

            # Handle based on Python types. bool is tested before int because
            # bool is a subclass of int.
            if isinstance(val, datetime.datetime):
                return True, cls._format_timestamp(val)
            if isinstance(val, bool):
                return True, bool(val)
            if isinstance(val, (float, Decimal)):
                return True, round(float(val), 3)
            if isinstance(val, int):
                return True, int(val)
            if isinstance(val, str):
                return True, str(val)

            return False, None

        except Exception:
            # Best effort: a field that cannot be inspected or converted is
            # left out rather than failing the whole serialization.
            return False, None

    def get_dict(self, exclude_list: Optional[list[InstrumentedAttribute]] = None) -> Dict[str, Any]:
        """
        Convert model instance to dictionary with customizable fields.

        Columns whose name ends in ``id`` are left out unless the name ends
        in ``uu_id`` (case-insensitive). Keys follow the table's column order.

        Args:
            exclude_list: List of fields to exclude from the dictionary

        Returns:
            Dictionary representation of the model; an empty dictionary if
            anything goes wrong while building it
        """
        try:
            excluded = {attr.key for attr in (exclude_list or [])}
            return_dict: Dict[str, Any] = {}

            for column in self.__table__.columns:
                name = column.name
                if name in excluded:
                    continue

                # Drop "...id" columns, but keep "...uu_id" ones.
                if name.endswith("id") and not name.lower().endswith("uu_id"):
                    continue

                correct, value_of_database = self.iterate_over_variables(
                    getattr(self, name), name
                )
                if correct:
                    return_dict[name] = value_of_database

            return return_dict

        except Exception:
            # Best effort: callers receive an empty dict instead of an error.
            return {}

    @classmethod
    def find_or_create(
        cls,
        db: Session,
        exclude_args: Optional[list[InstrumentedAttribute]] = None,
        **kwargs,
    ):
        """
        Find an existing record matching the criteria or create a new one.

        Sets ``meta_data.created`` to ``False`` when an existing record is
        returned and to ``True`` when a new one is inserted.

        Args:
            db: Database session
            exclude_args: Keys to exclude from search
            **kwargs: Search/creation criteria

        Returns:
            Existing or newly created record

        Raises:
            HTTPException: 500 (after a session rollback) if the lookup or
                the creation fails
        """
        try:
            skip_keys = {attr.key for attr in (exclude_args or [])}

            already_record = cls._first_active_match(db, kwargs, skip_keys)
            if already_record:  # Handle existing record
                cls.meta_data.created = False
                return already_record

            # Excluded keys are only left out of the search; the new record
            # still receives every keyword argument.
            created_record = cls._insert_record(db, kwargs)
            cls.meta_data.created = True
            return created_record

        except HTTPException:
            # Never re-wrap an HTTP error that already carries its own status.
            raise
        except Exception as e:
            db.rollback()
            cls.raise_exception(f"Failed to find or create record: {str(e)}", status_code=500)

    def update(self, db: Session, **kwargs):
        """
        Update the record with new values.

        Sets ``meta_data.updated`` to ``True`` on success and ``False`` on
        failure.

        Args:
            db: Database session
            **kwargs: Fields to update

        Returns:
            Updated record

        Raises:
            HTTPException: 500 (after a session rollback) if update fails
        """
        try:
            for key, value in kwargs.items():
                setattr(self, key, value)

            self.update_credentials()
            db.flush()
            self.meta_data.updated = True
            return self

        except Exception as e:
            self.meta_data.updated = False
            db.rollback()
            self.raise_exception(f"Failed to update record: {str(e)}", status_code=500)

    def update_credentials(self) -> None:
        """
        Stamp the updater audit columns on this record.

        Does nothing when no credentials are configured or when either the
        person id or the person name is missing/falsy.
        """
        if not self.creds:
            return

        person_id = getattr(self.creds, "person_id", None)
        person_name = getattr(self.creds, "person_name", None)

        if person_id and person_name:
            self.updated_by_id = person_id
            self.updated_by = person_name
|