import datetime
from decimal import Decimal

from sqlalchemy import (
    TIMESTAMP,
    NUMERIC,
    func,
    text,
    UUID,
    String,
    Integer,
    Boolean,
    SmallInteger,
)
from sqlalchemy.orm import (
    Mapped,
    mapped_column,
    InstrumentedAttribute,
)

from sqlalchemy_mixins.session import SessionMixin
from sqlalchemy_mixins.serialize import SerializeMixin
from sqlalchemy_mixins.repr import ReprMixin
from sqlalchemy_mixins.smartquery import SmartQueryMixin

from api_library.date_time_actions.date_functions import DateTimeLocal, client_arrow
from databases.sql_models.sql_operations import FilterAttributes
from databases.sql_models.postgres_database import Base


class CrudMixin(Base, SmartQueryMixin, SessionMixin, FilterAttributes):
    """Abstract base adding create / update / serialise helpers to models.

    It is not a database table itself (``__abstract__``). Credentials and the
    client-side arrow are stored at class level by
    ``set_user_define_properties``. ``find_or_create`` and ``update`` read the
    credentials to stamp the audit columns.
    """

    __abstract__ = True  # The model is abstract, not a database table.
    __session__ = Base.session  # The session used by the model.

    # Keys stripped from caller input by extract_system_fields(create=True).
    __system__fields__create__ = (
        "created_at",
        "updated_at",
        "cryp_uu_id",
        "created_by",
        "created_by_id",
        "updated_by",
        "updated_by_id",
        "replication_id",
        "confirmed_by",
        "confirmed_by_id",
        "is_confirmed",
        "deleted",
        "active",
        "is_notification_send",
        "is_email_send",
    )
    # Keys stripped from caller input by extract_system_fields(create=False).
    __system__fields__update__ = (
        "cryp_uu_id",
        "created_at",
        "updated_at",
        "created_by",
        "created_by_id",
        "confirmed_by",
        "confirmed_by_id",
        "updated_by",
        "updated_by_id",
        "replication_id",
    )
    # Columns that get_dict filters out of its output: all of them in the
    # default mode, only the "id"-suffixed ones in the include/exclude modes.
    __system_default_model__ = [
        "cryp_uu_id",
        "is_confirmed",
        "deleted",
        "is_notification_send",
        "replication_id",
        "is_email_send",
        "confirmed_by_id",
        "confirmed_by",
        "updated_by_id",
        "created_by_id",
    ]

    # Set by set_user_define_properties; read when stamping audit columns.
    creds = None
    # Set by set_user_define_properties. NOTE: iterate_over_variables uses
    # the module-level ``client_arrow`` import, not this attribute.
    client_arrow: DateTimeLocal = None

    # Column values describing an "active and not deleted" record.
    valid_record_dict: dict = {"active": True, "deleted": False}
    # Called with the model class explicitly. ``== True`` / ``== False`` are
    # deliberate: SQLAlchemy overloads ``==`` to build SQL expressions.
    valid_record_args = lambda class_: [  # noqa: E731
        class_.active == True,  # noqa: E712
        class_.deleted == False,  # noqa: E712
    ]

    # Validity window: find_or_create only matches rows whose window
    # contains the current time.
    expiry_starts: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP(timezone=True), server_default=func.now(), nullable=False
    )
    expiry_ends: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP(timezone=True), default="2099-12-31", server_default="2099-12-31"
    )

    @classmethod
    def set_user_define_properties(cls, token):
        """Store the caller's credentials and client timezone on the class.

        ``token`` must expose ``credentials`` and ``timezone`` attributes.

        NOTE(review): this writes class-level state, so the values are shared
        by every user of ``cls`` until overwritten. Confirm this is safe when
        requests are served concurrently.
        """
        cls.creds = token.credentials
        cls.client_arrow = DateTimeLocal(is_client=True, timezone=token.timezone)

    @classmethod
    def remove_non_related_inputs(cls, kwargs):
        """Return only the kwargs whose keys are model attributes.

        A key is kept if it names a column, a hybrid property or a settable
        relation of the model.
        """
        # Build the lookup once instead of re-concatenating per key.
        allowed = set(cls.columns + cls.hybrid_properties + cls.settable_relations)
        return {key: value for key, value in kwargs.items() if key in allowed}

    @classmethod
    def extract_system_fields(cls, filter_kwargs: dict, create: bool = True):
        """Return a copy of ``filter_kwargs`` without system-managed fields.

        Args:
            filter_kwargs: Caller-supplied attributes; this dict is not modified.
            create: If True, strip ``__system__fields__create__``; otherwise
                strip ``__system__fields__update__``.
        """
        system_fields = filter_kwargs.copy()
        extract_fields = (
            cls.__system__fields__create__ if create else cls.__system__fields__update__
        )
        for field in extract_fields:
            system_fields.pop(field, None)
        return system_fields

    @classmethod
    def iterate_over_variables(cls, val, key):
        """Convert one attribute value into a serialisable form.

        Returns:
            A tuple ``(include, value)``. ``include`` is False when ``key`` is
            a primary key, when it is a foreign-key column, or when no
            conversion rule matches. ``value`` is the converted value (None
            when excluded).

        Rules:
            * ``None`` becomes ``None``.
            * Keys ending in "uu_id" become ``str``.
            * Keys annotated in ``cls.__annotations__`` are converted by their
              ``Mapped[...]`` annotation. This is typically only the class's
              own annotations, not inherited ones.
            * Any other value is converted by its runtime type.
        """
        annotation = cls.__annotations__.get(key, None)
        is_primary = key in cls.primary_keys
        is_foreign = bool(getattr(getattr(cls, key), "foreign_keys", None))
        if is_primary or is_foreign:
            return False, None
        if val is None:
            return True, None
        if str(key[-5:]).lower() == "uu_id":
            return True, str(val)
        if annotation:
            if annotation == Mapped[int]:
                return True, int(val)
            if annotation == Mapped[bool]:
                return True, bool(val)
            if annotation == Mapped[float] or annotation == Mapped[NUMERIC]:
                return True, round(float(val), 3)
            if annotation == Mapped[TIMESTAMP]:
                return True, str(
                    client_arrow.get(str(val)).format("DD-MM-YYYY HH:mm:ss")
                )
            if annotation == Mapped[str]:
                return True, str(val)
            return False, None
        # No usable annotation: decide by the value's runtime type. The
        # previous code tested ``isinstance(type(val), ...)``, which never
        # matches. ``bool`` is checked before ``int`` because bool is an int
        # subclass.
        if isinstance(val, datetime.datetime):
            return True, str(client_arrow.get(str(val)).format("DD-MM-YYYY HH:mm:ss"))
        if isinstance(val, bool):
            return True, bool(val)
        if isinstance(val, (float, Decimal)):
            return True, round(float(val), 3)
        if isinstance(val, int):
            return True, int(val)
        if isinstance(val, str):
            return True, str(val)
        return False, None

    @classmethod
    def find_or_create(cls, **kwargs):
        """Create a record from ``kwargs`` unless a matching one exists.

        System fields are stripped from ``kwargs`` before matching. The lookup
        only considers rows whose ``expiry_starts``/``expiry_ends`` window
        contains the current time.

        If a match is found, ``raise_http_exception`` is called with status
        ``HTTP_406_NOT_ACCEPTABLE``. The error case is "DeletedRecord",
        "IsNotConfirmed" or "AlreadyExists".

        Otherwise the steps are:
            1. Populate a new instance with the remaining column, hybrid and
               relation inputs.
            2. Stamp ``created_by_id`` / ``created_by`` from the stored
               credentials when both are available.
            3. Flush the instance and return it.
        """
        from api_library.date_time_actions.date_functions import system_arrow

        check_kwargs = cls.extract_system_fields(kwargs)
        now = str(system_arrow.now())  # one timestamp for both window bounds
        # pre_query presumably scopes filter_by_one; always reset it, even if
        # the lookup raises, so later queries are not silently filtered.
        cls.pre_query = cls.query.filter(
            cls.expiry_ends > now,
            cls.expiry_starts <= now,
        )
        try:
            already_record = cls.filter_by_one(system=True, **check_kwargs).data
        finally:
            cls.pre_query = None

        # NOTE(review): assumes raise_http_exception raises. If it ever
        # returned, execution would fall through to record creation.
        if already_record:
            if already_record.deleted:
                cls.raise_http_exception(
                    status_code="HTTP_406_NOT_ACCEPTABLE",
                    error_case="DeletedRecord",
                    data=check_kwargs,
                    message="Record exits but is deleted. Contact with authorized user",
                )
            elif not already_record.is_confirmed:
                cls.raise_http_exception(
                    status_code="HTTP_406_NOT_ACCEPTABLE",
                    error_case="IsNotConfirmed",
                    data=check_kwargs,
                    message="Record exits but is not confirmed. Contact with authorized user",
                )
            cls.raise_http_exception(
                status_code="HTTP_406_NOT_ACCEPTABLE",
                error_case="AlreadyExists",
                data=check_kwargs,
                message="Record already exits. Refresh data and try again",
            )

        check_kwargs = cls.remove_non_related_inputs(check_kwargs)
        created_record = cls()
        for key, value in check_kwargs.items():
            setattr(created_record, key, value)

        person_id = getattr(cls.creds, "person_id", None)
        person_name = getattr(cls.creds, "person_name", None)
        if person_id and person_name:
            # Stamp the new record. The previous code assigned these to
            # ``cls``, which overwrote the mapped column attributes on the class.
            created_record.created_by_id = person_id
            created_record.created_by = person_name
        created_record.flush()
        return created_record

    def update(self, **kwargs):
        """Update this record with ``kwargs`` and flush it.

        Keys that are not model attributes, and system fields listed in
        ``__system__fields__update__``, are ignored.

        When ``is_confirmed`` is truthy it must be the only key; otherwise
        ``raise_http_exception`` is called with error case "ConfirmError".
        For a confirmation, ``confirmed_by_id`` / ``confirmed_by`` are stamped
        from the stored credentials. For any other update,
        ``updated_by_id`` / ``updated_by`` are stamped instead. Stamping only
        happens when both the person id and name are available.

        Returns:
            ``self``, to allow chaining.
        """
        check_kwargs = self.remove_non_related_inputs(kwargs)
        is_confirmed_argument = kwargs.get("is_confirmed", None)
        if is_confirmed_argument and len(kwargs) != 1:
            self.raise_http_exception(
                status_code="HTTP_406_NOT_ACCEPTABLE",
                error_case="ConfirmError",
                data=kwargs,
                message="Confirm field can not be updated with other fields",
            )
        check_kwargs = self.extract_system_fields(check_kwargs, create=False)
        for key, value in check_kwargs.items():
            setattr(self, key, value)

        person_id = getattr(self.creds, "person_id", None)
        person_name = getattr(self.creds, "person_name", None)
        if person_id and person_name:
            if is_confirmed_argument:
                self.confirmed_by_id = person_id
                self.confirmed_by = person_name
            else:
                self.updated_by_id = person_id
                # Previously set from person_id by mistake.
                self.updated_by = person_name
        self.flush()
        return self

    def _serialize_keys(self, keys) -> dict:
        """Serialise the named attributes of this record.

        Each attribute goes through ``iterate_over_variables``; attributes it
        rejects are left out of the result.
        """
        serialized = {}
        for key in keys:
            correct, value_of_database = self.iterate_over_variables(
                getattr(self, key), key
            )
            if correct:
                serialized[key] = value_of_database
        return serialized

    def get_dict(
        self, exclude: list = None, include: list = None, include_joins: list = None
    ):
        """Return a dict of this record's serialised column values.

        The function works in three modes:

        * ``include`` given: serialise exactly those keys plus "uu_id", minus
          any "uu_id"-suffixed system defaults.
        * ``exclude`` given: serialise every column except those in
          ``exclude``, ``__exclude__fields__`` and the "id"-suffixed system
          defaults, then add "uu_id" and "active" back.
        * Neither given: serialise every column that does not end in "id" and
          is not excluded, plus every "uu_id"-suffixed column, then drop all
          ``__system_default_model__`` fields.

        The caller's ``exclude`` list is not modified. ``include_joins`` is
        accepted for compatibility but is currently unused.
        """
        exclude_fields = getattr(self, "__exclude__fields__", None) or []

        if include:
            uu_id_defaults = {
                name
                for name in self.__system_default_model__
                if str(name)[-2:] == "id" and str(name)[-5:].lower() == "uu_id"
            }
            keys = list(set(include) - uu_id_defaults)
            keys.append("uu_id")
            return self._serialize_keys(keys)

        if exclude:
            excluded = set(exclude) | set(exclude_fields)
            excluded.update(
                name for name in self.__system_default_model__ if str(name)[-2:] == "id"
            )
            keys = list(set(self.columns) - excluded)
            keys.extend(["uu_id", "active"])
            return self._serialize_keys(keys)

        system_defaults = set(self.__system_default_model__)
        excluded = set(exclude_fields) | system_defaults
        keys = [name for name in set(self.columns) - excluded if str(name)[-2:] != "id"]
        keys.extend(
            name for name in self.columns if str(name)[-5:].lower() == "uu_id"
        )
        keys = [name for name in keys if name not in system_defaults]
        return self._serialize_keys(keys)


