"""Manual integration tests for the ``EndpointRestriction`` model's query helpers.

Run this module directly against a configured PostgreSQL database: it creates
the tables, runs every ``test_*`` function in sequence and, if they all pass,
runs a simple concurrent connection-pool smoke test.

NOTE: ``cleanup_test_data`` deletes every ``EndpointRestriction`` row whose
``endpoint_code`` starts with ``"TEST"``; do not point this script at a
database where real rows use that prefix.
"""

import random
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager

import arrow

from Controllers.Postgres.schema import EndpointRestriction
from Controllers.Postgres.database import Base, engine


def generate_table_in_postgres():
    """Create every table registered on ``Base.metadata`` that does not exist yet.

    This includes the ``endpoint_restriction`` table used by the tests below.

    Returns:
        True once ``create_all`` has completed.
    """
    Base.metadata.create_all(bind=engine)
    return True


def _unique_test_code():
    """Return a ``TEST``-prefixed endpoint code derived from the current time.

    A single ``arrow.now()`` reading is used so the seconds and microseconds
    parts come from the same instant (two separate readings can straddle a
    second boundary). Microseconds are zero-padded to a fixed width.
    """
    now = arrow.now()
    return f"TEST{int(now.timestamp())}{now.microsecond:06d}"


@contextmanager
def _scoped_get_pre_query(db_session):
    """Temporarily set ``EndpointRestriction.pre_query`` to a ``GET``-method query.

    The class attribute is shared by every test. The original value is
    restored on exit so one test's ``pre_query`` (built on a session that is
    closed afterwards) does not leak into later tests or into
    ``cleanup_test_data``.
    """
    class_attrs = vars(EndpointRestriction)
    had_own_value = "pre_query" in class_attrs
    previous = class_attrs.get("pre_query")
    EndpointRestriction.pre_query = EndpointRestriction.filter_all(
        EndpointRestriction.endpoint_method == "GET", db=db_session
    ).query
    try:
        yield
    finally:
        if had_own_value:
            EndpointRestriction.pre_query = previous
        else:
            # The value was inherited (or absent); drop the override set above.
            del EndpointRestriction.pre_query


def _is_single_response(result):
    """Return True if ``result`` is a non-list response holding exactly one row."""
    return (
        result is not None
        and result.count == 1
        and result.total_count == 1
        and result.is_list is False
    )


def _is_list_response(result, expected_count):
    """Return True if ``result`` is a list response holding ``expected_count`` rows."""
    return (
        result is not None
        and result.count == expected_count
        and result.total_count == expected_count
        and result.is_list is True
    )


def _report(success):
    """Print the standard pass/fail line for a test and return ``success``."""
    print(f"Test {'passed' if success else 'failed'}")
    return success


def cleanup_test_data():
    """Delete every EndpointRestriction row whose ``endpoint_code`` starts with "TEST".

    Raises:
        Exception: any database error, re-raised after printing it and rolling
            back the session.
    """
    with EndpointRestriction.new_session() as db_session:
        try:
            # Use the *_system variant (described by the tests below as skipping
            # status filtering) so inactive or expired leftovers from earlier
            # runs are removed too.
            test_records = EndpointRestriction.filter_all_system(
                EndpointRestriction.endpoint_code.like("TEST%"), db=db_session
            ).data
            for record in test_records:
                # Delete through the same session; merge defensively in case a
                # record is not attached to it.
                if record not in db_session:
                    record = db_session.merge(record)
                db_session.delete(record)
            db_session.commit()
        except Exception as e:
            print(f"Error cleaning up test data: {str(e)}")
            db_session.rollback()
            raise


