340 lines
12 KiB
Python
340 lines
12 KiB
Python
import arrow
|
|
from schema import EndpointRestriction
|
|
|
|
def create_sample_endpoint_restriction():
    """Fetch-or-create the ``TEST001`` endpoint restriction and save it.

    The record is a confirmed ``GET`` endpoint whose expiry window spans from
    one day before to one day after ``arrow.now()``. Saving happens through a
    session opened here, independent of any session the caller already holds.

    Returns:
        The object returned by ``EndpointRestriction.find_or_create``.
    """
    with EndpointRestriction.new_session() as session:
        attributes = {
            "endpoint_function": "test_function",
            "endpoint_name": "Test Endpoint",
            "endpoint_method": "GET",
            "endpoint_desc": "Test Description",
            "endpoint_code": "TEST001",
            "is_confirmed": True,
            "expiry_starts": arrow.now().shift(days=-1),
            "expiry_ends": arrow.now().shift(days=1),
        }
        record = EndpointRestriction.find_or_create(**attributes)
        record.save(db=session)
        return record
|
|
|
|
def test_filter_by_one():
    """Check ``filter_by_one`` (keyword-argument lookup) against ``TEST001``.

    Returns:
        bool: True when the response holds exactly one record with code
        ``TEST001`` and its ``PostgresResponse`` properties (``count``,
        ``total_count``, ``data``, ``is_list``, ``data_as_dict``) agree;
        False otherwise.
    """
    print("\nTesting filter_by_one...")
    with EndpointRestriction.new_session() as session:
        create_sample_endpoint_restriction()
        response = EndpointRestriction.filter_by_one(db=session, endpoint_code="TEST001")

        # Checks are evaluated lazily and in order, so the ones that
        # dereference ``response.data`` only run after it is known non-None.
        checks = (
            lambda: response.count == 1,
            lambda: response.total_count == 1,
            lambda: response.data is not None,
            lambda: response.data.endpoint_code == "TEST001",
            lambda: response.is_list is False,
            lambda: isinstance(response.data_as_dict, dict),
            lambda: response.data_as_dict.get("endpoint_code") == "TEST001",
        )
        success = all(check() for check in checks)

        print(f"Test {'passed' if success else 'failed'}")
        return success
|
|
|
|
def test_filter_by_one_system():
    """Check a keyword-argument single-record lookup with ``system=True``.

    Returns:
        bool: True when the response holds exactly one record with code
        ``TEST001`` and its ``PostgresResponse`` properties agree; False
        otherwise.
    """
    print("\nTesting filter_by_one_system...")
    with EndpointRestriction.new_session() as session:
        create_sample_endpoint_restriction()
        # NOTE(review): the sibling *_system tests call dedicated methods
        # (filter_one_system, filter_all_system, filter_by_all_system); this one
        # passes system=True to filter_by_one instead — confirm that is the
        # intended API rather than a missing filter_by_one_system call.
        response = EndpointRestriction.filter_by_one(
            db=session, endpoint_code="TEST001", system=True
        )

        # Lazily evaluated in order, so ``response.data`` is only dereferenced
        # after the non-None check has passed.
        checks = (
            lambda: response.count == 1,
            lambda: response.total_count == 1,
            lambda: response.data is not None,
            lambda: response.data.endpoint_code == "TEST001",
            lambda: response.is_list is False,
            lambda: isinstance(response.data_as_dict, dict),
            lambda: response.data_as_dict.get("endpoint_code") == "TEST001",
        )
        success = all(check() for check in checks)

        print(f"Test {'passed' if success else 'failed'}")
        return success
|
|
|
|
def test_filter_one():
    """Check ``filter_one`` (SQL-expression lookup) against ``TEST001``.

    Returns:
        bool: True when the response holds exactly one record with code
        ``TEST001`` and its ``PostgresResponse`` properties agree; False
        otherwise.
    """
    print("\nTesting filter_one...")
    with EndpointRestriction.new_session() as session:
        create_sample_endpoint_restriction()
        code_matches = EndpointRestriction.endpoint_code == "TEST001"
        response = EndpointRestriction.filter_one(code_matches, db=session)

        # Lazily evaluated in order; stops at the first failing check.
        checks = (
            lambda: response.count == 1,
            lambda: response.total_count == 1,
            lambda: response.data is not None,
            lambda: response.data.endpoint_code == "TEST001",
            lambda: response.is_list is False,
            lambda: isinstance(response.data_as_dict, dict),
            lambda: response.data_as_dict.get("endpoint_code") == "TEST001",
        )
        success = all(check() for check in checks)

        print(f"Test {'passed' if success else 'failed'}")
        return success
