prod-wag-backend-automate-s.../Controllers/Postgres/implementations.py

546 lines
20 KiB
Python

import arrow
from Controllers.Postgres.schema import EndpointRestriction
from Controllers.Postgres.database import Base, engine
def generate_table_in_postgres():
    """Create every missing table registered on ``Base.metadata``.

    This is not limited to ``endpoint_restriction``: ``create_all`` emits
    CREATE TABLE for all tables known to the shared declarative ``Base``.
    ``endpoint_restriction`` is included presumably because importing
    ``EndpointRestriction`` registers it on that ``Base``; verify against the
    schema module. With SQLAlchemy's default ``checkfirst=True``, tables that
    already exist are left alone, so calling this repeatedly is safe.

    Returns:
        bool: Always True.
    """
    Base.metadata.create_all(bind=engine)
    return True
def cleanup_test_data():
    """Delete every EndpointRestriction row whose endpoint_code starts with "TEST".

    The lookup uses ``filter_all_system`` (the variant without status
    filtering) rather than ``filter_all``. Sample rows expire one day after
    they are created, so rows left behind by an earlier run are presumably
    inactive or expired by now. A status-filtered query could skip them, and
    their endpoint codes would then collide with the rows the tests try to
    create.

    Raises:
        Any exception from the lookup, delete or commit. The error is printed,
        the session is rolled back, and the original exception is re-raised.
    """
    with EndpointRestriction.new_session() as db_session:
        try:
            test_records = EndpointRestriction.filter_all_system(
                EndpointRestriction.endpoint_code.like("TEST%"), db=db_session
            ).data
            for record in test_records:
                # The instance must belong to this session before it can be deleted.
                if record not in db_session:
                    record = db_session.merge(record)
                db_session.delete(record)
            db_session.commit()
        except Exception as e:
            print(f"Error cleaning up test data: {str(e)}")
            db_session.rollback()
            # Bare raise re-raises the active exception without adding this frame.
            raise
def create_sample_endpoint_restriction(endpoint_code=None):
    """Return a GET test endpoint restriction, creating it if it is not found.

    Args:
        endpoint_code: Code to look up or create. When None, a code of the
            form ``TEST<unix-seconds><6-digit microseconds>`` is generated from
            the current time.

    Returns:
        The row found by ``filter_one`` when one already exists. Otherwise the
        object returned by ``find_or_create``, after ``save`` has been called
        on it.

    Raises:
        Any exception from the lookup or creation. The error is printed, the
        session is rolled back, and the original exception is re-raised.
    """
    # Take a single timestamp so the generated code and the expiry window
    # are all derived from the same instant.
    now = arrow.now()
    if endpoint_code is None:
        # Zero-pad microseconds so generated codes have a fixed width.
        endpoint_code = f"TEST{int(now.timestamp())}{now.microsecond:06d}"
    with EndpointRestriction.new_session() as db_session:
        try:
            existing = EndpointRestriction.filter_one(
                EndpointRestriction.endpoint_code == endpoint_code, db=db_session
            )
            if existing and existing.data:
                return existing.data
            endpoint = EndpointRestriction.find_or_create(
                endpoint_function="test_function",
                endpoint_name="Test Endpoint",
                endpoint_method="GET",
                endpoint_desc="Test Description",
                endpoint_code=endpoint_code,
                is_confirmed=True,
                active=True,
                deleted=False,
                # Valid from one day ago until one day from now.
                expiry_starts=str(now.shift(days=-1)),
                expiry_ends=str(now.shift(days=1)),
                created_by="test_user",
                created_by_id=1,
                updated_by="test_user",
                updated_by_id=1,
                confirmed_by="test_user",
                confirmed_by_id=1,
                db=db_session,
            )
            endpoint.save(db=db_session)
            return endpoint
        except Exception as e:
            print(f"Error creating sample endpoint: {str(e)}")
            db_session.rollback()
            raise
