postgres-service/runner/db_tester.py

271 lines
8.6 KiB
Python

#!/usr/bin/env python
"""
PostgreSQL Database Tester
This script tests the connection to PostgreSQL and performs basic operations
using SQLAlchemy models and Alembic migrations.
"""
import os
import sys
import argparse
import uuid
from datetime import datetime
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from app import init_db, get_session, engine
from app.models import User, Post
def test_connection():
    """Check that the shared engine can reach the database.

    Opens a connection from the module-level ``engine`` and runs
    ``SELECT 1``.

    Returns:
        bool: True when the query succeeds, False when SQLAlchemy raises
        (the error is printed).
    """
    try:
        with engine.connect() as connection:
            # A trivial round-trip is enough to prove connectivity.
            connection.execute(text("SELECT 1"))
    except SQLAlchemyError as exc:
        print(f"❌ Database connection failed: {exc}")
        return False
    else:
        print("✅ Database connection successful!")
        return True
def create_tables():
    """Create the database tables by calling ``init_db()``.

    Returns:
        bool: True when ``init_db()`` completes, False when it raises a
        SQLAlchemyError (the error is printed).
    """
    try:
        init_db()
    except SQLAlchemyError as exc:
        print(f"❌ Failed to create tables: {exc}")
        return False
    else:
        print("✅ Database tables created successfully!")
        return True
def insert_test_data():
    """Insert two test users and three test posts in a single transaction.

    The users are flushed first so their IDs are available for the posts'
    ``user_id`` column; everything is then committed together. On any
    SQLAlchemyError the transaction is rolled back.

    Returns:
        bool: True when the commit succeeds, False otherwise (the error is
        printed).
    """
    session = get_session()
    try:
        # Create test users
        users = [
            User(
                username=f"testuser{n}",
                email=f"testuser{n}@example.com",
                password_hash=f"hashed_password_{n}",
                first_name="Test",
                last_name=f"User{n}",
            )
            for n in (1, 2)
        ]
        session.add_all(users)
        session.flush()  # Flush to get the IDs
        first_user, second_user = users

        # Create test posts: (author, post number, published?)
        post_specs = (
            (first_user, 1, True),
            (first_user, 2, False),
            (second_user, 3, True),
        )
        posts = [
            Post(
                user_id=author.id,
                title=f"Test Post {n}",
                content=f"This is test post {n} content",
                is_published=published,
            )
            for author, n, published in post_specs
        ]
        session.add_all(posts)
        session.commit()

        # Report the counts from the objects actually created so the
        # message cannot drift from the data above.
        print("✅ Test data inserted successfully!")
        print(f" - Created {len(users)} users")
        print(f" - Created {len(posts)} posts")
        return True
    except SQLAlchemyError as e:
        session.rollback()
        print(f"❌ Failed to insert test data: {e}")
        return False
    finally:
        session.close()
def query_test_data():
    """Print all users, all posts, and each post joined with its author.

    Three queries are run in order against a fresh session: every ``User``,
    every ``Post``, and ``Post`` joined to ``User`` on ``Post.user_id``.

    Returns:
        bool: True when all queries succeed, False when a SQLAlchemyError
        is raised (the error is printed). The session is always closed.
    """
    db = get_session()
    try:
        # Users
        user_rows = db.query(User).all()
        print(f"\n📊 Users in database ({len(user_rows)}):")
        for account in user_rows:
            print(f" - {account.username} ({account.email})")

        # Posts
        post_rows = db.query(Post).all()
        print(f"\n📊 Posts in database ({len(post_rows)}):")
        for entry in post_rows:
            print(f" - {entry.title} by user_id: {entry.user_id} (Published: {entry.is_published})")

        # Posts paired with the author's username (join)
        join_query = db.query(Post, User.username).join(User, Post.user_id == User.id)
        authored_posts = join_query.all()
        print(f"\n📊 Posts with usernames:")
        for entry, author_name in authored_posts:
            print(f" - {entry.title} by {author_name} (Published: {entry.is_published})")
    except SQLAlchemyError as exc:
        print(f"❌ Failed to query test data: {exc}")
        return False
    else:
        return True
    finally:
        db.close()