def create_sample_endpoint_restriction(endpoint_code=None):
    """Return an EndpointRestriction with ``endpoint_code``, creating it if needed.

    Args:
        endpoint_code: code for the row. A unique ``TEST``-prefixed code is
            generated when omitted.

    Returns:
        The existing row's ``data`` if ``filter_one`` finds one; otherwise the
        object returned by ``find_or_create`` after it has been saved.

    Raises:
        Exception: any database error, re-raised after printing it and rolling
            back the session.
    """
    if endpoint_code is None:
        endpoint_code = _unique_test_code()

    with EndpointRestriction.new_session() as db_session:
        try:
            existing = EndpointRestriction.filter_one(
                EndpointRestriction.endpoint_code == endpoint_code, db=db_session
            )
            if existing and existing.data:
                return existing.data

            # Active window of one day either side of a single "now" reading.
            now = arrow.now()
            endpoint = EndpointRestriction.find_or_create(
                endpoint_function="test_function",
                endpoint_name="Test Endpoint",
                endpoint_method="GET",
                endpoint_desc="Test Description",
                endpoint_code=endpoint_code,
                is_confirmed=True,
                active=True,
                deleted=False,
                expiry_starts=str(now.shift(days=-1)),
                expiry_ends=str(now.shift(days=1)),
                created_by="test_user",
                created_by_id=1,
                updated_by="test_user",
                updated_by_id=1,
                confirmed_by="test_user",
                confirmed_by_id=1,
                db=db_session,
            )
            endpoint.save(db=db_session)
            return endpoint
        except Exception as e:
            print(f"Error creating sample endpoint: {str(e)}")
            db_session.rollback()
            raise


def test_filter_by_one():
    """Test filtering a single record by keyword arguments."""
    print("\nTesting filter_by_one...")
    with EndpointRestriction.new_session() as db_session:
        try:
            with _scoped_get_pre_query(db_session):
                create_sample_endpoint_restriction("TEST001")
                result = EndpointRestriction.filter_by_one(
                    db=db_session, endpoint_code="TEST001"
                )
                return _report(_is_single_response(result))
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False


def test_filter_by_one_system():
    """Test filtering a single record by keyword arguments without status filtering."""
    print("\nTesting filter_by_one_system...")
    with EndpointRestriction.new_session() as db_session:
        try:
            with _scoped_get_pre_query(db_session):
                create_sample_endpoint_restriction("TEST002")
                # NOTE(review): this passes ``system=True`` to ``filter_by_one``,
                # while sibling tests call dedicated ``*_system`` methods. Confirm
                # the helper treats it as a flag rather than a column filter.
                result = EndpointRestriction.filter_by_one(
                    db=db_session, endpoint_code="TEST002", system=True
                )
                return _report(_is_single_response(result))
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False


def test_filter_one():
    """Test filtering a single record by expressions."""
    print("\nTesting filter_one...")
    with EndpointRestriction.new_session() as db_session:
        try:
            with _scoped_get_pre_query(db_session):
                create_sample_endpoint_restriction("TEST003")
                result = EndpointRestriction.filter_one(
                    EndpointRestriction.endpoint_code == "TEST003", db=db_session
                )
                return _report(_is_single_response(result))
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False


def test_filter_one_system():
    """Test filtering a single record by expressions without status filtering."""
    print("\nTesting filter_one_system...")
    with EndpointRestriction.new_session() as db_session:
        try:
            with _scoped_get_pre_query(db_session):
                create_sample_endpoint_restriction("TEST004")
                result = EndpointRestriction.filter_one_system(
                    EndpointRestriction.endpoint_code == "TEST004", db=db_session
                )
                return _report(_is_single_response(result))
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False


def test_filter_all():
    """Test filtering multiple records by expressions."""
    print("\nTesting filter_all...")
    with EndpointRestriction.new_session() as db_session:
        try:
            with _scoped_get_pre_query(db_session):
                codes = ["TEST005", "TEST006"]
                for code in codes:
                    create_sample_endpoint_restriction(code)
                # Restrict to the rows created above so pre-existing GET
                # endpoints in the database cannot inflate the count.
                result = EndpointRestriction.filter_all(
                    EndpointRestriction.endpoint_method.in_(["GET"]),
                    EndpointRestriction.endpoint_code.in_(codes),
                    db=db_session,
                )
                return _report(_is_list_response(result, len(codes)))
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False


