import datetime
from decimal import Decimal
from typing import Union

from sqlalchemy import (
    TIMESTAMP,
    NUMERIC,
    func,
    text,
    UUID,
    String,
    Integer,
    Boolean,
    SmallInteger,
)
from sqlalchemy.orm import (
    Mapped,
    mapped_column,
    InstrumentedAttribute,
)
from sqlalchemy_mixins.session import SessionMixin
from sqlalchemy_mixins.serialize import SerializeMixin
from sqlalchemy_mixins.repr import ReprMixin
from sqlalchemy_mixins.smartquery import SmartQueryMixin

from api_library.date_time_actions.date_functions import DateTimeLocal, client_arrow
from api_objects import EmployeeTokenObject, OccupantTokenObject
from api_objects.auth.token_objects import Credentials
from databases.sql_models.sql_operations import FilterAttributes
from databases.sql_models.postgres_database import Base


# Output format used for every serialized datetime / TIMESTAMP value.
_CLIENT_DATETIME_FORMAT = "DD-MM-YYYY HH:mm:ss"


def _format_client_datetime(value) -> str:
    """Render ``value`` through the module-level ``client_arrow`` helper.

    The value is passed as ``str(value)`` to ``client_arrow.get`` and formatted
    as ``DD-MM-YYYY HH:mm:ss``.
    """
    return str(client_arrow.get(str(value)).format(_CLIENT_DATETIME_FORMAT))