def test_filter_by_one():
    """Check that ``filter_by_one`` locates a single row from keyword filters.

    Returns:
        bool: True when the response is non-None and reports count == 1,
        total_count == 1 and is_list False. False on any mismatch, or when an
        exception is raised (its message is printed).
    """
    print("\nTesting filter_by_one...")
    with EndpointRestriction.new_session() as session:
        try:
            # Narrow the class-level base query to GET endpoints first.
            get_rows = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=session
            )
            EndpointRestriction.pre_query = get_rows.query
            create_sample_endpoint_restriction("TEST001")
            response = EndpointRestriction.filter_by_one(
                db=session, endpoint_code="TEST001"
            )
            # Each check runs only if all previous ones held.
            success = response is not None
            success = success and response.count == 1
            success = success and response.total_count == 1
            success = success and response.is_list is False
            verdict = "passed" if success else "failed"
            print(f"Test {verdict}")
            return success
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
def test_filter_by_one_system():
    """Check keyword-based single-row lookup with status filtering disabled.

    Returns:
        bool: True when the response is non-None and reports count == 1,
        total_count == 1 and is_list False. False on any mismatch, or when an
        exception is raised (its message is printed).
    """
    print("\nTesting filter_by_one_system...")
    with EndpointRestriction.new_session() as session:
        try:
            # Narrow the class-level base query to GET endpoints first.
            get_rows = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=session
            )
            EndpointRestriction.pre_query = get_rows.query
            create_sample_endpoint_restriction("TEST002")
            # NOTE(review): the other *_system tests call dedicated methods
            # (filter_one_system, filter_all_system, filter_by_all_system).
            # Confirm that filter_by_one really accepts a ``system`` flag and
            # does not treat it as a column filter.
            response = EndpointRestriction.filter_by_one(
                db=session, endpoint_code="TEST002", system=True
            )
            success = response is not None
            success = success and response.count == 1
            success = success and response.total_count == 1
            success = success and response.is_list is False
            verdict = "passed" if success else "failed"
            print(f"Test {verdict}")
            return success
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
def test_filter_one():
    """Check that ``filter_one`` locates a single row from a SQL expression.

    Returns:
        bool: True when the response is non-None and reports count == 1,
        total_count == 1 and is_list False. False on any mismatch, or when an
        exception is raised (its message is printed).
    """
    print("\nTesting filter_one...")
    with EndpointRestriction.new_session() as session:
        try:
            # Narrow the class-level base query to GET endpoints first.
            get_rows = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=session
            )
            EndpointRestriction.pre_query = get_rows.query
            create_sample_endpoint_restriction("TEST003")
            code_matches = EndpointRestriction.endpoint_code == "TEST003"
            response = EndpointRestriction.filter_one(code_matches, db=session)
            success = response is not None
            success = success and response.count == 1
            success = success and response.total_count == 1
            success = success and response.is_list is False
            verdict = "passed" if success else "failed"
            print(f"Test {verdict}")
            return success
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
def test_filter_one_system():
    """Check expression-based single-row lookup via ``filter_one_system``.

    Returns:
        bool: True when the response is non-None and reports count == 1,
        total_count == 1 and is_list False. False on any mismatch, or when an
        exception is raised (its message is printed).
    """
    print("\nTesting filter_one_system...")
    with EndpointRestriction.new_session() as session:
        try:
            # Narrow the class-level base query to GET endpoints first.
            get_rows = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=session
            )
            EndpointRestriction.pre_query = get_rows.query
            create_sample_endpoint_restriction("TEST004")
            code_matches = EndpointRestriction.endpoint_code == "TEST004"
            response = EndpointRestriction.filter_one_system(code_matches, db=session)
            success = response is not None
            success = success and response.count == 1
            success = success and response.total_count == 1
            success = success and response.is_list is False
            verdict = "passed" if success else "failed"
            print(f"Test {verdict}")
            return success
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
def test_filter_all():
    """Check that ``filter_all`` returns both freshly created test rows as a list.

    The filter is restricted to the two endpoint codes created here, not just
    to ``endpoint_method == "GET"``. Otherwise any other GET endpoint already
    in the table would inflate the count and fail the exact-count assertion.

    Returns:
        bool: True when the response is non-None and reports count == 2,
        total_count == 2 and is_list True. False on any mismatch, or when an
        exception is raised (its message is printed).
    """
    print("\nTesting filter_all...")
    with EndpointRestriction.new_session() as db_session:
        try:
            # Narrow the class-level base query to GET endpoints first.
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=db_session
            ).query
            codes = ["TEST005", "TEST006"]
            for code in codes:
                create_sample_endpoint_restriction(code)
            result = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET",
                EndpointRestriction.endpoint_code.in_(codes),
                db=db_session,
            )
            success = (
                result is not None
                and result.count == 2
                and result.total_count == 2
                and result.is_list is True
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
def test_filter_all_system():
    """Check that ``filter_all_system`` returns both freshly created test rows.

    This is the variant without status filtering. The filter is restricted to
    the two endpoint codes created here so that other GET endpoints already in
    the table cannot inflate the exact count being asserted.

    Returns:
        bool: True when the response is non-None and reports count == 2,
        total_count == 2 and is_list True. False on any mismatch, or when an
        exception is raised (its message is printed).
    """
    print("\nTesting filter_all_system...")
    with EndpointRestriction.new_session() as db_session:
        try:
            # Narrow the class-level base query to GET endpoints first.
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=db_session
            ).query
            codes = ["TEST007", "TEST008"]
            for code in codes:
                create_sample_endpoint_restriction(code)
            result = EndpointRestriction.filter_all_system(
                EndpointRestriction.endpoint_method == "GET",
                EndpointRestriction.endpoint_code.in_(codes),
                db=db_session,
            )
            success = (
                result is not None
                and result.count == 2
                and result.total_count == 2
                and result.is_list is True
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
def test_filter_by_all_system():
    """Check that ``filter_by_all_system`` returns both freshly created test rows.

    This is the keyword-argument variant without status filtering. Besides
    ``endpoint_method="GET"``, the lookup also matches the
    ``endpoint_function`` that the sample rows are created with. This keeps
    unrelated GET endpoints in the table from inflating the exact count.

    Returns:
        bool: True when the response is non-None and reports count == 2,
        total_count == 2 and is_list True. False on any mismatch, or when an
        exception is raised (its message is printed).
    """
    print("\nTesting filter_by_all_system...")
    with EndpointRestriction.new_session() as db_session:
        try:
            # Narrow the class-level base query to GET endpoints first.
            EndpointRestriction.pre_query = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=db_session
            ).query
            create_sample_endpoint_restriction("TEST009")
            create_sample_endpoint_restriction("TEST010")
            result = EndpointRestriction.filter_by_all_system(
                db=db_session,
                endpoint_method="GET",
                # Sample rows use this function name; presumably real endpoints
                # never do. Verify if the table is shared with real data.
                endpoint_function="test_function",
            )
            success = (
                result is not None
                and result.count == 2
                and result.total_count == 2
                and result.is_list is True
            )
            print(f"Test {'passed' if success else 'failed'}")
            return success
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False
def test_get_not_expired_query_arg():
    """Check that ``get_not_expired_query_arg`` adds the two expiry-window filters.

    Starting from an empty argument tuple, the result must contain exactly two
    expressions: one on ``expiry_starts`` and one on ``expiry_ends``. A sample
    row with a unique generated code is created first. No session is opened
    here because the method under test only builds expressions.

    Returns:
        bool: True when both expiry expressions are present and no others.
        False on mismatch, or when an exception is raised (its message is
        printed).
    """
    print("\nTesting get_not_expired_query_arg...")
    try:
        # With no argument, the helper generates a unique "TEST..." code itself.
        create_sample_endpoint_restriction()
        args = EndpointRestriction.get_not_expired_query_arg(())
        success = (
            len(args) == 2
            and any(
                str(arg).startswith("endpoint_restriction.expiry_starts")
                for arg in args
            )
            and any(
                str(arg).startswith("endpoint_restriction.expiry_ends")
                for arg in args
            )
        )
        print(f"Test {'passed' if success else 'failed'}")
        return success
    except Exception as e:
        print(f"Test failed with exception: {e}")
        return False
