# prod-wag-backend-automate-s.../Controllers/Postgres/implementations.py
#
# 486 lines
# 18 KiB
# Python

import arrow
from Controllers.Postgres.schema import EndpointRestriction
from Controllers.Postgres.database import Base, engine
def generate_table_in_postgres():
    """Create any missing tables registered on ``Base.metadata``.

    This includes the endpoint_restriction table exercised by the tests in
    this module, but ``create_all`` acts on every model attached to ``Base``.

    Returns:
        bool: Always ``True`` once ``create_all`` has returned.
    """
    metadata = Base.metadata
    # create_all only issues CREATE for tables that do not exist yet.
    metadata.create_all(bind=engine)
    return True
def cleanup_test_data():
    """Delete every EndpointRestriction row whose endpoint_code starts with "TEST".

    The lookup uses ``filter_all_system`` (the variant this module's tests
    describe as "without status filtering") instead of ``filter_all``. This
    means fixture rows that the status filters would hide are also removed,
    presumably e.g. rows left by an earlier run whose expiry window has since
    lapsed.

    Raises:
        Exception: Any error raised while querying or deleting is printed,
            the session is rolled back, and the original exception is
            re-raised with its traceback intact.
    """
    with EndpointRestriction.new_session() as db_session:
        try:
            test_records = EndpointRestriction.filter_all_system(
                EndpointRestriction.endpoint_code.like("TEST%"),
                db=db_session,
            ).data or []
            for record in test_records:
                # Defensive: delete() needs an instance attached to this session.
                if record not in db_session:
                    record = db_session.merge(record)
                db_session.delete(record)
            db_session.commit()
        except Exception as e:
            print(f"Error cleaning up test data: {e}")
            db_session.rollback()
            raise
def create_sample_endpoint_restriction(endpoint_code=None):
    """Return an EndpointRestriction fixture row, creating it when absent.

    Args:
        endpoint_code: Code for the fixture row. When ``None``, a code of the
            form "TEST<unix-seconds><microseconds>" is generated. The "TEST"
            prefix lets ``cleanup_test_data`` (which matches "TEST%") find it.

    Returns:
        The ``.data`` of ``filter_one`` when a row with this code is already
        visible to it. Otherwise, the object returned by ``find_or_create``
        after ``save`` has been called on it.

    Raises:
        Exception: Any error is printed, the session is rolled back and the
            original exception is re-raised.
    """
    # One clock reading, so that the generated code (seconds + microseconds)
    # and the +/- 1 day expiry window are all derived from the same instant.
    now = arrow.now()
    if endpoint_code is None:
        endpoint_code = f"TEST{int(now.timestamp())}{now.microsecond}"
    with EndpointRestriction.new_session() as db_session:
        try:
            # Reuse an existing fixture with this code if filter_one can see it.
            existing = EndpointRestriction.filter_one(
                EndpointRestriction.endpoint_code == endpoint_code,
                db=db_session,
            )
            if existing and existing.data:
                return existing.data
            endpoint = EndpointRestriction.find_or_create(
                endpoint_function="test_function",
                endpoint_name="Test Endpoint",
                endpoint_method="GET",
                endpoint_desc="Test Description",
                endpoint_code=endpoint_code,
                is_confirmed=True,
                active=True,
                deleted=False,
                expiry_starts=str(now.shift(days=-1)),
                expiry_ends=str(now.shift(days=1)),
                created_by="test_user",
                created_by_id=1,
                updated_by="test_user",
                updated_by_id=1,
                confirmed_by="test_user",
                confirmed_by_id=1,
                db=db_session,
            )
            endpoint.save(db=db_session)
            return endpoint
        except Exception as e:
            print(f"Error creating sample endpoint: {e}")
            db_session.rollback()
            raise