|
|
|
|
def test_filter_one_system():
    """Check ``filter_one_system`` (expression lookup, no status filtering).

    Returns:
        bool: True when the response holds exactly one record with code
        ``TEST001`` and its ``PostgresResponse`` properties agree; False
        otherwise.
    """
    print("\nTesting filter_one_system...")
    with EndpointRestriction.new_session() as session:
        create_sample_endpoint_restriction()
        code_matches = EndpointRestriction.endpoint_code == "TEST001"
        response = EndpointRestriction.filter_one_system(code_matches, db=session)

        # Lazily evaluated in order; stops at the first failing check.
        checks = (
            lambda: response.count == 1,
            lambda: response.total_count == 1,
            lambda: response.data is not None,
            lambda: response.data.endpoint_code == "TEST001",
            lambda: response.is_list is False,
            lambda: isinstance(response.data_as_dict, dict),
            lambda: response.data_as_dict.get("endpoint_code") == "TEST001",
        )
        success = all(check() for check in checks)

        print(f"Test {'passed' if success else 'failed'}")
        return success
|
|
|
|
def test_filter_all():
    """Check ``filter_all`` returns both sample records for a GET/POST filter.

    Seeds ``TEST001`` (GET) and ``TEST002`` (POST), then filters on
    ``endpoint_method IN ("GET", "POST")`` and verifies the list-shaped
    ``PostgresResponse`` properties. Assumes the table holds no other active
    GET/POST rows — TODO confirm against the test database setup.

    Returns:
        bool: True if every check passes, False otherwise.
    """
    print("\nTesting filter_all...")
    with EndpointRestriction.new_session() as db_session:
        # Create two endpoint restrictions
        create_sample_endpoint_restriction()
        endpoint2 = EndpointRestriction.find_or_create(
            endpoint_function="test_function2",
            endpoint_name="Test Endpoint 2",
            endpoint_method="POST",
            endpoint_desc="Test Description 2",
            endpoint_code="TEST002",
            is_confirmed=True,
            expiry_starts=arrow.now().shift(days=-1),
            expiry_ends=arrow.now().shift(days=1),
        )
        # Persist it the same way create_sample_endpoint_restriction() does;
        # unless find_or_create already persists, TEST002 would otherwise be
        # missing from the query below and count == 2 could never hold.
        endpoint2.save(db=db_session)

        result = EndpointRestriction.filter_all(
            EndpointRestriction.endpoint_method.in_(["GET", "POST"]),
            db=db_session,
        )

        # Test PostgresResponse properties
        success = (
            result.count == 2
            and result.total_count == 2
            and len(result.data) == 2
            and {r.endpoint_code for r in result.data} == {"TEST001", "TEST002"}
            and result.is_list is True
            and isinstance(result.data_as_dict, list)
            and len(result.data_as_dict) == 2
        )

        print(f"Test {'passed' if success else 'failed'}")
        return success