def test_add_new_arg_to_args():
    """Check that ``add_new_arg_to_args`` adds a column filter only once.

    The first call must grow a one-element tuple to two. A second call with an
    equivalent ``endpoint_method`` expression must leave the length at two.

    Returns:
        bool: True when both length checks hold. False otherwise, or when an
        exception is raised (its message is printed).
    """
    print("\nTesting add_new_arg_to_args...")
    try:
        base_args = (EndpointRestriction.endpoint_code == "TEST001",)
        with_method = EndpointRestriction.add_new_arg_to_args(
            base_args,
            "endpoint_method",
            EndpointRestriction.endpoint_method == "GET",
        )
        added_once = len(with_method) == 2
        # Re-adding an equivalent filter for the same column must be a no-op.
        with_method_again = EndpointRestriction.add_new_arg_to_args(
            with_method,
            "endpoint_method",
            EndpointRestriction.endpoint_method == "GET",
        )
        success = added_once and len(with_method_again) == 2
        verdict = "passed" if success else "failed"
        print(f"Test {verdict}")
        return success
    except Exception as exc:
        print(f"Test failed with exception: {exc}")
        return False
def test_produce_query_to_add():
    """Check that ``produce_query_to_add`` turns a ``{"query": {...}}`` dict into filters.

    Two keyword filters must yield two expressions. Running them through
    ``filter_all`` must return exactly the one matching row, as a list
    response.

    Returns:
        bool: True when every check holds. False otherwise, or when an
        exception is raised (its message is printed).
    """
    print("\nTesting produce_query_to_add...")
    with EndpointRestriction.new_session() as session:
        try:
            create_sample_endpoint_restriction("TEST001")
            options = {
                "query": {"endpoint_method": "GET", "endpoint_code": "TEST001"}
            }
            query_args = EndpointRestriction.produce_query_to_add(options, ())
            has_two_args = len(query_args) == 2
            response = EndpointRestriction.filter_all(*query_args, db=session)
            success = has_two_args and response is not None
            success = success and response.count == 1
            success = success and response.total_count == 1
            success = success and response.is_list is True
            verdict = "passed" if success else "failed"
            print(f"Test {verdict}")
            return success
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
def test_get_dict():
    """Check that ``get_dict()`` on a ``filter_one`` result mirrors the stored row.

    Creates or reuses the row "TEST_DICT_001". Text fields are compared by
    equality; the boolean flags are compared by identity with True/False.

    Returns:
        bool: True when every field matches. False otherwise, or when an
        exception is raised (its message is printed).
    """
    print("\nTesting get_dict...")
    with EndpointRestriction.new_session() as session:
        try:
            # Narrow the class-level base query to GET endpoints first.
            get_rows = EndpointRestriction.filter_all(
                EndpointRestriction.endpoint_method == "GET", db=session
            )
            EndpointRestriction.pre_query = get_rows.query
            code = "TEST_DICT_001"
            create_sample_endpoint_restriction(code)
            row = EndpointRestriction.filter_one(
                EndpointRestriction.endpoint_code == code, db=session
            ).data
            row_dict = row.get_dict()
            expected_text = {
                "endpoint_code": code,
                "endpoint_method": "GET",
                "endpoint_function": "test_function",
                "endpoint_name": "Test Endpoint",
                "endpoint_desc": "Test Description",
            }
            expected_flags = {"is_confirmed": True, "active": True, "deleted": False}
            success = (
                row_dict is not None
                and isinstance(row_dict, dict)
                and all(
                    row_dict.get(key) == value for key, value in expected_text.items()
                )
                and all(
                    row_dict.get(key) is value for key, value in expected_flags.items()
                )
            )
            verdict = "passed" if success else "failed"
            print(f"Test {verdict}")
            return success
        except Exception as exc:
            print(f"Test failed with exception: {exc}")
            return False
