auth endpoints added

This commit is contained in:
2025-04-03 14:19:34 +03:00
parent 3583d178e9
commit ee405133be
37 changed files with 976 additions and 570 deletions

View File

@@ -17,29 +17,30 @@ def test_basic_crud_operations():
try:
with mongo_handler.collection("users") as users_collection:
# Insert multiple documents
users_collection.insert_many([
{"username": "john", "email": "john@example.com", "role": "user"},
{"username": "jane", "email": "jane@example.com", "role": "admin"},
{"username": "bob", "email": "bob@example.com", "role": "user"}
])
users_collection.insert_many(
[
{"username": "john", "email": "john@example.com", "role": "user"},
{"username": "jane", "email": "jane@example.com", "role": "admin"},
{"username": "bob", "email": "bob@example.com", "role": "user"},
]
)
# Find with multiple conditions
admin_users = list(users_collection.find({"role": "admin"}))
# Update multiple documents
update_result = users_collection.update_many(
{"role": "user"},
{"$set": {"last_login": datetime.now().isoformat()}}
{"role": "user"}, {"$set": {"last_login": datetime.now().isoformat()}}
)
# Delete documents
delete_result = users_collection.delete_many({"username": "bob"})
success = (
len(admin_users) == 1 and
admin_users[0]["username"] == "jane" and
update_result.modified_count == 2 and
delete_result.deleted_count == 1
len(admin_users) == 1
and admin_users[0]["username"] == "jane"
and update_result.modified_count == 2
and delete_result.deleted_count == 1
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -54,35 +55,32 @@ def test_nested_documents():
try:
with mongo_handler.collection("products") as products_collection:
# Insert a product with nested data
products_collection.insert_one({
"name": "Laptop",
"price": 999.99,
"specs": {
"cpu": "Intel i7",
"ram": "16GB",
"storage": "512GB SSD"
},
"in_stock": True,
"tags": ["electronics", "computers", "laptops"]
})
products_collection.insert_one(
{
"name": "Laptop",
"price": 999.99,
"specs": {"cpu": "Intel i7", "ram": "16GB", "storage": "512GB SSD"},
"in_stock": True,
"tags": ["electronics", "computers", "laptops"],
}
)
# Find with nested field query
laptop = products_collection.find_one({"specs.cpu": "Intel i7"})
# Update nested field
update_result = products_collection.update_one(
{"name": "Laptop"},
{"$set": {"specs.ram": "32GB"}}
{"name": "Laptop"}, {"$set": {"specs.ram": "32GB"}}
)
# Verify the update
updated_laptop = products_collection.find_one({"name": "Laptop"})
success = (
laptop is not None and
laptop["specs"]["ram"] == "16GB" and
update_result.modified_count == 1 and
updated_laptop["specs"]["ram"] == "32GB"
laptop is not None
and laptop["specs"]["ram"] == "16GB"
and update_result.modified_count == 1
and updated_laptop["specs"]["ram"] == "32GB"
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -97,16 +95,18 @@ def test_array_operations():
try:
with mongo_handler.collection("orders") as orders_collection:
# Insert an order with array of items
orders_collection.insert_one({
"order_id": "ORD001",
"customer": "john",
"items": [
{"product": "Laptop", "quantity": 1},
{"product": "Mouse", "quantity": 2}
],
"total": 1099.99,
"status": "pending"
})
orders_collection.insert_one(
{
"order_id": "ORD001",
"customer": "john",
"items": [
{"product": "Laptop", "quantity": 1},
{"product": "Mouse", "quantity": 2},
],
"total": 1099.99,
"status": "pending",
}
)
# Find orders containing specific items
laptop_orders = list(orders_collection.find({"items.product": "Laptop"}))
@@ -114,17 +114,17 @@ def test_array_operations():
# Update array elements
update_result = orders_collection.update_one(
{"order_id": "ORD001"},
{"$push": {"items": {"product": "Keyboard", "quantity": 1}}}
{"$push": {"items": {"product": "Keyboard", "quantity": 1}}},
)
# Verify the update
updated_order = orders_collection.find_one({"order_id": "ORD001"})
success = (
len(laptop_orders) == 1 and
update_result.modified_count == 1 and
len(updated_order["items"]) == 3 and
updated_order["items"][-1]["product"] == "Keyboard"
len(laptop_orders) == 1
and update_result.modified_count == 1
and len(updated_order["items"]) == 3
and updated_order["items"][-1]["product"] == "Keyboard"
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -139,23 +139,32 @@ def test_aggregation():
try:
with mongo_handler.collection("sales") as sales_collection:
# Insert sample sales data
sales_collection.insert_many([
{"product": "Laptop", "amount": 999.99, "date": datetime.now()},
{"product": "Mouse", "amount": 29.99, "date": datetime.now()},
{"product": "Keyboard", "amount": 59.99, "date": datetime.now()}
])
sales_collection.insert_many(
[
{"product": "Laptop", "amount": 999.99, "date": datetime.now()},
{"product": "Mouse", "amount": 29.99, "date": datetime.now()},
{"product": "Keyboard", "amount": 59.99, "date": datetime.now()},
]
)
# Calculate total sales by product
pipeline = [
{"$group": {"_id": "$product", "total": {"$sum": "$amount"}}}
]
pipeline = [{"$group": {"_id": "$product", "total": {"$sum": "$amount"}}}]
sales_summary = list(sales_collection.aggregate(pipeline))
success = (
len(sales_summary) == 3 and
any(item["_id"] == "Laptop" and item["total"] == 999.99 for item in sales_summary) and
any(item["_id"] == "Mouse" and item["total"] == 29.99 for item in sales_summary) and
any(item["_id"] == "Keyboard" and item["total"] == 59.99 for item in sales_summary)
len(sales_summary) == 3
and any(
item["_id"] == "Laptop" and item["total"] == 999.99
for item in sales_summary
)
and any(
item["_id"] == "Mouse" and item["total"] == 29.99
for item in sales_summary
)
and any(
item["_id"] == "Keyboard" and item["total"] == 59.99
for item in sales_summary
)
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -174,11 +183,15 @@ def test_index_operations():
users_collection.create_index([("username", 1), ("role", 1)])
# Insert initial document
users_collection.insert_one({"username": "test_user", "email": "test@example.com"})
users_collection.insert_one(
{"username": "test_user", "email": "test@example.com"}
)
# Try to insert duplicate email (should fail)
try:
users_collection.insert_one({"username": "test_user2", "email": "test@example.com"})
users_collection.insert_one(
{"username": "test_user2", "email": "test@example.com"}
)
success = False # Should not reach here
except Exception:
success = True
@@ -196,49 +209,49 @@ def test_complex_queries():
try:
with mongo_handler.collection("products") as products_collection:
# Insert test data
products_collection.insert_many([
{
"name": "Expensive Laptop",
"price": 999.99,
"tags": ["electronics", "computers"],
"in_stock": True
},
{
"name": "Cheap Mouse",
"price": 29.99,
"tags": ["electronics", "peripherals"],
"in_stock": True
}
])
products_collection.insert_many(
[
{
"name": "Expensive Laptop",
"price": 999.99,
"tags": ["electronics", "computers"],
"in_stock": True,
},
{
"name": "Cheap Mouse",
"price": 29.99,
"tags": ["electronics", "peripherals"],
"in_stock": True,
},
]
)
# Find products with price range and specific tags
expensive_electronics = list(products_collection.find({
"price": {"$gt": 500},
"tags": {"$in": ["electronics"]},
"in_stock": True
}))
expensive_electronics = list(
products_collection.find(
{
"price": {"$gt": 500},
"tags": {"$in": ["electronics"]},
"in_stock": True,
}
)
)
# Update with multiple conditions
update_result = products_collection.update_many(
{
"price": {"$lt": 100},
"in_stock": True
},
{
"$set": {"discount": 0.1},
"$inc": {"price": -10}
}
{"price": {"$lt": 100}, "in_stock": True},
{"$set": {"discount": 0.1}, "$inc": {"price": -10}},
)
# Verify the update
updated_product = products_collection.find_one({"name": "Cheap Mouse"})
success = (
len(expensive_electronics) == 1 and
expensive_electronics[0]["name"] == "Expensive Laptop" and
update_result.modified_count == 1 and
updated_product["price"] == 19.99 and
updated_product["discount"] == 0.1
len(expensive_electronics) == 1
and expensive_electronics[0]["name"] == "Expensive Laptop"
and update_result.modified_count == 1
and updated_product["price"] == 19.99
and updated_product["discount"] == 0.1
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -250,19 +263,19 @@ def test_complex_queries():
def run_all_tests():
"""Run all MongoDB tests and report results."""
print("Starting MongoDB tests...")
# Clean up any existing test data before starting
cleanup_test_data()
tests = [
test_basic_crud_operations,
test_nested_documents,
test_array_operations,
test_aggregation,
test_index_operations,
test_complex_queries
test_complex_queries,
]
passed_list, not_passed_list = [], []
passed, failed = 0, 0
@@ -282,9 +295,9 @@ def run_all_tests():
not_passed_list.append(f"Test {test.__name__} failed")
print(f"\nTest Results: {passed} passed, {failed} failed")
print('Passed Tests:')
print("Passed Tests:")
print("\n".join(passed_list))
print('Failed Tests:')
print("Failed Tests:")
print("\n".join(not_passed_list))
return passed, failed

View File

@@ -14,6 +14,7 @@ class Credentials(BaseModel):
"""
Class to store user credentials.
"""
person_id: int
person_name: str
full_name: Optional[str] = None
@@ -23,6 +24,7 @@ class MetaData:
"""
Class to store metadata for a query.
"""
created: bool = False
updated: bool = False
@@ -30,7 +32,7 @@ class MetaData:
class CRUDModel:
"""
Base class for CRUD operations on PostgreSQL models.
Features:
- User credential tracking
- Metadata tracking for operations
@@ -38,21 +40,21 @@ class CRUDModel:
- Automatic timestamp management
- Soft delete support
"""
__abstract__ = True
creds: Credentials = None
meta_data: MetaData = MetaData()
# Define required columns for CRUD operations
required_columns = {
'expiry_starts': TIMESTAMP,
'expiry_ends': TIMESTAMP,
'created_by': str,
'created_by_id': int,
'updated_by': str,
'updated_by_id': int,
'deleted': bool
"expiry_starts": TIMESTAMP,
"expiry_ends": TIMESTAMP,
"created_by": str,
"created_by_id": int,
"updated_by": str,
"updated_by_id": int,
"deleted": bool,
}
@classmethod
@@ -65,24 +67,25 @@ class CRUDModel:
"""
if not cls.creds:
return
if getattr(cls.creds, "person_id", None) and getattr(cls.creds, "person_name", None):
if getattr(cls.creds, "person_id", None) and getattr(
cls.creds, "person_name", None
):
record_created.created_by_id = cls.creds.person_id
record_created.created_by = cls.creds.person_name
@classmethod
def raise_exception(cls, message: str = "Exception raised.", status_code: int = 400):
def raise_exception(
cls, message: str = "Exception raised.", status_code: int = 400
):
"""
Raise HTTP exception with custom message and status code.
Args:
message: Error message
status_code: HTTP status code
"""
raise HTTPException(
status_code=status_code,
detail={"message": message}
)
raise HTTPException(status_code=status_code, detail={"message": message})
@classmethod
def create_or_abort(cls, db: Session, **kwargs):
@@ -111,7 +114,7 @@ class CRUDModel:
query = query.filter(getattr(cls, key) == value)
already_record = query.first()
# Handle existing record
if already_record and already_record.deleted:
cls.raise_exception("Record already exists and is deleted")
@@ -122,12 +125,12 @@ class CRUDModel:
created_record = cls()
for key, value in kwargs.items():
setattr(created_record, key, value)
cls.create_credentials(created_record)
db.add(created_record)
db.flush()
return created_record
except Exception as e:
db.rollback()
cls.raise_exception(f"Failed to create record: {str(e)}", status_code=500)
@@ -146,7 +149,7 @@ class CRUDModel:
"""
try:
key_ = cls.__annotations__.get(key, None)
is_primary = key in getattr(cls, 'primary_keys', [])
is_primary = key in getattr(cls, "primary_keys", [])
row_attr = bool(getattr(getattr(cls, key), "foreign_keys", None))
# Skip primary keys and foreign keys
@@ -167,12 +170,16 @@ class CRUDModel:
elif key_ == Mapped[float] or key_ == Mapped[NUMERIC]:
return True, round(float(val), 3)
elif key_ == Mapped[TIMESTAMP]:
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))
return True, str(
arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ")
)
elif key_ == Mapped[str]:
return True, str(val)
else: # Handle based on Python types
if isinstance(val, datetime.datetime):
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ"))
return True, str(
arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss ZZ")
)
elif isinstance(val, bool):
return True, bool(val)
elif isinstance(val, (float, Decimal)):
@@ -185,17 +192,19 @@ class CRUDModel:
return True, None
return False, None
except Exception as e:
return False, None
def get_dict(self, exclude_list: Optional[list[InstrumentedAttribute]] = None) -> Dict[str, Any]:
def get_dict(
self, exclude_list: Optional[list[InstrumentedAttribute]] = None
) -> Dict[str, Any]:
"""
Convert model instance to dictionary with customizable fields.
Args:
exclude_list: List of fields to exclude from the dictionary
Returns:
Dictionary representation of the model
"""
@@ -207,7 +216,7 @@ class CRUDModel:
# Get all column names from the model
columns = [col.name for col in self.__table__.columns]
columns_set = set(columns)
# Filter columns
columns_list = set([col for col in columns_set if str(col)[-2:] != "id"])
columns_extend = set(
@@ -223,7 +232,7 @@ class CRUDModel:
return_dict[key] = value_of_database
return return_dict
except Exception as e:
return {}
@@ -251,10 +260,10 @@ class CRUDModel:
cls.expiry_ends > str(arrow.now()),
cls.expiry_starts <= str(arrow.now()),
)
exclude_args = exclude_args or []
exclude_args = [exclude_arg.key for exclude_arg in exclude_args]
for key, value in kwargs.items():
if hasattr(cls, key) and key not in exclude_args:
query = query.filter(getattr(cls, key) == value)
@@ -268,16 +277,18 @@ class CRUDModel:
created_record = cls()
for key, value in kwargs.items():
setattr(created_record, key, value)
cls.create_credentials(created_record)
db.add(created_record)
db.flush()
cls.meta_data.created = True
return created_record
except Exception as e:
db.rollback()
cls.raise_exception(f"Failed to find or create record: {str(e)}", status_code=500)
cls.raise_exception(
f"Failed to find or create record: {str(e)}", status_code=500
)
def update(self, db: Session, **kwargs):
"""
@@ -301,7 +312,7 @@ class CRUDModel:
db.flush()
self.meta_data.updated = True
return self
except Exception as e:
self.meta_data.updated = False
db.rollback()
@@ -313,10 +324,10 @@ class CRUDModel:
"""
if not self.creds:
return
person_id = getattr(self.creds, "person_id", None)
person_name = getattr(self.creds, "person_name", None)
if person_id and person_name:
self.updated_by_id = self.creds.person_id
self.updated_by = self.creds.person_name

View File

@@ -10,11 +10,11 @@ from sqlalchemy.orm import declarative_base, sessionmaker, scoped_session, Sessi
engine = create_engine(
postgres_configs.url,
pool_pre_ping=True,
pool_size=10, # Reduced from 20 to better match your CPU cores
max_overflow=5, # Reduced from 10 to prevent too many connections
pool_recycle=600, # Keep as is
pool_timeout=30, # Keep as is
echo=False, # Consider setting to False in production
pool_size=10, # Reduced from 20 to better match your CPU cores
max_overflow=5, # Reduced from 10 to prevent too many connections
pool_recycle=600, # Keep as is
pool_timeout=30, # Keep as is
echo=False, # Consider setting to False in production
)

View File

@@ -35,51 +35,49 @@ class QueryModel:
@classmethod
def add_new_arg_to_args(
cls: Type[T],
args_list: tuple[BinaryExpression, ...],
argument: str,
value: BinaryExpression
cls: Type[T],
args_list: tuple[BinaryExpression, ...],
argument: str,
value: BinaryExpression,
) -> tuple[BinaryExpression, ...]:
"""
Add a new argument to the query arguments if it doesn't exist.
Args:
args_list: Existing query arguments
argument: Key of the argument to check for
value: New argument value to add
Returns:
Updated tuple of query arguments
"""
# Convert to set to remove duplicates while preserving order
new_args = list(dict.fromkeys(
arg for arg in args_list
if isinstance(arg, BinaryExpression)
))
new_args = list(
dict.fromkeys(arg for arg in args_list if isinstance(arg, BinaryExpression))
)
# Check if argument already exists
if not any(
getattr(getattr(arg, "left", None), "key", None) == argument
getattr(getattr(arg, "left", None), "key", None) == argument
for arg in new_args
):
new_args.append(value)
return tuple(new_args)
@classmethod
def get_not_expired_query_arg(
cls: Type[T],
args: tuple[BinaryExpression, ...]
cls: Type[T], args: tuple[BinaryExpression, ...]
) -> tuple[BinaryExpression, ...]:
"""
Add expiry date filtering to the query arguments.
Args:
args: Existing query arguments
Returns:
Updated tuple of query arguments with expiry filters
Raises:
AttributeError: If model does not have expiry_starts or expiry_ends columns
"""
@@ -87,21 +85,21 @@ class QueryModel:
current_time = str(arrow.now())
# Only add expiry filters if they don't already exist
if not any(
getattr(getattr(arg, "left", None), "key", None) == "expiry_ends"
getattr(getattr(arg, "left", None), "key", None) == "expiry_ends"
for arg in args
):
ends = cls.expiry_ends > current_time
args = cls.add_new_arg_to_args(args, "expiry_ends", ends)
if not any(
getattr(getattr(arg, "left", None), "key", None) == "expiry_starts"
getattr(getattr(arg, "left", None), "key", None) == "expiry_starts"
for arg in args
):
starts = cls.expiry_starts <= current_time
args = cls.add_new_arg_to_args(args, "expiry_starts", starts)
return args
except AttributeError as e:
raise AttributeError(
f"Model {cls.__name__} must have expiry_starts and expiry_ends columns"
@@ -111,7 +109,7 @@ class QueryModel:
def produce_query_to_add(cls: Type[T], filter_list: dict, args: tuple) -> tuple:
"""
Adds query to main filter options
Args:
filter_list: Dictionary containing query parameters
args: Existing query arguments to add to
@@ -122,11 +120,11 @@ class QueryModel:
try:
if not filter_list or not isinstance(filter_list, dict):
return args
query_params = filter_list.get("query")
if not query_params or not isinstance(query_params, dict):
return args
for key, value in query_params.items():
if hasattr(cls, key):
# Create a new filter expression
@@ -134,39 +132,34 @@ class QueryModel:
# Add it to args if it doesn't exist
args = cls.add_new_arg_to_args(args, key, filter_expr)
return args
except Exception as e:
print(f"Error in produce_query_to_add: {str(e)}")
return args
@classmethod
def convert(
cls: Type[T],
smart_options: dict[str, Any],
validate_model: Any = None
cls: Type[T], smart_options: dict[str, Any], validate_model: Any = None
) -> Optional[tuple[BinaryExpression, ...]]:
"""
Convert smart options to SQLAlchemy filter expressions.
Args:
smart_options: Dictionary of filter options
validate_model: Optional model to validate against
Returns:
Tuple of SQLAlchemy filter expressions or None if validation fails
"""
if validate_model is not None:
# Add validation logic here if needed
pass
return tuple(cls.filter_expr(**smart_options))
@classmethod
def filter_by_one(
cls: Type[T],
db: Session,
system: bool = False,
**kwargs: Any
cls: Type[T], db: Session, system: bool = False, **kwargs: Any
) -> PostgresResponse[T]:
"""
Filter single record by keyword arguments.
@@ -181,30 +174,28 @@ class QueryModel:
"""
# Get base query (either pre_query or new query)
base_query = cls._query(db)
# Create the final query by applying filters
query = base_query
# Add keyword filters first
query = query.filter_by(**kwargs)
# Add status filters if not system query
if not system:
query = query.filter(
cls.is_confirmed == True,
cls.deleted == False,
cls.active == True
cls.is_confirmed == True, cls.deleted == False, cls.active == True
)
# Add expiry filters last
args = cls.get_not_expired_query_arg(())
query = query.filter(*args)
return PostgresResponse(
model=cls,
model=cls,
pre_query=base_query, # Use the base query for pre_query
query=query,
is_array=False
query=query,
is_array=False,
)
@classmethod
@@ -225,29 +216,27 @@ class QueryModel:
"""
# Get base query (either pre_query or new query)
base_query = cls._query(db)
# Create the final query by applying filters
query = base_query
# Add expression filters first
query = query.filter(*args)
# Add status filters
query = query.filter(
cls.is_confirmed == True,
cls.deleted == False,
cls.active == True
cls.is_confirmed == True, cls.deleted == False, cls.active == True
)
# Add expiry filters last
args = cls.get_not_expired_query_arg(())
query = query.filter(*args)
return PostgresResponse(
model=cls,
model=cls,
pre_query=base_query, # Use the base query for pre_query
query=query,
is_array=False
query=query,
is_array=False,
)
@classmethod
@@ -268,22 +257,22 @@ class QueryModel:
"""
# Get base query (either pre_query or new query)
base_query = cls._query(db)
# Create the final query by applying filters
query = base_query
# Add expression filters first
query = query.filter(*args)
# Add expiry filters last
args = cls.get_not_expired_query_arg(())
query = query.filter(*args)
return PostgresResponse(
model=cls,
model=cls,
pre_query=base_query, # Use the base query for pre_query
query=query,
is_array=False
query=query,
is_array=False,
)
@classmethod
@@ -304,22 +293,22 @@ class QueryModel:
"""
# Get base query (either pre_query or new query)
base_query = cls._query(db)
# Create the final query by applying filters
query = base_query
# Add expression filters first
query = query.filter(*args)
# Add expiry filters last
args = cls.get_not_expired_query_arg(())
query = query.filter(*args)
return PostgresResponse(
model=cls,
model=cls,
pre_query=base_query, # Use the base query for pre_query
query=query,
is_array=True
query=query,
is_array=True,
)
@classmethod
@@ -340,36 +329,32 @@ class QueryModel:
"""
# Get base query (either pre_query or new query)
base_query = cls._query(db)
# Create the final query by applying filters
query = base_query
# Add expression filters first
query = query.filter(*args)
# Add status filters
query = query.filter(
cls.is_confirmed == True,
cls.deleted == False,
cls.active == True
cls.is_confirmed == True, cls.deleted == False, cls.active == True
)
# Add expiry filters last
args = cls.get_not_expired_query_arg(())
query = query.filter(*args)
return PostgresResponse(
model=cls,
model=cls,
pre_query=base_query, # Use the base query for pre_query
query=query,
is_array=True
query=query,
is_array=True,
)
@classmethod
def filter_by_all_system(
cls: Type[T],
db: Session,
**kwargs: Any
cls: Type[T], db: Session, **kwargs: Any
) -> PostgresResponse[T]:
"""
Filter multiple records by keyword arguments without status filtering.
@@ -383,29 +368,27 @@ class QueryModel:
"""
# Get base query (either pre_query or new query)
base_query = cls._query(db)
# Create the final query by applying filters
query = base_query
# Add keyword filters first
query = query.filter_by(**kwargs)
# Add expiry filters last
args = cls.get_not_expired_query_arg(())
query = query.filter(*args)
return PostgresResponse(
model=cls,
model=cls,
pre_query=base_query, # Use the base query for pre_query
query=query,
is_array=True
query=query,
is_array=True,
)
@classmethod
def filter_by_one_system(
cls: Type[T],
db: Session,
**kwargs: Any
cls: Type[T], db: Session, **kwargs: Any
) -> PostgresResponse[T]:
"""
Filter single record by keyword arguments without status filtering.

View File

@@ -5,34 +5,35 @@ from Controllers.Postgres.database import Base, engine
def generate_table_in_postgres():
"""Create the endpoint_restriction table in PostgreSQL if it doesn't exist."""
# Create all tables defined in the Base metadata
Base.metadata.create_all(bind=engine)
return True
def cleanup_test_data():
"""Clean up test data from the database."""
with EndpointRestriction.new_session() as db_session:
try:
# Get all test records
test_records = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_code.like("TEST%"),
db=db_session
EndpointRestriction.endpoint_code.like("TEST%"), db=db_session
).data
# Delete each record using the same session
for record in test_records:
# Merge the record into the current session if it's not already attached
if record not in db_session:
record = db_session.merge(record)
db_session.delete(record)
db_session.commit()
except Exception as e:
print(f"Error cleaning up test data: {str(e)}")
db_session.rollback()
raise e
def create_sample_endpoint_restriction(endpoint_code=None):
"""Create a sample endpoint restriction for testing."""
if endpoint_code is None:
@@ -43,13 +44,12 @@ def create_sample_endpoint_restriction(endpoint_code=None):
try:
# First check if record exists
existing = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_code == endpoint_code,
db=db_session
EndpointRestriction.endpoint_code == endpoint_code, db=db_session
)
if existing and existing.data:
return existing.data
# If not found, create new record
endpoint = EndpointRestriction.find_or_create(
endpoint_function="test_function",
@@ -77,6 +77,7 @@ def create_sample_endpoint_restriction(endpoint_code=None):
db_session.rollback()
raise e
def test_filter_by_one():
"""Test filtering a single record by keyword arguments."""
print("\nTesting filter_by_one...")
@@ -84,22 +85,20 @@ def test_filter_by_one():
try:
# Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET",
db=db_session
EndpointRestriction.endpoint_method == "GET", db=db_session
).query
sample_endpoint = create_sample_endpoint_restriction("TEST001")
result = EndpointRestriction.filter_by_one(
db=db_session,
endpoint_code="TEST001"
db=db_session, endpoint_code="TEST001"
)
# Test PostgresResponse properties
success = (
result is not None and
result.count == 1 and
result.total_count == 1 and
result.is_list is False
result is not None
and result.count == 1
and result.total_count == 1
and result.is_list is False
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -107,6 +106,7 @@ def test_filter_by_one():
print(f"Test failed with exception: {e}")
return False
def test_filter_by_one_system():
"""Test filtering a single record by keyword arguments without status filtering."""
print("\nTesting filter_by_one_system...")
@@ -114,23 +114,20 @@ def test_filter_by_one_system():
try:
# Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET",
db=db_session
EndpointRestriction.endpoint_method == "GET", db=db_session
).query
sample_endpoint = create_sample_endpoint_restriction("TEST002")
result = EndpointRestriction.filter_by_one(
db=db_session,
endpoint_code="TEST002",
system=True
db=db_session, endpoint_code="TEST002", system=True
)
# Test PostgresResponse properties
success = (
result is not None and
result.count == 1 and
result.total_count == 1 and
result.is_list is False
result is not None
and result.count == 1
and result.total_count == 1
and result.is_list is False
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -138,6 +135,7 @@ def test_filter_by_one_system():
print(f"Test failed with exception: {e}")
return False
def test_filter_one():
"""Test filtering a single record by expressions."""
print("\nTesting filter_one...")
@@ -145,22 +143,20 @@ def test_filter_one():
try:
# Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET",
db=db_session
EndpointRestriction.endpoint_method == "GET", db=db_session
).query
sample_endpoint = create_sample_endpoint_restriction("TEST003")
result = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_code == "TEST003",
db=db_session
EndpointRestriction.endpoint_code == "TEST003", db=db_session
)
# Test PostgresResponse properties
success = (
result is not None and
result.count == 1 and
result.total_count == 1 and
result.is_list is False
result is not None
and result.count == 1
and result.total_count == 1
and result.is_list is False
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -168,6 +164,7 @@ def test_filter_one():
print(f"Test failed with exception: {e}")
return False
def test_filter_one_system():
"""Test filtering a single record by expressions without status filtering."""
print("\nTesting filter_one_system...")
@@ -175,22 +172,20 @@ def test_filter_one_system():
try:
# Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET",
db=db_session
EndpointRestriction.endpoint_method == "GET", db=db_session
).query
sample_endpoint = create_sample_endpoint_restriction("TEST004")
result = EndpointRestriction.filter_one_system(
EndpointRestriction.endpoint_code == "TEST004",
db=db_session
EndpointRestriction.endpoint_code == "TEST004", db=db_session
)
# Test PostgresResponse properties
success = (
result is not None and
result.count == 1 and
result.total_count == 1 and
result.is_list is False
result is not None
and result.count == 1
and result.total_count == 1
and result.is_list is False
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -198,6 +193,7 @@ def test_filter_one_system():
print(f"Test failed with exception: {e}")
return False
def test_filter_all():
"""Test filtering multiple records by expressions."""
print("\nTesting filter_all...")
@@ -205,25 +201,23 @@ def test_filter_all():
try:
# Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET",
db=db_session
EndpointRestriction.endpoint_method == "GET", db=db_session
).query
# Create two endpoint restrictions
endpoint1 = create_sample_endpoint_restriction("TEST005")
endpoint2 = create_sample_endpoint_restriction("TEST006")
result = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method.in_(["GET", "GET"]),
db=db_session
EndpointRestriction.endpoint_method.in_(["GET", "GET"]), db=db_session
)
# Test PostgresResponse properties
success = (
result is not None and
result.count == 2 and
result.total_count == 2 and
result.is_list is True
result is not None
and result.count == 2
and result.total_count == 2
and result.is_list is True
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -231,6 +225,7 @@ def test_filter_all():
print(f"Test failed with exception: {e}")
return False
def test_filter_all_system():
"""Test filtering multiple records by expressions without status filtering."""
print("\nTesting filter_all_system...")
@@ -238,25 +233,23 @@ def test_filter_all_system():
try:
# Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET",
db=db_session
EndpointRestriction.endpoint_method == "GET", db=db_session
).query
# Create two endpoint restrictions
endpoint1 = create_sample_endpoint_restriction("TEST007")
endpoint2 = create_sample_endpoint_restriction("TEST008")
result = EndpointRestriction.filter_all_system(
EndpointRestriction.endpoint_method.in_(["GET", "GET"]),
db=db_session
EndpointRestriction.endpoint_method.in_(["GET", "GET"]), db=db_session
)
# Test PostgresResponse properties
success = (
result is not None and
result.count == 2 and
result.total_count == 2 and
result.is_list is True
result is not None
and result.count == 2
and result.total_count == 2
and result.is_list is True
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -264,6 +257,7 @@ def test_filter_all_system():
print(f"Test failed with exception: {e}")
return False
def test_filter_by_all_system():
"""Test filtering multiple records by keyword arguments without status filtering."""
print("\nTesting filter_by_all_system...")
@@ -271,25 +265,23 @@ def test_filter_by_all_system():
try:
# Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET",
db=db_session
EndpointRestriction.endpoint_method == "GET", db=db_session
).query
# Create two endpoint restrictions
endpoint1 = create_sample_endpoint_restriction("TEST009")
endpoint2 = create_sample_endpoint_restriction("TEST010")
result = EndpointRestriction.filter_by_all_system(
db=db_session,
endpoint_method="GET"
db=db_session, endpoint_method="GET"
)
# Test PostgresResponse properties
success = (
result is not None and
result.count == 2 and
result.total_count == 2 and
result.is_list is True
result is not None
and result.count == 2
and result.total_count == 2
and result.is_list is True
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -297,23 +289,32 @@ def test_filter_by_all_system():
print(f"Test failed with exception: {e}")
return False
def test_get_not_expired_query_arg():
"""Test adding expiry date filtering to query arguments."""
print("\nTesting get_not_expired_query_arg...")
with EndpointRestriction.new_session() as db_session:
try:
# Create a sample endpoint with a unique code
endpoint_code = f"TEST{int(arrow.now().timestamp())}{arrow.now().microsecond}"
endpoint_code = (
f"TEST{int(arrow.now().timestamp())}{arrow.now().microsecond}"
)
sample_endpoint = create_sample_endpoint_restriction(endpoint_code)
# Test the query argument generation
args = EndpointRestriction.get_not_expired_query_arg(())
# Verify the arguments
success = (
len(args) == 2 and
any(str(arg).startswith("endpoint_restriction.expiry_starts") for arg in args) and
any(str(arg).startswith("endpoint_restriction.expiry_ends") for arg in args)
len(args) == 2
and any(
str(arg).startswith("endpoint_restriction.expiry_starts")
for arg in args
)
and any(
str(arg).startswith("endpoint_restriction.expiry_ends")
for arg in args
)
)
print(f"Test {'passed' if success else 'failed'}")
return success
@@ -321,27 +322,33 @@ def test_get_not_expired_query_arg():
print(f"Test failed with exception: {e}")
return False
def test_add_new_arg_to_args():
"""Test adding new arguments to query arguments."""
print("\nTesting add_new_arg_to_args...")
try:
args = (EndpointRestriction.endpoint_code == "TEST001",)
new_arg = EndpointRestriction.endpoint_method == "GET"
updated_args = EndpointRestriction.add_new_arg_to_args(args, "endpoint_method", new_arg)
updated_args = EndpointRestriction.add_new_arg_to_args(
args, "endpoint_method", new_arg
)
success = len(updated_args) == 2
# Test duplicate prevention
duplicate_arg = EndpointRestriction.endpoint_method == "GET"
updated_args = EndpointRestriction.add_new_arg_to_args(updated_args, "endpoint_method", duplicate_arg)
updated_args = EndpointRestriction.add_new_arg_to_args(
updated_args, "endpoint_method", duplicate_arg
)
success = success and len(updated_args) == 2 # Should not add duplicate
print(f"Test {'passed' if success else 'failed'}")
return success
except Exception as e:
print(f"Test failed with exception: {e}")
return False
def test_produce_query_to_add():
"""Test adding query parameters to filter options."""
print("\nTesting produce_query_to_add...")
@@ -349,36 +356,31 @@ def test_produce_query_to_add():
try:
sample_endpoint = create_sample_endpoint_restriction("TEST001")
filter_list = {
"query": {
"endpoint_method": "GET",
"endpoint_code": "TEST001"
}
"query": {"endpoint_method": "GET", "endpoint_code": "TEST001"}
}
args = ()
updated_args = EndpointRestriction.produce_query_to_add(filter_list, args)
success = len(updated_args) == 2
result = EndpointRestriction.filter_all(
*updated_args,
db=db_session
)
result = EndpointRestriction.filter_all(*updated_args, db=db_session)
# Test PostgresResponse properties
success = (
success and
result is not None and
result.count == 1 and
result.total_count == 1 and
result.is_list is True
success
and result is not None
and result.count == 1
and result.total_count == 1
and result.is_list is True
)
print(f"Test {'passed' if success else 'failed'}")
return success
except Exception as e:
print(f"Test failed with exception: {e}")
return False
def test_get_dict():
"""Test the get_dict() function for single-record filters."""
print("\nTesting get_dict...")
@@ -386,51 +388,50 @@ def test_get_dict():
try:
# Set up pre_query first
EndpointRestriction.pre_query = EndpointRestriction.filter_all(
EndpointRestriction.endpoint_method == "GET",
db=db_session
EndpointRestriction.endpoint_method == "GET", db=db_session
).query
# Create a sample endpoint
endpoint_code = "TEST_DICT_001"
sample_endpoint = create_sample_endpoint_restriction(endpoint_code)
# Get the endpoint using filter_one
result = EndpointRestriction.filter_one(
EndpointRestriction.endpoint_code == endpoint_code,
db=db_session
EndpointRestriction.endpoint_code == endpoint_code, db=db_session
)
# Get the data and convert to dict
data = result.data
data_dict = data.get_dict()
# Test dictionary properties
success = (
data_dict is not None and
isinstance(data_dict, dict) and
data_dict.get("endpoint_code") == endpoint_code and
data_dict.get("endpoint_method") == "GET" and
data_dict.get("endpoint_function") == "test_function" and
data_dict.get("endpoint_name") == "Test Endpoint" and
data_dict.get("endpoint_desc") == "Test Description" and
data_dict.get("is_confirmed") is True and
data_dict.get("active") is True and
data_dict.get("deleted") is False
data_dict is not None
and isinstance(data_dict, dict)
and data_dict.get("endpoint_code") == endpoint_code
and data_dict.get("endpoint_method") == "GET"
and data_dict.get("endpoint_function") == "test_function"
and data_dict.get("endpoint_name") == "Test Endpoint"
and data_dict.get("endpoint_desc") == "Test Description"
and data_dict.get("is_confirmed") is True
and data_dict.get("active") is True
and data_dict.get("deleted") is False
)
print(f"Test {'passed' if success else 'failed'}")
return success
except Exception as e:
print(f"Test failed with exception: {e}")
return False
def run_all_tests():
"""Run all tests and report results."""
print("Starting EndpointRestriction tests...")
# Clean up any existing test data before starting
cleanup_test_data()
tests = [
test_filter_by_one,
test_filter_by_one_system,
@@ -442,7 +443,7 @@ def run_all_tests():
test_get_not_expired_query_arg,
test_add_new_arg_to_args,
test_produce_query_to_add,
test_get_dict # Added new test
test_get_dict, # Added new test
]
passed_list, not_passed_list = [], []
passed, failed = 0, 0
@@ -453,33 +454,24 @@ def run_all_tests():
try:
if test():
passed += 1
passed_list.append(
f"Test {test.__name__} passed"
)
passed_list.append(f"Test {test.__name__} passed")
else:
failed += 1
not_passed_list.append(
f"Test {test.__name__} failed"
)
not_passed_list.append(f"Test {test.__name__} failed")
except Exception as e:
print(f"Test {test.__name__} failed with exception: {e}")
failed += 1
not_passed_list.append(
f"Test {test.__name__} failed"
)
not_passed_list.append(f"Test {test.__name__} failed")
print(f"\nTest Results: {passed} passed, {failed} failed")
print('Passed Tests:')
print(
"\n".join(passed_list)
)
print('Failed Tests:')
print(
"\n".join(not_passed_list)
)
print("Passed Tests:")
print("\n".join(passed_list))
print("Failed Tests:")
print("\n".join(not_passed_list))
return passed, failed
if __name__ == "__main__":
generate_table_in_postgres()
run_all_tests()

View File

@@ -27,4 +27,3 @@ class EndpointRestriction(CrudCollection):
endpoint_code: Mapped[str] = mapped_column(
String, server_default="", unique=True, comment="Unique code for the endpoint"
)

View File

@@ -15,7 +15,7 @@ from typing import Union, Dict, List, Optional, Any, TypeVar
from Controllers.Redis.connection import redis_cli
T = TypeVar('T', Dict[str, Any], List[Any])
T = TypeVar("T", Dict[str, Any], List[Any])
class RedisKeyError(Exception):
@@ -277,18 +277,18 @@ class RedisRow:
"""
if not key:
raise RedisKeyError("Cannot set empty key")
# Convert to string for validation
key_str = key.decode() if isinstance(key, bytes) else str(key)
# Validate key length (Redis has a 512MB limit for keys)
if len(key_str) > 512 * 1024 * 1024:
raise RedisKeyError("Key exceeds maximum length of 512MB")
# Validate key format (basic check for invalid characters)
if any(c in key_str for c in ['\n', '\r', '\t', '\0']):
if any(c in key_str for c in ["\n", "\r", "\t", "\0"]):
raise RedisKeyError("Key contains invalid characters")
self.key = key if isinstance(key, bytes) else str(key).encode()
@property

View File

@@ -5,11 +5,12 @@ class Configs(BaseSettings):
"""
MongoDB configuration settings.
"""
HOST: str = ""
PASSWORD: str = ""
PORT: int = 0
DB: int = 0
def as_dict(self):
return dict(
host=self.HOST,

View File

@@ -98,9 +98,7 @@ class RedisConn:
err = e
return False
def set_connection(
self, **kwargs
) -> Redis:
def set_connection(self, **kwargs) -> Redis:
"""
Recreate Redis connection with new parameters.

View File

@@ -14,6 +14,7 @@ def example_set_json() -> None:
result = RedisActions.set_json(list_keys=keys, value=data, expires=expiry)
print("Set JSON with expiry:", result.as_dict())
def example_get_json() -> None:
"""Example of retrieving JSON data from Redis."""
# Example 1: Get all matching keys
@@ -25,11 +26,16 @@ def example_get_json() -> None:
result = RedisActions.get_json(list_keys=keys, limit=5)
print("Get JSON with limit:", result.as_dict())
def example_get_json_iterator() -> None:
"""Example of using the JSON iterator for large datasets."""
keys = ["user", "profile", "*"]
for row in RedisActions.get_json_iterator(list_keys=keys):
print("Iterating over JSON row:", row.as_dict if isinstance(row.as_dict, dict) else row.as_dict)
print(
"Iterating over JSON row:",
row.as_dict if isinstance(row.as_dict, dict) else row.as_dict,
)
def example_delete_key() -> None:
"""Example of deleting a specific key."""
@@ -37,12 +43,14 @@ def example_delete_key() -> None:
result = RedisActions.delete_key(key)
print("Delete specific key:", result)
def example_delete() -> None:
"""Example of deleting multiple keys matching a pattern."""
keys = ["user", "profile", "*"]
result = RedisActions.delete(list_keys=keys)
print("Delete multiple keys:", result)
def example_refresh_ttl() -> None:
"""Example of refreshing TTL for a key."""
key = "user:profile:123"
@@ -50,48 +58,53 @@ def example_refresh_ttl() -> None:
result = RedisActions.refresh_ttl(key=key, expires=new_expiry)
print("Refresh TTL:", result.as_dict())
def example_key_exists() -> None:
"""Example of checking if a key exists."""
key = "user:profile:123"
exists = RedisActions.key_exists(key)
print(f"Key {key} exists:", exists)
def example_resolve_expires_at() -> None:
"""Example of resolving expiry time for a key."""
from Controllers.Redis.base import RedisRow
redis_row = RedisRow()
redis_row.set_key("user:profile:123")
print(redis_row.keys)
expires_at = RedisActions.resolve_expires_at(redis_row)
print("Resolve expires at:", expires_at)
def run_all_examples() -> None:
"""Run all example functions to demonstrate RedisActions functionality."""
print("\n=== Redis Actions Examples ===\n")
print("1. Setting JSON data:")
example_set_json()
print("\n2. Getting JSON data:")
example_get_json()
print("\n3. Using JSON iterator:")
example_get_json_iterator()
# print("\n4. Deleting specific key:")
# example_delete_key()
#
# print("\n5. Deleting multiple keys:")
# example_delete()
print("\n6. Refreshing TTL:")
example_refresh_ttl()
print("\n7. Checking key existence:")
example_key_exists()
print("\n8. Resolving expiry time:")
example_resolve_expires_at()
if __name__ == "__main__":
run_all_examples()

View File

@@ -67,7 +67,7 @@ class RedisResponse:
# Process single RedisRow
if isinstance(data, RedisRow):
result = {**main_dict}
if hasattr(data, 'keys') and hasattr(data, 'row'):
if hasattr(data, "keys") and hasattr(data, "row"):
if not isinstance(data.keys, str):
raise ValueError("RedisRow keys must be string type")
result[data.keys] = data.row
@@ -80,7 +80,11 @@ class RedisResponse:
# Handle list of RedisRow objects
rows_dict = {}
for row in data:
if isinstance(row, RedisRow) and hasattr(row, 'keys') and hasattr(row, 'row'):
if (
isinstance(row, RedisRow)
and hasattr(row, "keys")
and hasattr(row, "row")
):
if not isinstance(row.keys, str):
raise ValueError("RedisRow keys must be string type")
rows_dict[row.keys] = row.row
@@ -137,10 +141,10 @@ class RedisResponse:
if isinstance(self.data, list) and self.data:
item = self.data[0]
if isinstance(item, RedisRow) and hasattr(item, 'row'):
if isinstance(item, RedisRow) and hasattr(item, "row"):
return item.row
return item
elif isinstance(self.data, RedisRow) and hasattr(self.data, 'row'):
elif isinstance(self.data, RedisRow) and hasattr(self.data, "row"):
return self.data.row
elif isinstance(self.data, dict):
return self.data
@@ -168,16 +172,16 @@ class RedisResponse:
"success": self.status,
"message": self.message,
}
if self.error:
response["error"] = self.error
if self.data is not None:
if self.data_type == "row" and hasattr(self.data, 'to_dict'):
if self.data_type == "row" and hasattr(self.data, "to_dict"):
response["data"] = self.data.to_dict()
elif self.data_type == "list":
try:
if all(hasattr(item, 'to_dict') for item in self.data):
if all(hasattr(item, "to_dict") for item in self.data):
response["data"] = [item.to_dict() for item in self.data]
else:
response["data"] = self.data
@@ -192,5 +196,5 @@ class RedisResponse:
return {
"success": False,
"message": "Error formatting response",
"error": str(e)
}
"error": str(e),
}