import datetime
from decimal import Decimal
from typing import Any, Dict, List, Optional

from sqlalchemy import TIMESTAMP, NUMERIC
from sqlalchemy.orm import Session, Mapped
from pydantic import BaseModel

from ApiLibrary import system_arrow, get_line_number_for_error, client_arrow
from ErrorHandlers.Exceptions.api_exc import HTTPExceptionApi
from Services.PostgresDb.Models.core_alchemy import BaseAlchemyModel
from Services.PostgresDb.Models.system_fields import SystemFields


class MetaDataRow(BaseModel):
    """Outcome of the most recent create/find operation on a model class."""

    # True when the last operation inserted a new record.
    created: Optional[bool] = False
    # Free-form message forwarded as ``sys_msg`` by ``CRUDModel.raise_exception``.
    message: Optional[str] = None
    # Error identifier such as "AlreadyExists", "DeletedRecord", "IsNotConfirmed".
    error_case: Optional[str] = None


class Credentials(BaseModel):
    """Identity of the acting person, copied onto records for audit tracking."""

    person_id: int
    person_name: str


class CrudActions(SystemFields):
    """Helpers for filtering input dictionaries and serialising model rows."""

    @classmethod
    def extract_system_fields(
        cls, filter_kwargs: dict, create: bool = True
    ) -> Dict[str, Any]:
        """
        Remove system-managed fields from input dictionary.

        Args:
            filter_kwargs: Input dictionary of fields (left unmodified)
            create: If True, use creation field list, else use update field list

        Returns:
            Copy of the input dictionary with system fields removed
        """
        cleaned = filter_kwargs.copy()
        managed_fields = (
            cls.__system__fields__create__ if create else cls.__system__fields__update__
        )
        for field in managed_fields:
            cleaned.pop(field, None)
        return cleaned

    @classmethod
    def remove_non_related_inputs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Filter out inputs that don't correspond to model fields.

        Args:
            kwargs: Dictionary of field names and values

        Returns:
            Dictionary containing only valid model fields
        """
        # Build the allowed-name lookup once instead of once per key.
        allowed = set(cls.columns + cls.hybrid_properties + cls.settable_relations)
        return {key: value for key, value in kwargs.items() if key in allowed}

    @staticmethod
    def _format_timestamp(val: Any) -> str:
        """Render a timestamp-like value as 'YYYY-MM-DD HH:mm:ss ZZ' text."""
        return str(system_arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))

    @classmethod
    def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
        """
        Process a field value based on its type and convert it to the appropriate format.

        Args:
            val: Field value
            key: Field name

        Returns:
            Tuple of (should_include, processed_value)
        """
        annotation = cls.__annotations__.get(key, None)
        is_primary = key in cls.primary_keys
        has_foreign_keys = bool(getattr(getattr(cls, key), "foreign_keys", None))

        # Skip primary keys and foreign keys
        if is_primary or has_foreign_keys:
            return False, None

        # Handle None values
        if val is None:
            return True, None

        # Special handling for UUID fields
        if str(key[-5:]).lower() == "uu_id":
            return True, str(val)

        # Handle typed fields
        if annotation:
            if annotation == Mapped[int]:
                return True, int(val)
            if annotation == Mapped[bool]:
                return True, bool(val)
            if annotation == Mapped[float] or annotation == Mapped[NUMERIC]:
                return True, round(float(val), 3)
            if annotation == Mapped[TIMESTAMP]:
                return True, cls._format_timestamp(val)
            if annotation == Mapped[str]:
                return True, str(val)
            # NOTE(review): an annotated field whose annotation matches none of
            # the cases above is dropped rather than falling back to the
            # runtime-type handling below -- confirm this is intended.
            return False, None

        # Handle based on Python types (bool is tested before int because
        # bool is a subclass of int).
        if isinstance(val, datetime.datetime):
            return True, cls._format_timestamp(val)
        if isinstance(val, bool):
            return True, bool(val)
        if isinstance(val, (float, Decimal)):
            return True, round(float(val), 3)
        if isinstance(val, int):
            return True, int(val)
        if isinstance(val, str):
            return True, str(val)
        return False, None

    def _serialize_fields(self, keys: List[str]) -> Dict[str, Any]:
        """Convert the named attributes via ``iterate_over_variables``."""
        serialized: Dict[str, Any] = {}
        for key in keys:
            keep, converted = self.iterate_over_variables(getattr(self, key), key)
            if keep:
                serialized[key] = converted
        return serialized

    def get_dict(
        self,
        exclude: Optional[List[str]] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Convert model instance to dictionary with customizable fields.

        Args:
            exclude: List of fields to exclude (not modified by this call)
            include: List of fields to include (takes precedence over exclude)

        Returns:
            Dictionary representation of the model
        """
        system_defaults = self.__system_default_model__
        class_excludes = getattr(self, "__exclude__fields__", None) or []

        if include:
            # Handle explicitly included fields
            # NOTE(review): this filter only matches system fields ending in
            # "uu_id"; the exclude branch filters on "id" alone -- confirm
            # which one is intended.
            blocked = {
                element
                for element in system_defaults
                if str(element)[-2:] == "id" and str(element)[-5:].lower() == "uu_id"
            }
            keys = [key for key in dict.fromkeys(include) if key not in blocked]
            keys.append("uu_id")
        elif exclude:
            # Handle explicitly excluded fields. Work on a local set so the
            # caller's list is never mutated.
            blocked = set(exclude) | set(class_excludes)
            blocked.update(
                element for element in system_defaults if str(element)[-2:] == "id"
            )
            keys = [col for col in self.columns if col not in blocked]
            # "uu_id" and "active" are always emitted in this mode.
            keys.extend(["uu_id", "active"])
        else:
            # Handle default field selection
            system_default_set = set(system_defaults)
            blocked = set(class_excludes) | system_default_set
            keys = [
                col
                for col in self.columns
                if col not in blocked and str(col)[-2:] != "id"
            ]
            # Re-admit every *uu_id column, then drop system defaults again.
            keys.extend(col for col in self.columns if str(col)[-5:].lower() == "uu_id")
            keys = [key for key in keys if key not in system_default_set]

        return self._serialize_fields(keys)


