auth endpoints added
This commit is contained in:
@@ -14,6 +14,7 @@ class Credentials(BaseModel):
|
||||
"""
|
||||
Class to store user credentials.
|
||||
"""
|
||||
|
||||
person_id: int
|
||||
person_name: str
|
||||
full_name: Optional[str] = None
|
||||
@@ -23,6 +24,7 @@ class MetaData:
|
||||
"""
|
||||
Class to store metadata for a query.
|
||||
"""
|
||||
|
||||
created: bool = False
|
||||
updated: bool = False
|
||||
|
||||
@@ -30,7 +32,7 @@ class MetaData:
|
||||
class CRUDModel:
|
||||
"""
|
||||
Base class for CRUD operations on PostgreSQL models.
|
||||
|
||||
|
||||
Features:
|
||||
- User credential tracking
|
||||
- Metadata tracking for operations
|
||||
@@ -38,21 +40,21 @@ class CRUDModel:
|
||||
- Automatic timestamp management
|
||||
- Soft delete support
|
||||
"""
|
||||
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
creds: Credentials = None
|
||||
meta_data: MetaData = MetaData()
|
||||
|
||||
|
||||
# Define required columns for CRUD operations
|
||||
required_columns = {
|
||||
'expiry_starts': TIMESTAMP,
|
||||
'expiry_ends': TIMESTAMP,
|
||||
'created_by': str,
|
||||
'created_by_id': int,
|
||||
'updated_by': str,
|
||||
'updated_by_id': int,
|
||||
'deleted': bool
|
||||
"expiry_starts": TIMESTAMP,
|
||||
"expiry_ends": TIMESTAMP,
|
||||
"created_by": str,
|
||||
"created_by_id": int,
|
||||
"updated_by": str,
|
||||
"updated_by_id": int,
|
||||
"deleted": bool,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -65,24 +67,25 @@ class CRUDModel:
|
||||
"""
|
||||
if not cls.creds:
|
||||
return
|
||||
|
||||
if getattr(cls.creds, "person_id", None) and getattr(cls.creds, "person_name", None):
|
||||
|
||||
if getattr(cls.creds, "person_id", None) and getattr(
|
||||
cls.creds, "person_name", None
|
||||
):
|
||||
record_created.created_by_id = cls.creds.person_id
|
||||
record_created.created_by = cls.creds.person_name
|
||||
|
||||
|
||||
@classmethod
|
||||
def raise_exception(cls, message: str = "Exception raised.", status_code: int = 400):
|
||||
def raise_exception(
|
||||
cls, message: str = "Exception raised.", status_code: int = 400
|
||||
):
|
||||
"""
|
||||
Raise HTTP exception with custom message and status code.
|
||||
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
status_code: HTTP status code
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail={"message": message}
|
||||
)
|
||||
raise HTTPException(status_code=status_code, detail={"message": message})
|
||||
|
||||
@classmethod
|
||||
def create_or_abort(cls, db: Session, **kwargs):
|
||||
@@ -111,7 +114,7 @@ class CRUDModel:
|
||||
query = query.filter(getattr(cls, key) == value)
|
||||
|
||||
already_record = query.first()
|
||||
|
||||
|
||||
# Handle existing record
|
||||
if already_record and already_record.deleted:
|
||||
cls.raise_exception("Record already exists and is deleted")
|
||||
@@ -122,12 +125,12 @@ class CRUDModel:
|
||||
created_record = cls()
|
||||
for key, value in kwargs.items():
|
||||
setattr(created_record, key, value)
|
||||
|
||||
|
||||
cls.create_credentials(created_record)
|
||||
db.add(created_record)
|
||||
db.flush()
|
||||
return created_record
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
cls.raise_exception(f"Failed to create record: {str(e)}", status_code=500)
|
||||
@@ -146,7 +149,7 @@ class CRUDModel:
|
||||
"""
|
||||
try:
|
||||
key_ = cls.__annotations__.get(key, None)
|
||||
is_primary = key in getattr(cls, 'primary_keys', [])
|
||||
is_primary = key in getattr(cls, "primary_keys", [])
|
||||
row_attr = bool(getattr(getattr(cls, key), "foreign_keys", None))
|
||||
|
||||
# Skip primary keys and foreign keys
|
||||
@@ -167,12 +170,16 @@ class CRUDModel:
|
||||
elif key_ == Mapped[float] or key_ == Mapped[NUMERIC]:
|
||||
return True, round(float(val), 3)
|
||||
elif key_ == Mapped[TIMESTAMP]:
|
||||
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))
|
||||
return True, str(
|
||||
arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ")
|
||||
)
|
||||
elif key_ == Mapped[str]:
|
||||
return True, str(val)
|
||||
else: # Handle based on Python types
|
||||
if isinstance(val, datetime.datetime):
|
||||
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))
|
||||
return True, str(
|
||||
arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ")
|
||||
)
|
||||
elif isinstance(val, bool):
|
||||
return True, bool(val)
|
||||
elif isinstance(val, (float, Decimal)):
|
||||
@@ -185,17 +192,19 @@ class CRUDModel:
|
||||
return True, None
|
||||
|
||||
return False, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return False, None
|
||||
|
||||
def get_dict(self, exclude_list: Optional[list[InstrumentedAttribute]] = None) -> Dict[str, Any]:
|
||||
def get_dict(
|
||||
self, exclude_list: Optional[list[InstrumentedAttribute]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert model instance to dictionary with customizable fields.
|
||||
|
||||
|
||||
Args:
|
||||
exclude_list: List of fields to exclude from the dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the model
|
||||
"""
|
||||
@@ -207,7 +216,7 @@ class CRUDModel:
|
||||
# Get all column names from the model
|
||||
columns = [col.name for col in self.__table__.columns]
|
||||
columns_set = set(columns)
|
||||
|
||||
|
||||
# Filter columns
|
||||
columns_list = set([col for col in columns_set if str(col)[-2:] != "id"])
|
||||
columns_extend = set(
|
||||
@@ -223,7 +232,7 @@ class CRUDModel:
|
||||
return_dict[key] = value_of_database
|
||||
|
||||
return return_dict
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return {}
|
||||
|
||||
@@ -251,10 +260,10 @@ class CRUDModel:
|
||||
cls.expiry_ends > str(arrow.now()),
|
||||
cls.expiry_starts <= str(arrow.now()),
|
||||
)
|
||||
|
||||
|
||||
exclude_args = exclude_args or []
|
||||
exclude_args = [exclude_arg.key for exclude_arg in exclude_args]
|
||||
|
||||
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(cls, key) and key not in exclude_args:
|
||||
query = query.filter(getattr(cls, key) == value)
|
||||
@@ -268,16 +277,18 @@ class CRUDModel:
|
||||
created_record = cls()
|
||||
for key, value in kwargs.items():
|
||||
setattr(created_record, key, value)
|
||||
|
||||
|
||||
cls.create_credentials(created_record)
|
||||
db.add(created_record)
|
||||
db.flush()
|
||||
cls.meta_data.created = True
|
||||
return created_record
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
cls.raise_exception(f"Failed to find or create record: {str(e)}", status_code=500)
|
||||
cls.raise_exception(
|
||||
f"Failed to find or create record: {str(e)}", status_code=500
|
||||
)
|
||||
|
||||
def update(self, db: Session, **kwargs):
|
||||
"""
|
||||
@@ -301,7 +312,7 @@ class CRUDModel:
|
||||
db.flush()
|
||||
self.meta_data.updated = True
|
||||
return self
|
||||
|
||||
|
||||
except Exception as e:
|
||||
self.meta_data.updated = False
|
||||
db.rollback()
|
||||
@@ -313,10 +324,10 @@ class CRUDModel:
|
||||
"""
|
||||
if not self.creds:
|
||||
return
|
||||
|
||||
|
||||
person_id = getattr(self.creds, "person_id", None)
|
||||
person_name = getattr(self.creds, "person_name", None)
|
||||
|
||||
|
||||
if person_id and person_name:
|
||||
self.updated_by_id = self.creds.person_id
|
||||
self.updated_by = self.creds.person_name
|
||||
|
||||
@@ -10,11 +10,11 @@ from sqlalchemy.orm import declarative_base, sessionmaker, scoped_session, Sessi
|
||||
engine = create_engine(
|
||||
postgres_configs.url,
|
||||
pool_pre_ping=True,
|
||||
pool_size=10, # Reduced from 20 to better match your CPU cores
|
||||
max_overflow=5, # Reduced from 10 to prevent too many connections
|
||||
pool_recycle=600, # Keep as is
|
||||
pool_timeout=30, # Keep as is
|
||||
echo=False, # Consider setting to False in production
|
||||
pool_size=10, # Reduced from 20 to better match your CPU cores
|
||||
max_overflow=5, # Reduced from 10 to prevent too many connections
|
||||
pool_recycle=600, # Keep as is
|
||||
pool_timeout=30, # Keep as is
|
||||
echo=False, # Consider setting to False in production
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -35,51 +35,49 @@ class QueryModel:
|
||||
|
||||
@classmethod
|
||||
def add_new_arg_to_args(
|
||||
cls: Type[T],
|
||||
args_list: tuple[BinaryExpression, ...],
|
||||
argument: str,
|
||||
value: BinaryExpression
|
||||
cls: Type[T],
|
||||
args_list: tuple[BinaryExpression, ...],
|
||||
argument: str,
|
||||
value: BinaryExpression,
|
||||
) -> tuple[BinaryExpression, ...]:
|
||||
"""
|
||||
Add a new argument to the query arguments if it doesn't exist.
|
||||
|
||||
|
||||
Args:
|
||||
args_list: Existing query arguments
|
||||
argument: Key of the argument to check for
|
||||
value: New argument value to add
|
||||
|
||||
|
||||
Returns:
|
||||
Updated tuple of query arguments
|
||||
"""
|
||||
# Convert to set to remove duplicates while preserving order
|
||||
new_args = list(dict.fromkeys(
|
||||
arg for arg in args_list
|
||||
if isinstance(arg, BinaryExpression)
|
||||
))
|
||||
|
||||
new_args = list(
|
||||
dict.fromkeys(arg for arg in args_list if isinstance(arg, BinaryExpression))
|
||||
)
|
||||
|
||||
# Check if argument already exists
|
||||
if not any(
|
||||
getattr(getattr(arg, "left", None), "key", None) == argument
|
||||
getattr(getattr(arg, "left", None), "key", None) == argument
|
||||
for arg in new_args
|
||||
):
|
||||
new_args.append(value)
|
||||
|
||||
|
||||
return tuple(new_args)
|
||||
|
||||
@classmethod
|
||||
def get_not_expired_query_arg(
|
||||
cls: Type[T],
|
||||
args: tuple[BinaryExpression, ...]
|
||||
cls: Type[T], args: tuple[BinaryExpression, ...]
|
||||
) -> tuple[BinaryExpression, ...]:
|
||||
"""
|
||||
Add expiry date filtering to the query arguments.
|
||||
|
||||
|
||||
Args:
|
||||
args: Existing query arguments
|
||||
|
||||
|
||||
Returns:
|
||||
Updated tuple of query arguments with expiry filters
|
||||
|
||||
|
||||
Raises:
|
||||
AttributeError: If model does not have expiry_starts or expiry_ends columns
|
||||
"""
|
||||
@@ -87,21 +85,21 @@ class QueryModel:
|
||||
current_time = str(arrow.now())
|
||||
# Only add expiry filters if they don't already exist
|
||||
if not any(
|
||||
getattr(getattr(arg, "left", None), "key", None) == "expiry_ends"
|
||||
getattr(getattr(arg, "left", None), "key", None) == "expiry_ends"
|
||||
for arg in args
|
||||
):
|
||||
ends = cls.expiry_ends > current_time
|
||||
args = cls.add_new_arg_to_args(args, "expiry_ends", ends)
|
||||
|
||||
|
||||
if not any(
|
||||
getattr(getattr(arg, "left", None), "key", None) == "expiry_starts"
|
||||
getattr(getattr(arg, "left", None), "key", None) == "expiry_starts"
|
||||
for arg in args
|
||||
):
|
||||
starts = cls.expiry_starts <= current_time
|
||||
args = cls.add_new_arg_to_args(args, "expiry_starts", starts)
|
||||
|
||||
|
||||
return args
|
||||
|
||||
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
f"Model {cls.__name__} must have expiry_starts and expiry_ends columns"
|
||||
@@ -111,7 +109,7 @@ class QueryModel:
|
||||
def produce_query_to_add(cls: Type[T], filter_list: dict, args: tuple) -> tuple:
|
||||
"""
|
||||
Adds query to main filter options
|
||||
|
||||
|
||||
Args:
|
||||
filter_list: Dictionary containing query parameters
|
||||
args: Existing query arguments to add to
|
||||
@@ -122,11 +120,11 @@ class QueryModel:
|
||||
try:
|
||||
if not filter_list or not isinstance(filter_list, dict):
|
||||
return args
|
||||
|
||||
|
||||
query_params = filter_list.get("query")
|
||||
if not query_params or not isinstance(query_params, dict):
|
||||
return args
|
||||
|
||||
|
||||
for key, value in query_params.items():
|
||||
if hasattr(cls, key):
|
||||
# Create a new filter expression
|
||||
@@ -134,39 +132,34 @@ class QueryModel:
|
||||
# Add it to args if it doesn't exist
|
||||
args = cls.add_new_arg_to_args(args, key, filter_expr)
|
||||
return args
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in produce_query_to_add: {str(e)}")
|
||||
return args
|
||||
|
||||
@classmethod
|
||||
def convert(
|
||||
cls: Type[T],
|
||||
smart_options: dict[str, Any],
|
||||
validate_model: Any = None
|
||||
cls: Type[T], smart_options: dict[str, Any], validate_model: Any = None
|
||||
) -> Optional[tuple[BinaryExpression, ...]]:
|
||||
"""
|
||||
Convert smart options to SQLAlchemy filter expressions.
|
||||
|
||||
|
||||
Args:
|
||||
smart_options: Dictionary of filter options
|
||||
validate_model: Optional model to validate against
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of SQLAlchemy filter expressions or None if validation fails
|
||||
"""
|
||||
if validate_model is not None:
|
||||
# Add validation logic here if needed
|
||||
pass
|
||||
|
||||
|
||||
return tuple(cls.filter_expr(**smart_options))
|
||||
|
||||
@classmethod
|
||||
def filter_by_one(
|
||||
cls: Type[T],
|
||||
db: Session,
|
||||
system: bool = False,
|
||||
**kwargs: Any
|
||||
cls: Type[T], db: Session, system: bool = False, **kwargs: Any
|
||||
) -> PostgresResponse[T]:
|
||||
"""
|
||||
Filter single record by keyword arguments.
|
||||
@@ -181,30 +174,28 @@ class QueryModel:
|
||||
"""
|
||||
# Get base query (either pre_query or new query)
|
||||
base_query = cls._query(db)
|
||||
|
||||
|
||||
# Create the final query by applying filters
|
||||
query = base_query
|
||||
|
||||
|
||||
# Add keyword filters first
|
||||
query = query.filter_by(**kwargs)
|
||||
|
||||
|
||||
# Add status filters if not system query
|
||||
if not system:
|
||||
query = query.filter(
|
||||
cls.is_confirmed == True,
|
||||
cls.deleted == False,
|
||||
cls.active == True
|
||||
cls.is_confirmed == True, cls.deleted == False, cls.active == True
|
||||
)
|
||||
|
||||
|
||||
# Add expiry filters last
|
||||
args = cls.get_not_expired_query_arg(())
|
||||
query = query.filter(*args)
|
||||
|
||||
|
||||
return PostgresResponse(
|
||||
model=cls,
|
||||
model=cls,
|
||||
pre_query=base_query, # Use the base query for pre_query
|
||||
query=query,
|
||||
is_array=False
|
||||
query=query,
|
||||
is_array=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -225,29 +216,27 @@ class QueryModel:
|
||||
"""
|
||||
# Get base query (either pre_query or new query)
|
||||
base_query = cls._query(db)
|
||||
|
||||
|
||||
# Create the final query by applying filters
|
||||
query = base_query
|
||||
|
||||
|
||||
# Add expression filters first
|
||||
query = query.filter(*args)
|
||||
|
||||
|
||||
# Add status filters
|
||||
query = query.filter(
|
||||
cls.is_confirmed == True,
|
||||
cls.deleted == False,
|
||||
cls.active == True
|
||||
cls.is_confirmed == True, cls.deleted == False, cls.active == True
|
||||
)
|
||||
|
||||
|
||||
# Add expiry filters last
|
||||
args = cls.get_not_expired_query_arg(())
|
||||
query = query.filter(*args)
|
||||
|
||||
|
||||
return PostgresResponse(
|
||||
model=cls,
|
||||
model=cls,
|
||||
pre_query=base_query, # Use the base query for pre_query
|
||||
query=query,
|
||||
is_array=False
|
||||
query=query,
|
||||
is_array=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -268,22 +257,22 @@ class QueryModel:
|
||||
"""
|
||||
# Get base query (either pre_query or new query)
|
||||
base_query = cls._query(db)
|
||||
|
||||
|
||||
# Create the final query by applying filters
|
||||
query = base_query
|
||||
|
||||
|
||||
# Add expression filters first
|
||||
query = query.filter(*args)
|
||||
|
||||
|
||||
# Add expiry filters last
|
||||
args = cls.get_not_expired_query_arg(())
|
||||
query = query.filter(*args)
|
||||
|
||||
|
||||
return PostgresResponse(
|
||||
model=cls,
|
||||
model=cls,
|
||||
pre_query=base_query, # Use the base query for pre_query
|
||||
query=query,
|
||||
is_array=False
|
||||
query=query,
|
||||
is_array=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -304,22 +293,22 @@ class QueryModel:
|
||||
"""
|
||||
# Get base query (either pre_query or new query)
|
||||
base_query = cls._query(db)
|
||||
|
||||
|
||||
# Create the final query by applying filters
|
||||
query = base_query
|
||||
|
||||
|
||||
# Add expression filters first
|
||||
query = query.filter(*args)
|
||||
|
||||
|
||||
# Add expiry filters last
|
||||
args = cls.get_not_expired_query_arg(())
|
||||
query = query.filter(*args)
|
||||
|
||||
|
||||
return PostgresResponse(
|
||||
model=cls,
|
||||
model=cls,
|
||||
pre_query=base_query, # Use the base query for pre_query
|
||||
query=query,
|
||||
is_array=True
|
||||
query=query,
|
||||
is_array=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -340,36 +329,32 @@ class QueryModel:
|
||||
"""
|
||||
# Get base query (either pre_query or new query)
|
||||
base_query = cls._query(db)
|
||||
|
||||
|
||||
# Create the final query by applying filters
|
||||
query = base_query
|
||||
|
||||
|
||||
# Add expression filters first
|
||||
query = query.filter(*args)
|
||||
|
||||
|
||||
# Add status filters
|
||||
query = query.filter(
|
||||
cls.is_confirmed == True,
|
||||
cls.deleted == False,
|
||||
cls.active == True
|
||||
cls.is_confirmed == True, cls.deleted == False, cls.active == True
|
||||
)
|
||||
|
||||
|
||||
# Add expiry filters last
|
||||
args = cls.get_not_expired_query_arg(())
|
||||
query = query.filter(*args)
|
||||
|
||||
|
||||
return PostgresResponse(
|
||||
model=cls,
|
||||
model=cls,
|
||||
pre_query=base_query, # Use the base query for pre_query
|
||||
query=query,
|
||||
is_array=True
|
||||
query=query,
|
||||
is_array=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def filter_by_all_system(
|
||||
cls: Type[T],
|
||||
db: Session,
|
||||
**kwargs: Any
|
||||
cls: Type[T], db: Session, **kwargs: Any
|
||||
) -> PostgresResponse[T]:
|
||||
"""
|
||||
Filter multiple records by keyword arguments without status filtering.
|
||||
@@ -383,29 +368,27 @@ class QueryModel:
|
||||
"""
|
||||
# Get base query (either pre_query or new query)
|
||||
base_query = cls._query(db)
|
||||
|
||||
|
||||
# Create the final query by applying filters
|
||||
query = base_query
|
||||
|
||||
|
||||
# Add keyword filters first
|
||||
query = query.filter_by(**kwargs)
|
||||
|
||||
|
||||
# Add expiry filters last
|
||||
args = cls.get_not_expired_query_arg(())
|
||||
query = query.filter(*args)
|
||||
|
||||
|
||||
return PostgresResponse(
|
||||
model=cls,
|
||||
model=cls,
|
||||
pre_query=base_query, # Use the base query for pre_query
|
||||
query=query,
|
||||
is_array=True
|
||||
query=query,
|
||||
is_array=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def filter_by_one_system(
|
||||
cls: Type[T],
|
||||
db: Session,
|
||||
**kwargs: Any
|
||||
cls: Type[T], db: Session, **kwargs: Any
|
||||
) -> PostgresResponse[T]:
|
||||
"""
|
||||
Filter single record by keyword arguments without status filtering.
|
||||
|
||||
@@ -5,34 +5,35 @@ from Controllers.Postgres.database import Base, engine
|
||||
|
||||
def generate_table_in_postgres():
|
||||
"""Create the endpoint_restriction table in PostgreSQL if it doesn't exist."""
|
||||
|
||||
|
||||
# Create all tables defined in the Base metadata
|
||||
Base.metadata.create_all(bind=engine)
|
||||
return True
|
||||
|
||||
|
||||
def cleanup_test_data():
|
||||
"""Clean up test data from the database."""
|
||||
with EndpointRestriction.new_session() as db_session:
|
||||
try:
|
||||
# Get all test records
|
||||
test_records = EndpointRestriction.filter_all(
|
||||
EndpointRestriction.endpoint_code.like("TEST%"),
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_code.like("TEST%"), db=db_session
|
||||
).data
|
||||
|
||||
|
||||
# Delete each record using the same session
|
||||
for record in test_records:
|
||||
# Merge the record into the current session if it's not already attached
|
||||
if record not in db_session:
|
||||
record = db_session.merge(record)
|
||||
db_session.delete(record)
|
||||
|
||||
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up test data: {str(e)}")
|
||||
db_session.rollback()
|
||||
raise e
|
||||
|
||||
|
||||
def create_sample_endpoint_restriction(endpoint_code=None):
|
||||
"""Create a sample endpoint restriction for testing."""
|
||||
if endpoint_code is None:
|
||||
@@ -43,13 +44,12 @@ def create_sample_endpoint_restriction(endpoint_code=None):
|
||||
try:
|
||||
# First check if record exists
|
||||
existing = EndpointRestriction.filter_one(
|
||||
EndpointRestriction.endpoint_code == endpoint_code,
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_code == endpoint_code, db=db_session
|
||||
)
|
||||
|
||||
|
||||
if existing and existing.data:
|
||||
return existing.data
|
||||
|
||||
|
||||
# If not found, create new record
|
||||
endpoint = EndpointRestriction.find_or_create(
|
||||
endpoint_function="test_function",
|
||||
@@ -77,6 +77,7 @@ def create_sample_endpoint_restriction(endpoint_code=None):
|
||||
db_session.rollback()
|
||||
raise e
|
||||
|
||||
|
||||
def test_filter_by_one():
|
||||
"""Test filtering a single record by keyword arguments."""
|
||||
print("\nTesting filter_by_one...")
|
||||
@@ -84,22 +85,20 @@ def test_filter_by_one():
|
||||
try:
|
||||
# Set up pre_query first
|
||||
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
|
||||
EndpointRestriction.endpoint_method == "GET",
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_method == "GET", db=db_session
|
||||
).query
|
||||
|
||||
|
||||
sample_endpoint = create_sample_endpoint_restriction("TEST001")
|
||||
result = EndpointRestriction.filter_by_one(
|
||||
db=db_session,
|
||||
endpoint_code="TEST001"
|
||||
db=db_session, endpoint_code="TEST001"
|
||||
)
|
||||
|
||||
|
||||
# Test PostgresResponse properties
|
||||
success = (
|
||||
result is not None and
|
||||
result.count == 1 and
|
||||
result.total_count == 1 and
|
||||
result.is_list is False
|
||||
result is not None
|
||||
and result.count == 1
|
||||
and result.total_count == 1
|
||||
and result.is_list is False
|
||||
)
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
@@ -107,6 +106,7 @@ def test_filter_by_one():
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_filter_by_one_system():
|
||||
"""Test filtering a single record by keyword arguments without status filtering."""
|
||||
print("\nTesting filter_by_one_system...")
|
||||
@@ -114,23 +114,20 @@ def test_filter_by_one_system():
|
||||
try:
|
||||
# Set up pre_query first
|
||||
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
|
||||
EndpointRestriction.endpoint_method == "GET",
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_method == "GET", db=db_session
|
||||
).query
|
||||
|
||||
|
||||
sample_endpoint = create_sample_endpoint_restriction("TEST002")
|
||||
result = EndpointRestriction.filter_by_one(
|
||||
db=db_session,
|
||||
endpoint_code="TEST002",
|
||||
system=True
|
||||
db=db_session, endpoint_code="TEST002", system=True
|
||||
)
|
||||
|
||||
|
||||
# Test PostgresResponse properties
|
||||
success = (
|
||||
result is not None and
|
||||
result.count == 1 and
|
||||
result.total_count == 1 and
|
||||
result.is_list is False
|
||||
result is not None
|
||||
and result.count == 1
|
||||
and result.total_count == 1
|
||||
and result.is_list is False
|
||||
)
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
@@ -138,6 +135,7 @@ def test_filter_by_one_system():
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_filter_one():
|
||||
"""Test filtering a single record by expressions."""
|
||||
print("\nTesting filter_one...")
|
||||
@@ -145,22 +143,20 @@ def test_filter_one():
|
||||
try:
|
||||
# Set up pre_query first
|
||||
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
|
||||
EndpointRestriction.endpoint_method == "GET",
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_method == "GET", db=db_session
|
||||
).query
|
||||
|
||||
|
||||
sample_endpoint = create_sample_endpoint_restriction("TEST003")
|
||||
result = EndpointRestriction.filter_one(
|
||||
EndpointRestriction.endpoint_code == "TEST003",
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_code == "TEST003", db=db_session
|
||||
)
|
||||
|
||||
|
||||
# Test PostgresResponse properties
|
||||
success = (
|
||||
result is not None and
|
||||
result.count == 1 and
|
||||
result.total_count == 1 and
|
||||
result.is_list is False
|
||||
result is not None
|
||||
and result.count == 1
|
||||
and result.total_count == 1
|
||||
and result.is_list is False
|
||||
)
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
@@ -168,6 +164,7 @@ def test_filter_one():
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_filter_one_system():
|
||||
"""Test filtering a single record by expressions without status filtering."""
|
||||
print("\nTesting filter_one_system...")
|
||||
@@ -175,22 +172,20 @@ def test_filter_one_system():
|
||||
try:
|
||||
# Set up pre_query first
|
||||
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
|
||||
EndpointRestriction.endpoint_method == "GET",
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_method == "GET", db=db_session
|
||||
).query
|
||||
|
||||
|
||||
sample_endpoint = create_sample_endpoint_restriction("TEST004")
|
||||
result = EndpointRestriction.filter_one_system(
|
||||
EndpointRestriction.endpoint_code == "TEST004",
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_code == "TEST004", db=db_session
|
||||
)
|
||||
|
||||
|
||||
# Test PostgresResponse properties
|
||||
success = (
|
||||
result is not None and
|
||||
result.count == 1 and
|
||||
result.total_count == 1 and
|
||||
result.is_list is False
|
||||
result is not None
|
||||
and result.count == 1
|
||||
and result.total_count == 1
|
||||
and result.is_list is False
|
||||
)
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
@@ -198,6 +193,7 @@ def test_filter_one_system():
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_filter_all():
|
||||
"""Test filtering multiple records by expressions."""
|
||||
print("\nTesting filter_all...")
|
||||
@@ -205,25 +201,23 @@ def test_filter_all():
|
||||
try:
|
||||
# Set up pre_query first
|
||||
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
|
||||
EndpointRestriction.endpoint_method == "GET",
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_method == "GET", db=db_session
|
||||
).query
|
||||
|
||||
|
||||
# Create two endpoint restrictions
|
||||
endpoint1 = create_sample_endpoint_restriction("TEST005")
|
||||
endpoint2 = create_sample_endpoint_restriction("TEST006")
|
||||
|
||||
result = EndpointRestriction.filter_all(
|
||||
EndpointRestriction.endpoint_method.in_(["GET", "GET"]),
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_method.in_(["GET", "GET"]), db=db_session
|
||||
)
|
||||
|
||||
|
||||
# Test PostgresResponse properties
|
||||
success = (
|
||||
result is not None and
|
||||
result.count == 2 and
|
||||
result.total_count == 2 and
|
||||
result.is_list is True
|
||||
result is not None
|
||||
and result.count == 2
|
||||
and result.total_count == 2
|
||||
and result.is_list is True
|
||||
)
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
@@ -231,6 +225,7 @@ def test_filter_all():
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_filter_all_system():
|
||||
"""Test filtering multiple records by expressions without status filtering."""
|
||||
print("\nTesting filter_all_system...")
|
||||
@@ -238,25 +233,23 @@ def test_filter_all_system():
|
||||
try:
|
||||
# Set up pre_query first
|
||||
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
|
||||
EndpointRestriction.endpoint_method == "GET",
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_method == "GET", db=db_session
|
||||
).query
|
||||
|
||||
|
||||
# Create two endpoint restrictions
|
||||
endpoint1 = create_sample_endpoint_restriction("TEST007")
|
||||
endpoint2 = create_sample_endpoint_restriction("TEST008")
|
||||
|
||||
result = EndpointRestriction.filter_all_system(
|
||||
EndpointRestriction.endpoint_method.in_(["GET", "GET"]),
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_method.in_(["GET", "GET"]), db=db_session
|
||||
)
|
||||
|
||||
|
||||
# Test PostgresResponse properties
|
||||
success = (
|
||||
result is not None and
|
||||
result.count == 2 and
|
||||
result.total_count == 2 and
|
||||
result.is_list is True
|
||||
result is not None
|
||||
and result.count == 2
|
||||
and result.total_count == 2
|
||||
and result.is_list is True
|
||||
)
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
@@ -264,6 +257,7 @@ def test_filter_all_system():
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_filter_by_all_system():
|
||||
"""Test filtering multiple records by keyword arguments without status filtering."""
|
||||
print("\nTesting filter_by_all_system...")
|
||||
@@ -271,25 +265,23 @@ def test_filter_by_all_system():
|
||||
try:
|
||||
# Set up pre_query first
|
||||
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
|
||||
EndpointRestriction.endpoint_method == "GET",
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_method == "GET", db=db_session
|
||||
).query
|
||||
|
||||
|
||||
# Create two endpoint restrictions
|
||||
endpoint1 = create_sample_endpoint_restriction("TEST009")
|
||||
endpoint2 = create_sample_endpoint_restriction("TEST010")
|
||||
|
||||
result = EndpointRestriction.filter_by_all_system(
|
||||
db=db_session,
|
||||
endpoint_method="GET"
|
||||
db=db_session, endpoint_method="GET"
|
||||
)
|
||||
|
||||
|
||||
# Test PostgresResponse properties
|
||||
success = (
|
||||
result is not None and
|
||||
result.count == 2 and
|
||||
result.total_count == 2 and
|
||||
result.is_list is True
|
||||
result is not None
|
||||
and result.count == 2
|
||||
and result.total_count == 2
|
||||
and result.is_list is True
|
||||
)
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
@@ -297,23 +289,32 @@ def test_filter_by_all_system():
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_get_not_expired_query_arg():
|
||||
"""Test adding expiry date filtering to query arguments."""
|
||||
print("\nTesting get_not_expired_query_arg...")
|
||||
with EndpointRestriction.new_session() as db_session:
|
||||
try:
|
||||
# Create a sample endpoint with a unique code
|
||||
endpoint_code = f"TEST{int(arrow.now().timestamp())}{arrow.now().microsecond}"
|
||||
endpoint_code = (
|
||||
f"TEST{int(arrow.now().timestamp())}{arrow.now().microsecond}"
|
||||
)
|
||||
sample_endpoint = create_sample_endpoint_restriction(endpoint_code)
|
||||
|
||||
|
||||
# Test the query argument generation
|
||||
args = EndpointRestriction.get_not_expired_query_arg(())
|
||||
|
||||
|
||||
# Verify the arguments
|
||||
success = (
|
||||
len(args) == 2 and
|
||||
any(str(arg).startswith("endpoint_restriction.expiry_starts") for arg in args) and
|
||||
any(str(arg).startswith("endpoint_restriction.expiry_ends") for arg in args)
|
||||
len(args) == 2
|
||||
and any(
|
||||
str(arg).startswith("endpoint_restriction.expiry_starts")
|
||||
for arg in args
|
||||
)
|
||||
and any(
|
||||
str(arg).startswith("endpoint_restriction.expiry_ends")
|
||||
for arg in args
|
||||
)
|
||||
)
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
@@ -321,27 +322,33 @@ def test_get_not_expired_query_arg():
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_add_new_arg_to_args():
|
||||
"""Test adding new arguments to query arguments."""
|
||||
print("\nTesting add_new_arg_to_args...")
|
||||
try:
|
||||
args = (EndpointRestriction.endpoint_code == "TEST001",)
|
||||
new_arg = EndpointRestriction.endpoint_method == "GET"
|
||||
|
||||
updated_args = EndpointRestriction.add_new_arg_to_args(args, "endpoint_method", new_arg)
|
||||
|
||||
updated_args = EndpointRestriction.add_new_arg_to_args(
|
||||
args, "endpoint_method", new_arg
|
||||
)
|
||||
success = len(updated_args) == 2
|
||||
|
||||
|
||||
# Test duplicate prevention
|
||||
duplicate_arg = EndpointRestriction.endpoint_method == "GET"
|
||||
updated_args = EndpointRestriction.add_new_arg_to_args(updated_args, "endpoint_method", duplicate_arg)
|
||||
updated_args = EndpointRestriction.add_new_arg_to_args(
|
||||
updated_args, "endpoint_method", duplicate_arg
|
||||
)
|
||||
success = success and len(updated_args) == 2 # Should not add duplicate
|
||||
|
||||
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_produce_query_to_add():
|
||||
"""Test adding query parameters to filter options."""
|
||||
print("\nTesting produce_query_to_add...")
|
||||
@@ -349,36 +356,31 @@ def test_produce_query_to_add():
|
||||
try:
|
||||
sample_endpoint = create_sample_endpoint_restriction("TEST001")
|
||||
filter_list = {
|
||||
"query": {
|
||||
"endpoint_method": "GET",
|
||||
"endpoint_code": "TEST001"
|
||||
}
|
||||
"query": {"endpoint_method": "GET", "endpoint_code": "TEST001"}
|
||||
}
|
||||
args = ()
|
||||
|
||||
|
||||
updated_args = EndpointRestriction.produce_query_to_add(filter_list, args)
|
||||
success = len(updated_args) == 2
|
||||
|
||||
result = EndpointRestriction.filter_all(
|
||||
*updated_args,
|
||||
db=db_session
|
||||
)
|
||||
|
||||
|
||||
result = EndpointRestriction.filter_all(*updated_args, db=db_session)
|
||||
|
||||
# Test PostgresResponse properties
|
||||
success = (
|
||||
success and
|
||||
result is not None and
|
||||
result.count == 1 and
|
||||
result.total_count == 1 and
|
||||
result.is_list is True
|
||||
success
|
||||
and result is not None
|
||||
and result.count == 1
|
||||
and result.total_count == 1
|
||||
and result.is_list is True
|
||||
)
|
||||
|
||||
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_get_dict():
|
||||
"""Test the get_dict() function for single-record filters."""
|
||||
print("\nTesting get_dict...")
|
||||
@@ -386,51 +388,50 @@ def test_get_dict():
|
||||
try:
|
||||
# Set up pre_query first
|
||||
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
|
||||
EndpointRestriction.endpoint_method == "GET",
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_method == "GET", db=db_session
|
||||
).query
|
||||
|
||||
|
||||
# Create a sample endpoint
|
||||
endpoint_code = "TEST_DICT_001"
|
||||
sample_endpoint = create_sample_endpoint_restriction(endpoint_code)
|
||||
|
||||
|
||||
# Get the endpoint using filter_one
|
||||
result = EndpointRestriction.filter_one(
|
||||
EndpointRestriction.endpoint_code == endpoint_code,
|
||||
db=db_session
|
||||
EndpointRestriction.endpoint_code == endpoint_code, db=db_session
|
||||
)
|
||||
|
||||
|
||||
# Get the data and convert to dict
|
||||
data = result.data
|
||||
data_dict = data.get_dict()
|
||||
|
||||
|
||||
# Test dictionary properties
|
||||
success = (
|
||||
data_dict is not None and
|
||||
isinstance(data_dict, dict) and
|
||||
data_dict.get("endpoint_code") == endpoint_code and
|
||||
data_dict.get("endpoint_method") == "GET" and
|
||||
data_dict.get("endpoint_function") == "test_function" and
|
||||
data_dict.get("endpoint_name") == "Test Endpoint" and
|
||||
data_dict.get("endpoint_desc") == "Test Description" and
|
||||
data_dict.get("is_confirmed") is True and
|
||||
data_dict.get("active") is True and
|
||||
data_dict.get("deleted") is False
|
||||
data_dict is not None
|
||||
and isinstance(data_dict, dict)
|
||||
and data_dict.get("endpoint_code") == endpoint_code
|
||||
and data_dict.get("endpoint_method") == "GET"
|
||||
and data_dict.get("endpoint_function") == "test_function"
|
||||
and data_dict.get("endpoint_name") == "Test Endpoint"
|
||||
and data_dict.get("endpoint_desc") == "Test Description"
|
||||
and data_dict.get("is_confirmed") is True
|
||||
and data_dict.get("active") is True
|
||||
and data_dict.get("deleted") is False
|
||||
)
|
||||
|
||||
|
||||
print(f"Test {'passed' if success else 'failed'}")
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"Test failed with exception: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all tests and report results."""
|
||||
print("Starting EndpointRestriction tests...")
|
||||
|
||||
|
||||
# Clean up any existing test data before starting
|
||||
cleanup_test_data()
|
||||
|
||||
|
||||
tests = [
|
||||
test_filter_by_one,
|
||||
test_filter_by_one_system,
|
||||
@@ -442,7 +443,7 @@ def run_all_tests():
|
||||
test_get_not_expired_query_arg,
|
||||
test_add_new_arg_to_args,
|
||||
test_produce_query_to_add,
|
||||
test_get_dict # Added new test
|
||||
test_get_dict, # Added new test
|
||||
]
|
||||
passed_list, not_passed_list = [], []
|
||||
passed, failed = 0, 0
|
||||
@@ -453,33 +454,24 @@ def run_all_tests():
|
||||
try:
|
||||
if test():
|
||||
passed += 1
|
||||
passed_list.append(
|
||||
f"Test {test.__name__} passed"
|
||||
)
|
||||
passed_list.append(f"Test {test.__name__} passed")
|
||||
else:
|
||||
failed += 1
|
||||
not_passed_list.append(
|
||||
f"Test {test.__name__} failed"
|
||||
)
|
||||
not_passed_list.append(f"Test {test.__name__} failed")
|
||||
except Exception as e:
|
||||
print(f"Test {test.__name__} failed with exception: {e}")
|
||||
failed += 1
|
||||
not_passed_list.append(
|
||||
f"Test {test.__name__} failed"
|
||||
)
|
||||
not_passed_list.append(f"Test {test.__name__} failed")
|
||||
|
||||
print(f"\nTest Results: {passed} passed, {failed} failed")
|
||||
print('Passed Tests:')
|
||||
print(
|
||||
"\n".join(passed_list)
|
||||
)
|
||||
print('Failed Tests:')
|
||||
print(
|
||||
"\n".join(not_passed_list)
|
||||
)
|
||||
print("Passed Tests:")
|
||||
print("\n".join(passed_list))
|
||||
print("Failed Tests:")
|
||||
print("\n".join(not_passed_list))
|
||||
|
||||
return passed, failed
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_table_in_postgres()
|
||||
run_all_tests()
|
||||
|
||||
@@ -27,4 +27,3 @@ class EndpointRestriction(CrudCollection):
|
||||
endpoint_code: Mapped[str] = mapped_column(
|
||||
String, server_default="", unique=True, comment="Unique code for the endpoint"
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user