class BaseMixin(CrudMixin, ReprMixin, SerializeMixin, FilterAttributes):
    """Abstract base combining CrudMixin with repr and serialise mixins."""

    __abstract__ = True


class BaseCollection(BaseMixin):
    """Abstract base for models that only need an integer primary key."""

    __abstract__ = True
    __repr__ = ReprMixin.__repr__

    id: Mapped[int] = mapped_column(primary_key=True)


class CrudCollection(BaseMixin, SmartQueryMixin):
    """Abstract base for models carrying identity, audit and status columns."""

    __abstract__ = True
    __repr__ = ReprMixin.__repr__

    id: Mapped[int] = mapped_column(primary_key=True)
    # Generated by the database via gen_random_uuid().
    # NOTE(review): annotated as str but the column type is UUID — confirm
    # which Python type the driver returns.
    uu_id: Mapped[str] = mapped_column(
        UUID, server_default=text("gen_random_uuid()"), index=True, unique=True
    )

    ref_id: Mapped[str] = mapped_column(String(100), nullable=True, index=True)
    created_at: Mapped[TIMESTAMP] = mapped_column(
        "created_at",
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        nullable=False,
        index=True,
    )
    # Refreshed on every UPDATE via onupdate.
    updated_at: Mapped[TIMESTAMP] = mapped_column(
        "updated_at",
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
        index=True,
    )
    cryp_uu_id: Mapped[str] = mapped_column(String, nullable=True, index=True)

    # Audit columns stamped by find_or_create / update from the stored creds.
    created_by: Mapped[str] = mapped_column(String, nullable=True)
    created_by_id: Mapped[int] = mapped_column(Integer, nullable=True)
    updated_by: Mapped[str] = mapped_column(String, nullable=True)
    updated_by_id: Mapped[int] = mapped_column(Integer, nullable=True)
    confirmed_by: Mapped[str] = mapped_column(String, nullable=True)
    confirmed_by_id: Mapped[int] = mapped_column(Integer, nullable=True)

    # Status flags with database-side defaults.
    is_confirmed: Mapped[bool] = mapped_column(Boolean, server_default="0")
    replication_id: Mapped[int] = mapped_column(SmallInteger, server_default="0")
    deleted: Mapped[bool] = mapped_column(Boolean, server_default="0")
    active: Mapped[bool] = mapped_column(Boolean, server_default="1")
    is_notification_send: Mapped[bool] = mapped_column(Boolean, server_default="0")
    is_email_send: Mapped[bool] = mapped_column(Boolean, server_default="0")