213 lines
7.8 KiB
Python
213 lines
7.8 KiB
Python
from sqlalchemy import BinaryExpression
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from api_validations.validations_request import ListOptions
|
|
from databases.sql_models.response_model import AlchemyResponse
|
|
from databases.sql_models.postgres_database import Base
|
|
|
|
|
|
class FilterAttributes:
    """
    Mixin for SQLAlchemy models that adds filtering, pagination and
    persistence helpers.

    Usage:
        alchemy_objects = Model.filter_by_all(name="Something").data
        # -> [<AlchemyObject>, <AlchemyObject>]

        alchemy_object = Model.filter_by_one(name="Something").data
        # -> <AlchemyObject>

        alchemy_objects = Model.filter_all(Model.name == "Something").data
        # -> [<AlchemyObject>, <AlchemyObject>]

        alchemy_object = Model.filter_one(Model.name == "Something").data
        # -> <AlchemyObject>

    NOTE(review): the concrete model is expected to provide ``query``,
    ``is_confirmed``, ``expiry_starts`` and ``expiry_ends``; none of these
    are defined here — confirm every model using this mixin declares them.
    """

    __abstract__ = True

    __session__ = Base.session  # The session to use in the model.

    # The query to use before the filtering, such as:
    # query = cls.query.filter_by(active=True)
    pre_query = None

    filter_attr = None  # The filter attributes (pagination/ordering) to use in the model.

    FilterModel = ListOptions

    def flush(self):
        """Add this instance to the session and flush it.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            HTTPException: 400 if SQLAlchemy rejects the flush; the session is
                rolled back first.
        """
        try:
            self.__session__.add(self)
            self.__session__.flush()
        except SQLAlchemyError as e:
            self._raise_db_error(e)
        return self

    def destroy(self):
        """Delete this record from the database and commit immediately.

        Raises:
            SQLAlchemyError: re-raised unchanged if the delete or commit fails,
                after the session has been rolled back.
        """
        try:
            self.__session__.delete(self)
            self.__session__.commit()
        except SQLAlchemyError:
            # Without a rollback the session stays in a failed state and every
            # later operation on it errors out.
            self.__session__.rollback()
            raise

    @classmethod
    def save(cls):
        """Commit the current session, persisting all pending changes.

        Returns:
            The class itself.

        Raises:
            HTTPException: 400 if the commit fails; the session is rolled back
                first.
        """
        try:
            cls.__session__.commit()
        except SQLAlchemyError as e:
            cls._raise_db_error(e)
        return cls

    @classmethod
    def _raise_db_error(cls, error):
        """Translate a SQLAlchemy error into a 400 HTTPException (never returns)."""
        # Wrapped DBAPI errors carry the driver's message in __context__; fall
        # back to the error itself rather than reporting the string "None".
        source = error.__context__ if error.__context__ is not None else error
        cls.raise_http_exception(
            status_code="HTTP_400_BAD_REQUEST",
            error_case=error.__class__.__name__,
            data={},
            message=str(source).split("\n")[0],
        )

    @classmethod
    def _query(cls):
        """Return ``pre_query`` when it is set (truthy), otherwise ``cls.query``."""
        return cls.pre_query or cls.query

    @classmethod
    def add_query_to_filter(cls, filter_query, filter_list):
        """Apply ordering and pagination from ``filter_list`` to a query.

        Args:
            filter_query: The query to extend.
            filter_list: Mapping with ``order_field``, ``order_type``, ``page``
                and ``size`` keys, as returned by ``get_filter_attributes``.

        Returns:
            The ordered, limited and offset query with ``populate_existing()``
            applied.
        """
        column = getattr(cls, filter_list.get("order_field"))
        # Any order_type starting with "d"/"D" ("desc", "DESC", ...) sorts
        # descending; anything else, including an empty value, sorts ascending.
        descending = str(filter_list.get("order_type")).lower().startswith("d")
        size = int(filter_list.get("size"))
        # Convert before subtracting so string pages like "2" work, and clamp
        # so page <= 0 never yields a negative OFFSET.
        page = int(filter_list.get("page"))
        offset = max(page - 1, 0) * size
        return (
            filter_query.order_by(column.desc() if descending else column.asc())
            .limit(size)
            .offset(offset)
            .populate_existing()
        )

    @classmethod
    def get_filter_attributes(cls):
        """
        Return the pagination and ordering settings read from ``filter_attr``.

        page is the current page number (default 1).
        size is the number of records per page (default 10).
        order_field is the field to order by (default "id").
        order_type is the order type, asc or desc (default "asc").
        include_joins returns the joined tables when related field names are
        given as a list (default []).

        Attributes that are missing or set to None fall back to the defaults.
        """
        defaults = {
            "page": 1,
            "size": 10,
            "order_field": "id",
            "order_type": "asc",
            "include_joins": [],
        }
        settings = {}
        for key, default in defaults.items():
            value = getattr(cls.filter_attr, key, None)
            settings[key] = default if value is None else value
        return settings

    @classmethod
    def add_new_arg_to_args(cls, args_list, argument, value):
        """Append ``value`` to the criteria unless ``argument`` is already constrained.

        Args:
            args_list: Iterable of filter criteria (SQLAlchemy clauses).
            argument: Column key to look for, e.g. ``"is_confirmed"``.
            value: Clause to append when no existing binary expression has
                ``argument`` as its left-hand column key.

        Returns:
            A tuple of the SQL clauses from ``args_list`` in their original
            order (duplicates by identity removed), followed by ``value`` if it
            was added. Non-clause items such as ``None`` are dropped.
        """
        from sqlalchemy.sql.elements import ClauseElement

        criteria = []
        seen_ids = set()
        for clause in args_list:
            # Keep every SQL clause (including or_/and_/not_ groups and text()),
            # not only BinaryExpression, so compound criteria are not lost.
            if isinstance(clause, ClauseElement) and id(clause) not in seen_ids:
                seen_ids.add(id(clause))
                criteria.append(clause)

        def left_key(clause):
            # Column key on the left of "column <op> value", else None.
            if not isinstance(clause, BinaryExpression):
                return None
            return getattr(getattr(clause, "left", None), "key", None)

        if not any(left_key(clause) == argument for clause in criteria):
            criteria.append(value)
        return tuple(criteria)

    @classmethod
    def get_not_expired_query_arg(cls, arg):
        """Add ``expiry_ends > now`` and ``expiry_starts <= now`` to the criteria.

        Each bound is added only if the criteria do not already constrain that
        column. ``now`` is read once so both bounds use the same instant.
        """
        from api_library.date_time_actions.date_functions import system_arrow

        now = str(system_arrow.now())
        arg = cls.add_new_arg_to_args(arg, "expiry_ends", cls.expiry_ends > now)
        arg = cls.add_new_arg_to_args(arg, "expiry_starts", cls.expiry_starts <= now)
        return arg

    @classmethod
    def _with_default_criteria(cls, args, expired=False):
        """Add ``is_confirmed == True`` and, unless ``expired``, the expiry window."""
        args = cls.add_new_arg_to_args(
            args, "is_confirmed", cls.is_confirmed == True  # noqa: E712 - SQL expression
        )
        if not expired:
            args = cls.get_not_expired_query_arg(args)
        return args

    @classmethod
    def select_only(
        cls, *args, select_args: list, order_by=None, limit=None, system=False
    ):
        """Filter by ``args`` and return only the ``select_args`` entities.

        Unless ``system`` is true, ``is_confirmed == True`` and the expiry
        window are added to the criteria (if not already constrained).
        ``order_by`` is applied when not None; ``limit`` is applied when truthy.
        """
        if not system:
            args = cls._with_default_criteria(args)
        query = cls._query().filter(*args).with_entities(*select_args)
        if order_by is not None:
            query = query.order_by(order_by)
        if limit:
            query = query.limit(limit)
        return AlchemyResponse(query=query, first=False)

    @classmethod
    def filter_by_all(cls, system=False, **kwargs):
        """
        Return all records matching the ``column=value`` keyword filters.

        Unless ``system`` is true or ``is_confirmed`` is passed explicitly,
        ``is_confirmed=True`` is added. No expiry-window constraint is applied
        here (unlike ``filter_all``). When ``filter_attr`` is set, ordering and
        pagination from ``get_filter_attributes`` are applied.
        """
        if not system and "is_confirmed" not in kwargs:
            kwargs["is_confirmed"] = True
        query = cls._query().filter_by(**kwargs)
        if cls.filter_attr:
            query = cls.add_query_to_filter(query, cls.get_filter_attributes())
        return AlchemyResponse(query=query, first=False)

    @classmethod
    def filter_by_one(cls, system=False, **kwargs):
        """
        Return the first record matching the ``column=value`` keyword filters.

        Unless ``system`` is true or ``is_confirmed`` is passed explicitly,
        ``is_confirmed=True`` is added. No expiry-window constraint is applied.
        """
        if not system and "is_confirmed" not in kwargs:
            kwargs["is_confirmed"] = True
        query = cls._query().filter_by(**kwargs)
        return AlchemyResponse(query=query, first=True)

    @classmethod
    def filter_all(cls, *args, system=False):
        """
        Return all records matching the SQL expression criteria in ``args``.

        Unless ``system`` is true, ``is_confirmed == True`` and the expiry
        window are added (if not already constrained). When ``filter_attr`` is
        set, ordering and pagination from ``get_filter_attributes`` are applied.
        """
        if not system:
            args = cls._with_default_criteria(args)
        query = cls._query().filter(*args)
        if cls.filter_attr:
            query = cls.add_query_to_filter(query, cls.get_filter_attributes())
        return AlchemyResponse(query=query, first=False)

    @classmethod
    def filter_one(cls, *args, system=False, expired: bool = False):
        """
        Return the first record matching the SQL expression criteria in ``args``.

        Unless ``system`` is true, ``is_confirmed == True`` is added. The expiry
        window is added too, unless ``expired`` is true, in which case expired
        (or not-yet-started) records are also eligible.
        """
        if not system:
            args = cls._with_default_criteria(args, expired=expired)
        query = cls._query().filter(*args)
        return AlchemyResponse(query=query, first=True)

    @classmethod
    def raise_http_exception(cls, status_code, error_case, data, message):
        """Roll back the session and raise a FastAPI ``HTTPException``.

        Args:
            status_code: Name of a ``fastapi.status`` constant (e.g.
                ``"HTTP_400_BAD_REQUEST"``) or an integer status code. Unknown
                names fall back to 404.
            error_case: Short error identifier, sent under ``"error"``.
            data: Payload, sent under ``"data"``.
            message: Human-readable message, sent under ``"message"``.

        Raises:
            HTTPException: always; ``detail`` is a JSON string with the keys
                ``data``, ``error`` and ``message``.
        """
        from json import dumps

        from fastapi import status
        from fastapi.exceptions import HTTPException

        cls.__session__.rollback()
        if isinstance(status_code, int):
            code = status_code
        else:
            # Fall back to the integer constant, not its name: HTTPException
            # needs a numeric status code.
            code = getattr(status, status_code, status.HTTP_404_NOT_FOUND)
        raise HTTPException(
            status_code=code,
            detail=dumps(
                {
                    "data": data,
                    "error": error_case,
                    "message": message,
                }
            ),
        )