class CrudMixin(Base, SmartQueryMixin, SessionMixin, FilterAttributes):
    """Abstract model mixin adding CRUD helpers, audit stamping and dict serialization.

    Several names used below (``columns``, ``primary_keys``, ``hybrid_properties``,
    ``settable_relations``, ``query``, ``pre_query``, ``filter_by_one``,
    ``raise_http_exception``, ``flush``) are not defined in this module;
    presumably they come from the mixin bases (SmartQueryMixin / SessionMixin /
    FilterAttributes) — verify there before relying on their exact semantics.
    """

    __abstract__ = True  # The model is abstract, not a database table.
    __session__ = Base.session  # The session to use in the model.

    # System-managed fields stripped from user input by ``find_or_create``.
    __system__fields__create__ = (
        "ref_id",
        "created_at",
        "updated_at",
        "cryp_uu_id",
        "created_by",
        "created_by_id",
        "updated_by",
        "updated_by_id",
        "replication_id",
        "confirmed_by",
        "confirmed_by_id",
        "is_confirmed",
        "deleted",
        "active",
        "is_notification_send",
        "is_email_send",
        "expiry_starts",
        "expiry_ends",
    )
    # System-managed fields stripped from user input by ``update``.
    __system__fields__update__ = (
        "cryp_uu_id",
        "created_at",
        "updated_at",
        "created_by",
        "created_by_id",
        "confirmed_by",
        "confirmed_by_id",
        "updated_by",
        "updated_by_id",
        "replication_id",
    )
    # Bookkeeping fields hidden by ``get_dict``: entries ending in "id" are
    # hidden in every mode, the rest only when neither ``include`` nor
    # ``exclude`` is given.
    __system_default_model__ = [
        "cryp_uu_id",
        "is_confirmed",
        "deleted",
        "is_notification_send",
        "replication_id",
        "is_email_send",
        "confirmed_by_id",
        "confirmed_by",
        "updated_by_id",
        "created_by_id",
    ]

    creds: Credentials = None  # The credentials to use in the model.
    client_arrow: DateTimeLocal = None  # The arrow to use in the model.

    valid_record_dict: dict = {"active": True, "deleted": False}
    # Filter expressions selecting active, non-deleted rows; ``== True`` and
    # ``== False`` are deliberate because they build SQL expressions.
    valid_record_args = lambda class_: [class_.active == True, class_.deleted == False]  # noqa: E731, E712

    expiry_starts: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP, server_default=func.now(), nullable=False
    )
    expiry_ends: Mapped[TIMESTAMP] = mapped_column(
        TIMESTAMP, default="2099-12-31", server_default="2099-12-31"
    )

    @classmethod
    def set_user_define_properties(
        cls, token: Union[EmployeeTokenObject, OccupantTokenObject]
    ):
        """Store the acting user's credentials and a client-timezone helper on ``cls``.

        NOTE(review): these are class attributes, so they are shared by every
        instance of the model (and every request that uses it) until overwritten.
        """
        cls.creds = token.credentials
        cls.client_arrow = DateTimeLocal(is_client=True, timezone=token.timezone)

    @classmethod
    def remove_non_related_inputs(cls, kwargs):
        """Return a copy of ``kwargs`` keeping only model-settable keys.

        A key is kept when it names a column, a hybrid property or a settable
        relation of ``cls``.
        """
        # Build the lookup set once instead of re-concatenating per key.
        allowed_names = set(cls.columns + cls.hybrid_properties + cls.settable_relations)
        return {key: value for key, value in kwargs.items() if key in allowed_names}

    @classmethod
    def extract_system_fields(cls, filter_kwargs: dict, create: bool = True):
        """Return a copy of ``filter_kwargs`` without the system-managed fields.

        :param filter_kwargs: user-supplied attributes; not modified.
        :param create: strip ``__system__fields__create__`` when True,
            ``__system__fields__update__`` otherwise.
        """
        fields_to_drop = (
            cls.__system__fields__create__ if create else cls.__system__fields__update__
        )
        return {
            key: value for key, value in filter_kwargs.items() if key not in fields_to_drop
        }

    @classmethod
    def iterate_over_variables(cls, val, key):
        """Convert one attribute value into its serialized form for ``get_dict``.

        Returns None (meaning "omit this key") for primary keys, for attributes
        that have foreign keys and are annotated ``Mapped[int]`` on ``cls``, and
        for NULL values. Otherwise the conversion is taken from the annotation on
        ``cls`` when it is one of the handled ``Mapped[...]`` types, and from the
        runtime type of ``val`` if not:

        * bool -> bool, int -> int
        * float / NUMERIC / Decimal -> float rounded to 3 decimals
        * TIMESTAMP / datetime -> "DD-MM-YYYY HH:mm:ss" via ``client_arrow``
        * anything else -> ``str(val)``

        Falsy values such as False, 0 and "" are returned as-is, not dropped.
        """
        annotation = cls.__annotations__.get(key, None)
        is_primary = key in cls.primary_keys
        has_foreign_key = bool(getattr(getattr(cls, key), "foreign_keys", None))
        if is_primary or (has_foreign_key and annotation == Mapped[int]):
            return None
        if val is None:
            return None

        # Prefer the declared column type when it is one we know how to render.
        if annotation == Mapped[bool]:
            return bool(val)
        if annotation == Mapped[int]:
            return int(val)
        if annotation == Mapped[float] or annotation == Mapped[NUMERIC]:
            return round(float(val), 3)
        if annotation == Mapped[TIMESTAMP]:
            return _format_client_datetime(val)

        # Fall back to the runtime type of the value itself. ``bool`` is
        # checked before ``int`` because bool is a subclass of int.
        if isinstance(val, datetime.datetime):
            return _format_client_datetime(val)
        if isinstance(val, bool):
            return val
        if isinstance(val, (float, Decimal)):
            return round(float(val), 3)
        if isinstance(val, int):
            return int(val)
        return str(val)

    @classmethod
    def find_or_create(cls, **kwargs):
        """Create and flush a new record from ``kwargs`` unless a match exists.

        Despite the name, an existing match is never returned. If
        ``filter_by_one`` finds one, ``raise_http_exception`` is called with
        error_case "DeletedRecord", "IsNotConfirmed" or "AlreadyExists".

        The system fields in ``__system__fields__create__`` are stripped from
        ``kwargs`` before the lookup. Keys that are not model-settable are
        stripped before creation. ``created_by``/``created_by_id`` are stamped
        from ``cls.creds``.

        :return: the newly created (flushed) instance.
        """
        from api_library.date_time_actions.date_functions import system_arrow

        check_kwargs = cls.extract_system_fields(kwargs)
        # NOTE(review): this pre-filter keeps rows whose expiry_ends is already
        # in the past (expired rows). Confirm whether the duplicate check was
        # meant to look at live rows (``>=``) instead.
        cls.pre_query = cls.query.filter(cls.expiry_ends < system_arrow.now().date())
        try:
            lookup = cls.filter_by_one(**check_kwargs, system=True)
        finally:
            # Always clear the class-level pre-filter, even if the lookup fails.
            cls.pre_query = None

        if existing := lookup.data:
            if existing.deleted:
                cls.raise_http_exception(
                    status_code="HTTP_406_NOT_ACCEPTABLE",
                    error_case="DeletedRecord",
                    data=check_kwargs,
                    message="Record exits but is deleted. Contact with authorized user",
                )
            elif not existing.is_confirmed:
                cls.raise_http_exception(
                    status_code="HTTP_406_NOT_ACCEPTABLE",
                    error_case="IsNotConfirmed",
                    data=check_kwargs,
                    message="Record exits but is not confirmed. Contact with authorized user",
                )
            cls.raise_http_exception(
                status_code="HTTP_406_NOT_ACCEPTABLE",
                error_case="AlreadyExists",
                data=check_kwargs,
                message="Record already exits. Refresh data and try again",
            )

        create_kwargs = cls.remove_non_related_inputs(check_kwargs)
        created_record = cls()
        for key, value in create_kwargs.items():
            setattr(created_record, key, value)
        # Stamp the audit fields on the new row before flushing. Assigning them
        # on ``cls`` would replace the mapped column attributes for the whole model.
        created_record.created_by_id = cls.creds.person_id
        created_record.created_by = cls.creds.person_name
        created_record.flush()
        return created_record

    def update(self, **kwargs):
        """Apply ``kwargs`` to this record, stamp the audit fields and flush.

        Keys that are not columns, hybrid properties or settable relations are
        ignored, as are the fields in ``__system__fields__update__``. A truthy
        ``is_confirmed`` must be the only keyword argument. In that case
        ``confirmed_by``/``confirmed_by_id`` are stamped from ``creds``;
        otherwise ``updated_by``/``updated_by_id`` are.

        Calls ``raise_http_exception`` (error_case "ConfirmError") when
        ``is_confirmed`` is combined with other fields.

        :return: self.
        """
        check_kwargs = self.remove_non_related_inputs(kwargs)
        is_confirmed_argument = kwargs.get("is_confirmed", None)
        if is_confirmed_argument and len(kwargs) != 1:
            self.raise_http_exception(
                status_code="HTTP_406_NOT_ACCEPTABLE",
                error_case="ConfirmError",
                data=kwargs,
                message="Confirm field can not be updated with other fields",
            )
        check_kwargs = self.extract_system_fields(check_kwargs, create=False)
        for key, value in check_kwargs.items():
            setattr(self, key, value)
        if is_confirmed_argument:
            self.confirmed_by_id = self.creds.person_id
            self.confirmed_by = self.creds.person_name
        else:
            self.updated_by_id = self.creds.person_id
            self.updated_by = self.creds.person_name
        self.flush()
        return self

    def get_dict(
        self, exclude: list = None, include: list = None, include_joins: list = None
    ):
        """Serialize the record into a dict sorted by key.

        Column selection depends on which argument is given:

        * ``include``: only those names plus ``uu_id`` and ``active``. The
          entries of ``__system_default_model__`` ending in "id" are dropped.
        * ``exclude``: every column except those listed, ``__exclude__fields__``
          and the "id"-suffixed system fields. The caller's list is not modified.
        * neither: every column except ``__exclude__fields__`` and all of
          ``__system_default_model__``.

        Each value goes through ``iterate_over_variables``; keys for which it
        returns None are omitted. Relationship attributes named in
        ``include_joins`` are embedded via their own ``get_dict``.

        :param exclude: column names to leave out.
        :param include: column names to emit.
        :param include_joins: relationship attribute names to embed.
        :return: dict of serialized values, sorted by key.
        """
        model_excluded = getattr(self, "__exclude__fields__", None) or []
        id_system_fields = {
            name for name in self.__system_default_model__ if str(name)[-2:] == "id"
        }
        if include:
            column_names = list(set(include) - id_system_fields)
            column_names.extend(["uu_id", "active"])
        elif exclude:
            hidden = set(exclude) | set(model_excluded) | id_system_fields
            column_names = list(set(self.columns) - hidden)
        else:
            hidden = set(model_excluded) | set(self.__system_default_model__)
            column_names = list(set(self.columns) - hidden)

        return_dict = {}
        for key in column_names:
            value = self.iterate_over_variables(getattr(self, key), key)
            if value is not None:
                return_dict[key] = value

        if include_joins:
            for attr_name in self.__class__.__dict__:
                # Skip private/dunder names and names ending in "id"
                # (presumably foreign-key columns).
                if attr_name[0] == "_" or attr_name[-2:] == "id":
                    continue
                if attr_name not in include_joins:
                    continue
                column = getattr(self.__class__, attr_name)
                is_relationship = isinstance(
                    column, InstrumentedAttribute
                ) and not hasattr(column, "foreign_keys")
                if not is_relationship:
                    continue
                related = getattr(self, attr_name, None)
                if isinstance(related, list):
                    return_dict[attr_name] = [
                        item.get_dict() if item else [] for item in related
                    ]
                elif getattr(related, "get_dict", None):
                    return_dict[attr_name] = related.get_dict() if related else []
        return dict(sorted(return_dict.items()))