def run_all_tests():
    """Run every EndpointRestriction test, print a summary, and return the counts.

    Test rows (endpoint codes starting with "TEST") are removed before each
    test. They are removed once more after the last test, so the run does not
    leave test rows in the database. A test counts as failed if it returns a
    falsy value or raises.

    Returns:
        tuple[int, int]: ``(passed, failed)``.
    """
    print("Starting EndpointRestriction tests...")
    tests = [
        test_filter_by_one,
        test_filter_by_one_system,
        test_filter_one,
        test_filter_one_system,
        test_filter_all,
        test_filter_all_system,
        test_filter_by_all_system,
        test_get_not_expired_query_arg,
        test_add_new_arg_to_args,
        test_produce_query_to_add,
        test_get_dict,
    ]
    passed_list, not_passed_list = [], []
    for test in tests:
        # Start each test from a clean slate.
        cleanup_test_data()
        name = test.__name__
        try:
            ok = test()
        except Exception as e:
            print(f"Test {name} failed with exception: {e}")
            ok = False
        if ok:
            passed_list.append(f"Test {name} passed")
        else:
            not_passed_list.append(f"Test {name} failed")
    # The loop only cleans up before each test, so remove the last test's rows.
    cleanup_test_data()
    passed, failed = len(passed_list), len(not_passed_list)
    print(f"\nTest Results: {passed} passed, {failed} failed")
    print("Passed Tests:")
    print("\n".join(passed_list))
    print("Failed Tests:")
    print("\n".join(not_passed_list))
    return passed, failed
