import datetime
from decimal import Decimal
from typing import Any, Dict, NoReturn, Optional

import arrow
from fastapi.exceptions import HTTPException
from sqlalchemy import NUMERIC, TIMESTAMP
from sqlalchemy.orm import Mapped, Session
from sqlalchemy.orm.attributes import InstrumentedAttribute

# Output format used when serialising timestamp values in get_dict().
_TIMESTAMP_FORMAT = "YYYY-MM-DD HH:mm:ss ZZ"


class MetaData:
    """
    Class to store metadata for a query.

    Each flag records the outcome of the most recent corresponding operation.
    """

    created: bool = False
    updated: bool = False
    deleted: bool = False


class CRUDModel:
    """
    Base class for CRUD operations on PostgreSQL models.

    Features:
    - User credential tracking
    - Metadata tracking for operations
    - Type-safe field handling
    - Automatic timestamp management
    - Soft delete support
    """

    __abstract__ = True

    # creds: Credentials = None

    # NOTE(review): this is ONE MetaData instance stored as a class attribute,
    # so CRUDModel and every subclass share it. find_or_create() and update()
    # mutate it in place, so its flags only reflect the most recent call made
    # through any model. Callers should read it immediately after the call;
    # confirm this is acceptable under concurrent requests.
    meta_data: MetaData = MetaData()

    # Define required columns for CRUD operations
    required_columns = {
        "expiry_starts": TIMESTAMP,
        "expiry_ends": TIMESTAMP,
        "created_by": str,
        "created_by_id": int,
        "updated_by": str,
        "updated_by_id": int,
        "deleted": bool,
    }

    @classmethod
    def raise_exception(
        cls, message: str = "Exception raised.", status_code: int = 400
    ) -> NoReturn:
        """
        Raise HTTP exception with custom message and status code.

        Args:
            message: Error message, returned as ``{"message": message}``
            status_code: HTTP status code

        Raises:
            HTTPException: Always.
        """
        raise HTTPException(status_code=status_code, detail={"message": message})

    @classmethod
    def _active_query(cls, db: Session):
        """
        Return a query over records whose validity window contains "now".

        "Now" is read once so both bounds use the same instant:
        ``expiry_starts <= now < expiry_ends``.
        """
        now = str(arrow.now())
        return db.query(cls).filter(
            cls.expiry_ends > now,
            cls.expiry_starts <= now,
        )

    @classmethod
    def create_or_abort(cls, db: Session, **kwargs):
        """
        Create a new record or abort if it already exists.

        Only kwargs naming a class attribute are used to search for an
        existing active record. All kwargs are then set on the new record.

        Args:
            db: Database session
            **kwargs: Record fields

        Returns:
            New record if successfully created (added and flushed, not committed)

        Raises:
            HTTPException: 400 if a matching record already exists (deleted or
                not), 500 if the lookup or creation fails for another reason.
                The session is rolled back in both cases.
        """
        try:
            # Search for existing record
            query = cls._active_query(db)
            for key, value in kwargs.items():
                if hasattr(cls, key):
                    query = query.filter(getattr(cls, key) == value)
            already_record = query.first()

            # Handle existing record
            if already_record and already_record.deleted:
                cls.raise_exception("Record already exists and is deleted")
            elif already_record:
                cls.raise_exception("Record already exists")

            # Create new record
            created_record = cls()
            for key, value in kwargs.items():
                setattr(created_record, key, value)

            db.add(created_record)
            db.flush()
            return created_record
        except HTTPException:
            # Re-raise deliberate HTTP errors (the 400s above) as-is, so the
            # generic handler below does not turn them into 500s.
            db.rollback()
            raise
        except Exception as e:
            db.rollback()
            cls.raise_exception(
                f"Failed to create record: {str(e)}", status_code=500
            )

    @classmethod
    def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
        """
        Process a field value based on its type and convert it to the appropriate format.

        Args:
            val: Field value
            key: Field name

        Returns:
            Tuple of (should_include, processed_value). Primary keys (names in
            ``cls.primary_keys``) and foreign-key columns are excluded. Any
            error during processing also excludes the field (best-effort).
        """
        try:
            annotation = cls.__annotations__.get(key, None)
            is_primary = key in getattr(cls, "primary_keys", [])
            is_foreign = bool(getattr(getattr(cls, key), "foreign_keys", None))

            # Skip primary keys and foreign keys
            if is_primary or is_foreign:
                return False, None

            # Handle None values
            if val is None:
                return True, None

            # Special handling for UUID fields
            if key[-5:].lower() == "uu_id":
                return True, str(val)

            if annotation:
                # Handle typed fields
                if annotation == Mapped[int]:
                    return True, int(val)
                if annotation == Mapped[bool]:
                    return True, bool(val)
                if annotation == Mapped[float] or annotation == Mapped[NUMERIC]:
                    return True, round(float(val), 3)
                if annotation == Mapped[TIMESTAMP]:
                    return True, str(arrow.get(str(val)).format(_TIMESTAMP_FORMAT))
                if annotation == Mapped[str]:
                    return True, str(val)
                # NOTE(review): any other annotation (e.g. Mapped[Optional[str]]
                # or Mapped[datetime.datetime]) excludes the field from
                # get_dict(). Confirm this is intended before widening it, since
                # widening could expose columns that are currently hidden.
                return False, None

            # Unannotated field: handle based on Python types. bool is checked
            # before int because bool is a subclass of int.
            if isinstance(val, datetime.datetime):
                return True, str(arrow.get(str(val)).format(_TIMESTAMP_FORMAT))
            if isinstance(val, bool):
                return True, bool(val)
            if isinstance(val, (float, Decimal)):
                return True, round(float(val), 3)
            if isinstance(val, int):
                return True, int(val)
            if isinstance(val, str):
                return True, str(val)
            return False, None
        except Exception:
            # Best-effort: a field that cannot be inspected or converted is
            # simply left out.
            return False, None

    def get_dict(
        self, exclude_list: Optional[list[InstrumentedAttribute]] = None
    ) -> Dict[str, Any]:
        """
        Convert model instance to dictionary with customizable fields.

        Columns whose name ends in "id" are omitted, except those ending in
        "uu_id" (case-insensitive). Each remaining value is filtered and
        converted by iterate_over_variables().

        Args:
            exclude_list: List of fields to exclude from the dictionary

        Returns:
            Dictionary representation of the model in table column order, or
            ``{}`` if serialisation fails.
        """
        try:
            excluded = {attr.key for attr in (exclude_list or [])}
            return_dict: Dict[str, Any] = {}

            for column in self.__table__.columns:
                key = column.name
                if key in excluded:
                    continue
                # NOTE(review): this is a plain suffix check, so columns such as
                # "paid" or "valid" are hidden too. Verify that this is intended.
                if key[-2:] == "id" and key[-5:].lower() != "uu_id":
                    continue

                correct, value_of_database = self.iterate_over_variables(
                    getattr(self, key), key
                )
                if correct:
                    return_dict[key] = value_of_database
            return return_dict
        except Exception:
            # Best-effort serialisation: any failure yields an empty dict.
            return {}

    @classmethod
    def find_or_create(
        cls,
        db: Session,
        exclude_args: Optional[list[InstrumentedAttribute]] = None,
        include_args: Optional[list[InstrumentedAttribute]] = None,
        **kwargs,
    ):
        """
        Find an existing record matching the criteria or create a new one.

        Sets ``cls.meta_data.created`` to False when an existing record is
        returned and to True when a new one is created.

        Args:
            db: Database session
            exclude_args: Keys to exclude from search
            include_args: Keys to specifically include in search (if provided, only these will be used)
            **kwargs: Search/creation criteria (all of them are set on a new record)

        Returns:
            Existing or newly created record (a new record is added and
            flushed, not committed)

        Raises:
            HTTPException: 500 if the lookup or creation fails; the session is
                rolled back first.
        """
        try:
            # Search for existing record
            query = cls._active_query(db)

            excluded = {attr.key for attr in (exclude_args or [])}
            included = {attr.key for attr in (include_args or [])}

            # If include_args is provided, only use those fields for matching.
            # Otherwise, use all fields except those in exclude_args.
            for key, value in kwargs.items():
                if not hasattr(cls, key):
                    continue
                use_for_search = key in included if included else key not in excluded
                if use_for_search:
                    query = query.filter(getattr(cls, key) == value)

            already_record = query.first()
            if already_record:
                # Handle existing record
                cls.meta_data.created = False
                return already_record

            # Create new record
            created_record = cls()
            for key, value in kwargs.items():
                setattr(created_record, key, value)

            db.add(created_record)
            db.flush()
            cls.meta_data.created = True
            return created_record
        except Exception as e:
            db.rollback()
            cls.raise_exception(
                f"Failed to find or create record: {str(e)}", status_code=500
            )

    def update(self, db: Session, **kwargs):
        """
        Update the record with new values.

        Every kwarg is assigned with setattr, and then the session is flushed
        (not committed). Sets ``meta_data.updated`` to reflect the outcome.

        Args:
            db: Database session
            **kwargs: Fields to update

        Returns:
            Updated record

        Raises:
            HTTPException: 500 if update fails; the session is rolled back first.
        """
        try:
            for key, value in kwargs.items():
                setattr(self, key, value)

            db.flush()
            self.meta_data.updated = True
            return self
        except Exception as e:
            self.meta_data.updated = False
            db.rollback()
            self.raise_exception(
                f"Failed to update record: {str(e)}", status_code=500
            )