def test_filter_by_one():
    """Check that ``filter_by_one`` returns exactly one row for a keyword filter.

    Fixture "TEST001" is created and then looked up by ``endpoint_code``.

    Returns:
        bool: True when the response holds a single, non-list result.
        False otherwise, including when any exception is raised after the
        session has been opened.
    """
    print("\nTesting filter_by_one...")
    with EndpointRestriction.new_session() as db_session:
        try:
            # Class-level pre_query narrowed to GET rows. It stays set on the
            # class after this test returns.
            get_rows = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                db=db_session,
            )
            EndpointRestriction.pre_query = get_rows.query
            create_sample_endpoint_restriction("TEST001")
            response = EndpointRestriction.filter_by_one(
                db=db_session,
                endpoint_code="TEST001",
            )
            if response is None:
                success = False
            else:
                success = (
                    response.count == 1
                    and response.total_count == 1
                    and response.is_list is False
                )
            verdict = "passed" if success else "failed"
            print(f"Test {verdict}")
            return success
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
def test_filter_by_one_system():
    """Check that ``filter_by_one(..., system=True)`` returns exactly one row.

    Fixture "TEST002" is created and looked up by keyword with
    ``system=True``. Per this module's naming, that is the variant without
    status filtering.

    Returns:
        bool: True for a single, non-list response. False otherwise,
        including when any exception is raised after the session opens.
    """
    print("\nTesting filter_by_one_system...")
    with EndpointRestriction.new_session() as db_session:
        try:
            # Restrict the class-wide pre_query to GET rows (persists afterwards).
            pre = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                db=db_session,
            )
            EndpointRestriction.pre_query = pre.query
            create_sample_endpoint_restriction("TEST002")
            found = EndpointRestriction.filter_by_one(
                db=db_session,
                endpoint_code="TEST002",
                system=True,
            )
            # Evaluate the checks one at a time, stopping at the first failure.
            success = found is not None
            success = success and found.count == 1
            success = success and found.total_count == 1
            success = success and found.is_list is False
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as err:
            print(f"Test failed with exception: {err}")
            return False
def test_filter_one():
    """Check that ``filter_one`` returns exactly one row for an expression filter.

    Fixture "TEST003" is created and fetched with an
    ``endpoint_code == "TEST003"`` expression.

    Returns:
        bool: True for a single, non-list response. False otherwise,
        including when any exception is raised after the session opens.
    """
    print("\nTesting filter_one...")
    with EndpointRestriction.new_session() as db_session:
        try:
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                db=db_session,
            ).query
            create_sample_endpoint_restriction("TEST003")
            code_matches = EndpointRestriction.endpoint_code == "TEST003"
            result = EndpointRestriction.filter_one(code_matches, db=db_session)
            # all() over a generator short-circuits just like an `and` chain.
            checks = (
                lambda r: r.count == 1,
                lambda r: r.total_count == 1,
                lambda r: r.is_list is False,
            )
            success = result is not None and all(check(result) for check in checks)
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as problem:
            print(f"Test failed with exception: {problem}")
            return False