def run_simple_concurrent_test(num_threads=10):
    """Exercise session/connection handling with ``num_threads`` concurrent workers.

    Each worker opens its own session and runs a count query over
    EndpointRestriction. It then holds the session for 10-50 ms so checkouts
    overlap across threads. A worker passes if the count is non-negative; any
    exception counts as a failure and is recorded. Results are printed,
    including throughput and up to ten error messages.

    Args:
        num_threads: Number of workers, which is also the thread-pool size.
            Must be at least 1 (ThreadPoolExecutor rejects 0).

    Returns:
        bool: True when no worker failed.
    """
    import random
    import threading
    import time
    from concurrent.futures import ThreadPoolExecutor

    max_shown_errors = 10
    print(f"\nStarting simple concurrent test with {num_threads} threads...")
    results = {"passed": 0, "failed": 0, "errors": []}
    results_lock = threading.Lock()

    def worker(thread_id):
        try:
            with EndpointRestriction.new_session() as db_session:
                row_count = db_session.query(EndpointRestriction).count()
                # Keep the session checked out briefly to simulate work.
                time.sleep(random.uniform(0.01, 0.05))
        except Exception as e:
            with results_lock:
                results["failed"] += 1
                results["errors"].append(f"Thread {thread_id} exception: {str(e)}")
            return
        with results_lock:
            if row_count >= 0:
                results["passed"] += 1
            else:
                results["failed"] += 1
                results["errors"].append(f"Thread {thread_id} failed to get count")

    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Consuming map() surfaces anything the worker did not catch itself.
        list(executor.map(worker, range(num_threads)))
    execution_time = time.perf_counter() - start_time
    ops_per_second = num_threads / execution_time if execution_time > 0 else float("inf")

    print("\nConcurrent Operation Test Results:")
    print(f"Total threads: {num_threads}")
    print(f"Passed: {results['passed']}")
    print(f"Failed: {results['failed']}")
    print(f"Execution time: {execution_time:.2f} seconds")
    print(f"Operations per second: {ops_per_second:.2f}")
    if results["failed"] > 0:
        print("\nErrors:")
        # Cap the listing to avoid flooding the output.
        for error in results["errors"][:max_shown_errors]:
            print(f"- {error}")
        if len(results["errors"]) > max_shown_errors:
            print(f"- ... and {len(results['errors']) - max_shown_errors} more errors")
    return results["failed"] == 0
if __name__ == "__main__":
    import sys

    generate_table_in_postgres()
    passed, failed = run_all_tests()
    # Only stress the connection pool once the functional tests are green.
    exit_code = 1
    if failed == 0 and run_simple_concurrent_test(100):
        exit_code = 0
    # A non-zero exit status lets CI and shell callers detect failures.
    sys.exit(exit_code)