400 lines
13 KiB
Python
400 lines
13 KiB
Python
import datetime
|
|
|
|
from decimal import Decimal
|
|
from typing import Any, Dict, List, Optional
|
|
from sqlalchemy import TIMESTAMP, NUMERIC
|
|
from sqlalchemy.orm import Session, Mapped
|
|
from pydantic import BaseModel
|
|
|
|
from ApiLayers.ApiLibrary import system_arrow, get_line_number_for_error
|
|
from ApiLayers.ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
|
|
|
|
from Services.PostgresDb.Models.core_alchemy import BaseAlchemyModel
|
|
from Services.PostgresDb.Models.system_fields import SystemFields
|
|
|
|
|
|
class MetaDataRow(BaseModel):
    """Status snapshot describing the outcome of the last CRUD operation."""

    # True when the operation inserted a new row.
    created: Optional[bool] = False
    # Optional human-readable detail about the outcome.
    message: Optional[str] = None
    # Error code explaining why no row was created (e.g. "AlreadyExists").
    error_case: Optional[str] = None
|
|
|
|
|
|
class Credentials(BaseModel):
    """Identity of the person performing a write, copied into audit columns."""

    # Identifier of the acting person (written to *_by_id columns).
    person_id: int
    # Display name of the acting person (written to *_by columns).
    person_name: str
|
|
|
|
|
|
class CrudActions(SystemFields):
    """Helpers for sanitising input dictionaries and serialising model rows.

    NOTE(review): relies on attributes not defined in this class --
    ``__system__fields__create__``, ``__system__fields__update__``,
    ``__system_default_model__``, ``__exclude__fields__``, ``columns``,
    ``hybrid_properties``, ``settable_relations`` and ``primary_keys`` --
    presumably supplied by ``SystemFields`` or the concrete mapped model.
    """

    @classmethod
    def extract_system_fields(
        cls, filter_kwargs: dict, create: bool = True
    ) -> Dict[str, Any]:
        """
        Remove system-managed fields from an input dictionary.

        The input dictionary is not modified; a shallow copy is filtered.

        Args:
            filter_kwargs: Input dictionary of fields
            create: If True, use the creation field list, else the update field list

        Returns:
            Copy of ``filter_kwargs`` without the system fields
        """
        system_fields = filter_kwargs.copy()
        extract_fields = (
            cls.__system__fields__create__ if create else cls.__system__fields__update__
        )
        for field in extract_fields:
            system_fields.pop(field, None)
        return system_fields

    @classmethod
    def remove_non_related_inputs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Filter out inputs that don't correspond to model fields.

        Args:
            kwargs: Dictionary of field names and values

        Returns:
            Dictionary containing only keys that are columns, hybrid
            properties or settable relations of the model
        """
        # Build the allowed-name set once rather than concatenating three
        # lists and scanning them linearly for every key.
        allowed = (
            set(cls.columns) | set(cls.hybrid_properties) | set(cls.settable_relations)
        )
        return {key: value for key, value in kwargs.items() if key in allowed}

    @classmethod
    def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
        """
        Convert a field value into a serialisable representation.

        Rules, in order:
          * primary-key columns and attributes exposing ``foreign_keys`` are skipped;
          * ``None`` is kept as ``None``;
          * names ending in ``uu_id`` (case-insensitive) are stringified;
          * recognised ``Mapped[...]`` annotations drive the conversion;
          * otherwise (no annotation, or an unrecognised one such as
            ``Mapped[Optional[int]]``) the runtime type of ``val`` decides;
          * any other value is skipped.

        Args:
            val: Field value
            key: Field name (must be an attribute of the class)

        Returns:
            Tuple of (should_include, processed_value)
        """
        timestamp_format = "YYYY-MM-DD HH:mm:ss ZZ"
        # NOTE(review): only annotations reachable via ``cls.__annotations__``
        # are consulted; inherited columns may have no entry here.
        key_ = cls.__annotations__.get(key, None)
        is_primary = key in cls.primary_keys
        row_attr = bool(getattr(getattr(cls, key), "foreign_keys", None))

        # Skip primary keys and foreign keys
        if is_primary or row_attr:
            return False, None

        if val is None:
            return True, None

        # Special handling for UUID fields
        if str(key[-5:]).lower() == "uu_id":
            return True, str(val)

        # Handle typed fields
        if key_:
            if key_ == Mapped[int]:
                return True, int(val)
            if key_ == Mapped[bool]:
                return True, bool(val)
            if key_ == Mapped[float] or key_ == Mapped[NUMERIC]:
                return True, round(float(val), 3)
            if key_ == Mapped[TIMESTAMP]:
                return True, str(system_arrow.get(str(val)).format(timestamp_format))
            if key_ == Mapped[str]:
                return True, str(val)
            # Unrecognised annotation: previously the value was silently
            # dropped here; fall through to runtime-type handling instead.

        # Handle based on Python types (bool must be tested before int).
        if isinstance(val, datetime.datetime):
            return True, str(system_arrow.get(str(val)).format(timestamp_format))
        if isinstance(val, bool):
            return True, bool(val)
        if isinstance(val, (float, Decimal)):
            return True, round(float(val), 3)
        if isinstance(val, int):
            return True, int(val)
        if isinstance(val, str):
            return True, str(val)

        return False, None

    def get_dict(
        self,
        exclude: Optional[List[str]] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Convert the model instance to a dictionary with customisable fields.

        Modes:
          * ``include`` given: those names (minus system-default names ending
            in ``uu_id``) plus ``"uu_id"``.
          * ``exclude`` given: all columns minus ``exclude``, the model's
            ``__exclude__fields__`` and system-default names ending in
            ``id``; ``"uu_id"`` and ``"active"`` are always added back.
          * neither: all columns minus ``__exclude__fields__``, system
            defaults and names ending in ``id``, plus every non-system
            column ending in ``uu_id``.

        Every candidate is passed through ``iterate_over_variables``, which
        may still drop it (e.g. primary or foreign keys). The caller's
        ``exclude`` list is not modified.

        Args:
            exclude: List of fields to exclude
            include: List of fields to include (takes precedence over exclude)

        Returns:
            Dictionary representation of the model
        """
        system_defaults = list(self.__system_default_model__)
        # ``or []`` tolerates an explicit None; set() accepts lists or tuples.
        model_excludes = set(getattr(self, "__exclude__fields__", []) or [])

        if include:
            # Both tests together select names ending in lowercase "uu_id".
            exclude_list = [
                element
                for element in system_defaults
                if str(element)[-2:] == "id" and str(element)[-5:].lower() == "uu_id"
            ]
            keys = list(set(include).difference(exclude_list))
            keys.append("uu_id")
        elif exclude:
            # Build a fresh set: the previous version extended the caller's
            # list in place, so shared lists grew on every call.
            excluded = set(exclude) | model_excludes
            excluded.update(
                element for element in system_defaults if str(element)[-2:] == "id"
            )
            keys = list(set(self.columns).difference(excluded))
            keys.extend(["uu_id", "active"])
        else:
            system_default_set = set(system_defaults)
            excluded = model_excludes | system_default_set
            keys = [
                col
                for col in set(self.columns).difference(excluded)
                if str(col)[-2:] != "id"
            ]
            # Re-add uu_id columns (even ones listed in __exclude__fields__),
            # but never system-default fields.
            keys.extend(
                col
                for col in self.columns
                if str(col)[-5:].lower() == "uu_id" and col not in system_default_set
            )

        return_dict: Dict[str, Any] = {}
        for key in keys:
            correct, value_of_database = self.iterate_over_variables(
                getattr(self, key), key
            )
            if correct:
                return_dict[key] = value_of_database
        return return_dict