def test_filter_one_system():
    """Check that ``filter_one_system`` returns exactly one row.

    Fixture "TEST004" is created and fetched by expression through the
    variant without status filtering.

    Returns:
        bool: True for a single, non-list response. False otherwise,
        including when any exception is raised after the session opens.
    """
    print("\nTesting filter_one_system...")

    def _is_single_row(response):
        # Mirrors a short-circuiting `and` chain over the response properties.
        if response is None:
            return False
        return (
            response.count == 1
            and response.total_count == 1
            and response.is_list is False
        )

    with EndpointRestriction.new_session() as db_session:
        try:
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                db=db_session,
            ).query
            create_sample_endpoint_restriction("TEST004")
            outcome = EndpointRestriction.filter_one_system(
                EndpointRestriction.endpoint_code == "TEST004",
                db=db_session,
            )
            success = _is_single_row(outcome)
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
def test_filter_all():
    """Check that ``filter_all`` returns both fixture rows as a list response.

    Fixtures TEST005 and TEST006 are created. The query is then narrowed to
    exactly those codes. ``cleanup_test_data`` only removes "TEST%" rows, so
    a method-only filter would also count unrelated GET rows already in the
    table and break the ``count == 2`` check.

    Returns:
        bool: True when the response is a list holding exactly two rows.
        False otherwise, including when any exception is raised after the
        session is opened.
    """
    print("\nTesting filter_all...")
    fixture_codes = ["TEST005", "TEST006"]
    with EndpointRestriction.new_session() as db_session:
        try:
            # Class-level pre_query narrowed to GET rows (persists after the test).
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                db=db_session,
            ).query
            for code in fixture_codes:
                create_sample_endpoint_restriction(code)
            result = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                EndpointRestriction.endpoint_code.in_(fixture_codes),
                db=db_session,
            )
            success = (
                result is not None
                and result.count == 2
                and result.total_count == 2
                and result.is_list is True
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
def test_filter_all_system():
    """Check that ``filter_all_system`` returns both fixture rows as a list.

    Fixtures TEST007 and TEST008 are created. The query is then narrowed to
    exactly those codes. ``cleanup_test_data`` only removes "TEST%" rows, so
    a method-only filter would also count unrelated GET rows already in the
    table and break the ``count == 2`` check.

    Returns:
        bool: True when the response is a list holding exactly two rows.
        False otherwise, including when any exception is raised after the
        session is opened.
    """
    print("\nTesting filter_all_system...")
    fixture_codes = ["TEST007", "TEST008"]
    with EndpointRestriction.new_session() as db_session:
        try:
            # Class-level pre_query narrowed to GET rows (persists after the test).
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                db=db_session,
            ).query
            for code in fixture_codes:
                create_sample_endpoint_restriction(code)
            result = EndpointRestriction.filter_all_system(
                EndpointRestriction.endpoint_method == "GET",
                EndpointRestriction.endpoint_code.in_(fixture_codes),
                db=db_session,
            )
            success = (
                result is not None
                and result.count == 2
                and result.total_count == 2
                and result.is_list is True
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
def test_filter_by_all_system():
    """Check that ``filter_by_all_system`` returns both fixture rows as a list.

    Fixtures TEST009 and TEST010 are created by
    ``create_sample_endpoint_restriction``, which sets
    ``endpoint_function="test_function"``. That keyword is added to the filter
    so unrelated GET rows already in the table cannot inflate the
    ``count == 2`` check. Only "TEST%" rows are removed by
    ``cleanup_test_data``.

    Returns:
        bool: True when the response is a list holding exactly two rows.
        False otherwise, including when any exception is raised after the
        session is opened.
    """
    print("\nTesting filter_by_all_system...")
    with EndpointRestriction.new_session() as db_session:
        try:
            # Class-level pre_query narrowed to GET rows (persists after the test).
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                db=db_session,
            ).query
            create_sample_endpoint_restriction("TEST009")
            create_sample_endpoint_restriction("TEST010")
            result = EndpointRestriction.filter_by_all_system(
                db=db_session,
                endpoint_method="GET",
                endpoint_function="test_function",
            )
            success = (
                result is not None
                and result.count == 2
                and result.total_count == 2
                and result.is_list is True
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
def test_get_not_expired_query_arg():
    """Check the expiry conditions produced by ``get_not_expired_query_arg``.

    A fixture with a freshly generated "TEST..." code is created first. The
    method is then called with an empty argument tuple, and the result must
    contain exactly two conditions: one rendering as
    ``endpoint_restriction.expiry_starts...`` and one as
    ``endpoint_restriction.expiry_ends...``.

    Returns:
        bool: True when both conditions are present. False otherwise,
        including when any exception is raised after the session opens.
    """
    print("\nTesting get_not_expired_query_arg...")
    with EndpointRestriction.new_session():
        try:
            unique_code = f"TEST{int(arrow.now().timestamp())}{arrow.now().microsecond}"
            create_sample_endpoint_restriction(unique_code)
            expiry_args = EndpointRestriction.get_not_expired_query_arg(())

            def _mentions(column):
                # True if any produced condition renders starting with this column.
                prefix = f"endpoint_restriction.{column}"
                return any(str(arg).startswith(prefix) for arg in expiry_args)

            success = (
                len(expiry_args) == 2
                and _mentions("expiry_starts")
                and _mentions("expiry_ends")
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
def test_add_new_arg_to_args():
    """Check that ``add_new_arg_to_args`` appends once and ignores a repeat.

    Starting from one condition on ``endpoint_code``, adding an
    ``endpoint_method`` condition must give two arguments. Adding another
    ``endpoint_method`` condition must leave the count at two.

    Returns:
        bool: True when both counts match. False otherwise, including when
        any exception (e.g. a missing name) is raised.
    """
    print("\nTesting add_new_arg_to_args...")
    try:
        base_args = (EndpointRestriction.endpoint_code == "TEST001",)
        method_is_get = EndpointRestriction.endpoint_method == "GET"
        first_pass = EndpointRestriction.add_new_arg_to_args(base_args, "endpoint_method", method_is_get)
        added_once = len(first_pass) == 2
        # A second condition keyed on the same column must not be appended again.
        repeat = EndpointRestriction.endpoint_method == "GET"
        second_pass = EndpointRestriction.add_new_arg_to_args(first_pass, "endpoint_method", repeat)
        success = added_once and len(second_pass) == 2
        verdict = "passed" if success else "failed"
        print(f"Test {verdict}")
        return success
    except Exception as exc:
        print(f"Test failed with exception: {exc}")
        return False
def test_produce_query_to_add():
    """Check that ``produce_query_to_add`` turns a "query" dict into filter args.

    The dict ``{"endpoint_method": "GET", "endpoint_code": "TEST001"}`` must
    become two arguments. Passing those to ``filter_all`` must return a list
    response containing exactly the TEST001 fixture.

    Returns:
        bool: True when both the argument count and the response match.
        False otherwise, including when any exception is raised after the
        session opens.
    """
    print("\nTesting produce_query_to_add...")
    with EndpointRestriction.new_session() as db_session:
        try:
            create_sample_endpoint_restriction("TEST001")
            wanted = {
                "endpoint_method": "GET",
                "endpoint_code": "TEST001",
            }
            produced = EndpointRestriction.produce_query_to_add({"query": wanted}, ())
            arg_count_ok = len(produced) == 2
            result = EndpointRestriction.filter_all(*produced, db=db_session)
            success = (
                arg_count_ok
                and result is not None
                and result.count == 1
                and result.total_count == 1
                and result.is_list is True
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as err:
            print(f"Test failed with exception: {err}")
            return False
def test_get_dict():
    """Check that ``get_dict()`` on a ``filter_one`` result exposes fixture fields.

    Fixture "TEST_DICT_001" is created and fetched. Its dict form must carry
    the string fields and boolean flags set by
    ``create_sample_endpoint_restriction``.

    Returns:
        bool: True when every expected key/value matches. False otherwise,
        including when any exception is raised after the session opens.
    """
    print("\nTesting get_dict...")
    with EndpointRestriction.new_session() as db_session:
        try:
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                db=db_session,
            ).query
            endpoint_code = "TEST_DICT_001"
            create_sample_endpoint_restriction(endpoint_code)
            fetched = EndpointRestriction.filter_one(
                EndpointRestriction.endpoint_code == endpoint_code,
                db=db_session,
            )
            data_dict = fetched.data.get_dict()
            # String-valued fields are compared with ==, flags by identity.
            expected_values = {
                "endpoint_code": endpoint_code,
                "endpoint_method": "GET",
                "endpoint_function": "test_function",
                "endpoint_name": "Test Endpoint",
                "endpoint_desc": "Test Description",
            }
            expected_flags = {
                "is_confirmed": True,
                "active": True,
                "deleted": False,
            }
            success = (
                isinstance(data_dict, dict)
                and all(data_dict.get(key) == value for key, value in expected_values.items())
                and all(data_dict.get(key) is flag for key, flag in expected_flags.items())
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
def run_all_tests():
    """Run every test in this module against the database and print a summary.

    Test data matching "TEST%" is removed before each test and once more
    after the summary is printed, so fixtures are not left behind in the
    database. A test that raises is reported as failed rather than aborting
    the run.

    Returns:
        tuple[int, int]: ``(passed, failed)`` counts.
    """
    print("Starting EndpointRestriction tests...")
    tests = [
        test_filter_by_one,
        test_filter_by_one_system,
        test_filter_one,
        test_filter_one_system,
        test_filter_all,
        test_filter_all_system,
        test_filter_by_all_system,
        test_get_not_expired_query_arg,
        test_add_new_arg_to_args,
        test_produce_query_to_add,
        test_get_dict,
    ]
    passed_list, not_passed_list = [], []
    for test in tests:
        # Each test starts from a clean slate (this also covers the first one).
        cleanup_test_data()
        try:
            ok = test()
        except Exception as e:
            print(f"Test {test.__name__} failed with exception: {e}")
            ok = False
        if ok:
            passed_list.append(f"Test {test.__name__} passed")
        else:
            not_passed_list.append(f"Test {test.__name__} failed")
    passed, failed = len(passed_list), len(not_passed_list)
    print(f"\nTest Results: {passed} passed, {failed} failed")
    print('Passed Tests:')
    print("\n".join(passed_list))
    print('Failed Tests:')
    print("\n".join(not_passed_list))
    # Remove the fixtures created by the last test so they do not persist.
    cleanup_test_data()
    return passed, failed
if __name__ == "__main__":
    generate_table_in_postgres()
    _, failed_count = run_all_tests()
    # Report failures through the exit status so shells / CI can detect them.
    raise SystemExit(1 if failed_count else 0)