class CRUDModel(BaseAlchemyModel, CrudActions):
    """Abstract base model adding create/find/update helpers with audit tracking."""

    __abstract__ = True

    # Result of the last create/find call. Assigned on the class by
    # ``update_metadata``.
    # NOTE(review): class-level state is shared by every user of the class --
    # confirm this is safe where requests run concurrently.
    meta_data: MetaDataRow
    # Acting person; None until set elsewhere.
    creds: Credentials = None

    @property
    def is_created(self):
        """Return the ``created`` flag of the current ``meta_data``."""
        return self.meta_data.created

    @classmethod
    def create_credentials(cls, record_created) -> None:
        """
        Save user credentials for tracking.

        Args:
            record_created: Record whose ``created_by_id`` / ``created_by``
                fields are filled when both credential values are set
        """
        if getattr(cls.creds, "person_id", None) and getattr(
            cls.creds, "person_name", None
        ):
            record_created.created_by_id = cls.creds.person_id
            record_created.created_by = cls.creds.person_name

    @classmethod
    def update_metadata(
        cls,
        created: bool,
        error_case: Optional[str] = None,
        message: Optional[str] = None,
    ) -> None:
        """
        Replace the class-level ``meta_data`` with a fresh ``MetaDataRow``.

        Args:
            created: Whether the last operation created a record
            error_case: Optional error identifier
            message: Optional message
        """
        cls.meta_data = MetaDataRow(
            created=created, error_case=error_case, message=message
        )

    @classmethod
    def raise_exception(cls):
        """
        Raise ``HTTPExceptionApi`` built from the current ``meta_data``.

        Raises:
            HTTPExceptionApi: Always
        """
        raise HTTPExceptionApi(
            error_code=cls.meta_data.error_case,
            lang=cls.lang,
            loc=get_line_number_for_error(),
            sys_msg=cls.meta_data.message,
        )

    @classmethod
    def _lookup_active_record(cls, db: Session, criteria: Dict[str, Any]):
        """Return the first record inside its expiry window matching criteria."""
        # Use a single timestamp so both bounds are compared at the same instant.
        now = str(system_arrow.now())
        query = db.query(cls).filter(
            cls.expiry_ends > now,
            cls.expiry_starts <= now,
        )
        for key, value in criteria.items():
            if hasattr(cls, key):
                query = query.filter(getattr(cls, key) == value)
        return query.first()

    @staticmethod
    def _classify_existing(record) -> str:
        """Map an existing record to the error case describing its state."""
        if record.deleted:
            return "DeletedRecord"
        if not record.is_confirmed:
            return "IsNotConfirmed"
        return "AlreadyExists"

    @classmethod
    def _stage_new_record(cls, db: Session, criteria: Dict[str, Any]):
        """Build a record from criteria, add it to the session and flush."""
        new_record = cls()
        for key, value in cls.remove_non_related_inputs(criteria).items():
            setattr(new_record, key, value)
        cls.create_credentials(new_record)
        db.add(new_record)
        db.flush()
        cls.update_metadata(created=True)
        return new_record

    @classmethod
    def create_or_abort(cls, db: Session, **kwargs):
        """
        Create a new record or abort if it already exists.

        Args:
            db: Database session
            **kwargs: Record fields

        Returns:
            New record if successfully created

        Raises:
            HTTPExceptionApi: If a matching record already exists; the error
                code is "DeletedRecord", "IsNotConfirmed" or "AlreadyExists"
        """
        check_kwargs = cls.extract_system_fields(kwargs)
        already_record = cls._lookup_active_record(db, check_kwargs)
        if already_record:
            cls.update_metadata(
                created=False, error_case=cls._classify_existing(already_record)
            )
            cls.raise_exception()
        return cls._stage_new_record(db, check_kwargs)

    @classmethod
    def find_or_create(cls, db: Session, **kwargs):
        """
        Find an existing record matching the criteria or create a new one.

        When a record is found, ``meta_data.error_case`` describes its state
        ("DeletedRecord", "IsNotConfirmed" or "AlreadyExists").

        Args:
            db: Database session
            **kwargs: Search/creation criteria

        Returns:
            Existing or newly created record
        """
        check_kwargs = cls.extract_system_fields(kwargs)
        already_record = cls._lookup_active_record(db, check_kwargs)
        if already_record:
            cls.update_metadata(
                created=False, error_case=cls._classify_existing(already_record)
            )
            return already_record
        return cls._stage_new_record(db, check_kwargs)

    @staticmethod
    def _is_confirmation_request(fields: Dict[str, Any]) -> bool:
        """Return True for a confirm-only update; reject confirm mixed with others."""
        confirming = bool(fields.get("is_confirmed", None))
        if confirming and len(fields) != 1:
            raise ValueError("Confirm field cannot be updated with other fields")
        return confirming

    def update(self, db: Session, **kwargs):
        """
        Update the record with new values.

        Args:
            db: Database session
            **kwargs: Fields to update

        Returns:
            Updated record

        Raises:
            ValueError: If attempting to update is_confirmed with other fields
        """
        # Validate before touching the instance so a rejected call leaves it
        # unchanged.
        self._is_confirmation_request(kwargs)

        check_kwargs = self.remove_non_related_inputs(kwargs)
        check_kwargs = self.extract_system_fields(check_kwargs, create=False)
        for key, value in check_kwargs.items():
            setattr(self, key, value)

        # Unpack the fields: passing ``kwargs=kwargs`` would hide
        # "is_confirmed" inside a single keyword named "kwargs".
        self.update_credentials(**kwargs)
        db.flush()
        return self

    def update_credentials(self, **kwargs) -> None:
        """
        Save user credentials for tracking.

        A truthy ``is_confirmed`` fills ``confirmed_by_id`` / ``confirmed_by``;
        otherwise ``updated_by_id`` / ``updated_by`` are filled. Nothing is
        written unless both credential values are set.

        Args:
            **kwargs: The fields passed to the update call

        Raises:
            ValueError: If is_confirmed is truthy and other fields are present
        """
        confirming = self._is_confirmation_request(kwargs)

        has_credentials = getattr(self.creds, "person_id", None) and getattr(
            self.creds, "person_name", None
        )
        if not has_credentials:
            return

        if confirming:
            self.confirmed_by_id = self.creds.person_id
            self.confirmed_by = self.creds.person_name
        else:
            self.updated_by_id = self.creds.person_id
            self.updated_by = self.creds.person_name