156 lines
5.6 KiB
Python
156 lines
5.6 KiB
Python
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from databases.sql_models.response_model import AlchemyResponse
|
|
from databases.sql_models.postgres_database import Base
|
|
|
|
|
|
class FilterAttributes:
    """
    Mixin that gives models filtering, pagination and persistence helpers.

    Usage:

        alchemy_objects = Model.filter_by_all(name="Something").data
        # -> [<AlchemyObject>, <AlchemyObject>]

        alchemy_object = Model.filter_by_one(name="Something").data
        # -> <AlchemyObject>

        alchemy_objects = Model.filter_all(Model.name == "Something").data
        # -> [<AlchemyObject>, <AlchemyObject>]

        alchemy_object = Model.filter_one(Model.name == "Something").data
        # -> <AlchemyObject>

    The ``*_all`` helpers apply ordering and pagination taken from
    ``filter_attr`` (see ``get_filter_attributes``); the ``*_one`` helpers
    do not.
    """

    # Declarative flag: this class itself is not mapped to a table.
    __abstract__ = True

    # Session shared by every class using this mixin (taken from Base.session).
    __session__ = Base.session

    # Optional query to start from instead of ``cls.query``,
    # e.g. ``cls.query.filter_by(active=True)``.
    pre_query = None

    # Object whose page / size / order_field / order_type / include_joins
    # attributes drive pagination and ordering; None means "use defaults".
    filter_attr = None

    def flush(self):
        """
        Add this instance to the session and flush pending changes.

        Raises:
            HTTPException: via ``raise_http_exception`` when SQLAlchemy
                reports an error; the session is rolled back first.
        """
        try:
            self.__session__.add(self)
            self.__session__.flush()
        except SQLAlchemyError as e:
            self._raise_for_sqlalchemy_error(e)

    def destroy(self):
        """
        Delete this instance and commit the session.

        Raises:
            SQLAlchemyError: re-raised unchanged if the delete or commit
                fails, after rolling the session back.
        """
        try:
            self.__session__.delete(self)
            self.__session__.commit()
        except SQLAlchemyError:
            # The session is shared; without a rollback it would stay in a
            # failed transaction state for every later caller.
            self.__session__.rollback()
            raise

    @classmethod
    def save(cls):
        """
        Commit the shared session, persisting all pending changes.

        Raises:
            HTTPException: via ``raise_http_exception`` when the commit
                fails; the session is rolled back first.
        """
        try:
            cls.__session__.commit()
        except SQLAlchemyError as e:
            cls._raise_for_sqlalchemy_error(e)

    @classmethod
    def _raise_for_sqlalchemy_error(cls, error):
        """Convert a SQLAlchemy error into the HTTP error used by flush/save."""
        # Errors raised while handling another exception (e.g. wrapped driver
        # errors) expose it as __context__. Others have no context, and
        # str(None) would produce the useless message "None".
        source = error.__context__ if error.__context__ is not None else error
        # NOTE(review): a 304 Not Modified response cannot carry a body per
        # the HTTP spec, so the JSON detail may never reach the client —
        # confirm this status is intended.
        cls.raise_http_exception(
            status_code="HTTP_304_NOT_MODIFIED",
            error_case=error.__class__.__name__,
            data={},
            message=str(source).split('\n')[0],
        )

    @classmethod
    def _query(cls):
        """Return ``pre_query`` when one is configured, else ``cls.query``."""
        # Compare against None explicitly: some SQLAlchemy constructs
        # (ClauseElement-based ones) raise TypeError when tested with bool().
        return cls.pre_query if cls.pre_query is not None else cls.query

    @classmethod
    def add_query_to_filter(cls, filter_query, filter_list):
        """
        Apply ordering, pagination and ``populate_existing`` to a query.

        Args:
            filter_query: the query to extend.
            filter_list: dict as produced by ``get_filter_attributes``; reads
                ``order_field``, ``order_type``, ``page`` and ``size``.

        Returns:
            The ordered, limited and offset query.

        Raises:
            AttributeError: if ``order_field`` is not an attribute of the model.
        """
        order_column = getattr(cls, filter_list.get("order_field"))
        # Any order_type starting with "d"/"D" ("desc", "DESC", ...) sorts
        # descending; anything else (including "") sorts ascending.
        descending = str(filter_list.get("order_type")).lower().startswith("d")
        size = int(filter_list.get("size"))
        # Pages are 1-based; clamp so page <= 0 cannot yield a negative OFFSET.
        page = max(int(filter_list.get("page")), 1)
        return (
            filter_query.order_by(
                order_column.desc() if descending else order_column.asc()
            )
            .limit(size)
            .offset((page - 1) * size)
            .populate_existing()
        )

    @classmethod
    def get_filter_attributes(cls):
        """
        Return the pagination and ordering settings taken from ``filter_attr``.

        Keys:
            page: current page number (1-based, default 1).
            size: number of records per page (default 10).
            order_field: model attribute to order by (default "id").
            order_type: "asc" or "desc" (default "asc").
            include_joins: list of related field names (default []). Not
                applied by any method in this class.

        Attributes missing from ``filter_attr``, or explicitly set to None,
        fall back to these defaults.
        """

        def _setting(name, default):
            value = getattr(cls.filter_attr, name, None)
            return default if value is None else value

        return {
            "page": _setting("page", 1),
            "size": _setting("size", 10),
            "order_field": _setting("order_field", "id"),
            "order_type": _setting("order_type", "asc"),
            "include_joins": _setting("include_joins", []),
        }

    @classmethod
    def get_not_expired_query_arg(cls, *arg, expired=True):
        """
        Optionally extend a tuple of filter criteria with validity-window checks.

        Args:
            *arg: a single positional argument, the tuple of criteria.
                Callers pass their ``args`` tuple as one value, so
                ``arg[0]`` is that tuple.
            expired: when True, append conditions requiring
                ``expiry_starts <= now <= expiry_ends``. When False, return
                the criteria unchanged. Despite the name, True means "keep
                only records that are not expired".

        Returns:
            The criteria, ready for ``Query.filter(*criteria)``.

        Requires the model to define ``expiry_starts`` and ``expiry_ends``
        when ``expired`` is True.
        """
        # Local import, kept from the original (presumably to avoid an import
        # cycle — TODO confirm).
        from api_library.date_time_actions.date_functions import system_arrow

        criteria = arg[0]
        if not expired:
            return criteria
        # NOTE(review): wrapping the column attributes in system_arrow.get()
        # looks suspicious. A SQL filter normally compares the column itself
        # (cls.expiry_ends >= now). Confirm system_arrow.get returns a SQL
        # expression here.
        return (
            *criteria,
            system_arrow.get(cls.expiry_ends) >= system_arrow.now(),
            system_arrow.now() >= system_arrow.get(cls.expiry_starts),
        )

    @classmethod
    def filter_by_all(cls, **kwargs):
        """
        Return all records matching keyword equality filters, paginated.

        No is_deleted / is_confirmed criteria are added here (``pre_query``
        may still restrict rows). Ordering and pagination come from
        ``get_filter_attributes``.
        """
        filter_list = cls.get_filter_attributes()
        query = cls._query().filter_by(**kwargs)
        data = cls.add_query_to_filter(query, filter_list)
        return AlchemyResponse(query=data, first=False)

    @classmethod
    def filter_by_one(cls, **kwargs):
        """
        Return the first record matching keyword equality filters.

        No is_deleted / is_confirmed criteria are added here (``pre_query``
        may still restrict rows).
        """
        query = cls._query().filter_by(**kwargs)
        return AlchemyResponse(query=query, first=True)

    @classmethod
    def filter_all(cls, *args, expired: bool = False):
        """
        Return all records matching SQL expression criteria, paginated.

        Args:
            *args: SQLAlchemy filter expressions, e.g. ``Model.name == "x"``.
            expired: when True, also keep only rows inside their
                expiry_starts / expiry_ends window (see
                ``get_not_expired_query_arg``).

        No is_deleted / is_confirmed criteria are added here (``pre_query``
        may still restrict rows). Ordering and pagination come from
        ``get_filter_attributes``.
        """
        filter_list = cls.get_filter_attributes()
        criteria = cls.get_not_expired_query_arg(args, expired=expired)
        query = cls._query().filter(*criteria)
        data = cls.add_query_to_filter(query, filter_list)
        return AlchemyResponse(query=data, first=False)

    @classmethod
    def filter_one(cls, *args, expired: bool = False):
        """
        Return the first record matching SQL expression criteria.

        Args:
            *args: SQLAlchemy filter expressions, e.g. ``Model.name == "x"``.
            expired: when True, also keep only rows inside their
                expiry_starts / expiry_ends window.

        No is_deleted / is_confirmed criteria are added here (``pre_query``
        may still restrict rows).
        """
        criteria = cls.get_not_expired_query_arg(args, expired=expired)
        query = cls._query().filter(*criteria)
        return AlchemyResponse(query=query, first=True)

    @classmethod
    def raise_http_exception(cls, status_code, error_case, data, message):
        """
        Roll back the shared session and raise a FastAPI ``HTTPException``.

        Args:
            status_code: name of a constant in ``fastapi.status``
                (e.g. ``"HTTP_304_NOT_MODIFIED"``). Unknown names fall back
                to 404.
            error_case: short error identifier, stored under ``"error"``.
            data: JSON-serialisable payload, stored under ``"data"``.
            message: human-readable message, stored under ``"message"``.

        Raises:
            HTTPException: always. Its ``detail`` is the JSON-encoded dict
                ``{"data", "error", "message"}``.
        """
        from fastapi.exceptions import HTTPException
        from fastapi import status
        from json import dumps

        cls.__session__.rollback()
        raise HTTPException(
            # Fall back to the numeric constant, not its name:
            # HTTPException needs an int status code.
            status_code=getattr(status, status_code, status.HTTP_404_NOT_FOUND),
            detail=dumps({
                "data": data,
                "error": error_case,
                "message": message,
            }),
        )
|