class BaseMixin(CrudMixin, ReprMixin, SerializeMixin, FilterAttributes):
    """Abstract mixin combining CrudMixin with repr and serialization helpers."""

    __abstract__ = True


class BaseCollection(BaseMixin):
    """Abstract base for tables that only need an integer primary key."""

    __abstract__ = True
    __repr__ = ReprMixin.__repr__

    id: Mapped[int] = mapped_column(primary_key=True)


class CrudCollection(BaseMixin, SmartQueryMixin):
    """Abstract base for tables with an integer PK, a UUID, audit and status columns."""

    __abstract__ = True
    __repr__ = ReprMixin.__repr__

    id: Mapped[int] = mapped_column(primary_key=True)
    # Generated server-side by PostgreSQL's gen_random_uuid().
    uu_id: Mapped[UUID] = mapped_column(
        UUID, server_default=text("gen_random_uuid()"), index=True, unique=True
    )
    # NOTE(review): annotated Mapped[UUID] but stored as String(100); confirm intended.
    ref_id: Mapped[UUID] = mapped_column(String(100), nullable=True, index=True)

    created_at: Mapped[TIMESTAMP] = mapped_column(
        "created_at",
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        nullable=False,
        index=True,
    )
    # Refreshed to func.now() on every UPDATE via onupdate.
    updated_at: Mapped[TIMESTAMP] = mapped_column(
        "updated_at",
        TIMESTAMP(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
        index=True,
    )
    # NOTE(review): annotated Mapped[UUID] but stored as String; confirm intended.
    cryp_uu_id: Mapped[UUID] = mapped_column(String, nullable=True, index=True)

    # Audit columns, stamped from ``creds`` by find_or_create / update.
    created_by: Mapped[str] = mapped_column(String, nullable=True)
    created_by_id: Mapped[int] = mapped_column(Integer, nullable=True)

    updated_by: Mapped[str] = mapped_column(String, nullable=True)
    updated_by_id: Mapped[int] = mapped_column(Integer, nullable=True)

    confirmed_by: Mapped[str] = mapped_column(String, nullable=True)
    confirmed_by_id: Mapped[int] = mapped_column(Integer, nullable=True)

    # Status flags with server-side defaults.
    is_confirmed: Mapped[bool] = mapped_column(Boolean, server_default="0")
    replication_id: Mapped[int] = mapped_column(SmallInteger, server_default="0")
    deleted: Mapped[bool] = mapped_column(Boolean, server_default="0")
    active: Mapped[bool] = mapped_column(Boolean, server_default="1")
    is_notification_send: Mapped[bool] = mapped_column(Boolean, server_default="0")
    is_email_send: Mapped[bool] = mapped_column(Boolean, server_default="0")