|
|
|
|
def test_filter_all_system():
    """Check ``filter_all_system`` (no status filtering) for a GET/POST filter.

    Seeds ``TEST001`` (GET) and ``TEST002`` (POST), then filters on
    ``endpoint_method IN ("GET", "POST")`` and verifies the list-shaped
    ``PostgresResponse`` properties. Assumes the table holds no other
    GET/POST rows visible to the system query — TODO confirm.

    Returns:
        bool: True if every check passes, False otherwise.
    """
    print("\nTesting filter_all_system...")
    with EndpointRestriction.new_session() as db_session:
        # Create two endpoint restrictions
        create_sample_endpoint_restriction()
        endpoint2 = EndpointRestriction.find_or_create(
            endpoint_function="test_function2",
            endpoint_name="Test Endpoint 2",
            endpoint_method="POST",
            endpoint_desc="Test Description 2",
            endpoint_code="TEST002",
            is_confirmed=True,
            expiry_starts=arrow.now().shift(days=-1),
            expiry_ends=arrow.now().shift(days=1),
        )
        # Persist it the same way create_sample_endpoint_restriction() does;
        # unless find_or_create already persists, TEST002 would otherwise be
        # missing from the query below.
        endpoint2.save(db=db_session)

        result = EndpointRestriction.filter_all_system(
            EndpointRestriction.endpoint_method.in_(["GET", "POST"]),
            db=db_session,
        )

        # Test PostgresResponse properties
        success = (
            result.count == 2
            and result.total_count == 2
            and len(result.data) == 2
            and {r.endpoint_code for r in result.data} == {"TEST001", "TEST002"}
            and result.is_list is True
            and isinstance(result.data_as_dict, list)
            and len(result.data_as_dict) == 2
        )

        print(f"Test {'passed' if success else 'failed'}")
        return success
|
|
|
|
def test_filter_by_all_system():
    """Test filtering multiple records by keyword arguments without status filtering.

    Seeds ``TEST001`` (GET) and ``TEST002`` (POST), then calls
    ``filter_by_all_system(endpoint_method="POST")`` and expects only
    ``TEST002`` back, as a one-element list response.

    Returns:
        bool: True if every check passes, False otherwise.
    """
    print("\nTesting filter_by_all_system...")
    with EndpointRestriction.new_session() as db_session:
        # Create two endpoint restrictions
        create_sample_endpoint_restriction()
        endpoint2 = EndpointRestriction.find_or_create(
            endpoint_function="test_function2",
            endpoint_name="Test Endpoint 2",
            endpoint_method="POST",
            endpoint_desc="Test Description 2",
            endpoint_code="TEST002",
            is_confirmed=True,
            expiry_starts=arrow.now().shift(days=-1),
            expiry_ends=arrow.now().shift(days=1),
        )
        # Persist it the same way create_sample_endpoint_restriction() does;
        # unless find_or_create already persists, the POST query below would
        # find nothing.
        endpoint2.save(db=db_session)

        result = EndpointRestriction.filter_by_all_system(
            db=db_session,
            endpoint_method="POST",
        )

        # Test PostgresResponse properties
        success = (
            result.count == 1
            and result.total_count == 1
            and len(result.data) == 1
            and result.data[0].endpoint_code == "TEST002"
            and result.is_list is True
            and isinstance(result.data_as_dict, list)
            and len(result.data_as_dict) == 1
        )

        print(f"Test {'passed' if success else 'failed'}")
        return success
|
|
|
|
def test_get_not_expired_query_arg():
    """Check that ``filter_all`` leaves out endpoints whose expiry has passed.

    Seeds the active ``TEST001`` record and an ``EXP001`` record whose expiry
    window ended a day ago, then queries both codes with ``filter_all``. Only
    ``TEST001`` is expected back.

    Returns:
        bool: True if only the active record is returned, False otherwise.
    """
    print("\nTesting get_not_expired_query_arg...")
    with EndpointRestriction.new_session() as db_session:
        # Create active and expired endpoints
        create_sample_endpoint_restriction()
        expired_endpoint = EndpointRestriction.find_or_create(
            endpoint_function="expired_function",
            endpoint_name="Expired Endpoint",
            endpoint_method="GET",
            endpoint_desc="Expired Description",
            endpoint_code="EXP001",
            is_confirmed=True,
            expiry_starts=arrow.now().shift(days=-2),
            expiry_ends=arrow.now().shift(days=-1),
        )
        # The expired row must actually be persisted; otherwise the checks
        # below pass trivially because EXP001 is simply absent from the table,
        # and the expiry filter is never exercised.
        expired_endpoint.save(db=db_session)

        result = EndpointRestriction.filter_all(
            EndpointRestriction.endpoint_code.in_(["TEST001", "EXP001"]),
            db=db_session,
        )

        # Test PostgresResponse properties
        success = (
            result.count == 1
            and result.total_count == 1
            and len(result.data) == 1
            and result.data[0].endpoint_code == "TEST001"
            and result.is_list is True
            and isinstance(result.data_as_dict, list)
            and len(result.data_as_dict) == 1
            and result.data_as_dict[0].get("endpoint_code") == "TEST001"
        )

        print(f"Test {'passed' if success else 'failed'}")
        return success
