478 lines
17 KiB
Python
478 lines
17 KiB
Python
import arrow
|
|
from Controllers.Postgres.schema import EndpointRestriction
|
|
from Controllers.Postgres.database import Base, engine
|
|
|
|
|
|
def generate_table_in_postgres():
    """Create every table registered on ``Base.metadata`` that is missing.

    This includes the ``endpoint_restriction`` table used by the tests below.

    Returns:
        bool: always ``True``.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    return True
|
|
|
|
|
|
def cleanup_test_data():
    """Delete every EndpointRestriction row whose ``endpoint_code`` starts with ``TEST``.

    All matching rows are deleted in one session and committed together.

    Raises:
        Exception: any error raised while querying or deleting is printed,
            the session is rolled back, and the original exception is re-raised.
    """
    with EndpointRestriction.new_session() as db_session:
        try:
            # Use the ``_system`` variant, which skips status filtering.
            # Test rows that became inactive, deleted or expired must still be
            # found and removed.
            test_records = (
                EndpointRestriction.filter_all_system(
                    EndpointRestriction.endpoint_code.like("TEST%"), db=db_session
                ).data
                or []
            )

            # Delete each record using the same session
            for record in test_records:
                # Merge the record into the current session if it's not already attached
                if record not in db_session:
                    record = db_session.merge(record)
                db_session.delete(record)

            db_session.commit()
        except Exception as e:
            print(f"Error cleaning up test data: {str(e)}")
            db_session.rollback()
            # Bare raise re-raises the active exception with its original traceback.
            raise
|
|
|
|
|
|
def create_sample_endpoint_restriction(endpoint_code=None):
    """Return an EndpointRestriction test record with the given code, creating it if needed.

    Args:
        endpoint_code: Code of the record. When ``None``, a code of the form
            ``TEST<unix seconds><microseconds>`` is generated from a single
            timestamp.

    Returns:
        The record found by ``filter_one`` when one already exists with this
        code. Otherwise, the record returned by ``find_or_create`` after it
        has been saved.

    Raises:
        Exception: any error is printed, the session is rolled back, and the
            original exception is re-raised.
    """
    # Capture one instant: seconds and microseconds read from two separate
    # arrow.now() calls can straddle a second boundary and reproduce an
    # earlier code. filter_one below would then hand back that old record.
    now = arrow.now()
    if endpoint_code is None:
        endpoint_code = f"TEST{int(now.timestamp())}{now.microsecond}"

    with EndpointRestriction.new_session() as db_session:
        try:
            # First check if record exists
            existing = EndpointRestriction.filter_one(
                EndpointRestriction.endpoint_code == endpoint_code, db=db_session
            )
            if existing and existing.data:
                return existing.data

            # If not found, create new record; valid from yesterday to tomorrow
            endpoint = EndpointRestriction.find_or_create(
                endpoint_function="test_function",
                endpoint_name="Test Endpoint",
                endpoint_method="GET",
                endpoint_desc="Test Description",
                endpoint_code=endpoint_code,
                is_confirmed=True,
                active=True,
                deleted=False,
                expiry_starts=str(now.shift(days=-1)),
                expiry_ends=str(now.shift(days=1)),
                created_by="test_user",
                created_by_id=1,
                updated_by="test_user",
                updated_by_id=1,
                confirmed_by="test_user",
                confirmed_by_id=1,
                db=db_session,
            )
            endpoint.save(db=db_session)
            return endpoint
        except Exception as e:
            print(f"Error creating sample endpoint: {str(e)}")
            db_session.rollback()
            raise
|
|
|
|
|
|
def test_filter_by_one():
    """Check that ``filter_by_one`` finds exactly one record by keyword filter.

    Returns:
        bool: ``True`` when the response is non-None, holds one record
        (``count`` and ``total_count`` both 1) and is not a list. Returns
        ``False`` otherwise, including when any step raises.
    """
    print("\nTesting filter_by_one...")
    with EndpointRestriction.new_session() as session:
        try:
            # Overwrites the class-level pre_query; this test does not reset it.
            get_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=session
            ).query
            EndpointRestriction.pre_query = get_query

            create_sample_endpoint_restriction("TEST001")
            response = EndpointRestriction.filter_by_one(
                db=session, endpoint_code="TEST001"
            )

            if response is None:
                ok = False
            else:
                ok = (
                    response.count == 1
                    and response.total_count == 1
                    and response.is_list is False
                )
            print(f"Test {'passed' if ok else 'failed'}")
            return ok
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
|
|
|
|
|
|
def test_filter_by_one_system():
    """Check single-record keyword filtering with the ``system=True`` flag.

    Returns:
        bool: ``True`` when the response is non-None, holds one record and is
        not a list. Returns ``False`` otherwise, including on exception.
    """
    print("\nTesting filter_by_one_system...")
    with EndpointRestriction.new_session() as session:
        try:
            # Overwrites the class-level pre_query; this test does not reset it.
            get_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=session
            ).query
            EndpointRestriction.pre_query = get_query

            create_sample_endpoint_restriction("TEST002")
            # NOTE(review): despite the test name, this calls filter_by_one with
            # system=True rather than a filter_by_one_system method. Confirm
            # that filter_by_one honours a ``system`` flag and does not treat it
            # as a column filter.
            response = EndpointRestriction.filter_by_one(
                db=session, endpoint_code="TEST002", system=True
            )

            if response is None:
                ok = False
            else:
                ok = (
                    response.count == 1
                    and response.total_count == 1
                    and response.is_list is False
                )
            print(f"Test {'passed' if ok else 'failed'}")
            return ok
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
|
|
|
|
|
|
def test_filter_one():
    """Check that ``filter_one`` finds exactly one record by SQL expression.

    Returns:
        bool: ``True`` when the response is non-None, holds one record and is
        not a list. Returns ``False`` otherwise, including on exception.
    """
    print("\nTesting filter_one...")
    with EndpointRestriction.new_session() as session:
        try:
            # Overwrites the class-level pre_query; this test does not reset it.
            method_is_get = EndpointRestriction.endpoint_method == "GET"
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                method_is_get, db=session
            ).query

            create_sample_endpoint_restriction("TEST003")
            code_matches = EndpointRestriction.endpoint_code == "TEST003"
            response = EndpointRestriction.filter_one(code_matches, db=session)

            ok = False
            if response is not None:
                ok = (
                    response.count == 1
                    and response.total_count == 1
                    and response.is_list is False
                )
            print(f"Test {'passed' if ok else 'failed'}")
            return ok
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
|
|
|
|
|
|
def test_filter_one_system():
    """Check ``filter_one_system``, the single-record expression filter without status filtering.

    Returns:
        bool: ``True`` when the response is non-None, holds one record and is
        not a list. Returns ``False`` otherwise, including on exception.
    """
    print("\nTesting filter_one_system...")
    with EndpointRestriction.new_session() as session:
        try:
            # Overwrites the class-level pre_query; this test does not reset it.
            method_is_get = EndpointRestriction.endpoint_method == "GET"
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                method_is_get, db=session
            ).query

            create_sample_endpoint_restriction("TEST004")
            code_matches = EndpointRestriction.endpoint_code == "TEST004"
            response = EndpointRestriction.filter_one_system(code_matches, db=session)

            ok = False
            if response is not None:
                ok = (
                    response.count == 1
                    and response.total_count == 1
                    and response.is_list is False
                )
            print(f"Test {'passed' if ok else 'failed'}")
            return ok
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
|
|
|
|
|
|
def test_filter_all():
    """Check that ``filter_all`` returns both records created for this test.

    Two GET endpoint restrictions (``TEST005`` and ``TEST006``) are created and
    then queried by method *and* code. Other GET rows already in the table
    therefore do not match the filter.

    Returns:
        bool: ``True`` when the response is a list with ``count`` and
        ``total_count`` equal to 2. Returns ``False`` otherwise, including
        when any step raises.
    """
    print("\nTesting filter_all...")
    created_codes = ["TEST005", "TEST006"]
    with EndpointRestriction.new_session() as db_session:
        try:
            # Set up pre_query first (class-level; this test does not reset it)
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=db_session
            ).query

            # Create two endpoint restrictions
            for code in created_codes:
                create_sample_endpoint_restriction(code)

            # Filtering on method alone would match every GET row in the table,
            # so the test would fail on any non-empty database. Restrict to the
            # codes created above.
            result = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                EndpointRestriction.endpoint_code.in_(created_codes),
                db=db_session,
            )

            # Test PostgresResponse properties.
            # NOTE(review): if total_count is derived from pre_query rather than
            # from this filter, pre-existing GET rows could still affect it — confirm.
            success = (
                result is not None
                and result.count == 2
                and result.total_count == 2
                and result.is_list is True
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
|
|
|
|
|
|
def test_filter_all_system():
    """Check that ``filter_all_system`` returns both records created for this test.

    ``filter_all_system`` is the multi-record expression filter without status
    filtering. Two GET endpoint restrictions (``TEST007`` and ``TEST008``) are
    created and queried by method *and* code, so unrelated GET rows do not
    match the filter.

    Returns:
        bool: ``True`` when the response is a list with ``count`` and
        ``total_count`` equal to 2. Returns ``False`` otherwise, including
        when any step raises.
    """
    print("\nTesting filter_all_system...")
    created_codes = ["TEST007", "TEST008"]
    with EndpointRestriction.new_session() as db_session:
        try:
            # Set up pre_query first (class-level; this test does not reset it)
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=db_session
            ).query

            # Create two endpoint restrictions
            for code in created_codes:
                create_sample_endpoint_restriction(code)

            # Filtering on method alone would match every GET row in the table;
            # restrict to the codes created above.
            result = EndpointRestriction.filter_all_system(
                EndpointRestriction.endpoint_method == "GET",
                EndpointRestriction.endpoint_code.in_(created_codes),
                db=db_session,
            )

            # Test PostgresResponse properties.
            # NOTE(review): if total_count is derived from pre_query rather than
            # from this filter, pre-existing GET rows could still affect it — confirm.
            success = (
                result is not None
                and result.count == 2
                and result.total_count == 2
                and result.is_list is True
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
|
|
|
|
|
|
def test_filter_by_all_system():
    """Check ``filter_by_all_system``, the multi-record keyword filter without status filtering.

    Creates ``TEST009`` and ``TEST010`` and queries by ``endpoint_method="GET"``.

    Returns:
        bool: ``True`` when the response is a list with ``count`` and
        ``total_count`` equal to 2. Returns ``False`` otherwise, including on
        exception.
    """
    print("\nTesting filter_by_all_system...")
    with EndpointRestriction.new_session() as session:
        try:
            # Overwrites the class-level pre_query; this test does not reset it.
            get_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=session
            ).query
            EndpointRestriction.pre_query = get_query

            for code in ("TEST009", "TEST010"):
                create_sample_endpoint_restriction(code)

            # NOTE(review): the keyword filter matches every GET row, so the
            # expected count of 2 assumes no other GET endpoints exist in the table.
            response = EndpointRestriction.filter_by_all_system(
                db=session, endpoint_method="GET"
            )

            if response is None:
                ok = False
            else:
                ok = (
                    response.count == 2
                    and response.total_count == 2
                    and response.is_list is True
                )
            print(f"Test {'passed' if ok else 'failed'}")
            return ok
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
|
|
|
|
|
|
def test_get_not_expired_query_arg():
    """Check that ``get_not_expired_query_arg`` adds expiry-window filters.

    A sample record with a freshly generated ``TEST...`` code is created
    first. ``get_not_expired_query_arg(())`` is then expected to return exactly
    two expressions: one rendering as ``endpoint_restriction.expiry_starts...``
    and one as ``endpoint_restriction.expiry_ends...``.

    Returns:
        bool: ``True`` if both expressions are present. Returns ``False``
        otherwise, including when any step raises.
    """
    print("\nTesting get_not_expired_query_arg...")
    with EndpointRestriction.new_session() as db_session:
        try:
            # Create a sample endpoint with a unique code. Take seconds and
            # microseconds from one arrow.now() call; two calls can straddle a
            # second boundary and reproduce a previously used code.
            now = arrow.now()
            endpoint_code = f"TEST{int(now.timestamp())}{now.microsecond}"
            create_sample_endpoint_restriction(endpoint_code)

            # Test the query argument generation
            args = EndpointRestriction.get_not_expired_query_arg(())

            # Verify the arguments
            success = (
                len(args) == 2
                and any(
                    str(arg).startswith("endpoint_restriction.expiry_starts")
                    for arg in args
                )
                and any(
                    str(arg).startswith("endpoint_restriction.expiry_ends")
                    for arg in args
                )
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
|
|
|
|
|
|
def test_add_new_arg_to_args():
    """Check that ``add_new_arg_to_args`` appends a new filter once and ignores a repeat.

    Starts from a single ``endpoint_code`` filter and adds an
    ``endpoint_method`` filter, expecting two args. It then adds an equivalent
    ``endpoint_method`` filter again and expects the length to stay at two.
    No database session is used.

    Returns:
        bool: ``True`` when both length checks hold. Returns ``False``
        otherwise, including on exception.
    """
    print("\nTesting add_new_arg_to_args...")
    try:
        base_args = (EndpointRestriction.endpoint_code == "TEST001",)
        extended = EndpointRestriction.add_new_arg_to_args(
            base_args, "endpoint_method", EndpointRestriction.endpoint_method == "GET"
        )
        ok = len(extended) == 2

        # A second filter keyed on endpoint_method should not grow the args.
        extended = EndpointRestriction.add_new_arg_to_args(
            extended, "endpoint_method", EndpointRestriction.endpoint_method == "GET"
        )
        ok = ok and len(extended) == 2

        print(f"Test {'passed' if ok else 'failed'}")
        return ok
    except Exception as exc:
        print(f"Test failed with exception: {exc}")
        return False
|
|
|
|
|
|
def test_produce_query_to_add():
    """Check that ``produce_query_to_add`` turns a ``{"query": {...}}`` dict into filter args.

    Two key/value pairs should yield two args. Feeding those args to
    ``filter_all`` should return exactly the ``TEST001`` record, as a list.

    Returns:
        bool: ``True`` when all checks hold. Returns ``False`` otherwise,
        including on exception.
    """
    print("\nTesting produce_query_to_add...")
    with EndpointRestriction.new_session() as session:
        try:
            create_sample_endpoint_restriction("TEST001")
            options = {
                "query": {"endpoint_method": "GET", "endpoint_code": "TEST001"}
            }

            query_args = EndpointRestriction.produce_query_to_add(options, ())
            ok = len(query_args) == 2

            # NOTE(review): this test does not set pre_query itself, so
            # filter_all may see whatever class-level pre_query an earlier test left.
            response = EndpointRestriction.filter_all(*query_args, db=session)

            ok = (
                ok
                and response is not None
                and response.count == 1
                and response.total_count == 1
                and response.is_list is True
            )

            print(f"Test {'passed' if ok else 'failed'}")
            return ok
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
|
|
|
|
|
|
def test_get_dict():
    """Check that a record's ``get_dict()`` reproduces the values it was created with.

    Creates ``TEST_DICT_001``, fetches it with ``filter_one`` and compares the
    dictionary returned by ``get_dict()`` with the sample's fields. Text
    fields are compared with ``==``. The boolean flags are compared by
    identity with ``True``/``False``.

    Returns:
        bool: ``True`` when every field matches. Returns ``False`` otherwise,
        including on exception (e.g. no record found).
    """
    print("\nTesting get_dict...")
    with EndpointRestriction.new_session() as session:
        try:
            # Overwrites the class-level pre_query; this test does not reset it.
            get_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=session
            ).query
            EndpointRestriction.pre_query = get_query

            code = "TEST_DICT_001"
            create_sample_endpoint_restriction(code)

            response = EndpointRestriction.filter_one(
                EndpointRestriction.endpoint_code == code, db=session
            )
            as_dict = response.data.get_dict()

            expected_values = {
                "endpoint_code": code,
                "endpoint_method": "GET",
                "endpoint_function": "test_function",
                "endpoint_name": "Test Endpoint",
                "endpoint_desc": "Test Description",
            }
            expected_flags = {
                "is_confirmed": True,
                "active": True,
                "deleted": False,
            }

            ok = (
                isinstance(as_dict, dict)
                and all(as_dict.get(key) == value for key, value in expected_values.items())
                and all(as_dict.get(key) is value for key, value in expected_flags.items())
            )

            print(f"Test {'passed' if ok else 'failed'}")
            return ok
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
|
|
|
|
|
|
def run_all_tests():
    """Run every EndpointRestriction test, print a summary and return the tallies.

    Test data (``endpoint_code`` LIKE ``TEST%``) is removed at three points:
    before the run, before each test, and once more after the last test.

    - An error in the initial cleanup propagates and aborts the run.
    - An error in a per-test cleanup is counted as that test's failure.
    - An error in the final cleanup is only printed.

    Returns:
        tuple[int, int]: ``(passed, failed)`` counts.
    """
    print("Starting EndpointRestriction tests...")

    # Clean up any existing test data before starting (fails fast if the
    # database is unreachable).
    cleanup_test_data()

    tests = [
        test_filter_by_one,
        test_filter_by_one_system,
        test_filter_one,
        test_filter_one_system,
        test_filter_all,
        test_filter_all_system,
        test_filter_by_all_system,
        test_get_not_expired_query_arg,
        test_add_new_arg_to_args,
        test_produce_query_to_add,
        test_get_dict,
    ]

    passed_list, not_passed_list = [], []
    passed, failed = 0, 0

    for test in tests:
        try:
            # Cleanup is inside the try so a cleanup error fails only this
            # test instead of aborting the remaining ones.
            cleanup_test_data()
            if test():
                passed += 1
                passed_list.append(f"Test {test.__name__} passed")
            else:
                failed += 1
                not_passed_list.append(f"Test {test.__name__} failed")
        except Exception as e:
            print(f"Test {test.__name__} failed with exception: {e}")
            failed += 1
            not_passed_list.append(f"Test {test.__name__} failed")

    # Remove the data created by the last test so the table is left clean.
    try:
        cleanup_test_data()
    except Exception as e:
        print(f"Final cleanup failed: {e}")

    print(f"\nTest Results: {passed} passed, {failed} failed")
    print("Passed Tests:")
    print("\n".join(passed_list))
    print("Failed Tests:")
    print("\n".join(not_passed_list))

    return passed, failed
|
|
|
|
|
|
if __name__ == "__main__":
    generate_table_in_postgres()
    _passed_count, failed_count = run_all_tests()
    # Exit non-zero when any test failed so shells and CI notice the failure;
    # previously the script always exited with status 0.
    raise SystemExit(1 if failed_count else 0)
|