black shift
@@ -63,7 +63,7 @@ class MongoDBHandler:
         if not hasattr(self, "_initialized") or not self._initialized:
             self._debug_mode = debug_mode
             self._mock_mode = mock_mode

             if mock_mode:
                 # In mock mode, we don't need a real connection string
                 self.uri = "mongodb://mock:27017/mockdb"
@@ -76,7 +76,7 @@ class MongoDBHandler:
                 # Use the configured connection string with authentication
                 self.uri = mongo_configs.url
                 print(f"Connecting to MongoDB: {self.uri}")

             # Define MongoDB client options with increased timeouts for better reliability
             self.client_options = {
                 "maxPoolSize": 5,
@@ -132,11 +132,13 @@ class CollectionContext:
         # If we're in mock mode, return a mock collection immediately
         if self.db_handler._mock_mode:
             return self._create_mock_collection()

         try:
             # Create a new client connection
-            self.client = MongoClient(self.db_handler.uri, **self.db_handler.client_options)
+            self.client = MongoClient(
+                self.db_handler.uri, **self.db_handler.client_options
+            )

             if self.db_handler._debug_mode:
                 # In debug mode, we explicitly use the configured DB
                 db_name = mongo_configs.DB
@@ -148,7 +150,7 @@ class CollectionContext:
             except Exception:
                 db_name = mongo_configs.DB
                 print(f"Using fallback database '{db_name}'")

             self.collection = self.client[db_name][self.collection_name]

             # Enhance collection methods with retry capabilities
@@ -159,13 +161,17 @@ class CollectionContext:
             if "Authentication failed" in str(e):
                 print(f"MongoDB authentication error: {e}")
                 print("Attempting to reconnect with direct connection...")

                 try:
                     # Try a direct connection without authentication for testing
                     direct_uri = f"mongodb://{mongo_configs.HOST}:{mongo_configs.PORT}/{mongo_configs.DB}"
                     print(f"Trying direct connection: {direct_uri}")
-                    self.client = MongoClient(direct_uri, **self.db_handler.client_options)
-                    self.collection = self.client[mongo_configs.DB][self.collection_name]
+                    self.client = MongoClient(
+                        direct_uri, **self.db_handler.client_options
+                    )
+                    self.collection = self.client[mongo_configs.DB][
+                        self.collection_name
+                    ]
                     self._add_retry_capabilities()
                     return self.collection
                 except Exception as inner_e:
@@ -181,54 +187,56 @@ class CollectionContext:
             if self.client:
                 self.client.close()
                 self.client = None

             return self._create_mock_collection()

     def _create_mock_collection(self):
         """
         Create a mock collection for testing or graceful degradation.
         This prevents the application from crashing when MongoDB is unavailable.

         Returns:
             A mock MongoDB collection with simulated behaviors
         """
         from unittest.mock import MagicMock

         if self.db_handler._mock_mode:
             print(f"MOCK MODE: Using mock collection '{self.collection_name}'")
         else:
-            print(f"Using mock MongoDB collection '{self.collection_name}' for graceful degradation")
+            print(
+                f"Using mock MongoDB collection '{self.collection_name}' for graceful degradation"
+            )

         # Create in-memory storage for this mock collection
-        if not hasattr(self.db_handler, '_mock_storage'):
+        if not hasattr(self.db_handler, "_mock_storage"):
             self.db_handler._mock_storage = {}

         if self.collection_name not in self.db_handler._mock_storage:
             self.db_handler._mock_storage[self.collection_name] = []

         mock_collection = MagicMock()
         mock_data = self.db_handler._mock_storage[self.collection_name]

         # Define behavior for find operations
         def mock_find(query=None, *args, **kwargs):
             # Simple implementation that returns all documents
             return mock_data

         def mock_find_one(query=None, *args, **kwargs):
             # Simple implementation that returns the first matching document
             if not mock_data:
                 return None
             return mock_data[0]

         def mock_insert_one(document, *args, **kwargs):
             # Add _id if not present
-            if '_id' not in document:
-                document['_id'] = f"mock_id_{len(mock_data)}"
+            if "_id" not in document:
+                document["_id"] = f"mock_id_{len(mock_data)}"
             mock_data.append(document)
             result = MagicMock()
-            result.inserted_id = document['_id']
+            result.inserted_id = document["_id"]
             return result

         def mock_insert_many(documents, *args, **kwargs):
             inserted_ids = []
             for doc in documents:
@@ -237,40 +245,40 @@ class CollectionContext:
             result = MagicMock()
             result.inserted_ids = inserted_ids
             return result

         def mock_update_one(query, update, *args, **kwargs):
             result = MagicMock()
             result.modified_count = 1
             return result

         def mock_update_many(query, update, *args, **kwargs):
             result = MagicMock()
             result.modified_count = len(mock_data)
             return result

         def mock_delete_one(query, *args, **kwargs):
             result = MagicMock()
             result.deleted_count = 1
             if mock_data:
                 mock_data.pop(0)  # Just remove the first item for simplicity
             return result

         def mock_delete_many(query, *args, **kwargs):
             count = len(mock_data)
             mock_data.clear()
             result = MagicMock()
             result.deleted_count = count
             return result

         def mock_count_documents(query, *args, **kwargs):
             return len(mock_data)

         def mock_aggregate(pipeline, *args, **kwargs):
             return []

         def mock_create_index(keys, **kwargs):
             return f"mock_index_{keys}"

         # Assign the mock implementations
         mock_collection.find.side_effect = mock_find
         mock_collection.find_one.side_effect = mock_find_one
@@ -283,10 +291,10 @@ class CollectionContext:
         mock_collection.count_documents.side_effect = mock_count_documents
         mock_collection.aggregate.side_effect = mock_aggregate
         mock_collection.create_index.side_effect = mock_create_index

         # Add retry capabilities to the mock collection
         self._add_retry_capabilities_to_mock(mock_collection)

         self.collection = mock_collection
         return self.collection
@@ -322,17 +330,25 @@ class CollectionContext:
         """
         Add retry capabilities to mock collection methods.
         This is a simplified version that just wraps the mock methods.

         Args:
             mock_collection: The mock collection to enhance
         """
         # List of common MongoDB collection methods to add retry capabilities to
         methods = [
-            'insert_one', 'insert_many', 'find_one', 'find',
-            'update_one', 'update_many', 'delete_one', 'delete_many',
-            'replace_one', 'count_documents', 'aggregate'
+            "insert_one",
+            "insert_many",
+            "find_one",
+            "find",
+            "update_one",
+            "update_many",
+            "delete_one",
+            "delete_many",
+            "replace_one",
+            "count_documents",
+            "aggregate",
         ]

         # Add retry decorator to each method
         for method_name in methods:
             if hasattr(mock_collection, method_name):
@@ -340,7 +356,7 @@ class CollectionContext:
                 setattr(
                     mock_collection,
                     method_name,
-                    retry_operation(max_retries=1, retry_interval=0)(original_method)
+                    retry_operation(max_retries=1, retry_interval=0)(original_method),
                 )

     def __exit__(self, exc_type, exc_val, exc_tb):
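Note: the retry_operation decorator applied here is defined elsewhere in the codebase and is not part of this diff. A minimal sketch of such a decorator, assuming it simply re-invokes the wrapped callable on any exception after a fixed pause (the name and parameters mirror the call site; the body is illustrative):

    import time
    from functools import wraps


    def retry_operation(max_retries=3, retry_interval=1):
        """Retry the wrapped callable on exception (illustrative stand-in for the real helper)."""

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                last_error = None
                for attempt in range(max_retries + 1):
                    try:
                        return func(*args, **kwargs)
                    except Exception as e:
                        last_error = e
                        if attempt < max_retries:
                            time.sleep(retry_interval)
                raise last_error

            return wrapper

        return decorator

With max_retries=1 and retry_interval=0, as used on the mock methods above, this amounts to one immediate retry.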
@@ -22,7 +22,7 @@ def test_basic_crud_operations():
     # First, clear any existing data
     users_collection.delete_many({})
     print("Cleared existing data")

     # Insert multiple documents
     insert_result = users_collection.insert_many(
         [
@@ -58,7 +58,7 @@ def test_basic_crud_operations():
     condition2 = admin_users and admin_users[0].get("username") == "jane"
     condition3 = update_result.modified_count == 2
     condition4 = delete_result.deleted_count == 1

     print(f"Condition 1 (admin count): {condition1}")
     print(f"Condition 2 (admin is jane): {condition2}")
     print(f"Condition 3 (updated 2 users): {condition3}")
@@ -80,7 +80,7 @@ def test_nested_documents():
     # Clear any existing data
     products_collection.delete_many({})
     print("Cleared existing data")

     # Insert a product with nested data
     insert_result = products_collection.insert_one(
         {
@@ -110,15 +110,17 @@ def test_nested_documents():
     print(f"Found updated laptop: {updated_laptop is not None}")
     if updated_laptop:
         print(f"Updated laptop specs: {updated_laptop.get('specs')}")
-        if 'specs' in updated_laptop:
+        if "specs" in updated_laptop:
             print(f"Updated RAM: {updated_laptop['specs'].get('ram')}")

     # Check each condition separately
     condition1 = laptop is not None
-    condition2 = laptop and laptop.get('specs', {}).get('ram') == "16GB"
+    condition2 = laptop and laptop.get("specs", {}).get("ram") == "16GB"
     condition3 = update_result.modified_count == 1
-    condition4 = updated_laptop and updated_laptop.get('specs', {}).get('ram') == "32GB"
+    condition4 = (
+        updated_laptop and updated_laptop.get("specs", {}).get("ram") == "32GB"
+    )

     print(f"Condition 1 (laptop found): {condition1}")
     print(f"Condition 2 (original RAM is 16GB): {condition2}")
     print(f"Condition 3 (update modified 1 doc): {condition3}")
@@ -140,7 +142,7 @@ def test_array_operations():
     # Clear any existing data
     orders_collection.delete_many({})
     print("Cleared existing data")

     # Insert an order with array of items
     insert_result = orders_collection.insert_one(
         {
@@ -170,10 +172,12 @@ def test_array_operations():
     # Verify the update
     updated_order = orders_collection.find_one({"order_id": "ORD001"})
     print(f"Found updated order: {updated_order is not None}")

     if updated_order:
-        print(f"Number of items in order: {len(updated_order.get('items', []))}")
-        items = updated_order.get('items', [])
+        print(
+            f"Number of items in order: {len(updated_order.get('items', []))}"
+        )
+        items = updated_order.get("items", [])
         if items:
             last_item = items[-1] if items else None
             print(f"Last item in order: {last_item}")
@@ -181,9 +185,13 @@ def test_array_operations():
     # Check each condition separately
     condition1 = len(laptop_orders) == 1
     condition2 = update_result.modified_count == 1
-    condition3 = updated_order and len(updated_order.get('items', [])) == 3
-    condition4 = updated_order and updated_order.get('items', []) and updated_order['items'][-1].get('product') == "Keyboard"
+    condition3 = updated_order and len(updated_order.get("items", [])) == 3
+    condition4 = (
+        updated_order
+        and updated_order.get("items", [])
+        and updated_order["items"][-1].get("product") == "Keyboard"
+    )

     print(f"Condition 1 (found 1 laptop order): {condition1}")
     print(f"Condition 2 (update modified 1 doc): {condition2}")
     print(f"Condition 3 (order has 3 items): {condition3}")
@@ -205,7 +213,7 @@ def test_aggregation():
     # Clear any existing data
     sales_collection.delete_many({})
     print("Cleared existing data")

     # Insert sample sales data
     insert_result = sales_collection.insert_many(
         [
@@ -219,13 +227,13 @@ def test_aggregation():
     # Calculate total sales by product - use a simpler aggregation pipeline
     pipeline = [
         {"$match": {}},  # Match all documents
-        {"$group": {"_id": "$product", "total": {"$sum": "$amount"}}}
+        {"$group": {"_id": "$product", "total": {"$sum": "$amount"}}},
     ]

     # Execute the aggregation
     sales_summary = list(sales_collection.aggregate(pipeline))
     print(f"Aggregation returned {len(sales_summary)} results")

     # Print the results for debugging
     for item in sales_summary:
         print(f"Product: {item.get('_id')}, Total: {item.get('total')}")
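For the sample sales data, the $group stage emits one summary document per product; the tolerance checks below pin down two of the three totals. A sketch of the expected shape (the Mouse total is left elided because it is not visible in this hunk):

    # sales_summary ≈ [
    #     {"_id": "Laptop", "total": 999.99},
    #     {"_id": "Mouse", "total": ...},       # asserted in condition3, outside this hunk
    #     {"_id": "Keyboard", "total": 59.99},
    # ]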
@@ -233,7 +241,8 @@ def test_aggregation():
     # Check each condition separately
     condition1 = len(sales_summary) == 3
     condition2 = any(
-        item.get("_id") == "Laptop" and abs(item.get("total", 0) - 999.99) < 0.01
+        item.get("_id") == "Laptop"
+        and abs(item.get("total", 0) - 999.99) < 0.01
         for item in sales_summary
     )
     condition3 = any(
@@ -241,10 +250,11 @@ def test_aggregation():
         for item in sales_summary
     )
     condition4 = any(
-        item.get("_id") == "Keyboard" and abs(item.get("total", 0) - 59.99) < 0.01
+        item.get("_id") == "Keyboard"
+        and abs(item.get("total", 0) - 59.99) < 0.01
         for item in sales_summary
     )

     print(f"Condition 1 (3 summary items): {condition1}")
     print(f"Condition 2 (laptop total correct): {condition2}")
     print(f"Condition 3 (mouse total correct): {condition3}")
@@ -325,35 +335,37 @@ def test_complex_queries():
     # Update with multiple conditions - split into separate operations for better compatibility
     # First set the discount
     products_collection.update_many(
-        {"price": {"$lt": 100}, "in_stock": True},
-        {"$set": {"discount": 0.1}}
+        {"price": {"$lt": 100}, "in_stock": True}, {"$set": {"discount": 0.1}}
     )

     # Then update the price
     update_result = products_collection.update_many(
-        {"price": {"$lt": 100}, "in_stock": True},
-        {"$inc": {"price": -10}}
+        {"price": {"$lt": 100}, "in_stock": True}, {"$inc": {"price": -10}}
     )

     # Verify the update
     updated_product = products_collection.find_one({"name": "Cheap Mouse"})

     # Print debug information
     print(f"Found expensive electronics: {len(expensive_electronics)}")
     if expensive_electronics:
-        print(f"First expensive product: {expensive_electronics[0].get('name')}")
+        print(
+            f"First expensive product: {expensive_electronics[0].get('name')}"
+        )
     print(f"Modified count: {update_result.modified_count}")
     if updated_product:
         print(f"Updated product price: {updated_product.get('price')}")
         print(f"Updated product discount: {updated_product.get('discount')}")

     # More flexible verification with approximate float comparison
     success = (
         len(expensive_electronics) >= 1
-        and expensive_electronics[0].get("name") in ["Expensive Laptop", "Laptop"]
+        and expensive_electronics[0].get("name")
+        in ["Expensive Laptop", "Laptop"]
         and update_result.modified_count >= 1
         and updated_product is not None
-        and updated_product.get("discount", 0) > 0  # Just check that discount exists and is positive
+        and updated_product.get("discount", 0)
+        > 0  # Just check that discount exists and is positive
     )
     print(f"Test {'passed' if success else 'failed'}")
     return success
@@ -368,48 +380,51 @@ def run_concurrent_operation_test(num_threads=100):
     import time
     import uuid
     from concurrent.futures import ThreadPoolExecutor

     print(f"\nStarting concurrent operation test with {num_threads} threads...")

     # Results tracking
     results = {"passed": 0, "failed": 0, "errors": []}
     results_lock = threading.Lock()

     def worker(thread_id):
         # Create a unique collection name for this thread
         collection_name = f"concurrent_test_{thread_id}"

         try:
             # Generate unique data for this thread
             unique_id = str(uuid.uuid4())

             with mongo_handler.collection(collection_name) as collection:
                 # Insert a document
-                collection.insert_one({
-                    "thread_id": thread_id,
-                    "uuid": unique_id,
-                    "timestamp": time.time()
-                })
+                collection.insert_one(
+                    {
+                        "thread_id": thread_id,
+                        "uuid": unique_id,
+                        "timestamp": time.time(),
+                    }
+                )

                 # Find the document
                 doc = collection.find_one({"thread_id": thread_id})

                 # Update the document
                 collection.update_one(
-                    {"thread_id": thread_id},
-                    {"$set": {"updated": True}}
+                    {"thread_id": thread_id}, {"$set": {"updated": True}}
                 )

                 # Verify update
                 updated_doc = collection.find_one({"thread_id": thread_id})

                 # Clean up
                 collection.delete_many({"thread_id": thread_id})

-                success = (doc is not None and
-                           updated_doc is not None and
-                           updated_doc.get("updated") is True)
+                success = (
+                    doc is not None
+                    and updated_doc is not None
+                    and updated_doc.get("updated") is True
+                )

                 # Update results with thread safety
                 with results_lock:
                     if success:
@@ -421,15 +436,15 @@ def run_concurrent_operation_test(num_threads=100):
             with results_lock:
                 results["failed"] += 1
                 results["errors"].append(f"Thread {thread_id} exception: {str(e)}")

     # Create and start threads using a thread pool
     start_time = time.time()
     with ThreadPoolExecutor(max_workers=num_threads) as executor:
         futures = [executor.submit(worker, i) for i in range(num_threads)]

     # Calculate execution time
     execution_time = time.time() - start_time

     # Print results
     print(f"\nConcurrent Operation Test Results:")
     print(f"Total threads: {num_threads}")
@@ -437,14 +452,16 @@ def run_concurrent_operation_test(num_threads=100):
     print(f"Failed: {results['failed']}")
     print(f"Execution time: {execution_time:.2f} seconds")
     print(f"Operations per second: {num_threads / execution_time:.2f}")

     if results["failed"] > 0:
         print("\nErrors:")
-        for error in results["errors"][:10]:  # Show only first 10 errors to avoid flooding output
+        for error in results["errors"][
+            :10
+        ]:  # Show only first 10 errors to avoid flooding output
             print(f"- {error}")
         if len(results["errors"]) > 10:
             print(f"- ... and {len(results['errors']) - 10} more errors")

     return results["failed"] == 0
@@ -493,10 +510,10 @@ def run_all_tests():

 if __name__ == "__main__":
     mongo_handler = MongoDBHandler()

     # Run standard tests first
     passed, failed = run_all_tests()

     # If all tests pass, run the concurrent operation test
     if failed == 0:
         run_concurrent_operation_test(10000)
@@ -1,14 +1,16 @@
 """
 Test script for MongoDB handler with a local MongoDB instance.
 """

 import os
 from Controllers.Mongo.database import MongoDBHandler, CollectionContext
 from datetime import datetime


 # Create a custom handler class for local testing
 class LocalMongoDBHandler(MongoDBHandler):
     """A MongoDB handler for local testing without authentication."""

     def __init__(self):
         """Initialize with a direct MongoDB URI."""
         self._initialized = False
@@ -22,6 +24,7 @@ class LocalMongoDBHandler(MongoDBHandler):
         }
         self._initialized = True


 # Create a custom handler for local testing
 def create_local_handler():
     """Create a MongoDB handler for local testing."""
@@ -29,35 +32,36 @@ def create_local_handler():
     handler = LocalMongoDBHandler()
     return handler


 def test_connection_monitoring():
     """Test connection monitoring with the MongoDB handler."""
     print("\nTesting connection monitoring...")

     # Create a local handler
     local_handler = create_local_handler()

     # Add connection tracking to the handler
     local_handler._open_connections = 0

     # Modify the CollectionContext class to track connections
     original_enter = CollectionContext.__enter__
     original_exit = CollectionContext.__exit__

     def tracked_enter(self):
         result = original_enter(self)
         self.db_handler._open_connections += 1
         print(f"Connection opened. Total open: {self.db_handler._open_connections}")
         return result

     def tracked_exit(self, exc_type, exc_val, exc_tb):
         self.db_handler._open_connections -= 1
         print(f"Connection closed. Total open: {self.db_handler._open_connections}")
         return original_exit(self, exc_type, exc_val, exc_tb)

     # Apply the tracking methods
     CollectionContext.__enter__ = tracked_enter
     CollectionContext.__exit__ = tracked_exit

     try:
         # Test with multiple operations
         for i in range(3):
@@ -72,18 +76,19 @@ def test_connection_monitoring():
             print(f"Operation failed: {e}")
     except Exception as e:
         print(f"Connection failed: {e}")

         # Final connection count
         print(f"\nFinal open connections: {local_handler._open_connections}")
         if local_handler._open_connections == 0:
             print("✅ All connections were properly closed")
         else:
             print(f"❌ {local_handler._open_connections} connections remain open")

     finally:
         # Restore original methods
         CollectionContext.__enter__ = original_enter
         CollectionContext.__exit__ = original_exit


 if __name__ == "__main__":
     test_connection_monitoring()
@@ -38,4 +38,4 @@ class Configs(BaseSettings):

 # singleton instance of the POSTGRESQL configuration settings
 postgres_configs = Configs()
-print('url', postgres_configs.url)
+print("url", postgres_configs.url)
@@ -240,10 +240,10 @@ class CRUDModel:

         exclude_args = exclude_args or []
         exclude_args = [exclude_arg.key for exclude_arg in exclude_args]

         include_args = include_args or []
         include_args = [include_arg.key for include_arg in include_args]

         # If include_args is provided, only use those fields for matching
         # Otherwise, use all fields except those in exclude_args
         for key, value in kwargs.items():
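The loop body falls outside this hunk; per the comments, include_args wins over exclude_args when both are given. A hypothetical illustration of that filtering pattern (function and variable names are illustrative, not the actual CRUDModel code):

    def select_match_fields(kwargs, include_args, exclude_args):
        # If include_args is provided, only use those fields for matching
        if include_args:
            return {k: v for k, v in kwargs.items() if k in include_args}
        # Otherwise, use all fields except those in exclude_args
        return {k: v for k, v in kwargs.items() if k not in exclude_args}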
@@ -478,46 +478,48 @@ def run_simple_concurrent_test(num_threads=10):
     import time
     import random
     from concurrent.futures import ThreadPoolExecutor

     print(f"\nStarting simple concurrent test with {num_threads} threads...")

     # Results tracking
     results = {"passed": 0, "failed": 0, "errors": []}
     results_lock = threading.Lock()

     def worker(thread_id):
         try:
             # Simple query to test connection pooling
             with EndpointRestriction.new_session() as db_session:
                 # Just run a simple count query
                 count_query = db_session.query(EndpointRestriction).count()

                 # Small delay to simulate work
                 time.sleep(random.uniform(0.01, 0.05))

                 # Simple success criteria
                 success = count_query >= 0

                 # Update results with thread safety
                 with results_lock:
                     if success:
                         results["passed"] += 1
                     else:
                         results["failed"] += 1
-                        results["errors"].append(f"Thread {thread_id} failed to get count")
+                        results["errors"].append(
+                            f"Thread {thread_id} failed to get count"
+                        )
         except Exception as e:
             with results_lock:
                 results["failed"] += 1
                 results["errors"].append(f"Thread {thread_id} exception: {str(e)}")

     # Create and start threads using a thread pool
     start_time = time.time()
     with ThreadPoolExecutor(max_workers=num_threads) as executor:
         futures = [executor.submit(worker, i) for i in range(num_threads)]

     # Calculate execution time
     execution_time = time.time() - start_time

     # Print results
     print(f"\nConcurrent Operation Test Results:")
     print(f"Total threads: {num_threads}")
@@ -525,21 +527,23 @@ def run_simple_concurrent_test(num_threads=10):
     print(f"Failed: {results['failed']}")
     print(f"Execution time: {execution_time:.2f} seconds")
     print(f"Operations per second: {num_threads / execution_time:.2f}")

     if results["failed"] > 0:
         print("\nErrors:")
-        for error in results["errors"][:10]:  # Show only first 10 errors to avoid flooding output
+        for error in results["errors"][
+            :10
+        ]:  # Show only first 10 errors to avoid flooding output
             print(f"- {error}")
         if len(results["errors"]) > 10:
             print(f"- ... and {len(results['errors']) - 10} more errors")

     return results["failed"] == 0


 if __name__ == "__main__":
     generate_table_in_postgres()
     passed, failed = run_all_tests()

     # If all tests pass, run the simple concurrent test
     if failed == 0:
         run_simple_concurrent_test(100)
@@ -8,17 +8,17 @@ from Controllers.Redis.response import RedisResponse

 class RedisPublisher:
     """Redis Publisher class for broadcasting messages to channels."""

     def __init__(self, redis_client=redis_cli):
         self.redis_client = redis_client

     def publish(self, channel: str, message: Union[Dict, List, str]) -> RedisResponse:
         """Publish a message to a Redis channel.

         Args:
             channel: The channel to publish to
             message: The message to publish (will be JSON serialized if dict or list)

         Returns:
             RedisResponse with status and message
         """
@@ -26,113 +26,124 @@ class RedisPublisher:
             # Convert dict/list to JSON string if needed
             if isinstance(message, (dict, list)):
                 message = json.dumps(message)

             # Publish the message
             recipient_count = self.redis_client.publish(channel, message)

             return RedisResponse(
                 status=True,
                 message=f"Message published successfully to {channel}.",
-                data={"recipients": recipient_count}
+                data={"recipients": recipient_count},
             )
         except Exception as e:
             return RedisResponse(
                 status=False,
                 message=f"Failed to publish message to {channel}.",
-                error=str(e)
+                error=str(e),
             )


 class RedisSubscriber:
     """Redis Subscriber class for listening to channels."""

     def __init__(self, redis_client=redis_cli):
         self.redis_client = redis_client
         self.pubsub = self.redis_client.pubsub()
         self.active_threads = {}

-    def subscribe(self, channel: str, callback: Callable[[Dict], None]) -> RedisResponse:
+    def subscribe(
+        self, channel: str, callback: Callable[[Dict], None]
+    ) -> RedisResponse:
         """Subscribe to a Redis channel with a callback function.

         Args:
             channel: The channel to subscribe to
             callback: Function to call when a message is received

         Returns:
             RedisResponse with status and message
         """
         try:
             # Subscribe to the channel
             self.pubsub.subscribe(**{channel: self._message_handler(callback)})

             return RedisResponse(
-                status=True,
-                message=f"Successfully subscribed to {channel}."
+                status=True, message=f"Successfully subscribed to {channel}."
             )
         except Exception as e:
             return RedisResponse(
-                status=False,
-                message=f"Failed to subscribe to {channel}.",
-                error=str(e)
+                status=False, message=f"Failed to subscribe to {channel}.", error=str(e)
             )

-    def psubscribe(self, pattern: str, callback: Callable[[Dict], None]) -> RedisResponse:
+    def psubscribe(
+        self, pattern: str, callback: Callable[[Dict], None]
+    ) -> RedisResponse:
         """Subscribe to Redis channels matching a pattern.

         Args:
             pattern: The pattern to subscribe to (e.g., 'user.*')
             callback: Function to call when a message is received

         Returns:
             RedisResponse with status and message
         """
         try:
             # Subscribe to the pattern
             self.pubsub.psubscribe(**{pattern: self._message_handler(callback)})

             return RedisResponse(
-                status=True,
-                message=f"Successfully pattern-subscribed to {pattern}."
+                status=True, message=f"Successfully pattern-subscribed to {pattern}."
             )
         except Exception as e:
             return RedisResponse(
                 status=False,
                 message=f"Failed to pattern-subscribe to {pattern}.",
-                error=str(e)
+                error=str(e),
             )

     def _message_handler(self, callback: Callable[[Dict], None]):
         """Create a message handler function for the subscription."""

         def handler(message):
             # Skip subscription confirmation messages
-            if message['type'] in ('subscribe', 'psubscribe'):
+            if message["type"] in ("subscribe", "psubscribe"):
                 return

             # Parse JSON if the message is a JSON string
-            data = message['data']
+            data = message["data"]
             if isinstance(data, bytes):
-                data = data.decode('utf-8')
+                data = data.decode("utf-8")
             try:
                 data = json.loads(data)
             except json.JSONDecodeError:
                 # Not JSON, keep as is
                 pass

             # Call the callback with the message data
-            callback({
-                'channel': message.get('channel', b'').decode('utf-8') if isinstance(message.get('channel', b''), bytes) else message.get('channel', ''),
-                'pattern': message.get('pattern', b'').decode('utf-8') if isinstance(message.get('pattern', b''), bytes) else message.get('pattern', ''),
-                'data': data
-            })
+            callback(
+                {
+                    "channel": (
+                        message.get("channel", b"").decode("utf-8")
+                        if isinstance(message.get("channel", b""), bytes)
+                        else message.get("channel", "")
+                    ),
+                    "pattern": (
+                        message.get("pattern", b"").decode("utf-8")
+                        if isinstance(message.get("pattern", b""), bytes)
+                        else message.get("pattern", "")
+                    ),
+                    "data": data,
+                }
+            )

         return handler
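For orientation, a subscriber built from this API hands the callback the plain dict assembled by _message_handler; a minimal usage sketch (the channel name and handler are illustrative, not taken from this diff):

    def on_message(msg):
        # msg is the dict built by the handler above: channel, pattern, data
        print(msg["channel"], msg["data"])


    subscriber = RedisSubscriber()
    subscriber.subscribe("demo:channel", on_message)
    subscriber.start_listening()  # defaults to a daemon listener thread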
     def start_listening(self, in_thread: bool = True) -> RedisResponse:
         """Start listening for messages on subscribed channels.

         Args:
             in_thread: If True, start listening in a separate thread

         Returns:
             RedisResponse with status and message
         """
@@ -140,50 +151,41 @@ class RedisSubscriber:
             if in_thread:
                 thread = Thread(target=self._listen_thread, daemon=True)
                 thread.start()
-                self.active_threads['listener'] = thread
+                self.active_threads["listener"] = thread
                 return RedisResponse(
-                    status=True,
-                    message="Listening thread started successfully."
+                    status=True, message="Listening thread started successfully."
                 )
             else:
                 # This will block the current thread
                 self._listen_thread()
                 return RedisResponse(
-                    status=True,
-                    message="Listening started successfully (blocking)."
+                    status=True, message="Listening started successfully (blocking)."
                 )
         except Exception as e:
             return RedisResponse(
-                status=False,
-                message="Failed to start listening.",
-                error=str(e)
+                status=False, message="Failed to start listening.", error=str(e)
             )

     def _listen_thread(self):
         """Thread function for listening to messages."""
         self.pubsub.run_in_thread(sleep_time=0.01)

     def stop_listening(self) -> RedisResponse:
         """Stop listening for messages."""
         try:
             self.pubsub.close()
-            return RedisResponse(
-                status=True,
-                message="Successfully stopped listening."
-            )
+            return RedisResponse(status=True, message="Successfully stopped listening.")
         except Exception as e:
             return RedisResponse(
-                status=False,
-                message="Failed to stop listening.",
-                error=str(e)
+                status=False, message="Failed to stop listening.", error=str(e)
             )

     def unsubscribe(self, channel: Optional[str] = None) -> RedisResponse:
         """Unsubscribe from a channel or all channels.

         Args:
             channel: The channel to unsubscribe from, or None for all channels

         Returns:
             RedisResponse with status and message
         """
@@ -194,24 +196,21 @@ class RedisSubscriber:
             else:
                 self.pubsub.unsubscribe()
                 message = "Successfully unsubscribed from all channels."

-            return RedisResponse(
-                status=True,
-                message=message
-            )
+            return RedisResponse(status=True, message=message)
         except Exception as e:
             return RedisResponse(
                 status=False,
                 message=f"Failed to unsubscribe from {'channel' if channel else 'all channels'}.",
-                error=str(e)
+                error=str(e),
             )

     def punsubscribe(self, pattern: Optional[str] = None) -> RedisResponse:
         """Unsubscribe from a pattern or all patterns.

         Args:
             pattern: The pattern to unsubscribe from, or None for all patterns

         Returns:
             RedisResponse with status and message
         """
@@ -222,24 +221,21 @@ class RedisSubscriber:
             else:
                 self.pubsub.punsubscribe()
                 message = "Successfully unsubscribed from all patterns."

-            return RedisResponse(
-                status=True,
-                message=message
-            )
+            return RedisResponse(status=True, message=message)
         except Exception as e:
             return RedisResponse(
                 status=False,
                 message=f"Failed to unsubscribe from {'pattern' if pattern else 'all patterns'}.",
-                error=str(e)
+                error=str(e),
             )


 class RedisPubSub:
     """Singleton class that provides both publisher and subscriber functionality."""

     _instance = None

     def __new__(cls):
         if cls._instance is None:
             cls._instance = super(RedisPubSub, cls).__new__(cls)
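The __new__ override gives RedisPubSub classic singleton behavior: every construction returns the same instance, so the publisher, subscriber, and their pubsub state are shared process-wide. A quick illustration:

    a = RedisPubSub()
    b = RedisPubSub()
    assert a is b  # both names point at the one shared instance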
@@ -15,6 +15,7 @@ CHANNEL_WRITER = "chain:writer"
 # Flag to control the demo
 running = True


 def generate_mock_data():
     """Generate a mock message with UUID, timestamp, and sample data."""
     return {
@@ -24,40 +25,43 @@ def generate_mock_data():
         "data": {
             "value": f"Sample data {int(time.time())}",
             "status": "new",
-            "counter": 0
-        }
+            "counter": 0,
+        },
     }


 def reader_function():
     """
     First function in the chain.
     Generates mock data and publishes to the reader channel.
     """
     print("[READER] Function started")

     while running:
         # Generate mock data
         message = generate_mock_data()
         start_time = time.time()
         message["start_time"] = start_time

         # Publish to reader channel
         result = redis_pubsub.publisher.publish(CHANNEL_READER, message)

         if result.status:
             print(f"[READER] {time.time():.6f} | Published UUID: {message['uuid']}")
         else:
             print(f"[READER] Publish error: {result.error}")

         # Wait before generating next message
         time.sleep(2)


 def processor_function():
     """
     Second function in the chain.
     Subscribes to reader channel, processes messages, and publishes to processor channel.
     """
     print("[PROCESSOR] Function started")

     def on_reader_message(message):
         # The message structure from the subscriber has 'data' containing our actual message
         # If data is a string, parse it as JSON
@@ -68,47 +72,51 @@ def processor_function():
         except json.JSONDecodeError as e:
             print(f"[PROCESSOR] Error parsing message data: {e}")
             return

         # Check if stage is 'red' before processing
         if data.get("stage") == "red":
             # Process the message
             data["processor_timestamp"] = datetime.now().isoformat()
             data["data"]["status"] = "processed"
             data["data"]["counter"] += 1

             # Update stage to 'processed'
             data["stage"] = "processed"

             # Add some processing metadata
             data["processing"] = {
                 "duration_ms": 150,  # Mock processing time
-                "processor_id": "main-processor"
+                "processor_id": "main-processor",
             }

             # Publish to processor channel
             result = redis_pubsub.publisher.publish(CHANNEL_PROCESSOR, data)

             if result.status:
-                print(f"[PROCESSOR] {time.time():.6f} | Received UUID: {data['uuid']} | Published UUID: {data['uuid']}")
+                print(
+                    f"[PROCESSOR] {time.time():.6f} | Received UUID: {data['uuid']} | Published UUID: {data['uuid']}"
+                )
             else:
                 print(f"[PROCESSOR] Publish error: {result.error}")
         else:
             print(f"[PROCESSOR] Skipped message: {data['uuid']} (stage is not 'red')")

     # Subscribe to reader channel
     result = redis_pubsub.subscriber.subscribe(CHANNEL_READER, on_reader_message)

     if result.status:
         print(f"[PROCESSOR] Subscribed to channel: {CHANNEL_READER}")
     else:
         print(f"[PROCESSOR] Subscribe error: {result.error}")


 def writer_function():
     """
     Third function in the chain.
     Subscribes to processor channel and performs final processing.
     """
     print("[WRITER] Function started")

     def on_processor_message(message):
         # The message structure from the subscriber has 'data' containing our actual message
         # If data is a string, parse it as JSON
@@ -119,42 +127,45 @@ def writer_function():
         except json.JSONDecodeError as e:
             print(f"[WRITER] Error parsing message data: {e}")
             return

         # Check if stage is 'processed' before processing
         if data.get("stage") == "processed":
             # Process the message
             data["writer_timestamp"] = datetime.now().isoformat()
             data["data"]["status"] = "completed"
             data["data"]["counter"] += 1

             # Update stage to 'completed'
             data["stage"] = "completed"

             # Add some writer metadata
-            data["storage"] = {
-                "location": "main-db",
-                "partition": "events-2025-04"
-            }
+            data["storage"] = {"location": "main-db", "partition": "events-2025-04"}

             # Calculate elapsed time if start_time is available
             current_time = time.time()
             elapsed_ms = ""
             if "start_time" in data:
-                elapsed_ms = f" | Elapsed: {(current_time - data['start_time']) * 1000:.2f}ms"
+                elapsed_ms = (
+                    f" | Elapsed: {(current_time - data['start_time']) * 1000:.2f}ms"
+                )

             # Optionally publish to writer channel for any downstream listeners
             result = redis_pubsub.publisher.publish(CHANNEL_WRITER, data)

             if result.status:
-                print(f"[WRITER] {current_time:.6f} | Received UUID: {data['uuid']} | Published UUID: {data['uuid']}{elapsed_ms}")
+                print(
+                    f"[WRITER] {current_time:.6f} | Received UUID: {data['uuid']} | Published UUID: {data['uuid']}{elapsed_ms}"
+                )
             else:
                 print(f"[WRITER] Publish error: {result.error}")
         else:
-            print(f"[WRITER] Skipped message: {data['uuid']} (stage is not 'processed')")
+            print(
+                f"[WRITER] Skipped message: {data['uuid']} (stage is not 'processed')"
+            )

     # Subscribe to processor channel
     result = redis_pubsub.subscriber.subscribe(CHANNEL_PROCESSOR, on_processor_message)

     if result.status:
         print(f"[WRITER] Subscribed to channel: {CHANNEL_PROCESSOR}")
     else:
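Read together, the three handlers advance one message through the chain; the field transitions implied by the assignments above look like this (the initial stage value "red" is implied by the processor's check but is set outside this diff):

    # after reader_function      after processor_function     after writer_function
    # stage: "red"               stage: "processed"           stage: "completed"
    # data.status: "new"         data.status: "processed"     data.status: "completed"
    # data.counter: 0            data.counter: 1              data.counter: 2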
@@ -167,18 +178,18 @@ def run_demo():
     print("Chain: READER → PROCESSOR → WRITER")
     print(f"Channels: {CHANNEL_READER} → {CHANNEL_PROCESSOR} → {CHANNEL_WRITER}")
     print("Format: [SERVICE] TIMESTAMP | Received/Published UUID | [Elapsed time]")

     # Start the Redis subscriber listening thread
     redis_pubsub.subscriber.start_listening()

     # Start processor and writer functions (these subscribe to channels)
     processor_function()
     writer_function()

     # Create a thread for the reader function (this publishes messages)
     reader_thread = Thread(target=reader_function, daemon=True)
     reader_thread.start()

     # Keep the main thread alive
     try:
         while True:
@@ -41,15 +41,15 @@ class RedisConn:
         # Add connection pooling settings if not provided
         if "max_connections" not in self.config:
             self.config["max_connections"] = 50  # Increased for better concurrency

         # Add connection timeout settings
         if "health_check_interval" not in self.config:
             self.config["health_check_interval"] = 30  # Health check every 30 seconds

         # Add retry settings for operations
         if "retry_on_timeout" not in self.config:
             self.config["retry_on_timeout"] = True

         # Add connection pool settings for better performance
         if "socket_keepalive" not in self.config:
             self.config["socket_keepalive"] = True
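All four of these keys are options accepted by redis-py's client constructor, so the populated config can be passed through directly; a minimal sketch, assuming RedisConn ultimately does something equivalent to this (host and port are illustrative):

    import redis

    config = {
        "host": "localhost",  # illustrative; the real value comes from configuration
        "port": 6379,
        "max_connections": 50,
        "health_check_interval": 30,
        "retry_on_timeout": True,
        "socket_keepalive": True,
    }
    redis_cli = redis.Redis(**config)  # pool settings are forwarded to the ConnectionPool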
@@ -113,56 +113,58 @@ def run_all_examples() -> None:

 def run_concurrent_test(num_threads=100):
     """Run a comprehensive concurrent test with multiple threads to verify Redis connection handling."""
-    print(f"\nStarting comprehensive Redis concurrent test with {num_threads} threads...")
+    print(
+        f"\nStarting comprehensive Redis concurrent test with {num_threads} threads..."
+    )

     # Results tracking with detailed metrics
     results = {
         "passed": 0,
         "failed": 0,
         "retried": 0,
         "errors": [],
         "operation_times": [],
         "retry_count": 0,
         "max_retries": 3,
-        "retry_delay": 0.1
+        "retry_delay": 0.1,
     }
     results_lock = threading.Lock()

     def worker(thread_id):
         # Track operation timing
         start_time = time.time()
         retry_count = 0
         success = False
         error_message = None

         while retry_count <= results["max_retries"] and not success:
             try:
                 # Generate unique key for this thread
                 unique_id = str(uuid.uuid4())[:8]
                 full_key = f"test:concurrent:{thread_id}:{unique_id}"

                 # Simple string operations instead of JSON
                 test_value = f"test-value-{thread_id}-{time.time()}"

                 # Set data in Redis with pipeline for efficiency
                 from Controllers.Redis.database import redis_cli

                 # Use pipeline to reduce network overhead
                 with redis_cli.pipeline() as pipe:
                     pipe.set(full_key, test_value)
                     pipe.get(full_key)
                     pipe.delete(full_key)
                     results_list = pipe.execute()

                 # Check results
                 set_ok = results_list[0]
                 retrieved_value = results_list[1]
                 if isinstance(retrieved_value, bytes):
-                    retrieved_value = retrieved_value.decode('utf-8')
+                    retrieved_value = retrieved_value.decode("utf-8")

                 # Verify data
                 success = set_ok and retrieved_value == test_value

                 if success:
                     break
                 else:
@@ -170,26 +172,28 @@ def run_concurrent_test(num_threads=100):
                     retry_count += 1
                     with results_lock:
                         results["retry_count"] += 1
-                    time.sleep(results["retry_delay"] * (2 ** retry_count))  # Exponential backoff
+                    time.sleep(
+                        results["retry_delay"] * (2**retry_count)
+                    )  # Exponential backoff

             except Exception as e:
                 error_message = str(e)
                 retry_count += 1
                 with results_lock:
                     results["retry_count"] += 1

                 # Check if it's a connection error and retry
                 if "Too many connections" in str(e) or "Connection" in str(e):
                     # Exponential backoff for connection issues
-                    backoff_time = results["retry_delay"] * (2 ** retry_count)
+                    backoff_time = results["retry_delay"] * (2**retry_count)
                     time.sleep(backoff_time)
                 else:
                     # For other errors, use a smaller delay
                     time.sleep(results["retry_delay"])

         # Record operation time
         operation_time = time.time() - start_time

         # Update results
         with results_lock:
             if success:
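With retry_delay = 0.1 and max_retries = 3, the sleeps produced by the backoff above double on each attempt; a quick check:

    retry_delay = 0.1
    for retry_count in (1, 2, 3):  # retry_count is incremented before the sleep
        print(retry_delay * (2**retry_count))  # 0.2, 0.4, 0.8 seconds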
@@ -200,26 +204,30 @@ def run_concurrent_test(num_threads=100):
             else:
                 results["failed"] += 1
                 if error_message:
-                    results["errors"].append(f"Thread {thread_id} failed after {retry_count} retries: {error_message}")
+                    results["errors"].append(
+                        f"Thread {thread_id} failed after {retry_count} retries: {error_message}"
+                    )
                 else:
-                    results["errors"].append(f"Thread {thread_id} failed after {retry_count} retries with unknown error")
+                    results["errors"].append(
+                        f"Thread {thread_id} failed after {retry_count} retries with unknown error"
+                    )

     # Create and start threads using a thread pool
     start_time = time.time()
     with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
         futures = [executor.submit(worker, i) for i in range(num_threads)]
         concurrent.futures.wait(futures)

     # Calculate execution time and performance metrics
     execution_time = time.time() - start_time
     ops_per_second = num_threads / execution_time if execution_time > 0 else 0

     # Calculate additional metrics if we have successful operations
     avg_op_time = 0
     min_op_time = 0
     max_op_time = 0
     p95_op_time = 0

     if results["operation_times"]:
         avg_op_time = sum(results["operation_times"]) / len(results["operation_times"])
         min_op_time = min(results["operation_times"])
@@ -227,8 +235,12 @@ def run_concurrent_test(num_threads=100):
         # Calculate 95th percentile
         sorted_times = sorted(results["operation_times"])
         p95_index = int(len(sorted_times) * 0.95)
-        p95_op_time = sorted_times[p95_index] if p95_index < len(sorted_times) else sorted_times[-1]
+        p95_op_time = (
+            sorted_times[p95_index]
+            if p95_index < len(sorted_times)
+            else sorted_times[-1]
+        )

     # Print detailed results
     print("\nConcurrent Redis Test Results:")
     print(f"Total threads: {num_threads}")
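As a concrete check of the percentile index above: with 100 recorded times, int(100 * 0.95) is 95, so the 96th-smallest sample is reported; the fallback to sorted_times[-1] only matters for samples so small that the index would run past the end.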
@@ -237,17 +249,17 @@ def run_concurrent_test(num_threads=100):
     print(f"Operations with retries: {results['retried']}")
     print(f"Total retry attempts: {results['retry_count']}")
     print(f"Success rate: {(results['passed'] / num_threads) * 100:.2f}%")

     print("\nPerformance Metrics:")
     print(f"Total execution time: {execution_time:.2f} seconds")
     print(f"Operations per second: {ops_per_second:.2f}")

     if results["operation_times"]:
         print(f"Average operation time: {avg_op_time * 1000:.2f} ms")
         print(f"Minimum operation time: {min_op_time * 1000:.2f} ms")
         print(f"Maximum operation time: {max_op_time * 1000:.2f} ms")
         print(f"95th percentile operation time: {p95_op_time * 1000:.2f} ms")

     # Print errors (limited to 10 for readability)
     if results["errors"]:
         print("\nErrors:")
@@ -255,7 +267,7 @@ def run_concurrent_test(num_threads=100):
             print(f"- {error}")
         if len(results["errors"]) > 10:
             print(f"- ... and {len(results['errors']) - 10} more errors")

     # Return results for potential further analysis
     return results
@@ -263,6 +275,6 @@ def run_concurrent_test(num_threads=100):
 if __name__ == "__main__":
     # Run basic examples
     run_all_examples()

     # Run enhanced concurrent test
     run_concurrent_test(10000)