def test_filter_all_system():
    """Test filtering multiple records by expressions without status filtering."""
    print("\nTesting filter_all_system...")
    with EndpointRestriction.new_session() as db_session:
        try:
            with _scoped_get_pre_query(db_session):
                codes = ["TEST007", "TEST008"]
                for code in codes:
                    create_sample_endpoint_restriction(code)
                # Restrict to the rows created above so pre-existing GET
                # endpoints in the database cannot inflate the count.
                result = EndpointRestriction.filter_all_system(
                    EndpointRestriction.endpoint_method.in_(["GET"]),
                    EndpointRestriction.endpoint_code.in_(codes),
                    db=db_session,
                )
                return _report(_is_list_response(result, len(codes)))
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False


def test_filter_by_all_system():
    """Test filtering multiple records by keyword arguments without status filtering."""
    print("\nTesting filter_by_all_system...")
    with EndpointRestriction.new_session() as db_session:
        try:
            with _scoped_get_pre_query(db_session):
                create_sample_endpoint_restriction("TEST009")
                create_sample_endpoint_restriction("TEST010")
                # endpoint_function is set to "test_function" only by the sample
                # factory, which keeps real GET endpoints out of the count.
                result = EndpointRestriction.filter_by_all_system(
                    db=db_session,
                    endpoint_method="GET",
                    endpoint_function="test_function",
                )
                return _report(_is_list_response(result, 2))
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False


def test_get_not_expired_query_arg():
    """Test adding expiry date filtering to query arguments."""
    print("\nTesting get_not_expired_query_arg...")
    try:
        create_sample_endpoint_restriction(_unique_test_code())

        args = EndpointRestriction.get_not_expired_query_arg(())
        success = (
            len(args) == 2
            and any(
                str(arg).startswith("endpoint_restriction.expiry_starts")
                for arg in args
            )
            and any(
                str(arg).startswith("endpoint_restriction.expiry_ends")
                for arg in args
            )
        )
        return _report(success)
    except Exception as e:
        print(f"Test failed with exception: {e}")
        return False


def test_add_new_arg_to_args():
    """Test adding new arguments to query arguments, including duplicate prevention."""
    print("\nTesting add_new_arg_to_args...")
    try:
        args = (EndpointRestriction.endpoint_code == "TEST001",)
        updated_args = EndpointRestriction.add_new_arg_to_args(
            args, "endpoint_method", EndpointRestriction.endpoint_method == "GET"
        )
        added_ok = len(updated_args) == 2

        # A second expression for the same column should not be appended.
        updated_args = EndpointRestriction.add_new_arg_to_args(
            updated_args, "endpoint_method", EndpointRestriction.endpoint_method == "GET"
        )
        return _report(added_ok and len(updated_args) == 2)
    except Exception as e:
        print(f"Test failed with exception: {e}")
        return False


def test_produce_query_to_add():
    """Test adding query parameters to filter options."""
    print("\nTesting produce_query_to_add...")
    with EndpointRestriction.new_session() as db_session:
        try:
            create_sample_endpoint_restriction("TEST001")
            filter_list = {
                "query": {"endpoint_method": "GET", "endpoint_code": "TEST001"}
            }
            updated_args = EndpointRestriction.produce_query_to_add(filter_list, ())
            args_ok = len(updated_args) == 2

            result = EndpointRestriction.filter_all(*updated_args, db=db_session)
            return _report(args_ok and _is_list_response(result, 1))
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False


def test_get_dict():
    """Test the get_dict() function for single-record filters."""
    print("\nTesting get_dict...")
    with EndpointRestriction.new_session() as db_session:
        try:
            with _scoped_get_pre_query(db_session):
                endpoint_code = "TEST_DICT_001"
                create_sample_endpoint_restriction(endpoint_code)

                result = EndpointRestriction.filter_one(
                    EndpointRestriction.endpoint_code == endpoint_code, db=db_session
                )
                data_dict = result.data.get_dict()

                expected_values = {
                    "endpoint_code": endpoint_code,
                    "endpoint_method": "GET",
                    "endpoint_function": "test_function",
                    "endpoint_name": "Test Endpoint",
                    "endpoint_desc": "Test Description",
                }
                # Flags are compared by identity so 1/0 do not pass for True/False.
                expected_flags = {"is_confirmed": True, "active": True, "deleted": False}
                success = (
                    isinstance(data_dict, dict)
                    and all(
                        data_dict.get(key) == value
                        for key, value in expected_values.items()
                    )
                    and all(
                        data_dict.get(key) is value
                        for key, value in expected_flags.items()
                    )
                )
                return _report(success)
        except Exception as e:
            print(f"Test failed with exception: {e}")
            return False


