updated services api
This commit is contained in:
parent
e5829f0525
commit
f8184246d9
|
|
@ -2,7 +2,7 @@ import arrow
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
from config import api_config
|
from config import api_config
|
||||||
from schemas import (
|
from Schemas import (
|
||||||
Users,
|
Users,
|
||||||
People,
|
People,
|
||||||
BuildLivingSpace,
|
BuildLivingSpace,
|
||||||
|
|
@ -23,11 +23,11 @@ from schemas import (
|
||||||
Events,
|
Events,
|
||||||
EndpointRestriction,
|
EndpointRestriction,
|
||||||
)
|
)
|
||||||
from api_modules.token.password_module import PasswordModule
|
from Controllers.mongo.database import mongo_handler
|
||||||
from api_controllers.mongo.database import mongo_handler
|
from Validations.token.validations import TokenDictType, EmployeeTokenObject, OccupantTokenObject, CompanyToken, OccupantToken, UserType
|
||||||
from api_validations.token.validations import TokenDictType, EmployeeTokenObject, OccupantTokenObject, CompanyToken, OccupantToken, UserType
|
from Validations.defaults.validations import CommonHeaders
|
||||||
from api_validations.defaults.validations import CommonHeaders
|
from Extends.redis.redis_handlers import RedisHandlers
|
||||||
from api_modules.redis.redis_handlers import RedisHandlers
|
from Extends.token.password_module import PasswordModule
|
||||||
from validations.password.validations import PasswordHistoryViaUser
|
from validations.password.validations import PasswordHistoryViaUser
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -68,7 +68,6 @@ class LoginHandler:
|
||||||
return str(email).split("@")[1] == api_config.ACCESS_EMAIL_EXT
|
return str(email).split("@")[1] == api_config.ACCESS_EMAIL_EXT
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
# headers: CommonHeaders
|
|
||||||
def do_employee_login(cls, headers: CommonHeaders, data: Any, db_session):
|
def do_employee_login(cls, headers: CommonHeaders, data: Any, db_session):
|
||||||
"""Handle employee login."""
|
"""Handle employee login."""
|
||||||
|
|
||||||
|
|
@ -159,7 +158,6 @@ class LoginHandler:
|
||||||
raise ValueError("Something went wrong")
|
raise ValueError("Something went wrong")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
# headers=headers, data=data, db_session=db_session
|
|
||||||
def do_occupant_login(cls, headers: CommonHeaders, data: Any, db_session):
|
def do_occupant_login(cls, headers: CommonHeaders, data: Any, db_session):
|
||||||
"""
|
"""
|
||||||
Handle occupant login.
|
Handle occupant login.
|
||||||
|
|
@ -376,7 +374,7 @@ class LoginHandler:
|
||||||
)
|
)
|
||||||
return {"selected_uu_id": occupant_token.living_space_uu_id}
|
return {"selected_uu_id": occupant_token.living_space_uu_id}
|
||||||
|
|
||||||
@classmethod # Requires auth context
|
@classmethod
|
||||||
def authentication_select_company_or_occupant_type(cls, request: Any, data: Any):
|
def authentication_select_company_or_occupant_type(cls, request: Any, data: Any):
|
||||||
"""
|
"""
|
||||||
Handle selection of company or occupant type
|
Handle selection of company or occupant type
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Configs(BaseSettings):
|
||||||
|
"""
|
||||||
|
Email configuration settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
HOST: str = ""
|
||||||
|
USERNAME: str = ""
|
||||||
|
PASSWORD: str = ""
|
||||||
|
PORT: int = 0
|
||||||
|
SEND: bool = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_send(self):
|
||||||
|
return bool(self.SEND)
|
||||||
|
|
||||||
|
def as_dict(self):
|
||||||
|
return dict(
|
||||||
|
host=self.HOST,
|
||||||
|
port=self.PORT,
|
||||||
|
username=self.USERNAME,
|
||||||
|
password=self.PASSWORD,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_prefix="EMAIL_")
|
||||||
|
|
||||||
|
|
||||||
|
# singleton instance of the POSTGRESQL configuration settings
|
||||||
|
email_configs = Configs()
|
||||||
|
|
@ -0,0 +1,29 @@
|
||||||
|
from send_email import EmailService, EmailSendModel
|
||||||
|
|
||||||
|
|
||||||
|
# Create email parameters
|
||||||
|
email_params = EmailSendModel(
|
||||||
|
subject="Test Email",
|
||||||
|
html="<p>Hello world!</p>",
|
||||||
|
receivers=["recipient@example.com"],
|
||||||
|
text="Hello world!",
|
||||||
|
)
|
||||||
|
|
||||||
|
another_email_params = EmailSendModel(
|
||||||
|
subject="Test Email2",
|
||||||
|
html="<p>Hello world!2</p>",
|
||||||
|
receivers=["recipient@example.com"],
|
||||||
|
text="Hello world!2",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# The context manager handles connection errors
|
||||||
|
with EmailService.new_session() as email_session:
|
||||||
|
# Send email - any exceptions here will propagate up
|
||||||
|
EmailService.send_email(email_session, email_params)
|
||||||
|
|
||||||
|
# Or send directly through the session
|
||||||
|
email_session.send(email_params)
|
||||||
|
|
||||||
|
# Send more emails in the same session if needed
|
||||||
|
EmailService.send_email(email_session, another_email_params)
|
||||||
|
|
@ -0,0 +1,90 @@
|
||||||
|
from redmail import EmailSender
|
||||||
|
from typing import List, Optional, Dict
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from .config import email_configs
|
||||||
|
|
||||||
|
|
||||||
|
class EmailSendModel(BaseModel):
|
||||||
|
subject: str
|
||||||
|
html: str = ""
|
||||||
|
receivers: List[str]
|
||||||
|
text: Optional[str] = ""
|
||||||
|
cc: Optional[List[str]] = None
|
||||||
|
bcc: Optional[List[str]] = None
|
||||||
|
headers: Optional[Dict] = None
|
||||||
|
attachments: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class EmailSession:
|
||||||
|
|
||||||
|
def __init__(self, email_sender):
|
||||||
|
self.email_sender = email_sender
|
||||||
|
|
||||||
|
def send(self, params: EmailSendModel) -> bool:
|
||||||
|
"""Send email using this session."""
|
||||||
|
if not email_configs.is_send:
|
||||||
|
print("Email sending is disabled", params)
|
||||||
|
return False
|
||||||
|
receivers = [email_configs.USERNAME]
|
||||||
|
|
||||||
|
# Ensure connection is established before sending
|
||||||
|
try:
|
||||||
|
# Check if connection exists, if not establish it
|
||||||
|
if not hasattr(self.email_sender, '_connected') or not self.email_sender._connected:
|
||||||
|
self.email_sender.connect()
|
||||||
|
|
||||||
|
self.email_sender.send(
|
||||||
|
subject=params.subject,
|
||||||
|
receivers=receivers,
|
||||||
|
text=params.text + f" : Gonderilen [{str(receivers)}]",
|
||||||
|
html=params.html,
|
||||||
|
cc=params.cc,
|
||||||
|
bcc=params.bcc,
|
||||||
|
headers=params.headers or {},
|
||||||
|
attachments=params.attachments or {},
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error sending email: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class EmailService:
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(EmailService, cls).__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def new_session(cls):
|
||||||
|
"""Create and yield a new email session with active connection."""
|
||||||
|
email_sender = EmailSender(**email_configs.as_dict())
|
||||||
|
session = EmailSession(email_sender)
|
||||||
|
connection_established = False
|
||||||
|
try:
|
||||||
|
# Establish connection and set flag
|
||||||
|
email_sender.connect()
|
||||||
|
# Set a flag to track connection state
|
||||||
|
email_sender._connected = True
|
||||||
|
connection_established = True
|
||||||
|
yield session
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error with email connection: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Only close if connection was successfully established
|
||||||
|
if connection_established:
|
||||||
|
try:
|
||||||
|
email_sender.close()
|
||||||
|
email_sender._connected = False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error closing email connection: {e}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def send_email(cls, session: EmailSession, params: EmailSendModel) -> bool:
|
||||||
|
"""Send email using the provided session."""
|
||||||
|
return session.send(params)
|
||||||
|
|
@ -0,0 +1,219 @@
|
||||||
|
# MongoDB Handler
|
||||||
|
|
||||||
|
A singleton MongoDB handler with context manager support for MongoDB collections and automatic retry capabilities.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Singleton Pattern**: Ensures only one instance of the MongoDB handler exists
|
||||||
|
- **Context Manager**: Automatically manages connection lifecycle
|
||||||
|
- **Retry Capability**: Automatically retries MongoDB operations on failure
|
||||||
|
- **Connection Pooling**: Configurable connection pooling
|
||||||
|
- **Graceful Degradation**: Handles connection failures without crashing
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from Controllers.Mongo.database import mongo_handler
|
||||||
|
|
||||||
|
# Use the context manager to access a collection
|
||||||
|
with mongo_handler.collection("users") as users_collection:
|
||||||
|
# Perform operations on the collection
|
||||||
|
users_collection.insert_one({"username": "john", "email": "john@example.com"})
|
||||||
|
user = users_collection.find_one({"username": "john"})
|
||||||
|
# Connection is automatically closed when exiting the context
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
MongoDB connection settings are configured via environment variables with the `MONGO_` prefix:
|
||||||
|
|
||||||
|
- `MONGO_ENGINE`: Database engine (e.g., "mongodb")
|
||||||
|
- `MONGO_USER`: MongoDB username
|
||||||
|
- `MONGO_PASSWORD`: MongoDB password
|
||||||
|
- `MONGO_HOST`: MongoDB host
|
||||||
|
- `MONGO_PORT`: MongoDB port
|
||||||
|
- `MONGO_DB`: Database name
|
||||||
|
- `MONGO_AUTH_DB`: Authentication database
|
||||||
|
|
||||||
|
## Monitoring Connection Closure
|
||||||
|
|
||||||
|
To verify that MongoDB sessions are properly closed, you can implement one of the following approaches:
|
||||||
|
|
||||||
|
### 1. Add Logging to the `__exit__` Method
|
||||||
|
|
||||||
|
```python
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""
|
||||||
|
Exit context, closing the connection.
|
||||||
|
"""
|
||||||
|
if self.client:
|
||||||
|
print(f"Closing MongoDB connection for collection: {self.collection_name}")
|
||||||
|
# Or use a proper logger
|
||||||
|
# logger.info(f"Closing MongoDB connection for collection: {self.collection_name}")
|
||||||
|
self.client.close()
|
||||||
|
self.client = None
|
||||||
|
self.collection = None
|
||||||
|
print(f"MongoDB connection closed successfully")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Add Connection Tracking
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MongoDBHandler:
|
||||||
|
# Add these to your class
|
||||||
|
_open_connections = 0
|
||||||
|
|
||||||
|
def get_connection_stats(self):
|
||||||
|
"""Return statistics about open connections"""
|
||||||
|
return {"open_connections": self._open_connections}
|
||||||
|
```
|
||||||
|
|
||||||
|
Then modify the `CollectionContext` class:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def __enter__(self):
|
||||||
|
try:
|
||||||
|
# Create a new client connection
|
||||||
|
self.client = MongoClient(self.db_handler.uri, **self.db_handler.client_options)
|
||||||
|
# Increment connection counter
|
||||||
|
self.db_handler._open_connections += 1
|
||||||
|
# Rest of your code...
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if self.client:
|
||||||
|
# Decrement connection counter
|
||||||
|
self.db_handler._open_connections -= 1
|
||||||
|
self.client.close()
|
||||||
|
self.client = None
|
||||||
|
self.collection = None
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Use MongoDB's Built-in Monitoring
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pymongo import monitoring
|
||||||
|
|
||||||
|
class ConnectionCommandListener(monitoring.CommandListener):
|
||||||
|
def started(self, event):
|
||||||
|
print(f"Command {event.command_name} started on server {event.connection_id}")
|
||||||
|
|
||||||
|
def succeeded(self, event):
|
||||||
|
print(f"Command {event.command_name} succeeded in {event.duration_micros} microseconds")
|
||||||
|
|
||||||
|
def failed(self, event):
|
||||||
|
print(f"Command {event.command_name} failed in {event.duration_micros} microseconds")
|
||||||
|
|
||||||
|
# Register the listener
|
||||||
|
monitoring.register(ConnectionCommandListener())
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Add a Test Function
|
||||||
|
|
||||||
|
```python
|
||||||
|
def test_connection_closure():
|
||||||
|
"""Test that MongoDB connections are properly closed."""
|
||||||
|
print("\nTesting connection closure...")
|
||||||
|
|
||||||
|
# Record initial connection count (if you implemented the counter)
|
||||||
|
initial_count = mongo_handler.get_connection_stats()["open_connections"]
|
||||||
|
|
||||||
|
# Use multiple nested contexts
|
||||||
|
for i in range(5):
|
||||||
|
with mongo_handler.collection("test_collection") as collection:
|
||||||
|
# Do some simple operation
|
||||||
|
collection.find_one({})
|
||||||
|
|
||||||
|
# Check final connection count
|
||||||
|
final_count = mongo_handler.get_connection_stats()["open_connections"]
|
||||||
|
|
||||||
|
if final_count == initial_count:
|
||||||
|
print("Test passed: All connections were properly closed")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(f"Test failed: {final_count - initial_count} connections remain open")
|
||||||
|
return False
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Use MongoDB Server Logs
|
||||||
|
|
||||||
|
You can also check the MongoDB server logs to see connection events:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run this on your MongoDB server
|
||||||
|
tail -f /var/log/mongodb/mongod.log | grep "connection"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. Always use the context manager pattern to ensure connections are properly closed
|
||||||
|
2. Keep operations within the context manager as concise as possible
|
||||||
|
3. Handle exceptions within the context to prevent unexpected behavior
|
||||||
|
4. Avoid nesting multiple context managers unnecessarily
|
||||||
|
5. Use the retry decorator for operations that might fail due to transient issues
|
||||||
|
|
||||||
|
## LXC Container Configuration
|
||||||
|
|
||||||
|
### Authentication Issues
|
||||||
|
|
||||||
|
If you encounter authentication errors when connecting to the MongoDB container at 10.10.2.13:27017, you may need to update the container configuration:
|
||||||
|
|
||||||
|
1. **Check MongoDB Authentication**: Ensure the MongoDB container is configured with the correct authentication mechanism
|
||||||
|
|
||||||
|
2. **Verify Network Configuration**: Make sure the container network allows connections from your application
|
||||||
|
|
||||||
|
3. **Update MongoDB Configuration**:
|
||||||
|
- Edit the MongoDB configuration file in the container
|
||||||
|
- Ensure `bindIp` is set correctly (e.g., `0.0.0.0` to allow connections from any IP)
|
||||||
|
- Check that authentication is enabled with the correct mechanism
|
||||||
|
|
||||||
|
4. **User Permissions**:
|
||||||
|
- Verify that the application user (`appuser`) exists in the MongoDB instance
|
||||||
|
- Ensure the user has the correct roles and permissions for the database
|
||||||
|
|
||||||
|
### Example MongoDB Container Configuration
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Example docker-compose.yml configuration
|
||||||
|
services:
|
||||||
|
mongodb:
|
||||||
|
image: mongo:latest
|
||||||
|
container_name: mongodb
|
||||||
|
environment:
|
||||||
|
- MONGO_INITDB_ROOT_USERNAME=admin
|
||||||
|
- MONGO_INITDB_ROOT_PASSWORD=password
|
||||||
|
volumes:
|
||||||
|
- ./init-mongo.js:/docker-entrypoint-initdb.d/init-mongo.js:ro
|
||||||
|
ports:
|
||||||
|
- "27017:27017"
|
||||||
|
command: mongod --auth
|
||||||
|
```
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
// Example init-mongo.js
|
||||||
|
db.createUser({
|
||||||
|
user: 'appuser',
|
||||||
|
pwd: 'apppassword',
|
||||||
|
roles: [
|
||||||
|
{ role: 'readWrite', db: 'appdb' }
|
||||||
|
]
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
1. **Authentication Failed**:
|
||||||
|
- Verify username and password in environment variables
|
||||||
|
- Check that the user exists in the specified authentication database
|
||||||
|
- Ensure the user has appropriate permissions
|
||||||
|
|
||||||
|
2. **Connection Refused**:
|
||||||
|
- Verify the MongoDB host and port are correct
|
||||||
|
- Check network connectivity between application and MongoDB container
|
||||||
|
- Ensure MongoDB is running and accepting connections
|
||||||
|
|
||||||
|
3. **Resource Leaks**:
|
||||||
|
- Use the context manager pattern to ensure connections are properly closed
|
||||||
|
- Monitor connection pool size and active connections
|
||||||
|
- Implement proper error handling to close connections in case of exceptions
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
import os
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Configs(BaseSettings):
|
||||||
|
"""
|
||||||
|
MongoDB configuration settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# MongoDB connection settings
|
||||||
|
ENGINE: str = "mongodb"
|
||||||
|
USERNAME: str = "appuser" # Application user
|
||||||
|
PASSWORD: str = "apppassword" # Application password
|
||||||
|
HOST: str = "10.10.2.13"
|
||||||
|
PORT: int = 27017
|
||||||
|
DB: str = "appdb" # The application database
|
||||||
|
AUTH_DB: str = "appdb" # Authentication is done against admin database
|
||||||
|
|
||||||
|
@property
|
||||||
|
def url(self):
|
||||||
|
"""Generate the database URL.
|
||||||
|
mongodb://{MONGO_USERNAME}:{MONGO_PASSWORD}@{MONGO_HOST}:{MONGO_PORT}/{DB}?authSource={MONGO_AUTH_DB}
|
||||||
|
"""
|
||||||
|
# Include the database name in the URI
|
||||||
|
return f"{self.ENGINE}://{self.USERNAME}:{self.PASSWORD}@{self.HOST}:{self.PORT}/{self.DB}?authSource={self.DB}"
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_prefix="_MONGO_")
|
||||||
|
|
||||||
|
|
||||||
|
# Create a singleton instance of the MongoDB configuration settings
|
||||||
|
mongo_configs = Configs()
|
||||||
|
|
@ -0,0 +1,373 @@
|
||||||
|
import time
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from pymongo import MongoClient
|
||||||
|
from pymongo.errors import PyMongoError
|
||||||
|
from .config import mongo_configs
|
||||||
|
|
||||||
|
|
||||||
|
def retry_operation(max_attempts=3, delay=1.0, backoff=2.0, exceptions=(PyMongoError,)):
|
||||||
|
"""
|
||||||
|
Decorator for retrying MongoDB operations with exponential backoff.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_attempts: Maximum number of retry attempts
|
||||||
|
delay: Initial delay between retries in seconds
|
||||||
|
backoff: Multiplier for delay after each retry
|
||||||
|
exceptions: Tuple of exceptions to catch and retry
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
mtries, mdelay = max_attempts, delay
|
||||||
|
while mtries > 1:
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except exceptions as e:
|
||||||
|
time.sleep(mdelay)
|
||||||
|
mtries -= 1
|
||||||
|
mdelay *= backoff
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class MongoDBHandler:
|
||||||
|
"""
|
||||||
|
A MongoDB handler that provides context manager access to specific collections
|
||||||
|
with automatic retry capability. Implements singleton pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_debug_mode = False # Set to True to enable debug mode
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Implement singleton pattern for the handler.
|
||||||
|
"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(MongoDBHandler, cls).__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self, debug_mode=False, mock_mode=False):
|
||||||
|
"""Initialize the MongoDB handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
debug_mode: If True, use a simplified connection for debugging
|
||||||
|
mock_mode: If True, use mock collections instead of real MongoDB connections
|
||||||
|
"""
|
||||||
|
if not hasattr(self, "_initialized") or not self._initialized:
|
||||||
|
self._debug_mode = debug_mode
|
||||||
|
self._mock_mode = mock_mode
|
||||||
|
|
||||||
|
if mock_mode:
|
||||||
|
# In mock mode, we don't need a real connection string
|
||||||
|
self.uri = "mongodb://mock:27017/mockdb"
|
||||||
|
print("MOCK MODE: Using simulated MongoDB connections")
|
||||||
|
elif debug_mode:
|
||||||
|
# Use a direct connection without authentication for testing
|
||||||
|
self.uri = f"mongodb://{mongo_configs.HOST}:{mongo_configs.PORT}/{mongo_configs.DB}"
|
||||||
|
print(f"DEBUG MODE: Using direct connection: {self.uri}")
|
||||||
|
else:
|
||||||
|
# Use the configured connection string with authentication
|
||||||
|
self.uri = mongo_configs.url
|
||||||
|
print(f"Connecting to MongoDB: {self.uri}")
|
||||||
|
|
||||||
|
# Define MongoDB client options with increased timeouts for better reliability
|
||||||
|
self.client_options = {
|
||||||
|
"maxPoolSize": 5,
|
||||||
|
"minPoolSize": 1,
|
||||||
|
"maxIdleTimeMS": 60000,
|
||||||
|
"waitQueueTimeoutMS": 5000,
|
||||||
|
"serverSelectionTimeoutMS": 10000,
|
||||||
|
"connectTimeoutMS": 30000,
|
||||||
|
"socketTimeoutMS": 45000,
|
||||||
|
"retryWrites": True,
|
||||||
|
"retryReads": True,
|
||||||
|
}
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
def collection(self, collection_name: str):
|
||||||
|
"""
|
||||||
|
Get a context manager for a specific collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Name of the collection to access
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A context manager for the specified collection
|
||||||
|
"""
|
||||||
|
return CollectionContext(self, collection_name)
|
||||||
|
|
||||||
|
|
||||||
|
class CollectionContext:
|
||||||
|
"""
|
||||||
|
Context manager for MongoDB collections with automatic retry capability.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_handler: MongoDBHandler, collection_name: str):
|
||||||
|
"""
|
||||||
|
Initialize collection context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_handler: Reference to the MongoDB handler
|
||||||
|
collection_name: Name of the collection to access
|
||||||
|
"""
|
||||||
|
self.db_handler = db_handler
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.client = None
|
||||||
|
self.collection = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""
|
||||||
|
Enter context, establishing a new connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The MongoDB collection object with retry capabilities
|
||||||
|
"""
|
||||||
|
# If we're in mock mode, return a mock collection immediately
|
||||||
|
if self.db_handler._mock_mode:
|
||||||
|
return self._create_mock_collection()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create a new client connection
|
||||||
|
self.client = MongoClient(
|
||||||
|
self.db_handler.uri, **self.db_handler.client_options
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.db_handler._debug_mode:
|
||||||
|
# In debug mode, we explicitly use the configured DB
|
||||||
|
db_name = mongo_configs.DB
|
||||||
|
print(f"DEBUG MODE: Using database '{db_name}'")
|
||||||
|
else:
|
||||||
|
# In normal mode, extract database name from the URI
|
||||||
|
try:
|
||||||
|
db_name = self.client.get_database().name
|
||||||
|
except Exception:
|
||||||
|
db_name = mongo_configs.DB
|
||||||
|
print(f"Using fallback database '{db_name}'")
|
||||||
|
|
||||||
|
self.collection = self.client[db_name][self.collection_name]
|
||||||
|
|
||||||
|
# Enhance collection methods with retry capabilities
|
||||||
|
self._add_retry_capabilities()
|
||||||
|
|
||||||
|
return self.collection
|
||||||
|
except pymongo.errors.OperationFailure as e:
|
||||||
|
if "Authentication failed" in str(e):
|
||||||
|
print(f"MongoDB authentication error: {e}")
|
||||||
|
print("Attempting to reconnect with direct connection...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try a direct connection without authentication for testing
|
||||||
|
direct_uri = f"mongodb://{mongo_configs.HOST}:{mongo_configs.PORT}/{mongo_configs.DB}"
|
||||||
|
print(f"Trying direct connection: {direct_uri}")
|
||||||
|
self.client = MongoClient(
|
||||||
|
direct_uri, **self.db_handler.client_options
|
||||||
|
)
|
||||||
|
self.collection = self.client[mongo_configs.DB][
|
||||||
|
self.collection_name
|
||||||
|
]
|
||||||
|
self._add_retry_capabilities()
|
||||||
|
return self.collection
|
||||||
|
except Exception as inner_e:
|
||||||
|
print(f"Direct connection also failed: {inner_e}")
|
||||||
|
# Fall through to mock collection creation
|
||||||
|
else:
|
||||||
|
print(f"MongoDB operation error: {e}")
|
||||||
|
if self.client:
|
||||||
|
self.client.close()
|
||||||
|
self.client = None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"MongoDB connection error: {e}")
|
||||||
|
if self.client:
|
||||||
|
self.client.close()
|
||||||
|
self.client = None
|
||||||
|
|
||||||
|
return self._create_mock_collection()
|
||||||
|
|
||||||
|
def _create_mock_collection(self):
|
||||||
|
"""
|
||||||
|
Create a mock collection for testing or graceful degradation.
|
||||||
|
This prevents the application from crashing when MongoDB is unavailable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A mock MongoDB collection with simulated behaviors
|
||||||
|
"""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
if self.db_handler._mock_mode:
|
||||||
|
print(f"MOCK MODE: Using mock collection '{self.collection_name}'")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Using mock MongoDB collection '{self.collection_name}' for graceful degradation"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create in-memory storage for this mock collection
|
||||||
|
if not hasattr(self.db_handler, "_mock_storage"):
|
||||||
|
self.db_handler._mock_storage = {}
|
||||||
|
|
||||||
|
if self.collection_name not in self.db_handler._mock_storage:
|
||||||
|
self.db_handler._mock_storage[self.collection_name] = []
|
||||||
|
|
||||||
|
mock_collection = MagicMock()
|
||||||
|
mock_data = self.db_handler._mock_storage[self.collection_name]
|
||||||
|
|
||||||
|
# Define behavior for find operations
|
||||||
|
def mock_find(query=None, *args, **kwargs):
|
||||||
|
# Simple implementation that returns all documents
|
||||||
|
return mock_data
|
||||||
|
|
||||||
|
def mock_find_one(query=None, *args, **kwargs):
|
||||||
|
# Simple implementation that returns the first matching document
|
||||||
|
if not mock_data:
|
||||||
|
return None
|
||||||
|
return mock_data[0]
|
||||||
|
|
||||||
|
def mock_insert_one(document, *args, **kwargs):
|
||||||
|
# Add _id if not present
|
||||||
|
if "_id" not in document:
|
||||||
|
document["_id"] = f"mock_id_{len(mock_data)}"
|
||||||
|
mock_data.append(document)
|
||||||
|
result = MagicMock()
|
||||||
|
result.inserted_id = document["_id"]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def mock_insert_many(documents, *args, **kwargs):
|
||||||
|
inserted_ids = []
|
||||||
|
for doc in documents:
|
||||||
|
result = mock_insert_one(doc)
|
||||||
|
inserted_ids.append(result.inserted_id)
|
||||||
|
result = MagicMock()
|
||||||
|
result.inserted_ids = inserted_ids
|
||||||
|
return result
|
||||||
|
|
||||||
|
def mock_update_one(query, update, *args, **kwargs):
|
||||||
|
result = MagicMock()
|
||||||
|
result.modified_count = 1
|
||||||
|
return result
|
||||||
|
|
||||||
|
def mock_update_many(query, update, *args, **kwargs):
|
||||||
|
result = MagicMock()
|
||||||
|
result.modified_count = len(mock_data)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def mock_delete_one(query, *args, **kwargs):
|
||||||
|
result = MagicMock()
|
||||||
|
result.deleted_count = 1
|
||||||
|
if mock_data:
|
||||||
|
mock_data.pop(0) # Just remove the first item for simplicity
|
||||||
|
return result
|
||||||
|
|
||||||
|
def mock_delete_many(query, *args, **kwargs):
|
||||||
|
count = len(mock_data)
|
||||||
|
mock_data.clear()
|
||||||
|
result = MagicMock()
|
||||||
|
result.deleted_count = count
|
||||||
|
return result
|
||||||
|
|
||||||
|
def mock_count_documents(query, *args, **kwargs):
|
||||||
|
return len(mock_data)
|
||||||
|
|
||||||
|
def mock_aggregate(pipeline, *args, **kwargs):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def mock_create_index(keys, **kwargs):
|
||||||
|
return f"mock_index_{keys}"
|
||||||
|
|
||||||
|
# Assign the mock implementations
|
||||||
|
mock_collection.find.side_effect = mock_find
|
||||||
|
mock_collection.find_one.side_effect = mock_find_one
|
||||||
|
mock_collection.insert_one.side_effect = mock_insert_one
|
||||||
|
mock_collection.insert_many.side_effect = mock_insert_many
|
||||||
|
mock_collection.update_one.side_effect = mock_update_one
|
||||||
|
mock_collection.update_many.side_effect = mock_update_many
|
||||||
|
mock_collection.delete_one.side_effect = mock_delete_one
|
||||||
|
mock_collection.delete_many.side_effect = mock_delete_many
|
||||||
|
mock_collection.count_documents.side_effect = mock_count_documents
|
||||||
|
mock_collection.aggregate.side_effect = mock_aggregate
|
||||||
|
mock_collection.create_index.side_effect = mock_create_index
|
||||||
|
|
||||||
|
# Add retry capabilities to the mock collection
|
||||||
|
self._add_retry_capabilities_to_mock(mock_collection)
|
||||||
|
|
||||||
|
self.collection = mock_collection
|
||||||
|
return self.collection
|
||||||
|
|
||||||
|
def _add_retry_capabilities(self):
|
||||||
|
"""
|
||||||
|
Add retry capabilities to all collection methods.
|
||||||
|
"""
|
||||||
|
# Store original methods for common operations
|
||||||
|
original_insert_one = self.collection.insert_one
|
||||||
|
original_insert_many = self.collection.insert_many
|
||||||
|
original_find_one = self.collection.find_one
|
||||||
|
original_find = self.collection.find
|
||||||
|
original_update_one = self.collection.update_one
|
||||||
|
original_update_many = self.collection.update_many
|
||||||
|
original_delete_one = self.collection.delete_one
|
||||||
|
original_delete_many = self.collection.delete_many
|
||||||
|
original_replace_one = self.collection.replace_one
|
||||||
|
original_count_documents = self.collection.count_documents
|
||||||
|
|
||||||
|
# Add retry capabilities to methods
|
||||||
|
self.collection.insert_one = retry_operation()(original_insert_one)
|
||||||
|
self.collection.insert_many = retry_operation()(original_insert_many)
|
||||||
|
self.collection.find_one = retry_operation()(original_find_one)
|
||||||
|
self.collection.find = retry_operation()(original_find)
|
||||||
|
self.collection.update_one = retry_operation()(original_update_one)
|
||||||
|
self.collection.update_many = retry_operation()(original_update_many)
|
||||||
|
self.collection.delete_one = retry_operation()(original_delete_one)
|
||||||
|
self.collection.delete_many = retry_operation()(original_delete_many)
|
||||||
|
self.collection.replace_one = retry_operation()(original_replace_one)
|
||||||
|
self.collection.count_documents = retry_operation()(original_count_documents)
|
||||||
|
|
||||||
|
def _add_retry_capabilities_to_mock(self, mock_collection):
|
||||||
|
"""
|
||||||
|
Add retry capabilities to mock collection methods.
|
||||||
|
This is a simplified version that just wraps the mock methods.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mock_collection: The mock collection to enhance
|
||||||
|
"""
|
||||||
|
# List of common MongoDB collection methods to add retry capabilities to
|
||||||
|
methods = [
|
||||||
|
"insert_one",
|
||||||
|
"insert_many",
|
||||||
|
"find_one",
|
||||||
|
"find",
|
||||||
|
"update_one",
|
||||||
|
"update_many",
|
||||||
|
"delete_one",
|
||||||
|
"delete_many",
|
||||||
|
"replace_one",
|
||||||
|
"count_documents",
|
||||||
|
"aggregate",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add retry decorator to each method
|
||||||
|
for method_name in methods:
|
||||||
|
if hasattr(mock_collection, method_name):
|
||||||
|
original_method = getattr(mock_collection, method_name)
|
||||||
|
setattr(
|
||||||
|
mock_collection,
|
||||||
|
method_name,
|
||||||
|
retry_operation(max_retries=1, retry_interval=0)(original_method),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""
|
||||||
|
Exit context, closing the connection.
|
||||||
|
"""
|
||||||
|
if self.client:
|
||||||
|
self.client.close()
|
||||||
|
self.client = None
|
||||||
|
self.collection = None
|
||||||
|
|
||||||
|
|
||||||
|
# Create a singleton instance of the MongoDB handler
|
||||||
|
mongo_handler = MongoDBHandler()
|
||||||
|
|
@ -0,0 +1,519 @@
|
||||||
|
# Initialize the MongoDB handler with your configuration
|
||||||
|
from datetime import datetime
|
||||||
|
from .database import MongoDBHandler, mongo_handler
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_test_data():
|
||||||
|
"""Clean up any test data before running tests."""
|
||||||
|
try:
|
||||||
|
with mongo_handler.collection("test_collection") as collection:
|
||||||
|
collection.delete_many({})
|
||||||
|
print("Successfully cleaned up test data")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not clean up test data: {e}")
|
||||||
|
print("Continuing with tests using mock data...")
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_crud_operations():
|
||||||
|
"""Test basic CRUD operations on users collection."""
|
||||||
|
print("\nTesting basic CRUD operations...")
|
||||||
|
try:
|
||||||
|
with mongo_handler.collection("users") as users_collection:
|
||||||
|
# First, clear any existing data
|
||||||
|
users_collection.delete_many({})
|
||||||
|
print("Cleared existing data")
|
||||||
|
|
||||||
|
# Insert multiple documents
|
||||||
|
insert_result = users_collection.insert_many(
|
||||||
|
[
|
||||||
|
{"username": "john", "email": "john@example.com", "role": "user"},
|
||||||
|
{"username": "jane", "email": "jane@example.com", "role": "admin"},
|
||||||
|
{"username": "bob", "email": "bob@example.com", "role": "user"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(f"Inserted {len(insert_result.inserted_ids)} documents")
|
||||||
|
|
||||||
|
# Find with multiple conditions
|
||||||
|
admin_users = list(users_collection.find({"role": "admin"}))
|
||||||
|
print(f"Found {len(admin_users)} admin users")
|
||||||
|
if admin_users:
|
||||||
|
print(f"Admin user: {admin_users[0].get('username')}")
|
||||||
|
|
||||||
|
# Update multiple documents
|
||||||
|
update_result = users_collection.update_many(
|
||||||
|
{"role": "user"}, {"$set": {"last_login": datetime.now().isoformat()}}
|
||||||
|
)
|
||||||
|
print(f"Updated {update_result.modified_count} documents")
|
||||||
|
|
||||||
|
# Delete documents
|
||||||
|
delete_result = users_collection.delete_many({"username": "bob"})
|
||||||
|
print(f"Deleted {delete_result.deleted_count} documents")
|
||||||
|
|
||||||
|
# Count remaining documents
|
||||||
|
remaining = users_collection.count_documents({})
|
||||||
|
print(f"Remaining documents: {remaining}")
|
||||||
|
|
||||||
|
# Check each condition separately
|
||||||
|
condition1 = len(admin_users) == 1
|
||||||
|
condition2 = admin_users and admin_users[0].get("username") == "jane"
|
||||||
|
condition3 = update_result.modified_count == 2
|
||||||
|
condition4 = delete_result.deleted_count == 1
|
||||||
|
|
||||||
|
print(f"Condition 1 (admin count): {condition1}")
|
||||||
|
print(f"Condition 2 (admin is jane): {condition2}")
|
||||||
|
print(f"Condition 3 (updated 2 users): {condition3}")
|
||||||
|
print(f"Condition 4 (deleted bob): {condition4}")
|
||||||
|
|
||||||
|
success = condition1 and condition2 and condition3 and condition4
|
||||||
|
print(f"Test {'passed' if success else 'failed'}")
|
||||||
|
return success
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Test failed with exception: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_documents():
|
||||||
|
"""Test operations with nested documents in products collection."""
|
||||||
|
print("\nTesting nested documents...")
|
||||||
|
try:
|
||||||
|
with mongo_handler.collection("products") as products_collection:
|
||||||
|
# Clear any existing data
|
||||||
|
products_collection.delete_many({})
|
||||||
|
print("Cleared existing data")
|
||||||
|
|
||||||
|
# Insert a product with nested data
|
||||||
|
insert_result = products_collection.insert_one(
|
||||||
|
{
|
||||||
|
"name": "Laptop",
|
||||||
|
"price": 999.99,
|
||||||
|
"specs": {"cpu": "Intel i7", "ram": "16GB", "storage": "512GB SSD"},
|
||||||
|
"in_stock": True,
|
||||||
|
"tags": ["electronics", "computers", "laptops"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
print(f"Inserted document with ID: {insert_result.inserted_id}")
|
||||||
|
|
||||||
|
# Find with nested field query
|
||||||
|
laptop = products_collection.find_one({"specs.cpu": "Intel i7"})
|
||||||
|
print(f"Found laptop: {laptop is not None}")
|
||||||
|
if laptop:
|
||||||
|
print(f"Laptop RAM: {laptop.get('specs', {}).get('ram')}")
|
||||||
|
|
||||||
|
# Update nested field
|
||||||
|
update_result = products_collection.update_one(
|
||||||
|
{"name": "Laptop"}, {"$set": {"specs.ram": "32GB"}}
|
||||||
|
)
|
||||||
|
print(f"Update modified count: {update_result.modified_count}")
|
||||||
|
|
||||||
|
# Verify the update
|
||||||
|
updated_laptop = products_collection.find_one({"name": "Laptop"})
|
||||||
|
print(f"Found updated laptop: {updated_laptop is not None}")
|
||||||
|
if updated_laptop:
|
||||||
|
print(f"Updated laptop specs: {updated_laptop.get('specs')}")
|
||||||
|
if "specs" in updated_laptop:
|
||||||
|
print(f"Updated RAM: {updated_laptop['specs'].get('ram')}")
|
||||||
|
|
||||||
|
# Check each condition separately
|
||||||
|
condition1 = laptop is not None
|
||||||
|
condition2 = laptop and laptop.get("specs", {}).get("ram") == "16GB"
|
||||||
|
condition3 = update_result.modified_count == 1
|
||||||
|
condition4 = (
|
||||||
|
updated_laptop and updated_laptop.get("specs", {}).get("ram") == "32GB"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Condition 1 (laptop found): {condition1}")
|
||||||
|
print(f"Condition 2 (original RAM is 16GB): {condition2}")
|
||||||
|
print(f"Condition 3 (update modified 1 doc): {condition3}")
|
||||||
|
print(f"Condition 4 (updated RAM is 32GB): {condition4}")
|
||||||
|
|
||||||
|
success = condition1 and condition2 and condition3 and condition4
|
||||||
|
print(f"Test {'passed' if success else 'failed'}")
|
||||||
|
return success
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Test failed with exception: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_array_operations():
|
||||||
|
"""Test operations with arrays in orders collection."""
|
||||||
|
print("\nTesting array operations...")
|
||||||
|
try:
|
||||||
|
with mongo_handler.collection("orders") as orders_collection:
|
||||||
|
# Clear any existing data
|
||||||
|
orders_collection.delete_many({})
|
||||||
|
print("Cleared existing data")
|
||||||
|
|
||||||
|
# Insert an order with array of items
|
||||||
|
insert_result = orders_collection.insert_one(
|
||||||
|
{
|
||||||
|
"order_id": "ORD001",
|
||||||
|
"customer": "john",
|
||||||
|
"items": [
|
||||||
|
{"product": "Laptop", "quantity": 1},
|
||||||
|
{"product": "Mouse", "quantity": 2},
|
||||||
|
],
|
||||||
|
"total": 1099.99,
|
||||||
|
"status": "pending",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
print(f"Inserted order with ID: {insert_result.inserted_id}")
|
||||||
|
|
||||||
|
# Find orders containing specific items
|
||||||
|
laptop_orders = list(orders_collection.find({"items.product": "Laptop"}))
|
||||||
|
print(f"Found {len(laptop_orders)} orders with Laptop")
|
||||||
|
|
||||||
|
# Update array elements
|
||||||
|
update_result = orders_collection.update_one(
|
||||||
|
{"order_id": "ORD001"},
|
||||||
|
{"$push": {"items": {"product": "Keyboard", "quantity": 1}}},
|
||||||
|
)
|
||||||
|
print(f"Update modified count: {update_result.modified_count}")
|
||||||
|
|
||||||
|
# Verify the update
|
||||||
|
updated_order = orders_collection.find_one({"order_id": "ORD001"})
|
||||||
|
print(f"Found updated order: {updated_order is not None}")
|
||||||
|
|
||||||
|
if updated_order:
|
||||||
|
print(
|
||||||
|
f"Number of items in order: {len(updated_order.get('items', []))}"
|
||||||
|
)
|
||||||
|
items = updated_order.get("items", [])
|
||||||
|
if items:
|
||||||
|
last_item = items[-1] if items else None
|
||||||
|
print(f"Last item in order: {last_item}")
|
||||||
|
|
||||||
|
# Check each condition separately
|
||||||
|
condition1 = len(laptop_orders) == 1
|
||||||
|
condition2 = update_result.modified_count == 1
|
||||||
|
condition3 = updated_order and len(updated_order.get("items", [])) == 3
|
||||||
|
condition4 = (
|
||||||
|
updated_order
|
||||||
|
and updated_order.get("items", [])
|
||||||
|
and updated_order["items"][-1].get("product") == "Keyboard"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Condition 1 (found 1 laptop order): {condition1}")
|
||||||
|
print(f"Condition 2 (update modified 1 doc): {condition2}")
|
||||||
|
print(f"Condition 3 (order has 3 items): {condition3}")
|
||||||
|
print(f"Condition 4 (last item is keyboard): {condition4}")
|
||||||
|
|
||||||
|
success = condition1 and condition2 and condition3 and condition4
|
||||||
|
print(f"Test {'passed' if success else 'failed'}")
|
||||||
|
return success
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Test failed with exception: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregation():
|
||||||
|
"""Test aggregation operations on sales collection."""
|
||||||
|
print("\nTesting aggregation operations...")
|
||||||
|
try:
|
||||||
|
with mongo_handler.collection("sales") as sales_collection:
|
||||||
|
# Clear any existing data
|
||||||
|
sales_collection.delete_many({})
|
||||||
|
print("Cleared existing data")
|
||||||
|
|
||||||
|
# Insert sample sales data
|
||||||
|
insert_result = sales_collection.insert_many(
|
||||||
|
[
|
||||||
|
{"product": "Laptop", "amount": 999.99, "date": datetime.now()},
|
||||||
|
{"product": "Mouse", "amount": 29.99, "date": datetime.now()},
|
||||||
|
{"product": "Keyboard", "amount": 59.99, "date": datetime.now()},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(f"Inserted {len(insert_result.inserted_ids)} sales documents")
|
||||||
|
|
||||||
|
# Calculate total sales by product - use a simpler aggregation pipeline
|
||||||
|
pipeline = [
|
||||||
|
{"$match": {}}, # Match all documents
|
||||||
|
{"$group": {"_id": "$product", "total": {"$sum": "$amount"}}},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Execute the aggregation
|
||||||
|
sales_summary = list(sales_collection.aggregate(pipeline))
|
||||||
|
print(f"Aggregation returned {len(sales_summary)} results")
|
||||||
|
|
||||||
|
# Print the results for debugging
|
||||||
|
for item in sales_summary:
|
||||||
|
print(f"Product: {item.get('_id')}, Total: {item.get('total')}")
|
||||||
|
|
||||||
|
# Check each condition separately
|
||||||
|
condition1 = len(sales_summary) == 3
|
||||||
|
condition2 = any(
|
||||||
|
item.get("_id") == "Laptop"
|
||||||
|
and abs(item.get("total", 0) - 999.99) < 0.01
|
||||||
|
for item in sales_summary
|
||||||
|
)
|
||||||
|
condition3 = any(
|
||||||
|
item.get("_id") == "Mouse" and abs(item.get("total", 0) - 29.99) < 0.01
|
||||||
|
for item in sales_summary
|
||||||
|
)
|
||||||
|
condition4 = any(
|
||||||
|
item.get("_id") == "Keyboard"
|
||||||
|
and abs(item.get("total", 0) - 59.99) < 0.01
|
||||||
|
for item in sales_summary
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Condition 1 (3 summary items): {condition1}")
|
||||||
|
print(f"Condition 2 (laptop total correct): {condition2}")
|
||||||
|
print(f"Condition 3 (mouse total correct): {condition3}")
|
||||||
|
print(f"Condition 4 (keyboard total correct): {condition4}")
|
||||||
|
|
||||||
|
success = condition1 and condition2 and condition3 and condition4
|
||||||
|
print(f"Test {'passed' if success else 'failed'}")
|
||||||
|
return success
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Test failed with exception: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_index_operations():
|
||||||
|
"""Test index creation and unique constraints."""
|
||||||
|
print("\nTesting index operations...")
|
||||||
|
try:
|
||||||
|
with mongo_handler.collection("test_collection") as collection:
|
||||||
|
# Create indexes
|
||||||
|
collection.create_index("email", unique=True)
|
||||||
|
collection.create_index([("username", 1), ("role", 1)])
|
||||||
|
|
||||||
|
# Insert initial document
|
||||||
|
collection.insert_one(
|
||||||
|
{"username": "test_user", "email": "test@example.com"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to insert duplicate email (should fail)
|
||||||
|
try:
|
||||||
|
collection.insert_one(
|
||||||
|
{"username": "test_user2", "email": "test@example.com"}
|
||||||
|
)
|
||||||
|
success = False # Should not reach here
|
||||||
|
except Exception:
|
||||||
|
success = True
|
||||||
|
|
||||||
|
print(f"Test {'passed' if success else 'failed'}")
|
||||||
|
return success
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Test failed with exception: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def test_complex_queries():
|
||||||
|
"""Test complex queries with multiple conditions."""
|
||||||
|
print("\nTesting complex queries...")
|
||||||
|
try:
|
||||||
|
with mongo_handler.collection("products") as products_collection:
|
||||||
|
# Insert test data
|
||||||
|
products_collection.insert_many(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "Expensive Laptop",
|
||||||
|
"price": 999.99,
|
||||||
|
"tags": ["electronics", "computers"],
|
||||||
|
"in_stock": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Cheap Mouse",
|
||||||
|
"price": 29.99,
|
||||||
|
"tags": ["electronics", "peripherals"],
|
||||||
|
"in_stock": True,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find products with price range and specific tags
|
||||||
|
expensive_electronics = list(
|
||||||
|
products_collection.find(
|
||||||
|
{
|
||||||
|
"price": {"$gt": 500},
|
||||||
|
"tags": {"$in": ["electronics"]},
|
||||||
|
"in_stock": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update with multiple conditions - split into separate operations for better compatibility
|
||||||
|
# First set the discount
|
||||||
|
products_collection.update_many(
|
||||||
|
{"price": {"$lt": 100}, "in_stock": True}, {"$set": {"discount": 0.1}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then update the price
|
||||||
|
update_result = products_collection.update_many(
|
||||||
|
{"price": {"$lt": 100}, "in_stock": True}, {"$inc": {"price": -10}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the update
|
||||||
|
updated_product = products_collection.find_one({"name": "Cheap Mouse"})
|
||||||
|
|
||||||
|
# Print debug information
|
||||||
|
print(f"Found expensive electronics: {len(expensive_electronics)}")
|
||||||
|
if expensive_electronics:
|
||||||
|
print(
|
||||||
|
f"First expensive product: {expensive_electronics[0].get('name')}"
|
||||||
|
)
|
||||||
|
print(f"Modified count: {update_result.modified_count}")
|
||||||
|
if updated_product:
|
||||||
|
print(f"Updated product price: {updated_product.get('price')}")
|
||||||
|
print(f"Updated product discount: {updated_product.get('discount')}")
|
||||||
|
|
||||||
|
# More flexible verification with approximate float comparison
|
||||||
|
success = (
|
||||||
|
len(expensive_electronics) >= 1
|
||||||
|
and expensive_electronics[0].get("name")
|
||||||
|
in ["Expensive Laptop", "Laptop"]
|
||||||
|
and update_result.modified_count >= 1
|
||||||
|
and updated_product is not None
|
||||||
|
and updated_product.get("discount", 0)
|
||||||
|
> 0 # Just check that discount exists and is positive
|
||||||
|
)
|
||||||
|
print(f"Test {'passed' if success else 'failed'}")
|
||||||
|
return success
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Test failed with exception: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def run_concurrent_operation_test(num_threads=100):
|
||||||
|
"""Run a simple operation in multiple threads to verify connection pooling."""
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
print(f"\nStarting concurrent operation test with {num_threads} threads...")
|
||||||
|
|
||||||
|
# Results tracking
|
||||||
|
results = {"passed": 0, "failed": 0, "errors": []}
|
||||||
|
results_lock = threading.Lock()
|
||||||
|
|
||||||
|
def worker(thread_id):
|
||||||
|
# Create a unique collection name for this thread
|
||||||
|
collection_name = f"concurrent_test_{thread_id}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Generate unique data for this thread
|
||||||
|
unique_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
with mongo_handler.collection(collection_name) as collection:
|
||||||
|
# Insert a document
|
||||||
|
collection.insert_one(
|
||||||
|
{
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"uuid": unique_id,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Find the document
|
||||||
|
doc = collection.find_one({"thread_id": thread_id})
|
||||||
|
|
||||||
|
# Update the document
|
||||||
|
collection.update_one(
|
||||||
|
{"thread_id": thread_id}, {"$set": {"updated": True}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify update
|
||||||
|
updated_doc = collection.find_one({"thread_id": thread_id})
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
collection.delete_many({"thread_id": thread_id})
|
||||||
|
|
||||||
|
success = (
|
||||||
|
doc is not None
|
||||||
|
and updated_doc is not None
|
||||||
|
and updated_doc.get("updated") is True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update results with thread safety
|
||||||
|
with results_lock:
|
||||||
|
if success:
|
||||||
|
results["passed"] += 1
|
||||||
|
else:
|
||||||
|
results["failed"] += 1
|
||||||
|
results["errors"].append(f"Thread {thread_id} operation failed")
|
||||||
|
except Exception as e:
|
||||||
|
with results_lock:
|
||||||
|
results["failed"] += 1
|
||||||
|
results["errors"].append(f"Thread {thread_id} exception: {str(e)}")
|
||||||
|
|
||||||
|
# Create and start threads using a thread pool
|
||||||
|
start_time = time.time()
|
||||||
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||||
|
futures = [executor.submit(worker, i) for i in range(num_threads)]
|
||||||
|
|
||||||
|
# Calculate execution time
|
||||||
|
execution_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(f"\nConcurrent Operation Test Results:")
|
||||||
|
print(f"Total threads: {num_threads}")
|
||||||
|
print(f"Passed: {results['passed']}")
|
||||||
|
print(f"Failed: {results['failed']}")
|
||||||
|
print(f"Execution time: {execution_time:.2f} seconds")
|
||||||
|
print(f"Operations per second: {num_threads / execution_time:.2f}")
|
||||||
|
|
||||||
|
if results["failed"] > 0:
|
||||||
|
print("\nErrors:")
|
||||||
|
for error in results["errors"][
|
||||||
|
:10
|
||||||
|
]: # Show only first 10 errors to avoid flooding output
|
||||||
|
print(f"- {error}")
|
||||||
|
if len(results["errors"]) > 10:
|
||||||
|
print(f"- ... and {len(results['errors']) - 10} more errors")
|
||||||
|
|
||||||
|
return results["failed"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_tests():
|
||||||
|
"""Run all MongoDB tests and report results."""
|
||||||
|
print("Starting MongoDB tests...")
|
||||||
|
|
||||||
|
# Clean up any existing test data before starting
|
||||||
|
cleanup_test_data()
|
||||||
|
|
||||||
|
tests = [
|
||||||
|
test_basic_crud_operations,
|
||||||
|
test_nested_documents,
|
||||||
|
test_array_operations,
|
||||||
|
test_aggregation,
|
||||||
|
test_index_operations,
|
||||||
|
test_complex_queries,
|
||||||
|
]
|
||||||
|
|
||||||
|
passed_list, not_passed_list = [], []
|
||||||
|
passed, failed = 0, 0
|
||||||
|
|
||||||
|
for test in tests:
|
||||||
|
# Clean up test data before each test
|
||||||
|
cleanup_test_data()
|
||||||
|
try:
|
||||||
|
if test():
|
||||||
|
passed += 1
|
||||||
|
passed_list.append(f"Test {test.__name__} passed")
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
not_passed_list.append(f"Test {test.__name__} failed")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Test {test.__name__} failed with exception: {e}")
|
||||||
|
failed += 1
|
||||||
|
not_passed_list.append(f"Test {test.__name__} failed")
|
||||||
|
|
||||||
|
print(f"\nTest Results: {passed} passed, {failed} failed")
|
||||||
|
print("Passed Tests:")
|
||||||
|
print("\n".join(passed_list))
|
||||||
|
print("Failed Tests:")
|
||||||
|
print("\n".join(not_passed_list))
|
||||||
|
|
||||||
|
return passed, failed
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mongo_handler = MongoDBHandler()
|
||||||
|
|
||||||
|
# Run standard tests first
|
||||||
|
passed, failed = run_all_tests()
|
||||||
|
|
||||||
|
# If all tests pass, run the concurrent operation test
|
||||||
|
if failed == 0:
|
||||||
|
run_concurrent_operation_test(10000)
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
"""
|
||||||
|
Test script for MongoDB handler with a local MongoDB instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from .database import MongoDBHandler, CollectionContext
|
||||||
|
|
||||||
|
# Create a custom handler class for local testing
|
||||||
|
class LocalMongoDBHandler(MongoDBHandler):
|
||||||
|
"""A MongoDB handler for local testing without authentication."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize with a direct MongoDB URI."""
|
||||||
|
self._initialized = False
|
||||||
|
self.uri = "mongodb://localhost:27017/test"
|
||||||
|
self.client_options = {
|
||||||
|
"maxPoolSize": 5,
|
||||||
|
"minPoolSize": 2,
|
||||||
|
"maxIdleTimeMS": 30000,
|
||||||
|
"waitQueueTimeoutMS": 2000,
|
||||||
|
"serverSelectionTimeoutMS": 5000,
|
||||||
|
}
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
|
||||||
|
# Create a custom handler for local testing
|
||||||
|
def create_local_handler():
|
||||||
|
"""Create a MongoDB handler for local testing."""
|
||||||
|
# Create a fresh instance with direct MongoDB URI
|
||||||
|
handler = LocalMongoDBHandler()
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
def test_connection_monitoring():
|
||||||
|
"""Test connection monitoring with the MongoDB handler."""
|
||||||
|
print("\nTesting connection monitoring...")
|
||||||
|
|
||||||
|
# Create a local handler
|
||||||
|
local_handler = create_local_handler()
|
||||||
|
|
||||||
|
# Add connection tracking to the handler
|
||||||
|
local_handler._open_connections = 0
|
||||||
|
|
||||||
|
# Modify the CollectionContext class to track connections
|
||||||
|
original_enter = CollectionContext.__enter__
|
||||||
|
original_exit = CollectionContext.__exit__
|
||||||
|
|
||||||
|
def tracked_enter(self):
|
||||||
|
result = original_enter(self)
|
||||||
|
self.db_handler._open_connections += 1
|
||||||
|
print(f"Connection opened. Total open: {self.db_handler._open_connections}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
def tracked_exit(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.db_handler._open_connections -= 1
|
||||||
|
print(f"Connection closed. Total open: {self.db_handler._open_connections}")
|
||||||
|
return original_exit(self, exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
|
# Apply the tracking methods
|
||||||
|
CollectionContext.__enter__ = tracked_enter
|
||||||
|
CollectionContext.__exit__ = tracked_exit
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Test with multiple operations
|
||||||
|
for i in range(3):
|
||||||
|
print(f"\nTest iteration {i+1}:")
|
||||||
|
try:
|
||||||
|
with local_handler.collection("test_collection") as collection:
|
||||||
|
# Try a simple operation
|
||||||
|
try:
|
||||||
|
collection.find_one({})
|
||||||
|
print("Operation succeeded")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Operation failed: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Connection failed: {e}")
|
||||||
|
|
||||||
|
# Final connection count
|
||||||
|
print(f"\nFinal open connections: {local_handler._open_connections}")
|
||||||
|
if local_handler._open_connections == 0:
|
||||||
|
print("✅ All connections were properly closed")
|
||||||
|
else:
|
||||||
|
print(f"❌ {local_handler._open_connections} connections remain open")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Restore original methods
|
||||||
|
CollectionContext.__enter__ = original_enter
|
||||||
|
CollectionContext.__exit__ = original_exit
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_connection_monitoring()
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Configs(BaseSettings):
|
||||||
|
"""
|
||||||
|
Postgresql configuration settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DB: str = ""
|
||||||
|
USER: str = ""
|
||||||
|
PASSWORD: str = ""
|
||||||
|
HOST: str = ""
|
||||||
|
PORT: int = 0
|
||||||
|
ENGINE: str = "postgresql+psycopg2"
|
||||||
|
POOL_PRE_PING: bool = True
|
||||||
|
POOL_SIZE: int = 20
|
||||||
|
MAX_OVERFLOW: int = 10
|
||||||
|
POOL_RECYCLE: int = 600
|
||||||
|
POOL_TIMEOUT: int = 30
|
||||||
|
ECHO: bool = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def url(self):
|
||||||
|
"""Generate the database URL."""
|
||||||
|
return f"{self.ENGINE}://{self.USER}:{self.PASSWORD}@{self.HOST}:{self.PORT}/{self.DB}"
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_prefix="POSTGRES_")
|
||||||
|
|
||||||
|
|
||||||
|
# singleton instance of the POSTGRESQL configuration settings
|
||||||
|
postgres_configs = Configs()
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Generator
|
||||||
|
from api_controllers.postgres.config import postgres_configs
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import declarative_base, sessionmaker, scoped_session, Session
|
||||||
|
|
||||||
|
|
||||||
|
# Configure the database engine with proper pooling
|
||||||
|
engine = create_engine(
|
||||||
|
postgres_configs.url,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
pool_size=10, # Reduced from 20 to better match your CPU cores
|
||||||
|
max_overflow=5, # Reduced from 10 to prevent too many connections
|
||||||
|
pool_recycle=600, # Keep as is
|
||||||
|
pool_timeout=30, # Keep as is
|
||||||
|
echo=False, # Consider setting to False in production
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
# Create a cached session factory
|
||||||
|
@lru_cache()
|
||||||
|
def get_session_factory() -> scoped_session:
|
||||||
|
"""Create a thread-safe session factory."""
|
||||||
|
session_local = sessionmaker(
|
||||||
|
bind=engine,
|
||||||
|
autocommit=False,
|
||||||
|
autoflush=False,
|
||||||
|
expire_on_commit=True, # Prevent expired object issues
|
||||||
|
)
|
||||||
|
return scoped_session(session_local)
|
||||||
|
|
||||||
|
|
||||||
|
# Get database session with proper connection management
|
||||||
|
@contextmanager
|
||||||
|
def get_db() -> Generator[Session, None, None]:
|
||||||
|
"""Get database session with proper connection management.
|
||||||
|
|
||||||
|
This context manager ensures:
|
||||||
|
- Proper connection pooling
|
||||||
|
- Session cleanup
|
||||||
|
- Connection return to pool
|
||||||
|
- Thread safety
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Session: SQLAlchemy session object
|
||||||
|
"""
|
||||||
|
|
||||||
|
session_factory = get_session_factory()
|
||||||
|
session = session_factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
session.commit()
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
session_factory.remove() # Clean up the session from the registry
|
||||||
|
|
@ -0,0 +1,275 @@
|
||||||
|
import arrow
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Any, TypeVar, Type, Union, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Integer, String, Float, ForeignKey, UUID, TIMESTAMP, Boolean, SmallInteger, Numeric, func, text, NUMERIC, ColumnExpressionArgument
|
||||||
|
from sqlalchemy.orm import InstrumentedAttribute, Mapped, mapped_column, Query, Session
|
||||||
|
from sqlalchemy.sql.elements import BinaryExpression
|
||||||
|
|
||||||
|
from sqlalchemy_mixins.serialize import SerializeMixin
|
||||||
|
from sqlalchemy_mixins.repr import ReprMixin
|
||||||
|
from sqlalchemy_mixins.smartquery import SmartQueryMixin
|
||||||
|
from sqlalchemy_mixins.activerecord import ActiveRecordMixin
|
||||||
|
|
||||||
|
from api_controllers.postgres.engine import get_db, Base
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("CrudMixin", bound="CrudMixin")
|
||||||
|
|
||||||
|
|
||||||
|
class BasicMixin(Base, ActiveRecordMixin, SerializeMixin, ReprMixin, SmartQueryMixin):
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
__repr__ = ReprMixin.__repr__
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def new_session(cls):
|
||||||
|
"""Get database session."""
|
||||||
|
return get_db()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def iterate_over_variables(cls, val: Any, key: str) -> tuple[bool, Optional[Any]]:
|
||||||
|
"""
|
||||||
|
Process a field value based on its type and convert it to the appropriate format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
val: Field value
|
||||||
|
key: Field name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (should_include, processed_value)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key_ = cls.__annotations__.get(key, None)
|
||||||
|
is_primary = key in getattr(cls, "primary_keys", [])
|
||||||
|
row_attr = bool(getattr(getattr(cls, key), "foreign_keys", None))
|
||||||
|
|
||||||
|
# Skip primary keys and foreign keys
|
||||||
|
if is_primary or row_attr:
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
if val is None: # Handle None values
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
if str(key[-5:]).lower() == "uu_id": # Special handling for UUID fields
|
||||||
|
return True, str(val)
|
||||||
|
|
||||||
|
if key_: # Handle typed fields
|
||||||
|
if key_ == Mapped[int]:
|
||||||
|
return True, int(val)
|
||||||
|
elif key_ == Mapped[bool]:
|
||||||
|
return True, bool(val)
|
||||||
|
elif key_ == Mapped[float] or key_ == Mapped[NUMERIC]:
|
||||||
|
return True, round(float(val), 3)
|
||||||
|
elif key_ == Mapped[TIMESTAMP]:
|
||||||
|
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss"))
|
||||||
|
elif key_ == Mapped[str]:
|
||||||
|
return True, str(val)
|
||||||
|
else: # Handle based on Python types
|
||||||
|
if isinstance(val, datetime.datetime):
|
||||||
|
return True, str(arrow.get(str(val)).format("YYYY-MM-DD HH:mm:ss"))
|
||||||
|
elif isinstance(val, bool):
|
||||||
|
return True, bool(val)
|
||||||
|
elif isinstance(val, (float, Decimal)):
|
||||||
|
return True, round(float(val), 3)
|
||||||
|
elif isinstance(val, int):
|
||||||
|
return True, int(val)
|
||||||
|
elif isinstance(val, str):
|
||||||
|
return True, str(val)
|
||||||
|
elif val is None:
|
||||||
|
return True, None
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
err = e
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert(cls: Type[T], smart_options: dict[str, Any], validate_model: Any = None) -> Optional[tuple[BinaryExpression, ...]]:
|
||||||
|
"""
|
||||||
|
Convert smart options to SQLAlchemy filter expressions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
smart_options: Dictionary of filter options
|
||||||
|
validate_model: Optional model to validate against
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of SQLAlchemy filter expressions or None if validation fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Let SQLAlchemy handle the validation by attempting to create the filter expressions
|
||||||
|
return tuple(cls.filter_expr(**smart_options))
|
||||||
|
except Exception as e:
|
||||||
|
# If there's an error, provide a helpful message with valid columns and relationships
|
||||||
|
valid_columns = set()
|
||||||
|
relationship_names = set()
|
||||||
|
|
||||||
|
# Get column names if available
|
||||||
|
if hasattr(cls, '__table__') and hasattr(cls.__table__, 'columns'):
|
||||||
|
valid_columns = set(column.key for column in cls.__table__.columns)
|
||||||
|
|
||||||
|
# Get relationship names if available
|
||||||
|
if hasattr(cls, '__mapper__') and hasattr(cls.__mapper__, 'relationships'):
|
||||||
|
relationship_names = set(rel.key for rel in cls.__mapper__.relationships)
|
||||||
|
|
||||||
|
# Create a helpful error message
|
||||||
|
error_msg = f"Error in filter expression: {str(e)}\n"
|
||||||
|
error_msg += f"Attempted to filter with: {smart_options}\n"
|
||||||
|
error_msg += f"Valid columns are: {', '.join(valid_columns)}\n"
|
||||||
|
error_msg += f"Valid relationships are: {', '.join(relationship_names)}"
|
||||||
|
|
||||||
|
raise ValueError(error_msg) from e
|
||||||
|
|
||||||
|
def get_dict(self, exclude_list: Optional[list[InstrumentedAttribute]] = None) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert model instance to dictionary with customizable fields.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exclude_list: List of fields to exclude from the dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary representation of the model
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return_dict: Dict[str, Any] = {}
|
||||||
|
exclude_list = exclude_list or []
|
||||||
|
exclude_list = [exclude_arg.key for exclude_arg in exclude_list]
|
||||||
|
|
||||||
|
# Get all column names from the model
|
||||||
|
columns = [col.name for col in self.__table__.columns]
|
||||||
|
columns_set = set(columns)
|
||||||
|
|
||||||
|
# Filter columns
|
||||||
|
columns_list = set([col for col in columns_set if str(col)[-2:] != "id"])
|
||||||
|
columns_extend = set(col for col in columns_set if str(col)[-5:].lower() == "uu_id")
|
||||||
|
columns_list = set(columns_list) | set(columns_extend)
|
||||||
|
columns_list = list(set(columns_list) - set(exclude_list))
|
||||||
|
|
||||||
|
for key in columns_list:
|
||||||
|
val = getattr(self, key)
|
||||||
|
correct, value_of_database = self.iterate_over_variables(val, key)
|
||||||
|
if correct:
|
||||||
|
return_dict[key] = value_of_database
|
||||||
|
|
||||||
|
return return_dict
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
err = e
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CrudMixin(BasicMixin):
|
||||||
|
"""
|
||||||
|
Base mixin providing CRUD operations and common fields for PostgreSQL models.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Automatic timestamps (created_at, updated_at)
|
||||||
|
- Soft delete capability
|
||||||
|
- User tracking (created_by, updated_by)
|
||||||
|
- Data serialization
|
||||||
|
- Multi-language support
|
||||||
|
"""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
# Primary and reference fields
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||||
|
uu_id: Mapped[str] = mapped_column(
|
||||||
|
UUID,
|
||||||
|
server_default=text("gen_random_uuid()"),
|
||||||
|
index=True,
|
||||||
|
unique=True,
|
||||||
|
comment="Unique identifier UUID",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Common timestamp fields for all models
|
||||||
|
expiry_starts: Mapped[TIMESTAMP] = mapped_column(
|
||||||
|
TIMESTAMP(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
comment="Record validity start timestamp",
|
||||||
|
)
|
||||||
|
expiry_ends: Mapped[TIMESTAMP] = mapped_column(
|
||||||
|
TIMESTAMP(timezone=True),
|
||||||
|
default=str(arrow.get("2099-12-31")),
|
||||||
|
server_default=func.now(),
|
||||||
|
comment="Record validity end timestamp",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Timestamps
|
||||||
|
created_at: Mapped[TIMESTAMP] = mapped_column(
|
||||||
|
TIMESTAMP(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
comment="Record creation timestamp",
|
||||||
|
)
|
||||||
|
updated_at: Mapped[TIMESTAMP] = mapped_column(
|
||||||
|
TIMESTAMP(timezone=True),
|
||||||
|
server_default=func.now(),
|
||||||
|
onupdate=func.now(),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
comment="Last update timestamp",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CrudCollection(CrudMixin):
|
||||||
|
"""
|
||||||
|
Full-featured model class with all common fields.
|
||||||
|
|
||||||
|
Includes:
|
||||||
|
- UUID and reference ID
|
||||||
|
- Timestamps
|
||||||
|
- User tracking
|
||||||
|
- Confirmation status
|
||||||
|
- Soft delete
|
||||||
|
- Notification flags
|
||||||
|
"""
|
||||||
|
|
||||||
|
__abstract__ = True
|
||||||
|
__repr__ = ReprMixin.__repr__
|
||||||
|
|
||||||
|
# Outer reference fields
|
||||||
|
ref_id: Mapped[str] = mapped_column(
|
||||||
|
String(100), nullable=True, index=True, comment="External reference ID"
|
||||||
|
)
|
||||||
|
replication_id: Mapped[int] = mapped_column(
|
||||||
|
SmallInteger, server_default="0", comment="Replication identifier"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cryptographic and user tracking
|
||||||
|
cryp_uu_id: Mapped[str] = mapped_column(
|
||||||
|
String, nullable=True, index=True, comment="Cryptographic UUID"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Token fields of modification
|
||||||
|
created_credentials_token: Mapped[str] = mapped_column(
|
||||||
|
String, nullable=True, comment="Created Credentials token"
|
||||||
|
)
|
||||||
|
updated_credentials_token: Mapped[str] = mapped_column(
|
||||||
|
String, nullable=True, comment="Updated Credentials token"
|
||||||
|
)
|
||||||
|
confirmed_credentials_token: Mapped[str] = mapped_column(
|
||||||
|
String, nullable=True, comment="Confirmed Credentials token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Status flags
|
||||||
|
is_confirmed: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, server_default="0", comment="Record confirmation status"
|
||||||
|
)
|
||||||
|
deleted: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, server_default="0", comment="Soft delete flag"
|
||||||
|
)
|
||||||
|
active: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, server_default="1", comment="Record active status"
|
||||||
|
)
|
||||||
|
is_notification_send: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, server_default="0", comment="Notification sent flag"
|
||||||
|
)
|
||||||
|
is_email_send: Mapped[bool] = mapped_column(
|
||||||
|
Boolean, server_default="0", comment="Email sent flag"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Configs(BaseSettings):
|
||||||
|
"""
|
||||||
|
ApiTemplate configuration settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ACCESS_TOKEN_LENGTH: int = 90
|
||||||
|
REFRESHER_TOKEN_LENGTH: int = 144
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_prefix="API_")
|
||||||
|
|
||||||
|
|
||||||
|
token_config = Configs()
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
import hashlib
|
||||||
|
import uuid
|
||||||
|
import secrets
|
||||||
|
import random
|
||||||
|
|
||||||
|
from .config import token_config
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordModule:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_random_uu_id(str_std: bool = True):
|
||||||
|
return str(uuid.uuid4()) if str_std else uuid.uuid4()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_token(length=32) -> str:
|
||||||
|
letters = "abcdefghijklmnopqrstuvwxyz"
|
||||||
|
merged_letters = [letter for letter in letters] + [letter.upper() for letter in letters]
|
||||||
|
token_generated = secrets.token_urlsafe(length)
|
||||||
|
for i in str(token_generated):
|
||||||
|
if i not in merged_letters:
|
||||||
|
token_generated = token_generated.replace(i, random.choice(merged_letters), 1)
|
||||||
|
return token_generated
|
||||||
|
raise ValueError("EYS_0004")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_access_token(cls) -> str:
|
||||||
|
return cls.generate_token(int(token_config.ACCESS_TOKEN_LENGTH))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_refresher_token(cls) -> str:
|
||||||
|
return cls.generate_token(int(token_config.REFRESHER_TOKEN_LENGTH))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_hashed_password(domain: str, id_: str, password: str) -> str:
|
||||||
|
return hashlib.sha256(f"{domain}:{id_}:{password}".encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check_password(cls, domain, id_, password, password_hashed) -> bool:
|
||||||
|
return cls.create_hashed_password(domain, id_, password) == password_hashed
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from api_initializer.config import api_config
|
||||||
|
from api_initializer.create_app import create_app
|
||||||
|
|
||||||
|
# from prometheus_fastapi_instrumentator import Instrumentator
|
||||||
|
|
||||||
|
app = create_app() # Create FastAPI application
|
||||||
|
# Instrumentator().instrument(app=app).expose(app=app) # Setup Prometheus metrics
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn_config = uvicorn.Config(**api_config.app_as_dict, workers=1) # Run the application with Uvicorn Server
|
||||||
|
uvicorn.Server(uvicorn_config).run()
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
class Configs(BaseSettings):
|
||||||
|
"""
|
||||||
|
ApiTemplate configuration settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATH: str = ""
|
||||||
|
HOST: str = ""
|
||||||
|
PORT: int = 0
|
||||||
|
LOG_LEVEL: str = "info"
|
||||||
|
RELOAD: int = 0
|
||||||
|
ACCESS_TOKEN_TAG: str = ""
|
||||||
|
|
||||||
|
ACCESS_EMAIL_EXT: str = ""
|
||||||
|
TITLE: str = ""
|
||||||
|
ALGORITHM: str = ""
|
||||||
|
ACCESS_TOKEN_LENGTH: int = 90
|
||||||
|
REFRESHER_TOKEN_LENGTH: int = 144
|
||||||
|
EMAIL_HOST: str = ""
|
||||||
|
DATETIME_FORMAT: str = ""
|
||||||
|
FORGOT_LINK: str = ""
|
||||||
|
ALLOW_ORIGINS: list = ["http://localhost:3000", "http://localhost:3001", "http://localhost:3001/api", "http://localhost:3001/api/"]
|
||||||
|
VERSION: str = "0.1.001"
|
||||||
|
DESCRIPTION: str = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def app_as_dict(self) -> dict:
|
||||||
|
"""
|
||||||
|
Convert the settings to a dictionary.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"app": self.PATH,
|
||||||
|
"host": self.HOST,
|
||||||
|
"port": int(self.PORT),
|
||||||
|
"log_level": self.LOG_LEVEL,
|
||||||
|
"reload": bool(self.RELOAD),
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def api_info(self):
|
||||||
|
"""
|
||||||
|
Returns a dictionary with application information.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"title": self.TITLE,
|
||||||
|
"description": self.DESCRIPTION,
|
||||||
|
"default_response_class": JSONResponse,
|
||||||
|
"version": self.VERSION,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def forgot_link(cls, forgot_key):
|
||||||
|
"""
|
||||||
|
Generate a forgot password link.
|
||||||
|
"""
|
||||||
|
return cls.FORGOT_LINK + forgot_key
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_prefix="API_")
|
||||||
|
|
||||||
|
|
||||||
|
api_config = Configs()
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from event_clusters import RouterCluster, EventCluster
|
||||||
|
from config import api_config
|
||||||
|
from open_api_creator import create_openapi_schema
|
||||||
|
from create_route import RouteRegisterController
|
||||||
|
|
||||||
|
from api_middlewares.token_middleware import token_middleware
|
||||||
|
from endpoints.routes import get_routes
|
||||||
|
import events
|
||||||
|
|
||||||
|
|
||||||
|
cluster_is_set = False
|
||||||
|
|
||||||
|
|
||||||
|
def create_app():
|
||||||
|
|
||||||
|
def create_events_if_any_cluster_set():
|
||||||
|
|
||||||
|
global cluster_is_set
|
||||||
|
if not events.__all__ or cluster_is_set:
|
||||||
|
return
|
||||||
|
|
||||||
|
router_cluster_stack: list[RouterCluster] = [getattr(events, e, None) for e in events.__all__]
|
||||||
|
for router_cluster in router_cluster_stack:
|
||||||
|
event_cluster_stack: list[EventCluster] = list(router_cluster.event_clusters.values())
|
||||||
|
for event_cluster in event_cluster_stack:
|
||||||
|
try:
|
||||||
|
event_cluster.set_events_to_database()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error creating event cluster: {e}")
|
||||||
|
|
||||||
|
cluster_is_set = True
|
||||||
|
|
||||||
|
application = FastAPI(**api_config.api_info)
|
||||||
|
application.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=api_config.ALLOW_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@application.middleware("http")
|
||||||
|
async def add_token_middleware(request: Request, call_next):
|
||||||
|
return await token_middleware(request, call_next)
|
||||||
|
|
||||||
|
@application.get("/", description="Redirect Route", include_in_schema=False)
|
||||||
|
async def redirect_to_docs():
|
||||||
|
return RedirectResponse(url="/docs")
|
||||||
|
|
||||||
|
route_register = RouteRegisterController(app=application, router_list=get_routes())
|
||||||
|
application = route_register.register_routes()
|
||||||
|
|
||||||
|
create_events_if_any_cluster_set()
|
||||||
|
application.openapi = lambda _=application: create_openapi_schema(_)
|
||||||
|
return application
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
from typing import List
|
||||||
|
from fastapi import APIRouter, FastAPI
|
||||||
|
|
||||||
|
|
||||||
|
class RouteRegisterController:
|
||||||
|
|
||||||
|
def __init__(self, app: FastAPI, router_list: List[APIRouter]):
|
||||||
|
self.router_list = router_list
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_router_with_event_to_database(router: APIRouter):
|
||||||
|
from schemas import EndpointRestriction
|
||||||
|
|
||||||
|
with EndpointRestriction.new_session() as db_session:
|
||||||
|
EndpointRestriction.set_session(db_session)
|
||||||
|
for route in router.routes:
|
||||||
|
route_path = str(getattr(route, "path"))
|
||||||
|
route_summary = str(getattr(route, "name"))
|
||||||
|
operation_id = getattr(route, "operation_id", None)
|
||||||
|
if not operation_id:
|
||||||
|
raise ValueError(f"Route {route_path} operation_id is not found")
|
||||||
|
if not getattr(route, "methods") and isinstance(getattr(route, "methods")):
|
||||||
|
raise ValueError(f"Route {route_path} methods is not found")
|
||||||
|
|
||||||
|
route_method = [method.lower() for method in getattr(route, "methods")][0]
|
||||||
|
add_or_update_dict = dict(
|
||||||
|
endpoint_method=route_method, endpoint_name=route_path, endpoint_desc=route_summary.replace("_", " "), endpoint_function=route_summary, is_confirmed=True
|
||||||
|
)
|
||||||
|
if to_save_endpoint := EndpointRestriction.query.filter(EndpointRestriction.operation_uu_id == operation_id).first():
|
||||||
|
to_save_endpoint.update(**add_or_update_dict)
|
||||||
|
to_save_endpoint.save()
|
||||||
|
else:
|
||||||
|
created_endpoint = EndpointRestriction.create(**add_or_update_dict, operation_uu_id=operation_id)
|
||||||
|
created_endpoint.save()
|
||||||
|
|
||||||
|
def register_routes(self):
|
||||||
|
for router in self.router_list:
|
||||||
|
self.app.include_router(router)
|
||||||
|
self.add_router_with_event_to_database(router)
|
||||||
|
return self.app
|
||||||
|
|
@ -0,0 +1,122 @@
|
||||||
|
from typing import Optional, Type
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class EventCluster:
|
||||||
|
"""
|
||||||
|
EventCluster
|
||||||
|
"""
|
||||||
|
def __repr__(self):
|
||||||
|
return f"EventCluster(name={self.name})"
|
||||||
|
|
||||||
|
def __init__(self, endpoint_uu_id: str, name: str):
|
||||||
|
self.endpoint_uu_id = endpoint_uu_id
|
||||||
|
self.name = name
|
||||||
|
self.events: list["Event"] = []
|
||||||
|
|
||||||
|
def add_event(self, event: "Event"):
|
||||||
|
"""
|
||||||
|
Add an event to the cluster
|
||||||
|
"""
|
||||||
|
if event.key not in [e.key for e in self.events]:
|
||||||
|
self.events.append(event)
|
||||||
|
|
||||||
|
def get_event(self, event_key: str):
|
||||||
|
"""
|
||||||
|
Get an event by its key
|
||||||
|
"""
|
||||||
|
|
||||||
|
for event in self.events:
|
||||||
|
if event.key == event_key:
|
||||||
|
return event
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_events_to_database(self):
|
||||||
|
from schemas import Events, EndpointRestriction
|
||||||
|
|
||||||
|
with Events.new_session() as db_session:
|
||||||
|
Events.set_session(db_session)
|
||||||
|
EndpointRestriction.set_session(db_session)
|
||||||
|
|
||||||
|
if to_save_endpoint := EndpointRestriction.query.filter(EndpointRestriction.operation_uu_id == self.endpoint_uu_id).first():
|
||||||
|
print('to_save_endpoint', to_save_endpoint)
|
||||||
|
for event in self.events:
|
||||||
|
event_dict_to_save = dict(
|
||||||
|
function_code=event.key,
|
||||||
|
function_class=event.name,
|
||||||
|
description=event.description,
|
||||||
|
endpoint_code=self.endpoint_uu_id,
|
||||||
|
endpoint_id=to_save_endpoint.id,
|
||||||
|
endpoint_uu_id=str(to_save_endpoint.uu_id),
|
||||||
|
is_confirmed=True,
|
||||||
|
)
|
||||||
|
print('set_events_to_database event_dict_to_save', event_dict_to_save)
|
||||||
|
check_event = Events.query.filter(Events.endpoint_uu_id == event_dict_to_save["endpoint_uu_id"]).first()
|
||||||
|
if check_event:
|
||||||
|
check_event.update(**event_dict_to_save)
|
||||||
|
check_event.save()
|
||||||
|
else:
|
||||||
|
event_created = Events.create(**event_dict_to_save)
|
||||||
|
print(f"UUID: {event_created.uu_id} event is saved to {to_save_endpoint.uu_id}")
|
||||||
|
event_created.save()
|
||||||
|
|
||||||
|
def match_event(self, event_key: str) -> "Event":
|
||||||
|
"""
|
||||||
|
Match an event by its key
|
||||||
|
"""
|
||||||
|
if event := self.get_event(event_key=event_key):
|
||||||
|
return event
|
||||||
|
raise ValueError("Event key not found")
|
||||||
|
|
||||||
|
|
||||||
|
class Event:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
key: str,
|
||||||
|
request_validator: Optional[Type[BaseModel]] = None,
|
||||||
|
response_validator: Optional[Type[BaseModel]] = None,
|
||||||
|
description: str = "",
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.key = key
|
||||||
|
self.request_validator = request_validator
|
||||||
|
self.response_validator = response_validator
|
||||||
|
self.description = description
|
||||||
|
|
||||||
|
def event_callable(self):
|
||||||
|
"""
|
||||||
|
Example callable method
|
||||||
|
"""
|
||||||
|
print(self.name)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class RouterCluster:
|
||||||
|
"""
|
||||||
|
RouterCluster
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"RouterCluster(name={self.name})"
|
||||||
|
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self.name = name
|
||||||
|
self.event_clusters: dict[str, EventCluster] = {}
|
||||||
|
|
||||||
|
def set_event_cluster(self, event_cluster: EventCluster):
|
||||||
|
"""
|
||||||
|
Add an event cluster to the set
|
||||||
|
"""
|
||||||
|
print("Setting event cluster:", event_cluster.name)
|
||||||
|
if event_cluster.name not in self.event_clusters:
|
||||||
|
self.event_clusters[event_cluster.name] = event_cluster
|
||||||
|
|
||||||
|
def get_event_cluster(self, event_cluster_name: str) -> EventCluster:
|
||||||
|
"""
|
||||||
|
Get an event cluster by its name
|
||||||
|
"""
|
||||||
|
if event_cluster_name not in self.event_clusters:
|
||||||
|
raise ValueError("Event cluster not found")
|
||||||
|
return self.event_clusters[event_cluster_name]
|
||||||
|
|
@ -0,0 +1,116 @@
|
||||||
|
from typing import Any, Dict
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.routing import APIRoute
|
||||||
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
|
||||||
|
from endpoints.routes import get_safe_endpoint_urls
|
||||||
|
from config import api_config
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAPISchemaCreator:
|
||||||
|
"""
|
||||||
|
OpenAPI schema creator and customizer for FastAPI applications.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app: FastAPI):
|
||||||
|
"""
|
||||||
|
Initialize the OpenAPI schema creator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance
|
||||||
|
"""
|
||||||
|
self.app = app
|
||||||
|
self.safe_endpoint_list: list[tuple[str, str]] = get_safe_endpoint_urls()
|
||||||
|
self.routers_list = self.app.routes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_security_schemes() -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create security scheme definitions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Security scheme configurations
|
||||||
|
"""
|
||||||
|
|
||||||
|
return {
|
||||||
|
"BearerAuth": {
|
||||||
|
"type": "apiKey",
|
||||||
|
"in": "header",
|
||||||
|
"name": api_config.ACCESS_TOKEN_TAG,
|
||||||
|
"description": "Enter: **'Bearer <JWT>'**, where JWT is the access token",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def configure_route_security(
|
||||||
|
self, path: str, method: str, schema: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Configure security requirements for a specific route.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Route path
|
||||||
|
method: HTTP method
|
||||||
|
schema: OpenAPI schema to modify
|
||||||
|
"""
|
||||||
|
if not schema.get("paths", {}).get(path, {}).get(method):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if endpoint is in safe list
|
||||||
|
endpoint_path = f"{path}:{method}"
|
||||||
|
list_of_safe_endpoints = [
|
||||||
|
f"{e[0]}:{str(e[1]).lower()}" for e in self.safe_endpoint_list
|
||||||
|
]
|
||||||
|
if endpoint_path not in list_of_safe_endpoints:
|
||||||
|
if "security" not in schema["paths"][path][method]:
|
||||||
|
schema["paths"][path][method]["security"] = []
|
||||||
|
schema["paths"][path][method]["security"].append({"BearerAuth": []})
|
||||||
|
|
||||||
|
def create_schema(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create the complete OpenAPI schema.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Complete OpenAPI schema
|
||||||
|
"""
|
||||||
|
openapi_schema = get_openapi(
|
||||||
|
title=api_config.TITLE,
|
||||||
|
description=api_config.DESCRIPTION,
|
||||||
|
version=api_config.VERSION,
|
||||||
|
routes=self.app.routes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add security schemes
|
||||||
|
if "components" not in openapi_schema:
|
||||||
|
openapi_schema["components"] = {}
|
||||||
|
|
||||||
|
openapi_schema["components"]["securitySchemes"] = self.create_security_schemes()
|
||||||
|
|
||||||
|
# Configure route security and responses
|
||||||
|
for route in self.app.routes:
|
||||||
|
if isinstance(route, APIRoute) and route.include_in_schema:
|
||||||
|
path = str(route.path)
|
||||||
|
methods = [method.lower() for method in route.methods]
|
||||||
|
for method in methods:
|
||||||
|
self.configure_route_security(path, method, openapi_schema)
|
||||||
|
|
||||||
|
# Add custom documentation extensions
|
||||||
|
openapi_schema["x-documentation"] = {
|
||||||
|
"postman_collection": "/docs/postman",
|
||||||
|
"swagger_ui": "/docs",
|
||||||
|
"redoc": "/redoc",
|
||||||
|
}
|
||||||
|
return openapi_schema
|
||||||
|
|
||||||
|
|
||||||
|
def create_openapi_schema(app: FastAPI) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create OpenAPI schema for a FastAPI application.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Complete OpenAPI schema
|
||||||
|
"""
|
||||||
|
creator = OpenAPISchemaCreator(app)
|
||||||
|
return creator.create_schema()
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import time
|
||||||
|
|
||||||
|
while True:
|
||||||
|
|
||||||
|
time.sleep(10)
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
|
||||||
|
result_with_keys_dict {'Employees.id': 3, 'Employees.uu_id': UUID('2b757a5c-01bb-4213-9cf1-402480e73edc'), 'People.id': 2, 'People.uu_id': UUID('d945d320-2a4e-48db-a18c-6fd024beb517'), 'Users.id': 3, 'Users.uu_id': UUID('bdfa84d9-0e05-418c-9406-d6d1d41ae2a1'), 'Companies.id': 1, 'Companies.uu_id': UUID('da1de172-2f89-42d2-87f3-656b36a79d5b'), 'Departments.id': 3, 'Departments.uu_id': UUID('4edcec87-e072-408d-a780-3a62151b3971'), 'Duty.id': 9, 'Duty.uu_id': UUID('00d29292-c29e-4435-be41-9704ccf4b24d'), 'Addresses.id': None, 'Addresses.letter_address': None}
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
from fastapi import Header, Request, Response
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from config import api_config
|
||||||
|
|
||||||
|
|
||||||
|
class CommonHeaders(BaseModel):
|
||||||
|
language: str | None = None
|
||||||
|
domain: str | None = None
|
||||||
|
timezone: str | None = None
|
||||||
|
token: str | None = None
|
||||||
|
request: Request | None = None
|
||||||
|
response: Response | None = None
|
||||||
|
operation_id: str | None = None
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"arbitrary_types_allowed": True
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def as_dependency(
|
||||||
|
cls,
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
language: str = Header(None, alias="language"),
|
||||||
|
domain: str = Header(None, alias="domain"),
|
||||||
|
tz: str = Header(None, alias="timezone"),
|
||||||
|
):
|
||||||
|
token = request.headers.get(api_config.ACCESS_TOKEN_TAG, None)
|
||||||
|
|
||||||
|
# Extract operation_id from the route
|
||||||
|
operation_id = None
|
||||||
|
if hasattr(request.scope.get("route"), "operation_id"):
|
||||||
|
operation_id = request.scope.get("route").operation_id
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
language=language,
|
||||||
|
domain=domain,
|
||||||
|
timezone=tz,
|
||||||
|
token=token,
|
||||||
|
request=request,
|
||||||
|
response=response,
|
||||||
|
operation_id=operation_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_headers_dict(self):
|
||||||
|
"""Convert the headers to a dictionary format used in the application"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
return {
|
||||||
|
"language": self.language or "",
|
||||||
|
"domain": self.domain or "",
|
||||||
|
"eys-ext": f"{str(uuid.uuid4())}",
|
||||||
|
"tz": self.timezone or "GMT+3",
|
||||||
|
"token": self.token,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelCore(BaseModel):
|
||||||
|
|
||||||
|
"""
|
||||||
|
BaseModelCore
|
||||||
|
model_dump override for alias support Users.name -> Table[Users] Field(alias="name")
|
||||||
|
"""
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
validate_by_name = True
|
||||||
|
use_enum_values = True
|
||||||
|
|
||||||
|
def model_dump(self, *args, **kwargs):
|
||||||
|
data = super().model_dump(*args, **kwargs)
|
||||||
|
return {self.__class__.model_fields[field].alias: value for field, value in data.items()}
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .pagination import PaginateOnly, ListOptions, PaginationConfig
|
||||||
|
from .result import Pagination, PaginationResult
|
||||||
|
from .base import PostgresResponseSingle, PostgresResponse, ResultQueryJoin, ResultQueryJoinSingle
|
||||||
|
from .api import EndpointResponse, CreateEndpointResponse
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
from .result import PaginationResult
|
||||||
|
from .base import PostgresResponseSingle
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Any, Type
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointResponse(BaseModel):
|
||||||
|
"""Endpoint response model."""
|
||||||
|
|
||||||
|
completed: bool = True
|
||||||
|
message: str = "Success"
|
||||||
|
pagination_result: PaginationResult
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response(self):
|
||||||
|
"""Convert response to dictionary format."""
|
||||||
|
result_data = getattr(self.pagination_result, "data", None)
|
||||||
|
if not result_data:
|
||||||
|
return {
|
||||||
|
"completed": False,
|
||||||
|
"message": "MSG0004-NODATA",
|
||||||
|
"data": None,
|
||||||
|
"pagination": None,
|
||||||
|
}
|
||||||
|
result_pagination = getattr(self.pagination_result, "pagination", None)
|
||||||
|
if not result_pagination:
|
||||||
|
raise ValueError("Invalid pagination result pagination.")
|
||||||
|
pagination_dict = getattr(result_pagination, "as_dict", None)
|
||||||
|
if not pagination_dict:
|
||||||
|
raise ValueError("Invalid pagination result as_dict.")
|
||||||
|
return {
|
||||||
|
"completed": self.completed,
|
||||||
|
"message": self.message,
|
||||||
|
"data": result_data,
|
||||||
|
"pagination": pagination_dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"arbitrary_types_allowed": True
|
||||||
|
}
|
||||||
|
|
||||||
|
class CreateEndpointResponse(BaseModel):
|
||||||
|
"""Create endpoint response model."""
|
||||||
|
|
||||||
|
completed: bool = True
|
||||||
|
message: str = "Success"
|
||||||
|
data: PostgresResponseSingle
|
||||||
|
|
||||||
|
@property
|
||||||
|
def response(self):
|
||||||
|
"""Convert response to dictionary format."""
|
||||||
|
return {
|
||||||
|
"completed": self.completed,
|
||||||
|
"message": self.message,
|
||||||
|
"data": self.data.data,
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"arbitrary_types_allowed": True
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,193 @@
|
||||||
|
"""
|
||||||
|
Response handler for PostgreSQL query results.
|
||||||
|
|
||||||
|
This module provides a wrapper class for SQLAlchemy query results,
|
||||||
|
adding convenience methods for accessing data and managing query state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional, TypeVar, Generic, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Query
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresResponse(Generic[T]):
|
||||||
|
"""
|
||||||
|
Wrapper for PostgreSQL/SQLAlchemy query results.
|
||||||
|
|
||||||
|
Properties:
|
||||||
|
count: Total count of results
|
||||||
|
query: Get query object
|
||||||
|
as_dict: Convert response to dictionary format
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, query: Query, base_model: Optional[BaseModel] = None):
|
||||||
|
self._query = query
|
||||||
|
self._count: Optional[int] = None
|
||||||
|
self._base_model: Optional[BaseModel] = base_model
|
||||||
|
self.single = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query(self) -> Query:
|
||||||
|
"""Get query object."""
|
||||||
|
return self._query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> Union[list[T], T]:
|
||||||
|
"""Get query object."""
|
||||||
|
return self._query.all()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
"""Get query object."""
|
||||||
|
return self._query.count()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def to_dict(self, **kwargs) -> list[dict]:
|
||||||
|
"""Get query object."""
|
||||||
|
if self._base_model:
|
||||||
|
return [self._base_model(**item.to_dict()).model_dump(**kwargs) for item in self.data]
|
||||||
|
return [item.to_dict() for item in self.data]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert response to dictionary format."""
|
||||||
|
return {
|
||||||
|
"query": str(self.query),
|
||||||
|
"count": self.count,
|
||||||
|
"data": self.to_dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PostgresResponseSingle(Generic[T]):
|
||||||
|
"""
|
||||||
|
Wrapper for PostgreSQL/SQLAlchemy query results.
|
||||||
|
|
||||||
|
Properties:
|
||||||
|
count: Total count of results
|
||||||
|
query: Get query object
|
||||||
|
as_dict: Convert response to dictionary format
|
||||||
|
data: Get query object
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, query: Query, base_model: Optional[BaseModel] = None):
|
||||||
|
self._query = query
|
||||||
|
self._count: Optional[int] = None
|
||||||
|
self._base_model: Optional[BaseModel] = base_model
|
||||||
|
self.single = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query(self) -> Query:
|
||||||
|
"""Get query object."""
|
||||||
|
return self._query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def to_dict(self, **kwargs) -> dict:
|
||||||
|
"""Get query object."""
|
||||||
|
if self._base_model:
|
||||||
|
return self._base_model(**self._query.first().to_dict()).model_dump(**kwargs)
|
||||||
|
return self._query.first().to_dict()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> T:
|
||||||
|
"""Get query object."""
|
||||||
|
return self._query.first()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
"""Get query object."""
|
||||||
|
return self._query.count()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert response to dictionary format."""
|
||||||
|
return {"query": str(self.query),"data": self.to_dict, "count": self.count}
|
||||||
|
|
||||||
|
|
||||||
|
class ResultQueryJoin:
|
||||||
|
"""
|
||||||
|
ResultQueryJoin
|
||||||
|
params:
|
||||||
|
list_of_instrumented_attributes: list of instrumented attributes
|
||||||
|
query: query object
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, list_of_instrumented_attributes, query):
|
||||||
|
"""Initialize ResultQueryJoin"""
|
||||||
|
self.list_of_instrumented_attributes = list_of_instrumented_attributes
|
||||||
|
self._query = query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query(self):
|
||||||
|
"""Get query object."""
|
||||||
|
return self._query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def to_dict(self):
|
||||||
|
"""Convert response to dictionary format."""
|
||||||
|
list_of_dictionaries, result = [], dict()
|
||||||
|
for user_orders_shipping_iter in self.query.all():
|
||||||
|
for index, instrumented_attribute_iter in enumerate(self.list_of_instrumented_attributes):
|
||||||
|
result[str(instrumented_attribute_iter)] = user_orders_shipping_iter[index]
|
||||||
|
list_of_dictionaries.append(result)
|
||||||
|
return list_of_dictionaries
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self):
|
||||||
|
"""Get count of query."""
|
||||||
|
return self.query.count()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self):
|
||||||
|
"""Get query object."""
|
||||||
|
return self.query.all()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_dict(self):
|
||||||
|
"""Convert response to dictionary format."""
|
||||||
|
return {"query": str(self.query), "data": self.data, "count": self.count}
|
||||||
|
|
||||||
|
|
||||||
|
class ResultQueryJoinSingle:
|
||||||
|
"""
|
||||||
|
ResultQueryJoinSingle
|
||||||
|
params:
|
||||||
|
list_of_instrumented_attributes: list of instrumented attributes
|
||||||
|
query: query object
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, list_of_instrumented_attributes, query):
|
||||||
|
"""Initialize ResultQueryJoinSingle"""
|
||||||
|
self.list_of_instrumented_attributes = list_of_instrumented_attributes
|
||||||
|
self._query = query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query(self):
|
||||||
|
"""Get query object."""
|
||||||
|
return self._query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def to_dict(self):
|
||||||
|
"""Convert response to dictionary format."""
|
||||||
|
data, result = self.query.first(), dict()
|
||||||
|
for index, instrumented_attribute_iter in enumerate(self.list_of_instrumented_attributes):
|
||||||
|
result[str(instrumented_attribute_iter)] = data[index]
|
||||||
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self):
|
||||||
|
"""Get count of query."""
|
||||||
|
return self.query.count()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self):
|
||||||
|
"""Get query object."""
|
||||||
|
return self._query.first()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_dict(self):
|
||||||
|
"""Convert response to dictionary format."""
|
||||||
|
return {"query": str(self.query), "data": self.data, "count": self.count}
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
|
||||||
|
|
||||||
|
class UserPydantic(BaseModel):
|
||||||
|
|
||||||
|
username: str = Field(..., alias='user.username')
|
||||||
|
account_balance: float = Field(..., alias='user.account_balance')
|
||||||
|
preferred_category_id: Optional[int] = Field(None, alias='user.preferred_category_id')
|
||||||
|
last_ordered_product_id: Optional[int] = Field(None, alias='user.last_ordered_product_id')
|
||||||
|
supplier_rating_id: Optional[int] = Field(None, alias='user.supplier_rating_id')
|
||||||
|
other_rating_id: Optional[int] = Field(None, alias='product.supplier_rating_id')
|
||||||
|
id: int = Field(..., alias='user.id')
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
validate_by_name = True
|
||||||
|
use_enum_values = True
|
||||||
|
|
||||||
|
def model_dump(self, *args, **kwargs):
|
||||||
|
data = super().model_dump(*args, **kwargs)
|
||||||
|
return {self.__class__.model_fields[field].alias: value for field, value in data.items()}
|
||||||
|
|
@ -0,0 +1,70 @@
|
||||||
|
from typing import Any, Dict, Optional, Union, TypeVar, Type
|
||||||
|
from sqlalchemy import desc, asc
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .base import PostgresResponse
|
||||||
|
|
||||||
|
# Type variable for class methods returning self
|
||||||
|
T = TypeVar("T", bound="BaseModel")
|
||||||
|
|
||||||
|
|
||||||
|
class PaginateConfig:
|
||||||
|
"""
|
||||||
|
Configuration for pagination settings.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
DEFAULT_SIZE: Default number of items per page (10)
|
||||||
|
MIN_SIZE: Minimum allowed page size (10)
|
||||||
|
MAX_SIZE: Maximum allowed page size (40)
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_SIZE = 10
|
||||||
|
MIN_SIZE = 5
|
||||||
|
MAX_SIZE = 100
|
||||||
|
|
||||||
|
|
||||||
|
class ListOptions(BaseModel):
|
||||||
|
"""
|
||||||
|
Query for list option abilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
page: Optional[int] = 1
|
||||||
|
size: Optional[int] = 10
|
||||||
|
orderField: Optional[Union[tuple[str], list[str]]] = ["uu_id"]
|
||||||
|
orderType: Optional[Union[tuple[str], list[str]]] = ["asc"]
|
||||||
|
# include_joins: Optional[list] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PaginateOnly(ListOptions):
|
||||||
|
"""
|
||||||
|
Query for list option abilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
query: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PaginationConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration for pagination settings.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
page: Current page number (default: 1)
|
||||||
|
size: Items per page (default: 10)
|
||||||
|
orderField: Field to order by (default: "created_at")
|
||||||
|
orderType: Order direction (default: "desc")
|
||||||
|
"""
|
||||||
|
|
||||||
|
page: int = 1
|
||||||
|
size: int = 10
|
||||||
|
orderField: Optional[Union[tuple[str], list[str]]] = ["created_at"]
|
||||||
|
orderType: Optional[Union[tuple[str], list[str]]] = ["desc"]
|
||||||
|
|
||||||
|
def __init__(self, **data):
|
||||||
|
super().__init__(**data)
|
||||||
|
if self.orderField is None:
|
||||||
|
self.orderField = ["created_at"]
|
||||||
|
if self.orderType is None:
|
||||||
|
self.orderType = ["desc"]
|
||||||
|
|
||||||
|
|
||||||
|
default_paginate_config = PaginateConfig()
|
||||||
|
|
@ -0,0 +1,180 @@
|
||||||
|
from typing import Optional, Union, Type, Any, Dict, TypeVar
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Query
|
||||||
|
from sqlalchemy import asc, desc
|
||||||
|
|
||||||
|
from .pagination import default_paginate_config
|
||||||
|
from .base import PostgresResponse
|
||||||
|
from .pagination import PaginationConfig
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
class Pagination:
|
||||||
|
"""
|
||||||
|
Handles pagination logic for query results.
|
||||||
|
|
||||||
|
Manages page size, current page, ordering, and calculates total pages
|
||||||
|
and items based on the data source.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
DEFAULT_SIZE: Default number of items per page (10)
|
||||||
|
MIN_SIZE: Minimum allowed page size (10)
|
||||||
|
MAX_SIZE: Maximum allowed page size (40)
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_SIZE = default_paginate_config.DEFAULT_SIZE
|
||||||
|
MIN_SIZE = default_paginate_config.MIN_SIZE
|
||||||
|
MAX_SIZE = default_paginate_config.MAX_SIZE
|
||||||
|
|
||||||
|
def __init__(self, data: PostgresResponse):
|
||||||
|
self.query = data
|
||||||
|
self.size: int = self.DEFAULT_SIZE
|
||||||
|
self.page: int = 1
|
||||||
|
self.orderField: Optional[Union[tuple[str], list[str]]] = ["uu_id"]
|
||||||
|
self.orderType: Optional[Union[tuple[str], list[str]]] = ["asc"]
|
||||||
|
self.page_count: int = 1
|
||||||
|
self.total_count: int = 0
|
||||||
|
self.all_count: int = 0
|
||||||
|
self.total_pages: int = 1
|
||||||
|
self._update_page_counts()
|
||||||
|
|
||||||
|
def change(self, **kwargs) -> None:
|
||||||
|
"""Update pagination settings from config."""
|
||||||
|
config = PaginationConfig(**kwargs)
|
||||||
|
self.size = (
|
||||||
|
config.size
|
||||||
|
if self.MIN_SIZE <= config.size <= self.MAX_SIZE
|
||||||
|
else self.DEFAULT_SIZE
|
||||||
|
)
|
||||||
|
self.page = config.page
|
||||||
|
self.orderField = config.orderField
|
||||||
|
self.orderType = config.orderType
|
||||||
|
self._update_page_counts()
|
||||||
|
|
||||||
|
def feed(self, data: PostgresResponse) -> None:
|
||||||
|
"""Calculate pagination based on data source."""
|
||||||
|
self.query = data
|
||||||
|
self._update_page_counts()
|
||||||
|
|
||||||
|
def _update_page_counts(self) -> None:
|
||||||
|
"""Update page counts and validate current page."""
|
||||||
|
if self.query:
|
||||||
|
self.total_count = self.query.count()
|
||||||
|
self.all_count = self.query.count()
|
||||||
|
|
||||||
|
self.size = (
|
||||||
|
self.size
|
||||||
|
if self.MIN_SIZE <= self.size <= self.MAX_SIZE
|
||||||
|
else self.DEFAULT_SIZE
|
||||||
|
)
|
||||||
|
self.total_pages = max(1, (self.total_count + self.size - 1) // self.size)
|
||||||
|
self.page = max(1, min(self.page, self.total_pages))
|
||||||
|
self.page_count = (
|
||||||
|
self.total_count % self.size
|
||||||
|
if self.page == self.total_pages and self.total_count % self.size
|
||||||
|
else self.size
|
||||||
|
)
|
||||||
|
|
||||||
|
def refresh(self) -> None:
|
||||||
|
"""Reset pagination state to defaults."""
|
||||||
|
self._update_page_counts()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset pagination state to defaults."""
|
||||||
|
self.size = self.DEFAULT_SIZE
|
||||||
|
self.page = 1
|
||||||
|
self.orderField = "uu_id"
|
||||||
|
self.orderType = "asc"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def next_available(self) -> bool:
|
||||||
|
if self.page < self.total_pages:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def back_available(self) -> bool:
|
||||||
|
if self.page > 1:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Convert pagination state to dictionary format."""
|
||||||
|
self.refresh()
|
||||||
|
return {
|
||||||
|
"size": self.size,
|
||||||
|
"page": self.page,
|
||||||
|
"allCount": self.all_count,
|
||||||
|
"totalCount": self.total_count,
|
||||||
|
"totalPages": self.total_pages,
|
||||||
|
"pageCount": self.page_count,
|
||||||
|
"orderField": self.orderField,
|
||||||
|
"orderType": self.orderType,
|
||||||
|
"next": self.next_available,
|
||||||
|
"back": self.back_available,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PaginationResult:
|
||||||
|
"""
|
||||||
|
Result of a paginated query.
|
||||||
|
|
||||||
|
Contains the query result and pagination state.
|
||||||
|
data: PostgresResponse of query results
|
||||||
|
pagination: Pagination state
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_query: Original query object
|
||||||
|
pagination: Pagination state
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
data: PostgresResponse,
|
||||||
|
pagination: Pagination,
|
||||||
|
is_list: bool = True,
|
||||||
|
response_model: Type[T] = None,
|
||||||
|
):
|
||||||
|
self._query = data
|
||||||
|
self.pagination = pagination
|
||||||
|
self.response_type = is_list
|
||||||
|
self.limit = self.pagination.size
|
||||||
|
self.offset = self.pagination.size * (self.pagination.page - 1)
|
||||||
|
self.order_by = self.pagination.orderField
|
||||||
|
self.response_model = response_model
|
||||||
|
|
||||||
|
def dynamic_order_by(self):
|
||||||
|
"""
|
||||||
|
Dynamically order a query by multiple fields.
|
||||||
|
Returns:
|
||||||
|
Ordered query object.
|
||||||
|
"""
|
||||||
|
if not len(self.order_by) == len(self.pagination.orderType):
|
||||||
|
raise ValueError(
|
||||||
|
"Order by fields and order types must have the same length."
|
||||||
|
)
|
||||||
|
order_criteria = zip(self.order_by, self.pagination.orderType)
|
||||||
|
for field, direction in order_criteria:
|
||||||
|
if hasattr(self._query.column_descriptions[0]["entity"], field):
|
||||||
|
if direction.lower().startswith("d"):
|
||||||
|
self._query = self._query.order_by(
|
||||||
|
desc(
|
||||||
|
getattr(self._query.column_descriptions[0]["entity"], field)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._query = self._query.order_by(
|
||||||
|
asc(
|
||||||
|
getattr(self._query.column_descriptions[0]["entity"], field)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return self._query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> Union[list | dict]:
|
||||||
|
"""Get query object."""
|
||||||
|
query_paginated = self.dynamic_order_by().limit(self.limit).offset(self.offset)
|
||||||
|
queried_data = (query_paginated.all() if self.response_type else query_paginated.first())
|
||||||
|
data = ([result.get_dict() for result in queried_data] if self.response_type else queried_data.get_dict())
|
||||||
|
return [self.response_model(**item).model_dump() for item in data] if self.response_model else data
|
||||||
|
|
@ -0,0 +1,123 @@
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
class UserType(Enum):
|
||||||
|
|
||||||
|
employee = 1
|
||||||
|
occupant = 2
|
||||||
|
|
||||||
|
|
||||||
|
class Credentials(BaseModel):
|
||||||
|
|
||||||
|
person_id: int
|
||||||
|
person_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationToken(BaseModel):
|
||||||
|
# Application Token Object -> is the main object for the user
|
||||||
|
|
||||||
|
user_type: int = UserType.occupant.value
|
||||||
|
credential_token: str = ""
|
||||||
|
|
||||||
|
user_uu_id: str
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
person_id: int
|
||||||
|
person_uu_id: str
|
||||||
|
|
||||||
|
request: Optional[dict] = None # Request Info of Client
|
||||||
|
expires_at: Optional[float] = None # Expiry timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class OccupantToken(BaseModel):
|
||||||
|
|
||||||
|
# Selection of the occupant type for a build part is made by the user
|
||||||
|
|
||||||
|
living_space_id: int # Internal use
|
||||||
|
living_space_uu_id: str # Outer use
|
||||||
|
|
||||||
|
occupant_type_id: int
|
||||||
|
occupant_type_uu_id: str
|
||||||
|
occupant_type: str
|
||||||
|
|
||||||
|
build_id: int
|
||||||
|
build_uuid: str
|
||||||
|
build_part_id: int
|
||||||
|
build_part_uuid: str
|
||||||
|
|
||||||
|
responsible_company_id: Optional[int] = None
|
||||||
|
responsible_company_uuid: Optional[str] = None
|
||||||
|
responsible_employee_id: Optional[int] = None
|
||||||
|
responsible_employee_uuid: Optional[str] = None
|
||||||
|
|
||||||
|
# ID list of reachable event codes as "endpoint_code": ["UUID", "UUID"]
|
||||||
|
reachable_event_codes: Optional[dict[str, str]] = None
|
||||||
|
|
||||||
|
# ID list of reachable applications as "page_url": ["UUID", "UUID"]
|
||||||
|
reachable_app_codes: Optional[dict[str, str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CompanyToken(BaseModel):
|
||||||
|
|
||||||
|
# Selection of the company for an employee is made by the user
|
||||||
|
company_id: int
|
||||||
|
company_uu_id: str
|
||||||
|
|
||||||
|
department_id: int # ID list of departments
|
||||||
|
department_uu_id: str # ID list of departments
|
||||||
|
|
||||||
|
duty_id: int
|
||||||
|
duty_uu_id: str
|
||||||
|
|
||||||
|
staff_id: int
|
||||||
|
staff_uu_id: str
|
||||||
|
|
||||||
|
employee_id: int
|
||||||
|
employee_uu_id: str
|
||||||
|
bulk_duties_id: int
|
||||||
|
|
||||||
|
# ID list of reachable event codes as "endpoint_code": ["UUID", "UUID"]
|
||||||
|
reachable_event_codes: Optional[dict[str, str]] = None
|
||||||
|
|
||||||
|
# ID list of reachable applications as "page_url": ["UUID", "UUID"]
|
||||||
|
reachable_app_codes: Optional[dict[str, str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class OccupantTokenObject(ApplicationToken):
|
||||||
|
# Occupant Token Object -> Requires selection of the occupant type for a specific build part
|
||||||
|
|
||||||
|
available_occupants: dict = None
|
||||||
|
selected_occupant: Optional[OccupantToken] = None # Selected Occupant Type
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_employee(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_occupant(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class EmployeeTokenObject(ApplicationToken):
|
||||||
|
# Full hierarchy Employee[staff_id] -> Staff -> Duty -> Department -> Company
|
||||||
|
|
||||||
|
companies_id_list: list[int] # List of company objects
|
||||||
|
companies_uu_id_list: list[str] # List of company objects
|
||||||
|
|
||||||
|
duty_id_list: list[int] # List of duty objects
|
||||||
|
duty_uu_id_list: list[str] # List of duty objects
|
||||||
|
|
||||||
|
selected_company: Optional[CompanyToken] = None # Selected Company Object
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_employee(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_occupant(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
TokenDictType = Union[EmployeeTokenObject, OccupantTokenObject]
|
||||||
Loading…
Reference in New Issue