updated services api

This commit is contained in:
Berkay 2025-05-30 23:43:46 +03:00
parent e5829f0525
commit f8184246d9
31 changed files with 2963 additions and 9 deletions

View File

@@ -2,7 +2,7 @@ import arrow
from typing import Any, Dict, Optional, Union
from config import api_config
from schemas import (
from Schemas import (
Users,
People,
BuildLivingSpace,
@@ -23,11 +23,11 @@ from schemas import (
Events,
EndpointRestriction,
)
from api_modules.token.password_module import PasswordModule
from api_controllers.mongo.database import mongo_handler
from api_validations.token.validations import TokenDictType, EmployeeTokenObject, OccupantTokenObject, CompanyToken, OccupantToken, UserType
from api_validations.defaults.validations import CommonHeaders
from api_modules.redis.redis_handlers import RedisHandlers
from Controllers.mongo.database import mongo_handler
from Validations.token.validations import TokenDictType, EmployeeTokenObject, OccupantTokenObject, CompanyToken, OccupantToken, UserType
from Validations.defaults.validations import CommonHeaders
from Extends.redis.redis_handlers import RedisHandlers
from Extends.token.password_module import PasswordModule
from validations.password.validations import PasswordHistoryViaUser
@@ -68,7 +68,6 @@ class LoginHandler:
return str(email).split("@")[1] == api_config.ACCESS_EMAIL_EXT
@classmethod
# headers: CommonHeaders
def do_employee_login(cls, headers: CommonHeaders, data: Any, db_session):
"""Handle employee login."""
@@ -159,7 +158,6 @@ class LoginHandler:
raise ValueError("Something went wrong")
@classmethod
# headers=headers, data=data, db_session=db_session
def do_occupant_login(cls, headers: CommonHeaders, data: Any, db_session):
"""
Handle occupant login.
@@ -376,7 +374,7 @@ class LoginHandler:
)
return {"selected_uu_id": occupant_token.living_space_uu_id}
@classmethod # Requires auth context
@classmethod
def authentication_select_company_or_occupant_type(cls, request: Any, data: Any):
"""
Handle selection of company or occupant type

View File

@@ -0,0 +1,31 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Configs(BaseSettings):
"""
Email configuration settings.
"""
HOST: str = ""
USERNAME: str = ""
PASSWORD: str = ""
PORT: int = 0
SEND: bool = True
@property
def is_send(self):
return bool(self.SEND)
def as_dict(self):
return dict(
host=self.HOST,
port=self.PORT,
username=self.USERNAME,
password=self.PASSWORD,
)
model_config = SettingsConfigDict(env_prefix="EMAIL_")
# Singleton instance of the email configuration settings (reads EMAIL_* environment variables)
email_configs = Configs()

View File

@@ -0,0 +1,29 @@
from send_email import EmailService, EmailSendModel
# Create email parameters
email_params = EmailSendModel(
subject="Test Email",
html="<p>Hello world!</p>",
receivers=["recipient@example.com"],
text="Hello world!",
)
another_email_params = EmailSendModel(
subject="Test Email2",
html="<p>Hello world!2</p>",
receivers=["recipient@example.com"],
text="Hello world!2",
)
# The context manager handles connection errors
with EmailService.new_session() as email_session:
# Send email - any exceptions here will propagate up
EmailService.send_email(email_session, email_params)
# Or send directly through the session
email_session.send(email_params)
# Send more emails in the same session if needed
EmailService.send_email(email_session, another_email_params)

View File

@@ -0,0 +1,90 @@
from redmail import EmailSender
from typing import List, Optional, Dict
from pydantic import BaseModel
from contextlib import contextmanager
from .config import email_configs
class EmailSendModel(BaseModel):
subject: str
html: str = ""
receivers: List[str]
text: Optional[str] = ""
cc: Optional[List[str]] = None
bcc: Optional[List[str]] = None
headers: Optional[Dict] = None
attachments: Optional[Dict] = None
class EmailSession:
def __init__(self, email_sender):
self.email_sender = email_sender
def send(self, params: EmailSendModel) -> bool:
"""Send email using this session."""
if not email_configs.is_send:
print("Email sending is disabled", params)
return False
receivers = [email_configs.USERNAME]
# Ensure connection is established before sending
try:
# Check if connection exists, if not establish it
if not hasattr(self.email_sender, '_connected') or not self.email_sender._connected:
self.email_sender.connect()
self.email_sender.send(
subject=params.subject,
receivers=receivers,
text=params.text + f" : Gonderilen [{str(receivers)}]",
html=params.html,
cc=params.cc,
bcc=params.bcc,
headers=params.headers or {},
attachments=params.attachments or {},
)
return True
except Exception as e:
print(f"Error sending email: {e}")
raise
class EmailService:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(EmailService, cls).__new__(cls)
return cls._instance
@classmethod
@contextmanager
def new_session(cls):
"""Create and yield a new email session with active connection."""
email_sender = EmailSender(**email_configs.as_dict())
session = EmailSession(email_sender)
connection_established = False
try:
# Establish connection and set flag
email_sender.connect()
# Set a flag to track connection state
email_sender._connected = True
connection_established = True
yield session
except Exception as e:
print(f"Error with email connection: {e}")
raise
finally:
# Only close if connection was successfully established
if connection_established:
try:
email_sender.close()
email_sender._connected = False
except Exception as e:
print(f"Error closing email connection: {e}")
@classmethod
def send_email(cls, session: EmailSession, params: EmailSendModel) -> bool:
"""Send email using the provided session."""
return session.send(params)

View File

@@ -0,0 +1,219 @@
# MongoDB Handler
A singleton MongoDB handler with context manager support for MongoDB collections and automatic retry capabilities.
## Features
- **Singleton Pattern**: Ensures only one instance of the MongoDB handler exists
- **Context Manager**: Automatically manages connection lifecycle
- **Retry Capability**: Automatically retries MongoDB operations on failure
- **Connection Pooling**: Configurable connection pooling
- **Graceful Degradation**: Handles connection failures without crashing
## Usage
```python
from Controllers.mongo.database import mongo_handler
# Use the context manager to access a collection
with mongo_handler.collection("users") as users_collection:
# Perform operations on the collection
users_collection.insert_one({"username": "john", "email": "john@example.com"})
user = users_collection.find_one({"username": "john"})
# Connection is automatically closed when exiting the context
```
## Configuration
MongoDB connection settings are configured via environment variables with the `MONGO_` prefix:
- `MONGO_ENGINE`: Database engine (e.g., "mongodb")
- `MONGO_USERNAME`: MongoDB username
- `MONGO_PASSWORD`: MongoDB password
- `MONGO_HOST`: MongoDB host
- `MONGO_PORT`: MongoDB port
- `MONGO_DB`: Database name
- `MONGO_AUTH_DB`: Authentication database
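For quick local experiments, the same settings can be injected in-process before the handler module is imported. A minimal sketch (all values are placeholders, not the project's real settings):
```python
import os

# Placeholder values -- substitute your deployment's real settings
os.environ.setdefault("MONGO_ENGINE", "mongodb")
os.environ.setdefault("MONGO_USERNAME", "appuser")
os.environ.setdefault("MONGO_PASSWORD", "apppassword")
os.environ.setdefault("MONGO_HOST", "127.0.0.1")
os.environ.setdefault("MONGO_PORT", "27017")
os.environ.setdefault("MONGO_DB", "appdb")
os.environ.setdefault("MONGO_AUTH_DB", "appdb")

# Import after the environment is populated so the settings pick the values up
from Controllers.mongo.database import mongo_handler
```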
## Monitoring Connection Closure
To verify that MongoDB sessions are properly closed, you can implement one of the following approaches:
### 1. Add Logging to the `__exit__` Method
```python
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Exit context, closing the connection.
"""
if self.client:
print(f"Closing MongoDB connection for collection: {self.collection_name}")
# Or use a proper logger
# logger.info(f"Closing MongoDB connection for collection: {self.collection_name}")
self.client.close()
self.client = None
self.collection = None
print(f"MongoDB connection closed successfully")
```
### 2. Add Connection Tracking
```python
class MongoDBHandler:
# Add these to your class
_open_connections = 0
def get_connection_stats(self):
"""Return statistics about open connections"""
return {"open_connections": self._open_connections}
```
Then modify the `CollectionContext` class:
```python
def __enter__(self):
# Create a new client connection
self.client = MongoClient(self.db_handler.uri, **self.db_handler.client_options)
# Increment connection counter
self.db_handler._open_connections += 1
# Rest of your code...
def __exit__(self, exc_type, exc_val, exc_tb):
if self.client:
# Decrement connection counter
self.db_handler._open_connections -= 1
self.client.close()
self.client = None
self.collection = None
```
### 3. Use MongoDB's Built-in Monitoring
```python
from pymongo import monitoring
class ConnectionCommandListener(monitoring.CommandListener):
def started(self, event):
print(f"Command {event.command_name} started on server {event.connection_id}")
def succeeded(self, event):
print(f"Command {event.command_name} succeeded in {event.duration_micros} microseconds")
def failed(self, event):
print(f"Command {event.command_name} failed in {event.duration_micros} microseconds")
# Register the listener
monitoring.register(ConnectionCommandListener())
```
### 4. Add a Test Function
```python
def test_connection_closure():
"""Test that MongoDB connections are properly closed."""
print("\nTesting connection closure...")
# Record initial connection count (if you implemented the counter)
initial_count = mongo_handler.get_connection_stats()["open_connections"]
# Use multiple nested contexts
for i in range(5):
with mongo_handler.collection("test_collection") as collection:
# Do some simple operation
collection.find_one({})
# Check final connection count
final_count = mongo_handler.get_connection_stats()["open_connections"]
if final_count == initial_count:
print("Test passed: All connections were properly closed")
return True
else:
print(f"Test failed: {final_count - initial_count} connections remain open")
return False
```
### 5. Use MongoDB Server Logs
You can also check the MongoDB server logs to see connection events:
```bash
# Run this on your MongoDB server
tail -f /var/log/mongodb/mongod.log | grep "connection"
```
## Best Practices
1. Always use the context manager pattern to ensure connections are properly closed
2. Keep operations within the context manager as concise as possible
3. Handle exceptions within the context to prevent unexpected behavior (see the sketch after this list)
4. Avoid nesting multiple context managers unnecessarily
5. Use the retry decorator for operations that might fail due to transient issues
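As a sketch of practices 1 and 3, catching operation errors inside the `with` block keeps the connection lifecycle intact (the collection name and query are illustrative):
```python
from pymongo.errors import PyMongoError

with mongo_handler.collection("users") as users_collection:
    try:
        user = users_collection.find_one({"username": "john"})
    except PyMongoError as exc:
        # Handle the failure here; __exit__ still closes the client afterwards
        print(f"Query failed: {exc}")
# The connection is closed on exit whether or not the operation raised
```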
## LXC Container Configuration
### Authentication Issues
If you encounter authentication errors when connecting to the MongoDB container at 10.10.2.13:27017, you may need to update the container configuration:
1. **Check MongoDB Authentication**: Ensure the MongoDB container is configured with the correct authentication mechanism
2. **Verify Network Configuration**: Make sure the container network allows connections from your application
3. **Update MongoDB Configuration**:
- Edit the MongoDB configuration file in the container
- Ensure `bindIp` is set correctly (e.g., `0.0.0.0` to allow connections from any IP)
- Check that authentication is enabled with the correct mechanism
4. **User Permissions**:
- Verify that the application user (`appuser`) exists in the MongoDB instance
- Ensure the user has the correct roles and permissions for the database
### Example MongoDB Container Configuration
```yaml
# Example docker-compose.yml configuration
services:
mongodb:
image: mongo:latest
container_name: mongodb
environment:
- MONGO_INITDB_ROOT_USERNAME=admin
- MONGO_INITDB_ROOT_PASSWORD=password
volumes:
- ./init-mongo.js:/docker-entrypoint-initdb.d/init-mongo.js:ro
ports:
- "27017:27017"
command: mongod --auth
```
```javascript
// Example init-mongo.js
db.createUser({
user: 'appuser',
pwd: 'apppassword',
roles: [
{ role: 'readWrite', db: 'appdb' }
]
});
```
## Troubleshooting
### Common Issues
1. **Authentication Failed**:
- Verify username and password in environment variables
- Check that the user exists in the specified authentication database
- Ensure the user has appropriate permissions
2. **Connection Refused**:
- Verify the MongoDB host and port are correct
- Check network connectivity between application and MongoDB container
- Ensure MongoDB is running and accepting connections (see the connectivity probe sketched after this list)
3. **Resource Leaks**:
- Use the context manager pattern to ensure connections are properly closed
- Monitor connection pool size and active connections
- Implement proper error handling to close connections in case of exceptions
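A short-timeout ping is a quick way to tell a refused connection apart from an authentication problem. A minimal sketch (host and port are placeholders):
```python
from pymongo import MongoClient
from pymongo.errors import OperationFailure, ServerSelectionTimeoutError

client = MongoClient("mongodb://10.10.2.13:27017", serverSelectionTimeoutMS=2000)
try:
    client.admin.command("ping")
    print("Server is reachable")
except ServerSelectionTimeoutError as exc:
    print(f"Connection refused or host unreachable: {exc}")
except OperationFailure as exc:
    print(f"Server reachable but command failed (check credentials): {exc}")
finally:
    client.close()
```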

View File

@@ -0,0 +1,31 @@
import os
from pydantic_settings import BaseSettings, SettingsConfigDict
class Configs(BaseSettings):
"""
MongoDB configuration settings.
"""
# MongoDB connection settings
ENGINE: str = "mongodb"
USERNAME: str = "appuser" # Application user
PASSWORD: str = "apppassword" # Application password
HOST: str = "10.10.2.13"
PORT: int = 27017
DB: str = "appdb" # The application database
AUTH_DB: str = "appdb" # Authentication is done against admin database
@property
def url(self):
"""Generate the database URL.
mongodb://{MONGO_USERNAME}:{MONGO_PASSWORD}@{MONGO_HOST}:{MONGO_PORT}/{MONGO_DB}?authSource={MONGO_AUTH_DB}
"""
# Include the database name in the URI
return f"{self.ENGINE}://{self.USERNAME}:{self.PASSWORD}@{self.HOST}:{self.PORT}/{self.DB}?authSource={self.DB}"
model_config = SettingsConfigDict(env_prefix="_MONGO_")
# Create a singleton instance of the MongoDB configuration settings
mongo_configs = Configs()

View File

@@ -0,0 +1,373 @@
import time
import functools
from pymongo import MongoClient
from pymongo.errors import PyMongoError, OperationFailure
from .config import mongo_configs
def retry_operation(max_attempts=3, delay=1.0, backoff=2.0, exceptions=(PyMongoError,)):
"""
Decorator for retrying MongoDB operations with exponential backoff.
Args:
max_attempts: Maximum number of retry attempts
delay: Initial delay between retries in seconds
backoff: Multiplier for delay after each retry
exceptions: Tuple of exceptions to catch and retry
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
mtries, mdelay = max_attempts, delay
while mtries > 1:
try:
return func(*args, **kwargs)
except exceptions as e:
time.sleep(mdelay)
mtries -= 1
mdelay *= backoff
# Final attempt; exceptions propagate to the caller
return func(*args, **kwargs)
return wrapper
return decorator
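# Illustrative usage of the decorator (the function below is an example, not part of this module):
#     @retry_operation(max_attempts=5, delay=0.5)
#     def safe_find_one(collection, query):
#         return collection.find_one(query)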
class MongoDBHandler:
"""
A MongoDB handler that provides context manager access to specific collections
with automatic retry capability. Implements singleton pattern.
"""
_instance = None
_debug_mode = False # Set to True to enable debug mode
def __new__(cls, *args, **kwargs):
"""
Implement singleton pattern for the handler.
"""
if cls._instance is None:
cls._instance = super(MongoDBHandler, cls).__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self, debug_mode=False, mock_mode=False):
"""Initialize the MongoDB handler.
Args:
debug_mode: If True, use a simplified connection for debugging
mock_mode: If True, use mock collections instead of real MongoDB connections
"""
if not hasattr(self, "_initialized") or not self._initialized:
self._debug_mode = debug_mode
self._mock_mode = mock_mode
if mock_mode:
# In mock mode, we don't need a real connection string
self.uri = "mongodb://mock:27017/mockdb"
print("MOCK MODE: Using simulated MongoDB connections")
elif debug_mode:
# Use a direct connection without authentication for testing
self.uri = f"mongodb://{mongo_configs.HOST}:{mongo_configs.PORT}/{mongo_configs.DB}"
print(f"DEBUG MODE: Using direct connection: {self.uri}")
else:
# Use the configured connection string with authentication
self.uri = mongo_configs.url
print(f"Connecting to MongoDB: {self.uri}")
# Define MongoDB client options with increased timeouts for better reliability
self.client_options = {
"maxPoolSize": 5,
"minPoolSize": 1,
"maxIdleTimeMS": 60000,
"waitQueueTimeoutMS": 5000,
"serverSelectionTimeoutMS": 10000,
"connectTimeoutMS": 30000,
"socketTimeoutMS": 45000,
"retryWrites": True,
"retryReads": True,
}
self._initialized = True
def collection(self, collection_name: str):
"""
Get a context manager for a specific collection.
Args:
collection_name: Name of the collection to access
Returns:
A context manager for the specified collection
"""
return CollectionContext(self, collection_name)
class CollectionContext:
"""
Context manager for MongoDB collections with automatic retry capability.
"""
def __init__(self, db_handler: MongoDBHandler, collection_name: str):
"""
Initialize collection context.
Args:
db_handler: Reference to the MongoDB handler
collection_name: Name of the collection to access
"""
self.db_handler = db_handler
self.collection_name = collection_name
self.client = None
self.collection = None
def __enter__(self):
"""
Enter context, establishing a new connection.
Returns:
The MongoDB collection object with retry capabilities
"""
# If we're in mock mode, return a mock collection immediately
if self.db_handler._mock_mode:
return self._create_mock_collection()
try:
# Create a new client connection
self.client = MongoClient(
self.db_handler.uri, **self.db_handler.client_options
)
if self.db_handler._debug_mode:
# In debug mode, we explicitly use the configured DB
db_name = mongo_configs.DB
print(f"DEBUG MODE: Using database '{db_name}'")
else:
# In normal mode, extract database name from the URI
try:
db_name = self.client.get_database().name
except Exception:
db_name = mongo_configs.DB
print(f"Using fallback database '{db_name}'")
self.collection = self.client[db_name][self.collection_name]
# Enhance collection methods with retry capabilities
self._add_retry_capabilities()
return self.collection
except OperationFailure as e:
if "Authentication failed" in str(e):
print(f"MongoDB authentication error: {e}")
print("Attempting to reconnect with direct connection...")
try:
# Try a direct connection without authentication for testing
direct_uri = f"mongodb://{mongo_configs.HOST}:{mongo_configs.PORT}/{mongo_configs.DB}"
print(f"Trying direct connection: {direct_uri}")
self.client = MongoClient(
direct_uri, **self.db_handler.client_options
)
self.collection = self.client[mongo_configs.DB][
self.collection_name
]
self._add_retry_capabilities()
return self.collection
except Exception as inner_e:
print(f"Direct connection also failed: {inner_e}")
# Fall through to mock collection creation
else:
print(f"MongoDB operation error: {e}")
if self.client:
self.client.close()
self.client = None
except Exception as e:
print(f"MongoDB connection error: {e}")
if self.client:
self.client.close()
self.client = None
return self._create_mock_collection()
def _create_mock_collection(self):
"""
Create a mock collection for testing or graceful degradation.
This prevents the application from crashing when MongoDB is unavailable.
Returns:
A mock MongoDB collection with simulated behaviors
"""
from unittest.mock import MagicMock
if self.db_handler._mock_mode:
print(f"MOCK MODE: Using mock collection '{self.collection_name}'")
else:
print(
f"Using mock MongoDB collection '{self.collection_name}' for graceful degradation"
)
# Create in-memory storage for this mock collection
if not hasattr(self.db_handler, "_mock_storage"):
self.db_handler._mock_storage = {}
if self.collection_name not in self.db_handler._mock_storage:
self.db_handler._mock_storage[self.collection_name] = []
mock_collection = MagicMock()
mock_data = self.db_handler._mock_storage[self.collection_name]
# Define behavior for find operations
def mock_find(query=None, *args, **kwargs):
# Simple implementation that returns all documents
return mock_data
def mock_find_one(query=None, *args, **kwargs):
# Simple implementation that returns the first matching document
if not mock_data:
return None
return mock_data[0]
def mock_insert_one(document, *args, **kwargs):
# Add _id if not present
if "_id" not in document:
document["_id"] = f"mock_id_{len(mock_data)}"
mock_data.append(document)
result = MagicMock()
result.inserted_id = document["_id"]
return result
def mock_insert_many(documents, *args, **kwargs):
inserted_ids = []
for doc in documents:
result = mock_insert_one(doc)
inserted_ids.append(result.inserted_id)
result = MagicMock()
result.inserted_ids = inserted_ids
return result
def mock_update_one(query, update, *args, **kwargs):
result = MagicMock()
result.modified_count = 1
return result
def mock_update_many(query, update, *args, **kwargs):
result = MagicMock()
result.modified_count = len(mock_data)
return result
def mock_delete_one(query, *args, **kwargs):
result = MagicMock()
result.deleted_count = 1
if mock_data:
mock_data.pop(0) # Just remove the first item for simplicity
return result
def mock_delete_many(query, *args, **kwargs):
count = len(mock_data)
mock_data.clear()
result = MagicMock()
result.deleted_count = count
return result
def mock_count_documents(query, *args, **kwargs):
return len(mock_data)
def mock_aggregate(pipeline, *args, **kwargs):
return []
def mock_create_index(keys, **kwargs):
return f"mock_index_{keys}"
# Assign the mock implementations
mock_collection.find.side_effect = mock_find
mock_collection.find_one.side_effect = mock_find_one
mock_collection.insert_one.side_effect = mock_insert_one
mock_collection.insert_many.side_effect = mock_insert_many
mock_collection.update_one.side_effect = mock_update_one
mock_collection.update_many.side_effect = mock_update_many
mock_collection.delete_one.side_effect = mock_delete_one
mock_collection.delete_many.side_effect = mock_delete_many
mock_collection.count_documents.side_effect = mock_count_documents
mock_collection.aggregate.side_effect = mock_aggregate
mock_collection.create_index.side_effect = mock_create_index
# Add retry capabilities to the mock collection
self._add_retry_capabilities_to_mock(mock_collection)
self.collection = mock_collection
return self.collection
def _add_retry_capabilities(self):
"""
Add retry capabilities to all collection methods.
"""
# Store original methods for common operations
original_insert_one = self.collection.insert_one
original_insert_many = self.collection.insert_many
original_find_one = self.collection.find_one
original_find = self.collection.find
original_update_one = self.collection.update_one
original_update_many = self.collection.update_many
original_delete_one = self.collection.delete_one
original_delete_many = self.collection.delete_many
original_replace_one = self.collection.replace_one
original_count_documents = self.collection.count_documents
# Add retry capabilities to methods
self.collection.insert_one = retry_operation()(original_insert_one)
self.collection.insert_many = retry_operation()(original_insert_many)
self.collection.find_one = retry_operation()(original_find_one)
self.collection.find = retry_operation()(original_find)
self.collection.update_one = retry_operation()(original_update_one)
self.collection.update_many = retry_operation()(original_update_many)
self.collection.delete_one = retry_operation()(original_delete_one)
self.collection.delete_many = retry_operation()(original_delete_many)
self.collection.replace_one = retry_operation()(original_replace_one)
self.collection.count_documents = retry_operation()(original_count_documents)
def _add_retry_capabilities_to_mock(self, mock_collection):
"""
Add retry capabilities to mock collection methods.
This is a simplified version that just wraps the mock methods.
Args:
mock_collection: The mock collection to enhance
"""
# List of common MongoDB collection methods to add retry capabilities to
methods = [
"insert_one",
"insert_many",
"find_one",
"find",
"update_one",
"update_many",
"delete_one",
"delete_many",
"replace_one",
"count_documents",
"aggregate",
]
# Add retry decorator to each method
for method_name in methods:
if hasattr(mock_collection, method_name):
original_method = getattr(mock_collection, method_name)
setattr(
mock_collection,
method_name,
retry_operation(max_attempts=1, delay=0)(original_method),
)
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Exit context, closing the connection.
"""
if self.client:
self.client.close()
self.client = None
self.collection = None
# Create a singleton instance of the MongoDB handler
mongo_handler = MongoDBHandler()

View File

@@ -0,0 +1,519 @@
# Initialize the MongoDB handler with your configuration
from datetime import datetime
from .database import MongoDBHandler, mongo_handler
def cleanup_test_data():
"""Clean up any test data before running tests."""
try:
with mongo_handler.collection("test_collection") as collection:
collection.delete_many({})
print("Successfully cleaned up test data")
except Exception as e:
print(f"Warning: Could not clean up test data: {e}")
print("Continuing with tests using mock data...")
def test_basic_crud_operations():
"""Test basic CRUD operations on users collection."""
print("\nTesting basic CRUD operations...")
try:
with mongo_handler.collection("users") as users_collection:
# First, clear any existing data
users_collection.delete_many({})
print("Cleared existing data")
# Insert multiple documents
insert_result = users_collection.insert_many(
[
{"username": "john", "email": "john@example.com", "role": "user"},
{"username": "jane", "email": "jane@example.com", "role": "admin"},
{"username": "bob", "email": "bob@example.com", "role": "user"},
]
)
print(f"Inserted {len(insert_result.inserted_ids)} documents")
# Find with multiple conditions
admin_users = list(users_collection.find({"role": "admin"}))
print(f"Found {len(admin_users)} admin users")
if admin_users:
print(f"Admin user: {admin_users[0].get('username')}")
# Update multiple documents
update_result = users_collection.update_many(
{"role": "user"}, {"$set": {"last_login": datetime.now().isoformat()}}
)
print(f"Updated {update_result.modified_count} documents")
# Delete documents
delete_result = users_collection.delete_many({"username": "bob"})
print(f"Deleted {delete_result.deleted_count} documents")
# Count remaining documents
remaining = users_collection.count_documents({})
print(f"Remaining documents: {remaining}")
# Check each condition separately
condition1 = len(admin_users) == 1
condition2 = admin_users and admin_users[0].get("username") == "jane"
condition3 = update_result.modified_count == 2
condition4 = delete_result.deleted_count == 1
print(f"Condition 1 (admin count): {condition1}")
print(f"Condition 2 (admin is jane): {condition2}")
print(f"Condition 3 (updated 2 users): {condition3}")
print(f"Condition 4 (deleted bob): {condition4}")
success = condition1 and condition2 and condition3 and condition4
print(f"Test {'passed' if success else 'failed'}")
return success
except Exception as e:
print(f"Test failed with exception: {e}")
return False
def test_nested_documents():
"""Test operations with nested documents in products collection."""
print("\nTesting nested documents...")
try:
with mongo_handler.collection("products") as products_collection:
# Clear any existing data
products_collection.delete_many({})
print("Cleared existing data")
# Insert a product with nested data
insert_result = products_collection.insert_one(
{
"name": "Laptop",
"price": 999.99,
"specs": {"cpu": "Intel i7", "ram": "16GB", "storage": "512GB SSD"},
"in_stock": True,
"tags": ["electronics", "computers", "laptops"],
}
)
print(f"Inserted document with ID: {insert_result.inserted_id}")
# Find with nested field query
laptop = products_collection.find_one({"specs.cpu": "Intel i7"})
print(f"Found laptop: {laptop is not None}")
if laptop:
print(f"Laptop RAM: {laptop.get('specs', {}).get('ram')}")
# Update nested field
update_result = products_collection.update_one(
{"name": "Laptop"}, {"$set": {"specs.ram": "32GB"}}
)
print(f"Update modified count: {update_result.modified_count}")
# Verify the update
updated_laptop = products_collection.find_one({"name": "Laptop"})
print(f"Found updated laptop: {updated_laptop is not None}")
if updated_laptop:
print(f"Updated laptop specs: {updated_laptop.get('specs')}")
if "specs" in updated_laptop:
print(f"Updated RAM: {updated_laptop['specs'].get('ram')}")
# Check each condition separately
condition1 = laptop is not None
condition2 = laptop and laptop.get("specs", {}).get("ram") == "16GB"
condition3 = update_result.modified_count == 1
condition4 = (
updated_laptop and updated_laptop.get("specs", {}).get("ram") == "32GB"
)
print(f"Condition 1 (laptop found): {condition1}")
print(f"Condition 2 (original RAM is 16GB): {condition2}")
print(f"Condition 3 (update modified 1 doc): {condition3}")
print(f"Condition 4 (updated RAM is 32GB): {condition4}")
success = condition1 and condition2 and condition3 and condition4
print(f"Test {'passed' if success else 'failed'}")
return success
except Exception as e:
print(f"Test failed with exception: {e}")
return False
def test_array_operations():
"""Test operations with arrays in orders collection."""
print("\nTesting array operations...")
try:
with mongo_handler.collection("orders") as orders_collection:
# Clear any existing data
orders_collection.delete_many({})
print("Cleared existing data")
# Insert an order with array of items
insert_result = orders_collection.insert_one(
{
"order_id": "ORD001",
"customer": "john",
"items": [
{"product": "Laptop", "quantity": 1},
{"product": "Mouse", "quantity": 2},
],
"total": 1099.99,
"status": "pending",
}
)
print(f"Inserted order with ID: {insert_result.inserted_id}")
# Find orders containing specific items
laptop_orders = list(orders_collection.find({"items.product": "Laptop"}))
print(f"Found {len(laptop_orders)} orders with Laptop")
# Update array elements
update_result = orders_collection.update_one(
{"order_id": "ORD001"},
{"$push": {"items": {"product": "Keyboard", "quantity": 1}}},
)
print(f"Update modified count: {update_result.modified_count}")
# Verify the update
updated_order = orders_collection.find_one({"order_id": "ORD001"})
print(f"Found updated order: {updated_order is not None}")
if updated_order:
print(
f"Number of items in order: {len(updated_order.get('items', []))}"
)
items = updated_order.get("items", [])
if items:
last_item = items[-1] if items else None
print(f"Last item in order: {last_item}")
# Check each condition separately
condition1 = len(laptop_orders) == 1
condition2 = update_result.modified_count == 1
condition3 = updated_order and len(updated_order.get("items", [])) == 3
condition4 = (
updated_order
and updated_order.get("items", [])
and updated_order["items"][-1].get("product") == "Keyboard"
)
print(f"Condition 1 (found 1 laptop order): {condition1}")
print(f"Condition 2 (update modified 1 doc): {condition2}")
print(f"Condition 3 (order has 3 items): {condition3}")
print(f"Condition 4 (last item is keyboard): {condition4}")
success = condition1 and condition2 and condition3 and condition4
print(f"Test {'passed' if success else 'failed'}")
return success
except Exception as e:
print(f"Test failed with exception: {e}")
return False
def test_aggregation():
"""Test aggregation operations on sales collection."""
print("\nTesting aggregation operations...")
try:
with mongo_handler.collection("sales") as sales_collection:
# Clear any existing data
sales_collection.delete_many({})
print("Cleared existing data")
# Insert sample sales data
insert_result = sales_collection.insert_many(
[
{"product": "Laptop", "amount": 999.99, "date": datetime.now()},
{"product": "Mouse", "amount": 29.99, "date": datetime.now()},
{"product": "Keyboard", "amount": 59.99, "date": datetime.now()},
]
)
print(f"Inserted {len(insert_result.inserted_ids)} sales documents")
# Calculate total sales by product - use a simpler aggregation pipeline
pipeline = [
{"$match": {}}, # Match all documents
{"$group": {"_id": "$product", "total": {"$sum": "$amount"}}},
]
# Execute the aggregation
sales_summary = list(sales_collection.aggregate(pipeline))
print(f"Aggregation returned {len(sales_summary)} results")
# Print the results for debugging
for item in sales_summary:
print(f"Product: {item.get('_id')}, Total: {item.get('total')}")
# Check each condition separately
condition1 = len(sales_summary) == 3
condition2 = any(
item.get("_id") == "Laptop"
and abs(item.get("total", 0) - 999.99) < 0.01
for item in sales_summary
)
condition3 = any(
item.get("_id") == "Mouse" and abs(item.get("total", 0) - 29.99) < 0.01
for item in sales_summary
)
condition4 = any(
item.get("_id") == "Keyboard"
and abs(item.get("total", 0) - 59.99) < 0.01
for item in sales_summary
)
print(f"Condition 1 (3 summary items): {condition1}")
print(f"Condition 2 (laptop total correct): {condition2}")
print(f"Condition 3 (mouse total correct): {condition3}")
print(f"Condition 4 (keyboard total correct): {condition4}")
success = condition1 and condition2 and condition3 and condition4
print(f"Test {'passed' if success else 'failed'}")
return success
except Exception as e:
print(f"Test failed with exception: {e}")
return False
def test_index_operations():
"""Test index creation and unique constraints."""
print("\nTesting index operations...")
try:
with mongo_handler.collection("test_collection") as collection:
# Create indexes
collection.create_index("email", unique=True)
collection.create_index([("username", 1), ("role", 1)])
# Insert initial document
collection.insert_one(
{"username": "test_user", "email": "test@example.com"}
)
# Try to insert duplicate email (should fail)
try:
collection.insert_one(
{"username": "test_user2", "email": "test@example.com"}
)
success = False # Should not reach here
except Exception:
success = True
print(f"Test {'passed' if success else 'failed'}")
return success
except Exception as e:
print(f"Test failed with exception: {e}")
return False
def test_complex_queries():
"""Test complex queries with multiple conditions."""
print("\nTesting complex queries...")
try:
with mongo_handler.collection("products") as products_collection:
# Insert test data
products_collection.insert_many(
[
{
"name": "Expensive Laptop",
"price": 999.99,
"tags": ["electronics", "computers"],
"in_stock": True,
},
{
"name": "Cheap Mouse",
"price": 29.99,
"tags": ["electronics", "peripherals"],
"in_stock": True,
},
]
)
# Find products with price range and specific tags
expensive_electronics = list(
products_collection.find(
{
"price": {"$gt": 500},
"tags": {"$in": ["electronics"]},
"in_stock": True,
}
)
)
# Update with multiple conditions - split into separate operations for better compatibility
# First set the discount
products_collection.update_many(
{"price": {"$lt": 100}, "in_stock": True}, {"$set": {"discount": 0.1}}
)
# Then update the price
update_result = products_collection.update_many(
{"price": {"$lt": 100}, "in_stock": True}, {"$inc": {"price": -10}}
)
# Verify the update
updated_product = products_collection.find_one({"name": "Cheap Mouse"})
# Print debug information
print(f"Found expensive electronics: {len(expensive_electronics)}")
if expensive_electronics:
print(
f"First expensive product: {expensive_electronics[0].get('name')}"
)
print(f"Modified count: {update_result.modified_count}")
if updated_product:
print(f"Updated product price: {updated_product.get('price')}")
print(f"Updated product discount: {updated_product.get('discount')}")
# More flexible verification with approximate float comparison
success = (
len(expensive_electronics) >= 1
and expensive_electronics[0].get("name")
in ["Expensive Laptop", "Laptop"]
and update_result.modified_count >= 1
and updated_product is not None
and updated_product.get("discount", 0)
> 0 # Just check that discount exists and is positive
)
print(f"Test {'passed' if success else 'failed'}")
return success
except Exception as e:
print(f"Test failed with exception: {e}")
return False
def run_concurrent_operation_test(num_threads=100):
"""Run a simple operation in multiple threads to verify connection pooling."""
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
print(f"\nStarting concurrent operation test with {num_threads} threads...")
# Results tracking
results = {"passed": 0, "failed": 0, "errors": []}
results_lock = threading.Lock()
def worker(thread_id):
# Create a unique collection name for this thread
collection_name = f"concurrent_test_{thread_id}"
try:
# Generate unique data for this thread
unique_id = str(uuid.uuid4())
with mongo_handler.collection(collection_name) as collection:
# Insert a document
collection.insert_one(
{
"thread_id": thread_id,
"uuid": unique_id,
"timestamp": time.time(),
}
)
# Find the document
doc = collection.find_one({"thread_id": thread_id})
# Update the document
collection.update_one(
{"thread_id": thread_id}, {"$set": {"updated": True}}
)
# Verify update
updated_doc = collection.find_one({"thread_id": thread_id})
# Clean up
collection.delete_many({"thread_id": thread_id})
success = (
doc is not None
and updated_doc is not None
and updated_doc.get("updated") is True
)
# Update results with thread safety
with results_lock:
if success:
results["passed"] += 1
else:
results["failed"] += 1
results["errors"].append(f"Thread {thread_id} operation failed")
except Exception as e:
with results_lock:
results["failed"] += 1
results["errors"].append(f"Thread {thread_id} exception: {str(e)}")
# Create and start threads using a thread pool
start_time = time.time()
with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [executor.submit(worker, i) for i in range(num_threads)]
# Calculate execution time
execution_time = time.time() - start_time
# Print results
print(f"\nConcurrent Operation Test Results:")
print(f"Total threads: {num_threads}")
print(f"Passed: {results['passed']}")
print(f"Failed: {results['failed']}")
print(f"Execution time: {execution_time:.2f} seconds")
print(f"Operations per second: {num_threads / execution_time:.2f}")
if results["failed"] > 0:
print("\nErrors:")
for error in results["errors"][
:10
]: # Show only first 10 errors to avoid flooding output
print(f"- {error}")
if len(results["errors"]) > 10:
print(f"- ... and {len(results['errors']) - 10} more errors")
return results["failed"] == 0
def run_all_tests():
"""Run all MongoDB tests and report results."""
print("Starting MongoDB tests...")
# Clean up any existing test data before starting
cleanup_test_data()
tests = [
test_basic_crud_operations,
test_nested_documents,
test_array_operations,
test_aggregation,
test_index_operations,
test_complex_queries,
]
passed_list, not_passed_list = [], []
passed, failed = 0, 0
for test in tests:
# Clean up test data before each test
cleanup_test_data()
try:
if test():
passed += 1
passed_list.append(f"Test {test.__name__} passed")
else:
failed += 1
not_passed_list.append(f"Test {test.__name__} failed")
except Exception as e:
print(f"Test {test.__name__} failed with exception: {e}")
failed += 1
not_passed_list.append(f"Test {test.__name__} failed")
print(f"\nTest Results: {passed} passed, {failed} failed")
print("Passed Tests:")
print("\n".join(passed_list))
print("Failed Tests:")
print("\n".join(not_passed_list))
return passed, failed
if __name__ == "__main__":
mongo_handler = MongoDBHandler()
# Run standard tests first
passed, failed = run_all_tests()
# If all tests pass, run the concurrent operation test
if failed == 0:
run_concurrent_operation_test(10000)

View File

@@ -0,0 +1,93 @@
"""
Test script for MongoDB handler with a local MongoDB instance.
"""
import os
from datetime import datetime
from .database import MongoDBHandler, CollectionContext
# Create a custom handler class for local testing
class LocalMongoDBHandler(MongoDBHandler):
"""A MongoDB handler for local testing without authentication."""
def __init__(self):
"""Initialize with a direct MongoDB URI."""
self._initialized = False
self.uri = "mongodb://localhost:27017/test"
self.client_options = {
"maxPoolSize": 5,
"minPoolSize": 2,
"maxIdleTimeMS": 30000,
"waitQueueTimeoutMS": 2000,
"serverSelectionTimeoutMS": 5000,
}
self._initialized = True
# Create a custom handler for local testing
def create_local_handler():
"""Create a MongoDB handler for local testing."""
# Create a fresh instance with direct MongoDB URI
handler = LocalMongoDBHandler()
return handler
def test_connection_monitoring():
"""Test connection monitoring with the MongoDB handler."""
print("\nTesting connection monitoring...")
# Create a local handler
local_handler = create_local_handler()
# Add connection tracking to the handler
local_handler._open_connections = 0
# Modify the CollectionContext class to track connections
original_enter = CollectionContext.__enter__
original_exit = CollectionContext.__exit__
def tracked_enter(self):
result = original_enter(self)
self.db_handler._open_connections += 1
print(f"Connection opened. Total open: {self.db_handler._open_connections}")
return result
def tracked_exit(self, exc_type, exc_val, exc_tb):
self.db_handler._open_connections -= 1
print(f"Connection closed. Total open: {self.db_handler._open_connections}")
return original_exit(self, exc_type, exc_val, exc_tb)
# Apply the tracking methods
CollectionContext.__enter__ = tracked_enter
CollectionContext.__exit__ = tracked_exit
try:
# Test with multiple operations
for i in range(3):
print(f"\nTest iteration {i+1}:")
try:
with local_handler.collection("test_collection") as collection:
# Try a simple operation
try:
collection.find_one({})
print("Operation succeeded")
except Exception as e:
print(f"Operation failed: {e}")
except Exception as e:
print(f"Connection failed: {e}")
# Final connection count
print(f"\nFinal open connections: {local_handler._open_connections}")
if local_handler._open_connections == 0:
print("✅ All connections were properly closed")
else:
print(f"{local_handler._open_connections} connections remain open")
finally:
# Restore original methods
CollectionContext.__enter__ = original_enter
CollectionContext.__exit__ = original_exit
if __name__ == "__main__":
test_connection_monitoring()

View File

@@ -0,0 +1,31 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Configs(BaseSettings):
"""
Postgresql configuration settings.
"""
DB: str = ""
USER: str = ""
PASSWORD: str = ""
HOST: str = ""
PORT: int = 0
ENGINE: str = "postgresql+psycopg2"
POOL_PRE_PING: bool = True
POOL_SIZE: int = 20
MAX_OVERFLOW: int = 10
POOL_RECYCLE: int = 600
POOL_TIMEOUT: int = 30
ECHO: bool = True
@property
def url(self):
"""Generate the database URL."""
return f"{self.ENGINE}://{self.USER}:{self.PASSWORD}@{self.HOST}:{self.PORT}/{self.DB}"
model_config = SettingsConfigDict(env_prefix="POSTGRES_")
# singleton instance of the POSTGRESQL configuration settings
postgres_configs = Configs()

View File

@@ -0,0 +1,63 @@
from contextlib import contextmanager
from functools import lru_cache
from typing import Generator
from api_controllers.postgres.config import postgres_configs
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker, scoped_session, Session
# Configure the database engine with proper pooling
engine = create_engine(
postgres_configs.url,
pool_pre_ping=True,
pool_size=10, # Reduced from 20 to better match your CPU cores
max_overflow=5, # Reduced from 10 to prevent too many connections
pool_recycle=600, # Keep as is
pool_timeout=30, # Keep as is
echo=False, # Consider setting to False in production
)
Base = declarative_base()
# Create a cached session factory
@lru_cache()
def get_session_factory() -> scoped_session:
"""Create a thread-safe session factory."""
session_local = sessionmaker(
bind=engine,
autocommit=False,
autoflush=False,
expire_on_commit=True, # Expire objects after commit so reads reflect committed state
)
return scoped_session(session_local)
# Get database session with proper connection management
@contextmanager
def get_db() -> Generator[Session, None, None]:
"""Get database session with proper connection management.
This context manager ensures:
- Proper connection pooling
- Session cleanup
- Connection return to pool
- Thread safety
Yields:
Session: SQLAlchemy session object
"""
session_factory = get_session_factory()
session = session_factory()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
session_factory.remove() # Clean up the session from the registry
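# Illustrative usage (the statement is a placeholder):
#     from sqlalchemy import text
#     with get_db() as session:
#         session.execute(text("SELECT 1"))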

View File

@@ -0,0 +1,275 @@
import arrow
import datetime
from decimal import Decimal
from typing import Any, TypeVar, Type, Union, Optional
from sqlalchemy import Column, Integer, String, Float, ForeignKey, UUID, TIMESTAMP, Boolean, SmallInteger, Numeric, func, text, NUMERIC, ColumnExpressionArgument
from sqlalchemy.orm import InstrumentedAttribute, Mapped, mapped_column, Query, Session
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy_mixins.serialize import SerializeMixin
from sqlalchemy_mixins.repr import ReprMixin
from sqlalchemy_mixins.smartquery import SmartQueryMixin
from sqlalchemy_mixins.activerecord import ActiveRecordMixin
from api_controllers.postgres.engine import get_db, Base
T = TypeVar("CrudMixin", bound="CrudMixin")
class BasicMixin(Base, ActiveRecordMixin, SerializeMixin, ReprMixin, SmartQueryMixin):
__abstract__ = True
__repr__ = ReprMixin.__repr__
@classmethod
def new_session(cls):
"""Get database session."""
return get_db()
@classmethod
def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
"""
Process a field value based on its type and convert it to the appropriate format.
Args:
val: Field value
key: Field name
Returns:
Tuple of (should_include, processed_value)
"""
try:
key_ = cls.__annotations__.get(key, None)
is_primary = key in getattr(cls, "primary_keys", [])
row_attr = bool(getattr(getattr(cls, key), "foreign_keys", None))
# Skip primary keys and foreign keys
if is_primary or row_attr:
return False, None
if val is None: # Handle None values
return True, None
if str(key[-5:]).lower() == "uu_id": # Special handling for UUID fields
return True, str(val)
if key_: # Handle typed fields
if key_ == Mapped[int]:
return True, int(val)
elif key_ == Mapped[bool]:
return True, bool(val)
elif key_ == Mapped[float] or key_ == Mapped[NUMERIC]:
return True, round(float(val), 3)
elif key_ == Mapped[TIMESTAMP]:
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss"))
elif key_ == Mapped[str]:
return True, str(val)
else: # Handle based on Python types
if isinstance(val, datetime.datetime):
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss"))
elif isinstance(val, bool):
return True, bool(val)
elif isinstance(val, (float, Decimal)):
return True, round(float(val), 3)
elif isinstance(val, int):
return True, int(val)
elif isinstance(val, str):
return True, str(val)
elif val is None:
return True, None
return False, None
except Exception:
# Skip fields that cannot be converted
return False, None
@classmethod
def convert(cls: Type[T], smart_options: dict[str, Any], validate_model: Any = None) -> Optional[tuple[BinaryExpression, ...]]:
"""
Convert smart options to SQLAlchemy filter expressions.
Args:
smart_options: Dictionary of filter options
validate_model: Optional model to validate against
Returns:
Tuple of SQLAlchemy filter expressions or None if validation fails
"""
try:
# Let SQLAlchemy handle the validation by attempting to create the filter expressions
return tuple(cls.filter_expr(**smart_options))
except Exception as e:
# If there's an error, provide a helpful message with valid columns and relationships
valid_columns = set()
relationship_names = set()
# Get column names if available
if hasattr(cls, '__table__') and hasattr(cls.__table__, 'columns'):
valid_columns = set(column.key for column in cls.__table__.columns)
# Get relationship names if available
if hasattr(cls, '__mapper__') and hasattr(cls.__mapper__, 'relationships'):
relationship_names = set(rel.key for rel in cls.__mapper__.relationships)
# Create a helpful error message
error_msg = f"Error in filter expression: {str(e)}\n"
error_msg += f"Attempted to filter with: {smart_options}\n"
error_msg += f"Valid columns are: {', '.join(valid_columns)}\n"
error_msg += f"Valid relationships are: {', '.join(relationship_names)}"
raise ValueError(error_msg) from e
def get_dict(self, exclude_list: Optional[list[InstrumentedAttribute]] = None) -> dict[str, Any]:
"""
Convert model instance to dictionary with customizable fields.
Args:
exclude_list: List of fields to exclude from the dictionary
Returns:
Dictionary representation of the model
"""
try:
return_dict: dict[str, Any] = {}
exclude_list = exclude_list or []
exclude_list = [exclude_arg.key for exclude_arg in exclude_list]
# Get all column names from the model
columns = [col.name for col in self.__table__.columns]
columns_set = set(columns)
# Drop "*id" surrogate-key columns but keep "*uu_id" public identifiers
columns_list = {col for col in columns_set if str(col)[-2:] != "id"}
columns_extend = {col for col in columns_set if str(col)[-5:].lower() == "uu_id"}
columns_list = list((columns_list | columns_extend) - set(exclude_list))
for key in columns_list:
val = getattr(self, key)
correct, value_of_database = self.iterate_over_variables(val, key)
if correct:
return_dict[key] = value_of_database
return return_dict
except Exception:
# Fall back to an empty dict if serialization fails
return {}
class CrudMixin(BasicMixin):
"""
Base mixin providing CRUD operations and common fields for PostgreSQL models.
Features:
- Automatic timestamps (created_at, updated_at)
- Soft delete capability
- User tracking (created_by, updated_by)
- Data serialization
- Multi-language support
"""
__abstract__ = True
# Primary and reference fields
id: Mapped[int] = mapped_column(Integer, primary_key=True)
uu_id: Mapped[str] = mapped_column(
UUID,
server_default=text("gen_random_uuid()"),
index=True,
unique=True,
comment="Unique identifier UUID",
)
# Common timestamp fields for all models
expiry_starts: Mapped[TIMESTAMP] = mapped_column(
TIMESTAMP(timezone=True),
server_default=func.now(),
comment="Record validity start timestamp",
)
expiry_ends: Mapped[TIMESTAMP] = mapped_column(
TIMESTAMP(timezone=True),
default=str(arrow.get("2099-12-31")),
server_default=func.now(),
comment="Record validity end timestamp",
)
# Timestamps
created_at: Mapped[TIMESTAMP] = mapped_column(
TIMESTAMP(timezone=True),
server_default=func.now(),
nullable=False,
index=True,
comment="Record creation timestamp",
)
updated_at: Mapped[TIMESTAMP] = mapped_column(
TIMESTAMP(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
index=True,
comment="Last update timestamp",
)
class CrudCollection(CrudMixin):
"""
Full-featured model class with all common fields.
Includes:
- UUID and reference ID
- Timestamps
- User tracking
- Confirmation status
- Soft delete
- Notification flags
"""
__abstract__ = True
__repr__ = ReprMixin.__repr__
# Outer reference fields
ref_id: Mapped[str] = mapped_column(
String(100), nullable=True, index=True, comment="External reference ID"
)
replication_id: Mapped[int] = mapped_column(
SmallInteger, server_default="0", comment="Replication identifier"
)
# Cryptographic and user tracking
cryp_uu_id: Mapped[str] = mapped_column(
String, nullable=True, index=True, comment="Cryptographic UUID"
)
# Token fields of modification
created_credentials_token: Mapped[str] = mapped_column(
String, nullable=True, comment="Created Credentials token"
)
updated_credentials_token: Mapped[str] = mapped_column(
String, nullable=True, comment="Updated Credentials token"
)
confirmed_credentials_token: Mapped[str] = mapped_column(
String, nullable=True, comment="Confirmed Credentials token"
)
# Status flags
is_confirmed: Mapped[bool] = mapped_column(
Boolean, server_default="0", comment="Record confirmation status"
)
deleted: Mapped[bool] = mapped_column(
Boolean, server_default="0", comment="Soft delete flag"
)
active: Mapped[bool] = mapped_column(
Boolean, server_default="1", comment="Record active status"
)
is_notification_send: Mapped[bool] = mapped_column(
Boolean, server_default="0", comment="Notification sent flag"
)
is_email_send: Mapped[bool] = mapped_column(
Boolean, server_default="0", comment="Email sent flag"
)

View File

@@ -0,0 +1,15 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Configs(BaseSettings):
"""
Token configuration settings.
"""
ACCESS_TOKEN_LENGTH: int = 90
REFRESHER_TOKEN_LENGTH: int = 144
model_config = SettingsConfigDict(env_prefix="API_")
token_config = Configs()

View File

@@ -0,0 +1,40 @@
import hashlib
import uuid
import secrets
import random
from .config import token_config
class PasswordModule:
@staticmethod
def generate_random_uu_id(str_std: bool = True):
return str(uuid.uuid4()) if str_std else uuid.uuid4()
@staticmethod
def generate_token(length=32) -> str:
letters = "abcdefghijklmnopqrstuvwxyz"
merged_letters = [letter for letter in letters] + [letter.upper() for letter in letters]
token_generated = secrets.token_urlsafe(length)
# Replace any non-letter character with a random letter
for i in str(token_generated):
if i not in merged_letters:
token_generated = token_generated.replace(i, random.choice(merged_letters), 1)
return token_generated
@classmethod
def generate_access_token(cls) -> str:
return cls.generate_token(int(token_config.ACCESS_TOKEN_LENGTH))
@classmethod
def generate_refresher_token(cls) -> str:
return cls.generate_token(int(token_config.REFRESHER_TOKEN_LENGTH))
@staticmethod
def create_hashed_password(domain: str, id_: str, password: str) -> str:
return hashlib.sha256(f"{domain}:{id_}:{password}".encode("utf-8")).hexdigest()
@classmethod
def check_password(cls, domain, id_, password, password_hashed) -> bool:
return cls.create_hashed_password(domain, id_, password) == password_hashed
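# Illustrative usage (all arguments are placeholder values):
#     hashed = PasswordModule.create_hashed_password("example.com", "42", "s3cret")
#     assert PasswordModule.check_password("example.com", "42", "s3cret", hashed)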

View File

@@ -0,0 +1,14 @@
import uvicorn
from api_initializer.config import api_config
from api_initializer.create_app import create_app
# from prometheus_fastapi_instrumentator import Instrumentator
app = create_app() # Create FastAPI application
# Instrumentator().instrument(app=app).expose(app=app) # Setup Prometheus metrics
if __name__ == "__main__":
uvicorn_config = uvicorn.Config(**api_config.app_as_dict, workers=1) # Run the application with Uvicorn Server
uvicorn.Server(uvicorn_config).run()

View File

@@ -0,0 +1,64 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from fastapi.responses import JSONResponse
class Configs(BaseSettings):
"""
ApiTemplate configuration settings.
"""
PATH: str = ""
HOST: str = ""
PORT: int = 0
LOG_LEVEL: str = "info"
RELOAD: int = 0
ACCESS_TOKEN_TAG: str = ""
ACCESS_EMAIL_EXT: str = ""
TITLE: str = ""
ALGORITHM: str = ""
ACCESS_TOKEN_LENGTH: int = 90
REFRESHER_TOKEN_LENGTH: int = 144
EMAIL_HOST: str = ""
DATETIME_FORMAT: str = ""
FORGOT_LINK: str = ""
ALLOW_ORIGINS: list = ["http://localhost:3000", "http://localhost:3001", "http://localhost:3001/api", "http://localhost:3001/api/"]
VERSION: str = "0.1.001"
DESCRIPTION: str = ""
@property
def app_as_dict(self) -> dict:
"""
Convert the settings to a dictionary.
"""
return {
"app": self.PATH,
"host": self.HOST,
"port": int(self.PORT),
"log_level": self.LOG_LEVEL,
"reload": bool(self.RELOAD),
}
@property
def api_info(self):
"""
Returns a dictionary with application information.
"""
return {
"title": self.TITLE,
"description": self.DESCRIPTION,
"default_response_class": JSONResponse,
"version": self.VERSION,
}
def forgot_link(self, forgot_key):
"""
Generate a forgot password link.
"""
return self.FORGOT_LINK + forgot_key
model_config = SettingsConfigDict(env_prefix="API_")
api_config = Configs()

View File

@@ -0,0 +1,58 @@
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from event_clusters import RouterCluster, EventCluster
from config import api_config
from open_api_creator import create_openapi_schema
from create_route import RouteRegisterController
from api_middlewares.token_middleware import token_middleware
from endpoints.routes import get_routes
import events
cluster_is_set = False
def create_app():
def create_events_if_any_cluster_set():
global cluster_is_set
if not events.__all__ or cluster_is_set:
return
router_cluster_stack: list[RouterCluster] = [getattr(events, e, None) for e in events.__all__]
for router_cluster in router_cluster_stack:
event_cluster_stack: list[EventCluster] = list(router_cluster.event_clusters.values())
for event_cluster in event_cluster_stack:
try:
event_cluster.set_events_to_database()
except Exception as e:
print(f"Error creating event cluster: {e}")
cluster_is_set = True
application = FastAPI(**api_config.api_info)
application.add_middleware(
CORSMiddleware,
allow_origins=api_config.ALLOW_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@application.middleware("http")
async def add_token_middleware(request: Request, call_next):
return await token_middleware(request, call_next)
@application.get("/", description="Redirect Route", include_in_schema=False)
async def redirect_to_docs():
return RedirectResponse(url="/docs")
route_register = RouteRegisterController(app=application, router_list=get_routes())
application = route_register.register_routes()
create_events_if_any_cluster_set()
application.openapi = lambda _=application: create_openapi_schema(_)
return application

View File

@@ -0,0 +1,41 @@
from typing import List
from fastapi import APIRouter, FastAPI
class RouteRegisterController:
def __init__(self, app: FastAPI, router_list: List[APIRouter]):
self.router_list = router_list
self.app = app
@staticmethod
def add_router_with_event_to_database(router: APIRouter):
from schemas import EndpointRestriction
with EndpointRestriction.new_session() as db_session:
EndpointRestriction.set_session(db_session)
for route in router.routes:
route_path = str(getattr(route, "path"))
route_summary = str(getattr(route, "name"))
operation_id = getattr(route, "operation_id", None)
if not operation_id:
raise ValueError(f"Route {route_path} operation_id not found")
route_methods = getattr(route, "methods", None)
if not route_methods or not isinstance(route_methods, (set, frozenset, list)):
raise ValueError(f"Route {route_path} methods not found")
route_method = [method.lower() for method in route_methods][0]
add_or_update_dict = dict(
endpoint_method=route_method, endpoint_name=route_path, endpoint_desc=route_summary.replace("_", " "), endpoint_function=route_summary, is_confirmed=True
)
if to_save_endpoint := EndpointRestriction.query.filter(EndpointRestriction.operation_uu_id == operation_id).first():
to_save_endpoint.update(**add_or_update_dict)
to_save_endpoint.save()
else:
created_endpoint = EndpointRestriction.create(**add_or_update_dict, operation_uu_id=operation_id)
created_endpoint.save()
def register_routes(self):
for router in self.router_list:
self.app.include_router(router)
self.add_router_with_event_to_database(router)
return self.app
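# Usage sketch: wiring the controller by hand (assumes a configured
# database, since registering also upserts EndpointRestriction rows).
def build_app_with_routes() -> FastAPI:
from endpoints.routes import get_routes  # project-local import, as used elsewhere in this commit
app = FastAPI()
register = RouteRegisterController(app=app, router_list=get_routes())
return register.register_routes()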

View File

@ -0,0 +1,122 @@
from typing import Optional, Type
from pydantic import BaseModel
class EventCluster:
"""
EventCluster
"""
def __repr__(self):
return f"EventCluster(name={self.name})"
def __init__(self, endpoint_uu_id: str, name: str):
self.endpoint_uu_id = endpoint_uu_id
self.name = name
self.events: list["Event"] = []
def add_event(self, event: "Event"):
"""
Add an event to the cluster
"""
if event.key not in [e.key for e in self.events]:
self.events.append(event)
def get_event(self, event_key: str):
"""
Get an event by its key
"""
for event in self.events:
if event.key == event_key:
return event
return None
def set_events_to_database(self):
from schemas import Events, EndpointRestriction
with Events.new_session() as db_session:
Events.set_session(db_session)
EndpointRestriction.set_session(db_session)
if to_save_endpoint := EndpointRestriction.query.filter(EndpointRestriction.operation_uu_id == self.endpoint_uu_id).first():
print('to_save_endpoint', to_save_endpoint)
for event in self.events:
event_dict_to_save = dict(
function_code=event.key,
function_class=event.name,
description=event.description,
endpoint_code=self.endpoint_uu_id,
endpoint_id=to_save_endpoint.id,
endpoint_uu_id=str(to_save_endpoint.uu_id),
is_confirmed=True,
)
print('set_events_to_database event_dict_to_save', event_dict_to_save)
check_event = Events.query.filter(
Events.endpoint_uu_id == event_dict_to_save["endpoint_uu_id"],
Events.function_code == event_dict_to_save["function_code"],
).first()
if check_event:
check_event.update(**event_dict_to_save)
check_event.save()
else:
event_created = Events.create(**event_dict_to_save)
print(f"UUID: {event_created.uu_id} event is saved to {to_save_endpoint.uu_id}")
event_created.save()
def match_event(self, event_key: str) -> "Event":
"""
Match an event by its key
"""
if event := self.get_event(event_key=event_key):
return event
raise ValueError("Event key not found")
class Event:
def __init__(
self,
name: str,
key: str,
request_validator: Optional[Type[BaseModel]] = None,
response_validator: Optional[Type[BaseModel]] = None,
description: str = "",
):
self.name = name
self.key = key
self.request_validator = request_validator
self.response_validator = response_validator
self.description = description
def event_callable(self):
"""
Example callable method
"""
print(self.name)
return {}
class RouterCluster:
"""
RouterCluster
"""
def __repr__(self):
return f"RouterCluster(name={self.name})"
def __init__(self, name: str):
self.name = name
self.event_clusters: dict[str, EventCluster] = {}
def set_event_cluster(self, event_cluster: EventCluster):
"""
Add an event cluster to the set
"""
print("Setting event cluster:", event_cluster.name)
if event_cluster.name not in self.event_clusters:
self.event_clusters[event_cluster.name] = event_cluster
def get_event_cluster(self, event_cluster_name: str) -> EventCluster:
"""
Get an event cluster by its name
"""
if event_cluster_name not in self.event_clusters:
raise ValueError("Event cluster not found")
return self.event_clusters[event_cluster_name]
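# Usage sketch (illustrative names and keys): compose an Event into an
# EventCluster, hang it on a RouterCluster, then resolve it by key. The
# endpoint_uu_id must match an existing EndpointRestriction.operation_uu_id
# for set_events_to_database to persist anything.
login_cluster = EventCluster(endpoint_uu_id="employee-login-operation-id", name="login")
login_cluster.add_event(Event(name="LoginAttempt", key="EVT0001", description="Audit login attempts"))
auth_cluster = RouterCluster(name="auth")
auth_cluster.set_event_cluster(login_cluster)
matched = auth_cluster.get_event_cluster("login").match_event("EVT0001")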

View File

@ -0,0 +1,116 @@
from typing import Any, Dict
from fastapi import FastAPI
from fastapi.routing import APIRoute
from fastapi.openapi.utils import get_openapi
from endpoints.routes import get_safe_endpoint_urls
from config import api_config
class OpenAPISchemaCreator:
"""
OpenAPI schema creator and customizer for FastAPI applications.
"""
def __init__(self, app: FastAPI):
"""
Initialize the OpenAPI schema creator.
Args:
app: FastAPI application instance
"""
self.app = app
self.safe_endpoint_list: list[tuple[str, str]] = get_safe_endpoint_urls()
self.routers_list = self.app.routes
@staticmethod
def create_security_schemes() -> Dict[str, Any]:
"""
Create security scheme definitions.
Returns:
Dict[str, Any]: Security scheme configurations
"""
return {
"BearerAuth": {
"type": "apiKey",
"in": "header",
"name": api_config.ACCESS_TOKEN_TAG,
"description": "Enter: **'Bearer &lt;JWT&gt;'**, where JWT is the access token",
}
}
def configure_route_security(
self, path: str, method: str, schema: Dict[str, Any]
) -> None:
"""
Configure security requirements for a specific route.
Args:
path: Route path
method: HTTP method
schema: OpenAPI schema to modify
"""
if not schema.get("paths", {}).get(path, {}).get(method):
return
# Check if endpoint is in safe list
endpoint_path = f"{path}:{method}"
list_of_safe_endpoints = [
f"{e[0]}:{str(e[1]).lower()}" for e in self.safe_endpoint_list
]
if endpoint_path not in list_of_safe_endpoints:
if "security" not in schema["paths"][path][method]:
schema["paths"][path][method]["security"] = []
schema["paths"][path][method]["security"].append({"BearerAuth": []})
def create_schema(self) -> Dict[str, Any]:
"""
Create the complete OpenAPI schema.
Returns:
Dict[str, Any]: Complete OpenAPI schema
"""
openapi_schema = get_openapi(
title=api_config.TITLE,
description=api_config.DESCRIPTION,
version=api_config.VERSION,
routes=self.app.routes,
)
# Add security schemes
if "components" not in openapi_schema:
openapi_schema["components"] = {}
openapi_schema["components"]["securitySchemes"] = self.create_security_schemes()
# Configure route security and responses
for route in self.app.routes:
if isinstance(route, APIRoute) and route.include_in_schema:
path = str(route.path)
methods = [method.lower() for method in route.methods]
for method in methods:
self.configure_route_security(path, method, openapi_schema)
# Add custom documentation extensions
openapi_schema["x-documentation"] = {
"postman_collection": "/docs/postman",
"swagger_ui": "/docs",
"redoc": "/redoc",
}
return openapi_schema
def create_openapi_schema(app: FastAPI) -> Dict[str, Any]:
"""
Create OpenAPI schema for a FastAPI application.
Args:
app: FastAPI application instance
Returns:
Dict[str, Any]: Complete OpenAPI schema
"""
creator = OpenAPISchemaCreator(app)
return creator.create_schema()
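# Usage sketch: attach the customized schema to an application, mirroring
# the lambda assignment used in create_app.
def attach_custom_openapi(app: FastAPI) -> None:
app.openapi = lambda: create_openapi_schema(app)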

View File

@ -0,0 +1,8 @@
if __name__ == "__main__":
import time
while True:
time.sleep(10)

View File

@ -0,0 +1,2 @@
result_with_keys_dict {'Employees.id': 3, 'Employees.uu_id': UUID('2b757a5c-01bb-4213-9cf1-402480e73edc'), 'People.id': 2, 'People.uu_id': UUID('d945d320-2a4e-48db-a18c-6fd024beb517'), 'Users.id': 3, 'Users.uu_id': UUID('bdfa84d9-0e05-418c-9406-d6d1d41ae2a1'), 'Companies.id': 1, 'Companies.uu_id': UUID('da1de172-2f89-42d2-87f3-656b36a79d5b'), 'Departments.id': 3, 'Departments.uu_id': UUID('4edcec87-e072-408d-a780-3a62151b3971'), 'Duty.id': 9, 'Duty.uu_id': UUID('00d29292-c29e-4435-be41-9704ccf4b24d'), 'Addresses.id': None, 'Addresses.letter_address': None}

View File

@ -0,0 +1,56 @@
from fastapi import Header, Request, Response
from pydantic import BaseModel
from config import api_config
class CommonHeaders(BaseModel):
language: str | None = None
domain: str | None = None
timezone: str | None = None
token: str | None = None
request: Request | None = None
response: Response | None = None
operation_id: str | None = None
model_config = {
"arbitrary_types_allowed": True
}
@classmethod
def as_dependency(
cls,
request: Request,
response: Response,
language: str = Header(None, alias="language"),
domain: str = Header(None, alias="domain"),
tz: str = Header(None, alias="timezone"),
):
token = request.headers.get(api_config.ACCESS_TOKEN_TAG, None)
# Extract operation_id from the route
route = request.scope.get("route")
operation_id = getattr(route, "operation_id", None)
return cls(
language=language,
domain=domain,
timezone=tz,
token=token,
request=request,
response=response,
operation_id=operation_id,
)
def get_headers_dict(self):
"""Convert the headers to a dictionary format used in the application"""
import uuid
return {
"language": self.language or "",
"domain": self.domain or "",
"eys-ext": f"{str(uuid.uuid4())}",
"tz": self.timezone or "GMT+3",
"token": self.token,
}
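# Usage sketch (illustrative route): consume the headers as a dependency so
# every handler receives language, domain, timezone, token and operation_id.
def build_example_router():
from fastapi import APIRouter, Depends
router = APIRouter()
@router.get("/ping", operation_id="ping_example")
async def ping(headers: CommonHeaders = Depends(CommonHeaders.as_dependency)):
return headers.get_headers_dict()
return router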

View File

@ -0,0 +1,17 @@
from pydantic import BaseModel
class BaseModelCore(BaseModel):
"""
BaseModelCore
model_dump override for alias support, e.g. Users.name -> Field(alias="Users.name")
"""
__abstract__ = True
class Config:
validate_by_name = True
use_enum_values = True
def model_dump(self, *args, **kwargs):
data = super().model_dump(*args, **kwargs)
# Fall back to the field name when a field declares no alias.
return {self.__class__.model_fields[field].alias or field: value for field, value in data.items()}
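# Usage sketch: a subclass whose dumped keys come back in "Table.column"
# alias form thanks to the override above (model name is illustrative).
from pydantic import Field
class UserNameModel(BaseModelCore):
name: str = Field(..., alias="Users.name")
# UserNameModel(**{"Users.name": "Ada"}).model_dump() == {"Users.name": "Ada"}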

View File

@ -0,0 +1,4 @@
from .pagination import PaginateOnly, ListOptions, PaginationConfig
from .result import Pagination, PaginationResult
from .base import PostgresResponseSingle, PostgresResponse, ResultQueryJoin, ResultQueryJoinSingle
from .api import EndpointResponse, CreateEndpointResponse

View File

@ -0,0 +1,60 @@
from .result import PaginationResult
from .base import PostgresResponseSingle
from pydantic import BaseModel
class EndpointResponse(BaseModel):
"""Endpoint response model."""
completed: bool = True
message: str = "Success"
pagination_result: PaginationResult
@property
def response(self):
"""Convert response to dictionary format."""
result_data = getattr(self.pagination_result, "data", None)
if not result_data:
return {
"completed": False,
"message": "MSG0004-NODATA",
"data": None,
"pagination": None,
}
result_pagination = getattr(self.pagination_result, "pagination", None)
if not result_pagination:
raise ValueError("Invalid pagination result pagination.")
pagination_dict = getattr(result_pagination, "as_dict", None)
if not pagination_dict:
raise ValueError("Invalid pagination result as_dict.")
return {
"completed": self.completed,
"message": self.message,
"data": result_data,
"pagination": pagination_dict,
}
model_config = {
"arbitrary_types_allowed": True
}
class CreateEndpointResponse(BaseModel):
"""Create endpoint response model."""
completed: bool = True
message: str = "Success"
data: PostgresResponseSingle
@property
def response(self):
"""Convert response to dictionary format."""
return {
"completed": self.completed,
"message": self.message,
"data": self.data.data,
}
model_config = {
"arbitrary_types_allowed": True
}
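# Usage sketch: an endpoint handler would wrap its paginated result and
# return the plain-dict form.
def as_endpoint_payload(pagination_result: PaginationResult) -> dict:
return EndpointResponse(pagination_result=pagination_result).response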

View File

@ -0,0 +1,193 @@
"""
Response handler for PostgreSQL query results.
This module provides a wrapper class for SQLAlchemy query results,
adding convenience methods for accessing data and managing query state.
"""
from typing import Any, Dict, Optional, TypeVar, Generic, Union
from pydantic import BaseModel
from sqlalchemy.orm import Query
T = TypeVar("T")
class PostgresResponse(Generic[T]):
"""
Wrapper for PostgreSQL/SQLAlchemy query results.
Properties:
count: Total count of results
query: Get query object
as_dict: Convert response to dictionary format
"""
def __init__(self, query: Query, base_model: Optional[BaseModel] = None):
self._query = query
self._count: Optional[int] = None
self._base_model: Optional[BaseModel] = base_model
self.single = False
@property
def query(self) -> Query:
"""Get query object."""
return self._query
@property
def data(self) -> Union[list[T], T]:
"""Get all query results."""
return self._query.all()
@property
def count(self) -> int:
"""Get total count of results."""
return self._query.count()
@property
def to_dict(self) -> list[dict]:
"""Convert results to a list of dictionaries."""
# Note: a property getter cannot receive keyword arguments.
if self._base_model:
return [self._base_model(**item.to_dict()).model_dump() for item in self.data]
return [item.to_dict() for item in self.data]
@property
def as_dict(self) -> Dict[str, Any]:
"""Convert response to dictionary format."""
return {
"query": str(self.query),
"count": self.count,
"data": self.to_dict,
}
class PostgresResponseSingle(Generic[T]):
"""
Wrapper for PostgreSQL/SQLAlchemy query results.
Properties:
count: Total count of results
query: Get query object
as_dict: Convert response to dictionary format
data: Get query object
"""
def __init__(self, query: Query, base_model: Optional[BaseModel] = None):
self._query = query
self._count: Optional[int] = None
self._base_model: Optional[BaseModel] = base_model
self.single = True
@property
def query(self) -> Query:
"""Get query object."""
return self._query
@property
def to_dict(self) -> dict:
"""Convert the first result to a dictionary."""
first = self._query.first()
if first is None:
return {}
if self._base_model:
return self._base_model(**first.to_dict()).model_dump()
return first.to_dict()
@property
def data(self) -> T:
"""Get the first query result."""
return self._query.first()
@property
def count(self) -> int:
"""Get total count of results."""
return self._query.count()
@property
def as_dict(self) -> Dict[str, Any]:
"""Convert response to dictionary format."""
return {"query": str(self.query),"data": self.to_dict, "count": self.count}
class ResultQueryJoin:
"""
ResultQueryJoin
params:
list_of_instrumented_attributes: list of instrumented attributes
query: query object
"""
def __init__(self, list_of_instrumented_attributes, query):
"""Initialize ResultQueryJoin"""
self.list_of_instrumented_attributes = list_of_instrumented_attributes
self._query = query
@property
def query(self):
"""Get query object."""
return self._query
@property
def to_dict(self):
"""Convert joined rows to a list of dictionaries."""
list_of_dictionaries = []
for row in self.query.all():
# Build a fresh dict per row; reusing a single dict would make every
# appended entry point at the same (last) row.
result = {str(attr): row[index] for index, attr in enumerate(self.list_of_instrumented_attributes)}
list_of_dictionaries.append(result)
return list_of_dictionaries
@property
def count(self):
"""Get count of query."""
return self.query.count()
@property
def data(self):
"""Get all query results."""
return self.query.all()
@property
def as_dict(self):
"""Convert response to dictionary format."""
return {"query": str(self.query), "data": self.data, "count": self.count}
class ResultQueryJoinSingle:
"""
ResultQueryJoinSingle
params:
list_of_instrumented_attributes: list of instrumented attributes
query: query object
"""
def __init__(self, list_of_instrumented_attributes, query):
"""Initialize ResultQueryJoinSingle"""
self.list_of_instrumented_attributes = list_of_instrumented_attributes
self._query = query
@property
def query(self):
"""Get query object."""
return self._query
@property
def to_dict(self):
"""Convert the first joined row to dictionary format."""
data = self.query.first()
if data is None:
return {}
return {str(attr): data[index] for index, attr in enumerate(self.list_of_instrumented_attributes)}
@property
def count(self):
"""Get count of query."""
return self.query.count()
@property
def data(self):
"""Get the first query result."""
return self._query.first()
@property
def as_dict(self):
"""Convert response to dictionary format."""
return {"query": str(self.query), "data": self.data, "count": self.count}

View File

@ -0,0 +1,19 @@
from typing import Optional
from pydantic import BaseModel, Field
class UserPydantic(BaseModel):
username: str = Field(..., alias='user.username')
account_balance: float = Field(..., alias='user.account_balance')
preferred_category_id: Optional[int] = Field(None, alias='user.preferred_category_id')
last_ordered_product_id: Optional[int] = Field(None, alias='user.last_ordered_product_id')
supplier_rating_id: Optional[int] = Field(None, alias='user.supplier_rating_id')
other_rating_id: Optional[int] = Field(None, alias='product.supplier_rating_id')
id: int = Field(..., alias='user.id')
class Config:
validate_by_name = True
use_enum_values = True
def model_dump(self, *args, **kwargs):
data = super().model_dump(*args, **kwargs)
return {self.__class__.model_fields[field].alias: value for field, value in data.items()}

View File

@ -0,0 +1,70 @@
from typing import Optional, Union
from pydantic import BaseModel
class PaginateConfig:
"""
Configuration for pagination settings.
Attributes:
DEFAULT_SIZE: Default number of items per page (10)
MIN_SIZE: Minimum allowed page size (5)
MAX_SIZE: Maximum allowed page size (100)
"""
DEFAULT_SIZE = 10
MIN_SIZE = 5
MAX_SIZE = 100
class ListOptions(BaseModel):
"""
Query for list option abilities
"""
page: Optional[int] = 1
size: Optional[int] = 10
orderField: Optional[Union[tuple[str, ...], list[str]]] = ["uu_id"]
orderType: Optional[Union[tuple[str, ...], list[str]]] = ["asc"]
# include_joins: Optional[list] = None
class PaginateOnly(ListOptions):
"""
Query for list option abilities
"""
query: Optional[dict] = None
class PaginationConfig(BaseModel):
"""
Configuration for pagination settings.
Attributes:
page: Current page number (default: 1)
size: Items per page (default: 10)
orderField: Field to order by (default: "created_at")
orderType: Order direction (default: "desc")
"""
page: int = 1
size: int = 10
orderField: Optional[Union[tuple[str, ...], list[str]]] = ["created_at"]
orderType: Optional[Union[tuple[str, ...], list[str]]] = ["desc"]
def __init__(self, **data):
super().__init__(**data)
if self.orderField is None:
self.orderField = ["created_at"]
if self.orderType is None:
self.orderType = ["desc"]
default_paginate_config = PaginateConfig()
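# Usage sketch (illustrative values): validate client list options before
# feeding them to Pagination.change(**example_options.model_dump()).
example_options = ListOptions(page=2, size=20, orderField=["created_at"], orderType=["desc"])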

View File

@ -0,0 +1,180 @@
from typing import Optional, Union, Type, Any, Dict, TypeVar
from pydantic import BaseModel
from sqlalchemy.orm import Query
from sqlalchemy import asc, desc
from .pagination import default_paginate_config
from .base import PostgresResponse
from .pagination import PaginationConfig
T = TypeVar("T")
class Pagination:
"""
Handles pagination logic for query results.
Manages page size, current page, ordering, and calculates total pages
and items based on the data source.
Attributes:
DEFAULT_SIZE: Default number of items per page (from PaginateConfig, 10)
MIN_SIZE: Minimum allowed page size (from PaginateConfig, 5)
MAX_SIZE: Maximum allowed page size (from PaginateConfig, 100)
"""
DEFAULT_SIZE = default_paginate_config.DEFAULT_SIZE
MIN_SIZE = default_paginate_config.MIN_SIZE
MAX_SIZE = default_paginate_config.MAX_SIZE
def __init__(self, data: Union[Query, PostgresResponse]):
# Accept a raw SQLAlchemy Query or a PostgresResponse wrapper; the
# count()/order calls below need the underlying Query.
self.query = data.query if isinstance(data, PostgresResponse) else data
self.size: int = self.DEFAULT_SIZE
self.page: int = 1
self.orderField: Optional[Union[tuple[str, ...], list[str]]] = ["uu_id"]
self.orderType: Optional[Union[tuple[str, ...], list[str]]] = ["asc"]
self.page_count: int = 1
self.total_count: int = 0
self.all_count: int = 0
self.total_pages: int = 1
self._update_page_counts()
def change(self, **kwargs) -> None:
"""Update pagination settings from config."""
config = PaginationConfig(**kwargs)
self.size = (
config.size
if self.MIN_SIZE <= config.size <= self.MAX_SIZE
else self.DEFAULT_SIZE
)
self.page = config.page
self.orderField = config.orderField
self.orderType = config.orderType
self._update_page_counts()
def feed(self, data: Union[Query, PostgresResponse]) -> None:
"""Recalculate pagination from a new data source."""
self.query = data.query if isinstance(data, PostgresResponse) else data
self._update_page_counts()
def _update_page_counts(self) -> None:
"""Update page counts and validate current page."""
if self.query:
self.total_count = self.all_count = self.query.count()
self.size = (
self.size
if self.MIN_SIZE <= self.size <= self.MAX_SIZE
else self.DEFAULT_SIZE
)
self.total_pages = max(1, (self.total_count + self.size - 1) // self.size)
self.page = max(1, min(self.page, self.total_pages))
self.page_count = (
self.total_count % self.size
if self.page == self.total_pages and self.total_count % self.size
else self.size
)
def refresh(self) -> None:
"""Recalculate page counts from the current query."""
self._update_page_counts()
def reset(self) -> None:
"""Reset pagination state to defaults."""
self.size = self.DEFAULT_SIZE
self.page = 1
self.orderField = "uu_id"
self.orderType = "asc"
@property
def next_available(self) -> bool:
return self.page < self.total_pages
@property
def back_available(self) -> bool:
return self.page > 1
@property
def as_dict(self) -> Dict[str, Any]:
"""Convert pagination state to dictionary format."""
self.refresh()
return {
"size": self.size,
"page": self.page,
"allCount": self.all_count,
"totalCount": self.total_count,
"totalPages": self.total_pages,
"pageCount": self.page_count,
"orderField": self.orderField,
"orderType": self.orderType,
"next": self.next_available,
"back": self.back_available,
}
class PaginationResult:
"""
Result of a paginated query.
Contains the query result and pagination state.
data: PostgresResponse of query results
pagination: Pagination state
Attributes:
_query: Original query object
pagination: Pagination state
"""
def __init__(
self,
data: Union[Query, PostgresResponse],
pagination: Pagination,
is_list: bool = True,
response_model: Optional[Type[T]] = None,
):
# Unwrap to the raw Query so ordering, limit and offset below work.
self._query = data.query if isinstance(data, PostgresResponse) else data
self.pagination = pagination
self.response_type = is_list
self.limit = self.pagination.size
self.offset = self.pagination.size * (self.pagination.page - 1)
self.order_by = self.pagination.orderField
self.response_model = response_model
def dynamic_order_by(self):
"""
Dynamically order a query by multiple fields.
Returns:
Ordered query object.
"""
if len(self.order_by) != len(self.pagination.orderType):
raise ValueError("Order by fields and order types must have the same length.")
entity = self._query.column_descriptions[0]["entity"]
for field, direction in zip(self.order_by, self.pagination.orderType):
if hasattr(entity, field):
order_fn = desc if direction.lower().startswith("d") else asc
self._query = self._query.order_by(order_fn(getattr(entity, field)))
return self._query
@property
def data(self) -> Union[list, dict, None]:
"""Get the ordered, paginated results."""
query_paginated = self.dynamic_order_by().limit(self.limit).offset(self.offset)
queried_data = query_paginated.all() if self.response_type else query_paginated.first()
if queried_data is None:
return None
data = [result.get_dict() for result in queried_data] if self.response_type else queried_data.get_dict()
return [self.response_model(**item).model_dump() for item in data] if self.response_model else data
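# Usage sketch, assuming `query` is a SQLAlchemy Query over a model whose
# instances implement get_dict() and that has the default order field
# (created_at):
def paginate(query: Query, page: int = 1, size: int = 10) -> dict:
pagination = Pagination(query)
pagination.change(page=page, size=size)
result = PaginationResult(data=query, pagination=pagination)
return {"data": result.data, "pagination": pagination.as_dict}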

View File

@ -0,0 +1,123 @@
from enum import Enum
from pydantic import BaseModel
from typing import Optional, Union
class UserType(Enum):
employee = 1
occupant = 2
class Credentials(BaseModel):
person_id: int
person_name: str
class ApplicationToken(BaseModel):
# Application Token Object -> is the main object for the user
user_type: int = UserType.occupant.value
credential_token: str = ""
user_uu_id: str
user_id: int
person_id: int
person_uu_id: str
request: Optional[dict] = None # Request Info of Client
expires_at: Optional[float] = None # Expiry timestamp
class OccupantToken(BaseModel):
# Selection of the occupant type for a build part is made by the user
living_space_id: int # Internal use
living_space_uu_id: str # Outer use
occupant_type_id: int
occupant_type_uu_id: str
occupant_type: str
build_id: int
build_uuid: str
build_part_id: int
build_part_uuid: str
responsible_company_id: Optional[int] = None
responsible_company_uuid: Optional[str] = None
responsible_employee_id: Optional[int] = None
responsible_employee_uuid: Optional[str] = None
# Mapping of reachable event codes: "endpoint_code" -> "UUID"
reachable_event_codes: Optional[dict[str, str]] = None
# Mapping of reachable applications: "page_url" -> "UUID"
reachable_app_codes: Optional[dict[str, str]] = None
class CompanyToken(BaseModel):
# Selection of the company for an employee is made by the user
company_id: int
company_uu_id: str
department_id: int  # selected department ID
department_uu_id: str  # selected department UUID
duty_id: int
duty_uu_id: str
staff_id: int
staff_uu_id: str
employee_id: int
employee_uu_id: str
bulk_duties_id: int
# Mapping of reachable event codes: "endpoint_code" -> "UUID"
reachable_event_codes: Optional[dict[str, str]] = None
# Mapping of reachable applications: "page_url" -> "UUID"
reachable_app_codes: Optional[dict[str, str]] = None
class OccupantTokenObject(ApplicationToken):
# Occupant Token Object -> Requires selection of the occupant type for a specific build part
available_occupants: Optional[dict] = None
selected_occupant: Optional[OccupantToken] = None # Selected Occupant Type
@property
def is_employee(self) -> bool:
return False
@property
def is_occupant(self) -> bool:
return True
class EmployeeTokenObject(ApplicationToken):
# Full hierarchy Employee[staff_id] -> Staff -> Duty -> Department -> Company
companies_id_list: list[int]  # IDs of companies the employee belongs to
companies_uu_id_list: list[str]  # UUIDs of those companies
duty_id_list: list[int]  # IDs of the employee's duties
duty_uu_id_list: list[str]  # UUIDs of those duties
selected_company: Optional[CompanyToken] = None # Selected Company Object
@property
def is_employee(self) -> bool:
return True
@property
def is_occupant(self) -> bool:
return False
TokenDictType = Union[EmployeeTokenObject, OccupantTokenObject]
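# Usage sketch (placeholder identifiers): an employee token before company
# selection; selected_company stays None until the user picks one.
employee_token = EmployeeTokenObject(
user_type=UserType.employee.value,
user_uu_id="user-uuid",
user_id=1,
person_id=1,
person_uu_id="person-uuid",
companies_id_list=[1],
companies_uu_id_list=["company-uuid"],
duty_id_list=[1],
duty_uu_id_list=["duty-uuid"],
)
assert employee_token.is_employee and not employee_token.is_occupant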