271 lines
8.6 KiB
Python
271 lines
8.6 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
PostgreSQL Database Tester
|
|
|
|
This script tests the connection to PostgreSQL and performs basic operations
|
|
using SQLAlchemy models and Alembic migrations.
|
|
"""
|
|
import os
|
|
import sys
|
|
import argparse
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
from sqlalchemy import text
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
|
|
from app import init_db, get_session, engine
|
|
from app.models import User, Post
|
|
|
|
|
|
def test_connection():
    """Check that the database behind the shared ``engine`` is reachable.

    Opens a connection from ``engine`` and executes ``SELECT 1``.

    Returns:
        bool: True if the query ran; False if SQLAlchemy raised an error
        (the error is printed).
    """
    try:
        with engine.connect() as conn:
            # Executing the query is the check itself; its result is not needed.
            conn.execute(text("SELECT 1"))
        print("✅ Database connection successful!")
        return True
    except SQLAlchemyError as e:
        print(f"❌ Database connection failed: {e}")
        return False
|
|
|
|
|
|
def create_tables():
    """Create the model tables by delegating to ``app.init_db``.

    Returns:
        bool: True when ``init_db`` completes; False if it raised a
        SQLAlchemy error (the error is printed).
    """
    try:
        init_db()
    except SQLAlchemyError as exc:
        print(f"❌ Failed to create tables: {exc}")
        return False
    print("✅ Database tables created successfully!")
    return True
|
|
|
|
|
|
def insert_test_data():
    """Insert two sample users and three sample posts in a single transaction.

    Users are flushed first so their generated primary keys can be used as the
    posts' ``user_id`` values. On a SQLAlchemy error the transaction is rolled
    back, so nothing is committed.

    NOTE(review): usernames/emails are fixed, so a second run will probably
    fail if the model enforces uniqueness — confirm against ``app.models``.

    Returns:
        bool: True if the data was committed; False otherwise (error printed).
    """
    session = get_session()
    try:
        users = [
            User(
                username=f"testuser{n}",
                email=f"testuser{n}@example.com",
                password_hash=f"hashed_password_{n}",
                first_name="Test",
                last_name=f"User{n}",
            )
            for n in (1, 2)
        ]
        session.add_all(users)
        session.flush()  # Populate users[*].id for the posts' foreign keys.

        # (author, is_published) for "Test Post 1".."Test Post 3", in insert order.
        post_specs = [(users[0], True), (users[0], False), (users[1], True)]
        posts = [
            Post(
                user_id=author.id,
                title=f"Test Post {n}",
                content=f"This is test post {n} content",
                is_published=published,
            )
            for n, (author, published) in enumerate(post_specs, start=1)
        ]
        session.add_all(posts)
        session.commit()

        print("✅ Test data inserted successfully!")
        # Report the real counts rather than hard-coded literals.
        print(f" - Created {len(users)} users")
        print(f" - Created {len(posts)} posts")
        return True
    except SQLAlchemyError as e:
        session.rollback()
        print(f"❌ Failed to insert test data: {e}")
        return False
    finally:
        session.close()
|
|
|
|
|
|
def query_test_data():
    """Print all users, all posts, and every post joined to its author's name.

    Returns:
        bool: True when all three queries ran; False if SQLAlchemy raised
        (the error is printed). The session is closed in every case.
    """
    session = get_session()
    try:
        all_users = session.query(User).all()
        print(f"\n📊 Users in database ({len(all_users)}):")
        for u in all_users:
            print(f" - {u.username} ({u.email})")

        all_posts = session.query(Post).all()
        print(f"\n📊 Posts in database ({len(all_posts)}):")
        for p in all_posts:
            print(f" - {p.title} by user_id: {p.user_id} (Published: {p.is_published})")

        # Inner join: posts whose user_id matches no User row are not listed.
        join_query = session.query(Post, User.username).join(User, Post.user_id == User.id)
        joined_rows = join_query.all()
        print("\n📊 Posts with usernames:")
        for p, author_name in joined_rows:
            print(f" - {p.title} by {author_name} (Published: {p.is_published})")
    except SQLAlchemyError as exc:
        print(f"❌ Failed to query test data: {exc}")
        return False
    else:
        return True
    finally:
        session.close()
|
|
|
|
|
|
def run_alembic_migration():
    """Autogenerate an Alembic revision, then upgrade the database to head.

    Runs, in order:
      1. ``alembic revision --autogenerate -m 'Create initial tables'``
      2. ``alembic upgrade head``

    Each command's exit status is checked and the first failing step aborts
    the sequence, so success is only reported when Alembic actually succeeded.

    NOTE(review): step 1 writes a new revision file on every call, even when
    the models have not changed — confirm that is intended.

    Returns:
        bool: True if both commands exited with status 0; otherwise False
        (the reason is printed).
    """
    import subprocess

    steps = [
        ["alembic", "revision", "--autogenerate", "-m", "Create initial tables"],
        ["alembic", "upgrade", "head"],
    ]
    try:
        for cmd in steps:
            # Argument list (no shell): the -m message needs no shell quoting.
            completed = subprocess.run(cmd, check=False)
            if completed.returncode != 0:
                command_text = " ".join(cmd)
                print(
                    f"❌ Failed to run Alembic migration: '{command_text}' "
                    f"exited with status {completed.returncode}"
                )
                return False
    except (OSError, subprocess.SubprocessError) as e:
        # e.g. FileNotFoundError when the alembic executable is not installed.
        print(f"❌ Failed to run Alembic migration: {e}")
        return False
    print("✅ Alembic migration completed successfully!")
    return True
|
|
|
|
|
|
def run_connection_test(thread_id):
    """Open an independent connection from a worker thread and run two queries.

    Each call builds its own engine from ``DATABASE_URL`` (falling back to a
    local default), so no pooled connection is shared between threads. The
    engine is disposed afterwards so its pool is closed right away instead of
    lingering until garbage collection.

    Args:
        thread_id: Integer label for this worker. It is echoed by the first
            query and shown in the printed status line.

    Returns:
        bool: True if both queries succeeded; False otherwise (error printed).
    """
    thread_engine = None
    try:
        from sqlalchemy import create_engine
        import time
        import random

        database_url = os.environ.get(
            "DATABASE_URL",
            "postgresql://postgres:password@localhost:5432/postgres",
        )
        thread_engine = create_engine(database_url)

        with thread_engine.connect() as conn:
            # int() guarantees only a number is interpolated into the SQL text.
            row = conn.execute(
                text(f"SELECT {int(thread_id)} as thread_id, now() as time")
            ).fetchone()

            # Hold the connection for a random interval so workers overlap.
            time.sleep(random.uniform(0.1, 0.5))
            conn.execute(text("SELECT pg_sleep(0.1)"))

        print(f"✅ Thread {thread_id:03d} connected successfully at {row.time}")
        return True
    except Exception as e:
        print(f"❌ Thread {thread_id:03d} connection failed: {e}")
        return False
    finally:
        # Close this engine's pool; otherwise each call leaks an open connection.
        if thread_engine is not None:
            thread_engine.dispose()
|
|
|
|
|
|
def run_stress_test(num_threads=100):
    """Run ``num_threads`` concurrent ``run_connection_test`` calls and report.

    The thread pool is sized to ``num_threads``, so every connection attempt
    gets its own worker. Prints the totals, wall-clock duration and throughput.

    Returns:
        bool: True only if every connection attempt returned a truthy result.
    """
    import concurrent.futures
    import time

    print(f"🧪 Starting stress test with {num_threads} concurrent connections")
    started_at = time.time()

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as pool:
        # Executor.map submits every task up front and yields results in order.
        outcomes = pool.map(run_connection_test, range(num_threads))
        ok_count = sum(1 for outcome in outcomes if outcome)

    elapsed = time.time() - started_at

    print("\n📊 Stress test results:")
    print(f" - Total connections: {num_threads}")
    print(f" - Successful connections: {ok_count}")
    print(f" - Failed connections: {num_threads - ok_count}")
    print(f" - Duration: {elapsed:.2f} seconds")
    print(f" - Connections per second: {num_threads / elapsed:.2f}")

    return ok_count == num_threads
|
|
|
|
|
|
def main():
|
|
"""Main function to run the database tests"""
|
|
parser = argparse.ArgumentParser(description="PostgreSQL Database Tester")
|
|
parser.add_argument("--connection", action="store_true", help="Test database connection")
|
|
parser.add_argument("--create-tables", action="store_true", help="Create database tables")
|
|
parser.add_argument("--insert-data", action="store_true", help="Insert test data")
|
|
parser.add_argument("--query-data", action="store_true", help="Query test data")
|
|
parser.add_argument("--migration", action="store_true", help="Run Alembic migration")
|
|
parser.add_argument("--stress-test", action="store_true", help="Run stress test with 100 concurrent connections")
|
|
parser.add_argument("--threads", type=int, default=100, help="Number of threads for stress test")
|
|
parser.add_argument("--all", action="store_true", help="Run all tests")
|
|
|
|
args = parser.parse_args()
|
|
|
|
# If no arguments provided, show help
|
|
if len(sys.argv) == 1:
|
|
parser.print_help()
|
|
sys.exit(1)
|
|
|
|
# Print database URL (with password masked)
|
|
db_url = os.environ.get("DATABASE_URL", "postgresql://postgres:******@localhost:5432/postgres")
|
|
print(f"🔌 Using database: {db_url.replace(':password@', ':******@')}")
|
|
|
|
# Run tests based on arguments
|
|
if args.all or args.connection:
|
|
if not test_connection():
|
|
print("❌ Connection test failed. Exiting.")
|
|
sys.exit(1)
|
|
|
|
if args.all or args.migration:
|
|
run_alembic_migration()
|
|
|
|
if args.all or args.create_tables:
|
|
create_tables()
|
|
|
|
if args.all or args.insert_data:
|
|
insert_test_data()
|
|
|
|
if args.all or args.query_data:
|
|
query_test_data()
|
|
|
|
if args.all or args.stress_test:
|
|
num_threads = args.threads
|
|
run_stress_test(num_threads)
|
|
|
|
print("\n✅ All tests completed!")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module runs nothing.
    sys.exit(main())
|