def run_alembic_migration():
    """Autogenerate an Alembic revision and upgrade the database to head.

    Runs ``alembic revision --autogenerate -m "Create initial tables"``
    followed by ``alembic upgrade head``. Both commands are always
    attempted, but success is reported only when each exits with status 0.

    Returns:
        bool: True when both commands succeed, False when either exits
        non-zero or the ``alembic`` executable cannot be started.
    """
    # Local import, matching the function-scope imports used elsewhere in
    # this script.
    import subprocess

    # Argument lists (no shell) avoid quoting problems with the message.
    commands = (
        ["alembic", "revision", "--autogenerate", "-m", "Create initial tables"],
        ["alembic", "upgrade", "head"],
    )
    succeeded = True
    for command in commands:
        try:
            # check=True turns a non-zero exit status into CalledProcessError;
            # os.system() silently returned the status instead of raising.
            subprocess.run(command, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            print(f"❌ Failed to run Alembic migration: {e}")
            succeeded = False
    if succeeded:
        print("✅ Alembic migration completed successfully!")
    return succeeded
def run_connection_test(thread_id):
    """Open a dedicated connection and run two short queries.

    Intended to be called from a worker thread. Each call builds its own
    engine from ``DATABASE_URL`` (or a localhost default) so connections
    are not shared between threads, and disposes of that engine before
    returning.

    Args:
        thread_id: Integer identifying the worker; echoed back by the
            query and used in the printed messages.

    Returns:
        bool: True when both queries succeed, False on any exception (the
        error is printed).
    """
    import random
    import time

    thread_engine = None
    try:
        # Imported inside the try so an import failure is reported like
        # any other connection failure.
        from sqlalchemy import create_engine

        # Get database connection string from environment variable or use default
        database_url = os.environ.get(
            "DATABASE_URL",
            "postgresql://postgres:password@localhost:5432/postgres"
        )
        # Create a new engine for this thread to avoid sharing connections
        thread_engine = create_engine(database_url)
        with thread_engine.connect() as conn:
            # Pass the id as a bound parameter rather than formatting it
            # into the SQL text.
            result = conn.execute(
                text("SELECT :thread_id AS thread_id, now() AS time"),
                {"thread_id": thread_id},
            )
            row = result.fetchone()
            # Simulate some work
            time.sleep(random.uniform(0.1, 0.5))
            # Execute another query
            conn.execute(text("SELECT pg_sleep(0.1)"))
            print(f"✅ Thread {thread_id:03d} connected successfully at {row.time}")
        return True
    except Exception as e:
        print(f"❌ Thread {thread_id:03d} connection failed: {e}")
        return False
    finally:
        # Release the per-call engine's connection pool; otherwise every
        # call leaves a pool behind.
        if thread_engine is not None:
            thread_engine.dispose()
def run_stress_test(num_threads=100):
    """Run ``run_connection_test`` concurrently from ``num_threads`` threads.

    Args:
        num_threads: Number of worker threads, and therefore concurrent
            connection attempts. Must be at least 1.

    Returns:
        bool: True when every connection test succeeded; False when any
        failed or when ``num_threads`` is less than 1.
    """
    import concurrent.futures
    import time

    # ThreadPoolExecutor raises ValueError for max_workers <= 0; report the
    # bad input as a failed test instead of crashing.
    if num_threads < 1:
        print(f"❌ Stress test needs at least 1 thread (got {num_threads})")
        return False

    print(f"🧪 Starting stress test with {num_threads} concurrent connections")
    # perf_counter is monotonic, so the duration cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [executor.submit(run_connection_test, i) for i in range(num_threads)]
    # Leaving the ``with`` block waits for every submitted task to finish.
    successful = sum(1 for future in futures if future.result())
    duration = time.perf_counter() - start_time
    rate = num_threads / duration if duration > 0 else float("inf")

    print("\n📊 Stress test results:")
    print(f" - Total connections: {num_threads}")
    print(f" - Successful connections: {successful}")
    print(f" - Failed connections: {num_threads - successful}")
    print(f" - Duration: {duration:.2f} seconds")
    print(f" - Connections per second: {rate:.2f}")
    return successful == num_threads
def _mask_db_password(db_url):
    """Return ``db_url`` with any password component replaced by ``******``."""
    from urllib.parse import urlsplit, urlunsplit

    try:
        parts = urlsplit(db_url)
    except ValueError:
        # Never echo a URL that cannot be parsed; it may hold credentials.
        return "<unparseable database URL>"
    if parts.password is None:
        return db_url
    # netloc is "user:password@host[:port]"; split on the last "@" so a
    # password containing "@" is still fully masked.
    userinfo, _, hostinfo = parts.netloc.rpartition("@")
    username = userinfo.partition(":")[0]
    return urlunsplit(parts._replace(netloc=f"{username}:******@{hostinfo}"))


def main():
    """Parse command-line flags and run the selected database tests.

    With no arguments the help text is printed and the process exits with
    status 1. A failed connection test aborts immediately with status 1.
    The remaining selected steps are all run; if any of them reports
    failure the failed steps are listed and the process exits with
    status 1.
    """
    parser = argparse.ArgumentParser(description="PostgreSQL Database Tester")
    parser.add_argument("--connection", action="store_true", help="Test database connection")
    parser.add_argument("--create-tables", action="store_true", help="Create database tables")
    parser.add_argument("--insert-data", action="store_true", help="Insert test data")
    parser.add_argument("--query-data", action="store_true", help="Query test data")
    parser.add_argument("--migration", action="store_true", help="Run Alembic migration")
    parser.add_argument("--stress-test", action="store_true", help="Run stress test with 100 concurrent connections")
    parser.add_argument("--threads", type=int, default=100, help="Number of threads for stress test")
    parser.add_argument("--all", action="store_true", help="Run all tests")
    args = parser.parse_args()

    # If no arguments provided, show help
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    # Print database URL with the password masked, whatever its value.
    db_url = os.environ.get("DATABASE_URL", "postgresql://postgres:******@localhost:5432/postgres")
    print(f"🔌 Using database: {_mask_db_password(db_url)}")

    # Run tests based on arguments
    if args.all or args.connection:
        if not test_connection():
            print("❌ Connection test failed. Exiting.")
            sys.exit(1)

    # Each step prints its own diagnostics and returns a bool; collect the
    # failures so the exit status reflects the real outcome.
    failed = []
    if (args.all or args.migration) and not run_alembic_migration():
        failed.append("migration")
    if (args.all or args.create_tables) and not create_tables():
        failed.append("create-tables")
    if (args.all or args.insert_data) and not insert_test_data():
        failed.append("insert-data")
    if (args.all or args.query_data) and not query_test_data():
        failed.append("query-data")
    if (args.all or args.stress_test) and not run_stress_test(args.threads):
        failed.append("stress-test")

    if failed:
        print(f"\n❌ Failed steps: {', '.join(failed)}")
        sys.exit(1)
    print("\n✅ All tests completed!")
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()