def run_all_tests():
    """Run every test in sequence, print a summary, and return ``(passed, failed)``.

    Test rows are deleted before each test and once more after the last one,
    so the run does not leave ``TEST%`` rows behind.
    """
    print("Starting EndpointRestriction tests...")

    tests = [
        test_filter_by_one,
        test_filter_by_one_system,
        test_filter_one,
        test_filter_one_system,
        test_filter_all,
        test_filter_all_system,
        test_filter_by_all_system,
        test_get_not_expired_query_arg,
        test_add_new_arg_to_args,
        test_produce_query_to_add,
        test_get_dict,
    ]

    passed_list, not_passed_list = [], []
    for test in tests:
        # Each test starts from a clean slate.
        cleanup_test_data()
        try:
            ok = test()
        except Exception as e:
            print(f"Test {test.__name__} failed with exception: {e}")
            ok = False
        if ok:
            passed_list.append(f"Test {test.__name__} passed")
        else:
            not_passed_list.append(f"Test {test.__name__} failed")

    # Remove the rows created by the last test.
    cleanup_test_data()

    passed, failed = len(passed_list), len(not_passed_list)
    print(f"\nTest Results: {passed} passed, {failed} failed")
    print("Passed Tests:")
    print("\n".join(passed_list))
    print("Failed Tests:")
    print("\n".join(not_passed_list))

    return passed, failed


def run_simple_concurrent_test(num_threads=10):
    """Run a simplified concurrent test that just verifies connection pooling.

    Each of ``num_threads`` workers opens its own session, runs a row count and
    sleeps briefly.

    Args:
        num_threads: number of workers, also used as the pool size.

    Returns:
        True if no worker failed.
    """
    print(f"\nStarting simple concurrent test with {num_threads} threads...")

    results = {"passed": 0, "failed": 0, "errors": []}
    results_lock = threading.Lock()

    def worker(thread_id):
        try:
            with EndpointRestriction.new_session() as db_session:
                count_query = db_session.query(EndpointRestriction).count()
                # Small delay so sessions overlap and the pool is exercised.
                time.sleep(random.uniform(0.01, 0.05))
                success = count_query >= 0
                with results_lock:
                    if success:
                        results["passed"] += 1
                    else:
                        results["failed"] += 1
                        results["errors"].append(
                            f"Thread {thread_id} failed to get count"
                        )
        except Exception as e:
            with results_lock:
                results["failed"] += 1
                results["errors"].append(f"Thread {thread_id} exception: {str(e)}")

    # perf_counter is monotonic and meant for measuring elapsed time.
    start_time = time.perf_counter()
    # Leaving the executor's context waits for every submitted worker.
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        for thread_id in range(num_threads):
            executor.submit(worker, thread_id)
    execution_time = time.perf_counter() - start_time
    ops_per_second = (
        num_threads / execution_time if execution_time > 0 else float("inf")
    )

    print("\nConcurrent Operation Test Results:")
    print(f"Total threads: {num_threads}")
    print(f"Passed: {results['passed']}")
    print(f"Failed: {results['failed']}")
    print(f"Execution time: {execution_time:.2f} seconds")
    print(f"Operations per second: {ops_per_second:.2f}")

    if results["failed"] > 0:
        max_shown = 10  # avoid flooding the output
        print("\nErrors:")
        for error in results["errors"][:max_shown]:
            print(f"- {error}")
        if len(results["errors"]) > max_shown:
            print(f"- ... and {len(results['errors']) - max_shown} more errors")

    return results["failed"] == 0


if __name__ == "__main__":
    generate_table_in_postgres()
    passed, failed = run_all_tests()

    # Only stress the connection pool once the functional tests are green.
    if failed == 0:
        run_simple_concurrent_test(100)