|
|
|
|
def test_add_new_arg_to_args():
    """Check that ``add_new_arg_to_args`` appends a clause without duplicating it.

    Starting from a one-clause tuple, adding an ``endpoint_method == "GET"``
    clause must give two clauses. Adding the same ``endpoint_method`` clause a
    second time must leave the count at two.

    Returns:
        bool: True if both expectations hold, False otherwise.
    """
    print("\nTesting add_new_arg_to_args...")
    base_args = (EndpointRestriction.endpoint_code == "TEST001",)

    first_pass = EndpointRestriction.add_new_arg_to_args(
        base_args, "endpoint_method", EndpointRestriction.endpoint_method == "GET"
    )
    success = len(first_pass) == 2

    # Adding the same endpoint_method clause again must not add a duplicate.
    second_pass = EndpointRestriction.add_new_arg_to_args(
        first_pass, "endpoint_method", EndpointRestriction.endpoint_method == "GET"
    )
    success = success and len(second_pass) == 2

    print(f"Test {'passed' if success else 'failed'}")
    return success
|
|
|
|
def test_produce_query_to_add():
    """Check that ``produce_query_to_add`` turns a ``"query"`` mapping into clauses.

    A filter dict with two entries under ``"query"`` must yield two clauses.
    Feeding those clauses to ``filter_all`` must return only ``TEST001``.

    Returns:
        bool: True if the clause count and every response check pass,
        False otherwise.
    """
    print("\nTesting produce_query_to_add...")
    with EndpointRestriction.new_session() as session:
        create_sample_endpoint_restriction()
        filter_list = {"query": {"endpoint_method": "GET", "endpoint_code": "TEST001"}}
        clauses = EndpointRestriction.produce_query_to_add(filter_list, ())
        produced_two = len(clauses) == 2

        response = EndpointRestriction.filter_all(*clauses, db=session)

        # Lazily evaluated in order; stops at the first failing check.
        checks = (
            lambda: produced_two,
            lambda: response.count == 1,
            lambda: response.total_count == 1,
            lambda: len(response.data) == 1,
            lambda: response.data[0].endpoint_code == "TEST001",
            lambda: response.is_list is True,
            lambda: isinstance(response.data_as_dict, list),
            lambda: len(response.data_as_dict) == 1,
            lambda: response.data_as_dict[0].get("endpoint_code") == "TEST001",
        )
        success = all(check() for check in checks)

        print(f"Test {'passed' if success else 'failed'}")
        return success
|
|
|
|
def run_all_tests():
    """Run every test in this module and print a pass/fail summary.

    Each test function returns a truthy value on success. A test that raises
    is reported and counted as a failure, so one broken test does not abort
    the rest of the suite.

    NOTE(review): under pytest these ``test_*`` functions are also collected,
    but pytest does not fail a test that merely returns ``False`` (it only
    warns about a non-None return). Use this runner, or assert on the returned
    values, to get real failures.

    Returns:
        bool: True when every test passed, False otherwise.
    """
    print("Starting EndpointRestriction tests...")

    tests = [
        test_filter_by_one,
        test_filter_by_one_system,
        test_filter_one,
        test_filter_one_system,
        test_filter_all,
        test_filter_all_system,
        test_filter_by_all_system,
        test_get_not_expired_query_arg,
        test_add_new_arg_to_args,
        test_produce_query_to_add,
    ]

    passed = 0
    failed = 0

    for test in tests:
        try:
            ok = test()
        except Exception as exc:  # top-level runner boundary: report and continue
            print(f"Test {test.__name__} raised {type(exc).__name__}: {exc}")
            ok = False
        if ok:
            passed += 1
        else:
            failed += 1

    print("\nTest Summary:")
    print(f"Total tests: {len(tests)}")
    print(f"Passed: {passed}")
    print(f"Failed: {failed}")
    return failed == 0


if __name__ == "__main__":
    # Exit non-zero when any test fails so shell/CI callers can detect it.
    raise SystemExit(0 if run_all_tests() else 1)
|