|
|
|
|
|
|
class CRUDModel(BaseAlchemyModel, CrudActions):
    """Abstract model base adding create/find/update helpers with audit stamps.

    ``meta_data`` and ``creds`` are read from, and (for ``meta_data``) written
    to, the class rather than the instance, so they are shared by every
    instance of a given model class.
    NOTE(review): shared mutable class state -- confirm it is never used
    concurrently across requests.
    """

    __abstract__ = True

    # Outcome of the last create/find call; set by ``update_metadata``.
    meta_data: MetaDataRow
    # Acting person used to fill created_by/updated_by/confirmed_by columns.
    creds: Credentials = None

    @property
    def is_created(self):
        """Whether the last create/find call inserted a new row.

        Raises AttributeError if ``update_metadata`` has never run.
        """
        return self.meta_data.created

    @classmethod
    def create_credentials(cls, record_created) -> None:
        """
        Stamp creator columns on a record from ``cls.creds``.

        ``created_by_id`` and ``created_by`` are set only when both
        ``person_id`` and ``person_name`` are truthy; otherwise the record is
        left unchanged.

        Args:
            record_created: Newly built record to stamp
        """
        person_id = getattr(cls.creds, "person_id", None)
        person_name = getattr(cls.creds, "person_name", None)
        if person_id and person_name:
            record_created.created_by_id = person_id
            record_created.created_by = person_name

    @classmethod
    def update_metadata(
        cls,
        created: bool,
        error_case: Optional[str] = None,
        message: Optional[str] = None,
    ) -> None:
        """Replace the class-level ``meta_data`` with a new ``MetaDataRow``."""
        cls.meta_data = MetaDataRow(
            created=created, error_case=error_case, message=message
        )

    @classmethod
    def raise_exception(cls):
        """Raise ``HTTPExceptionApi`` built from the current ``meta_data``.

        NOTE(review): ``cls.lang`` is not defined here; presumably provided
        by ``BaseAlchemyModel`` -- confirm.
        """
        raise HTTPExceptionApi(
            error_code=cls.meta_data.error_case,
            lang=cls.lang,
            loc=get_line_number_for_error(),
            sys_msg=cls.meta_data.message,
        )

    @classmethod
    def _query_active_match(cls, db: Session, check_kwargs: Dict[str, Any]):
        """Return the first row inside its expiry window matching ``check_kwargs``."""
        # One timestamp for both bounds so the window is consistent.
        now = str(system_arrow.now())
        query = db.query(cls).filter(
            cls.expiry_ends > now,
            cls.expiry_starts <= now,
        )
        for key, value in check_kwargs.items():
            # Keys that are not attributes of the model are ignored.
            if hasattr(cls, key):
                query = query.filter(getattr(cls, key) == value)
        return query.first()

    @staticmethod
    def _existing_record_error_case(record) -> str:
        """Classify an already-existing record into an error code."""
        if record.deleted:
            return "DeletedRecord"
        if not record.is_confirmed:
            return "IsNotConfirmed"
        return "AlreadyExists"

    @classmethod
    def _insert_from_kwargs(cls, db: Session, check_kwargs: Dict[str, Any]):
        """Build a row from model-related inputs, stamp creator, add and flush it."""
        fields = cls.remove_non_related_inputs(check_kwargs)
        created_record = cls()
        for key, value in fields.items():
            setattr(created_record, key, value)
        cls.create_credentials(created_record)
        db.add(created_record)
        db.flush()
        cls.update_metadata(created=True)
        return created_record

    @classmethod
    def create_or_abort(cls, db: Session, **kwargs):
        """
        Create a new record, or abort if a matching active record exists.

        Args:
            db: Database session
            **kwargs: Record fields (system-managed fields are ignored)

        Returns:
            New record if successfully created

        Raises:
            HTTPExceptionApi: via ``raise_exception`` when a match exists
                (error case "DeletedRecord", "IsNotConfirmed" or "AlreadyExists")
        """
        check_kwargs = cls.extract_system_fields(kwargs)
        already_record = cls._query_active_match(db, check_kwargs)
        if already_record:
            cls.update_metadata(
                created=False,
                error_case=cls._existing_record_error_case(already_record),
            )
            cls.raise_exception()
        return cls._insert_from_kwargs(db, check_kwargs)

    @classmethod
    def find_or_create(cls, db: Session, **kwargs):
        """
        Find an existing active record matching the criteria or create one.

        When a match is found, ``meta_data`` records ``created=False`` and an
        error case ("DeletedRecord", "IsNotConfirmed" or "AlreadyExists"),
        and the match is returned as-is.

        Args:
            db: Database session
            **kwargs: Search/creation criteria (system-managed fields are ignored)

        Returns:
            Existing or newly created record
        """
        check_kwargs = cls.extract_system_fields(kwargs)
        already_record = cls._query_active_match(db, check_kwargs)
        if already_record:
            cls.update_metadata(
                created=False,
                error_case=cls._existing_record_error_case(already_record),
            )
            return already_record
        return cls._insert_from_kwargs(db, check_kwargs)

    @staticmethod
    def _check_confirmation_only(kwargs: Dict[str, Any]) -> None:
        """Raise ValueError if a truthy ``is_confirmed`` comes with other fields."""
        if kwargs.get("is_confirmed", None) and len(kwargs) != 1:
            raise ValueError("Confirm field cannot be updated with other fields")

    def update(self, db: Session, **kwargs):
        """
        Update the record with new values and stamp the acting person.

        Args:
            db: Database session
            **kwargs: Fields to update (unknown and system-managed fields are ignored)

        Returns:
            Updated record

        Raises:
            ValueError: If attempting to update is_confirmed with other fields
        """
        # Validate before mutating so a rejected call leaves the row untouched.
        self._check_confirmation_only(kwargs)

        check_kwargs = self.remove_non_related_inputs(kwargs)
        check_kwargs = self.extract_system_fields(check_kwargs, create=False)
        for key, value in check_kwargs.items():
            setattr(self, key, value)

        # Unpack: passing ``kwargs=kwargs`` nested everything under a
        # "kwargs" key, so is_confirmed was never seen downstream.
        self.update_credentials(**kwargs)
        db.flush()
        return self

    def update_credentials(self, **kwargs) -> None:
        """
        Stamp confirmation or modification audit columns from ``self.creds``.

        A truthy ``is_confirmed`` sets ``confirmed_by_id``/``confirmed_by``;
        otherwise ``updated_by_id``/``updated_by`` are set. Nothing is stamped
        unless both ``creds.person_id`` and ``creds.person_name`` are truthy.

        Args:
            **kwargs: The update arguments that were applied to the record

        Raises:
            ValueError: If a truthy is_confirmed is passed with other fields
        """
        self._check_confirmation_only(kwargs)

        person_id = getattr(self.creds, "person_id", None)
        person_name = getattr(self.creds, "person_name", None)
        if not (person_id and person_name):
            return

        if kwargs.get("is_confirmed", None):
            self.confirmed_by_id = person_id
            self.confirmed_by = person_name
        else:
            self.updated_by_id = person_id
            self.updated_by = person_name
|