116 lines
3.7 KiB
Python
116 lines
3.7 KiB
Python
"""Base models for MongoDB documents."""
|
|
|
|
from typing import Any, Dict, Optional, Union
|
|
from bson import ObjectId
|
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
from pydantic.json_schema import JsonSchemaValue
|
|
from pydantic_core import CoreSchema, core_schema
|
|
|
|
from ApiLibrary import system_arrow
|
|
|
|
|
|
class PyObjectId(ObjectId):
    """Custom type for handling MongoDB ObjectId in Pydantic models.

    Validation accepts either an existing ``ObjectId`` instance (Python mode
    only) or a string that ``ObjectId.is_valid`` accepts, and yields a
    ``bson.ObjectId``. Serialization renders the value as its string form in
    JSON mode and leaves it as an ``ObjectId`` in Python mode.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: Any,
    ) -> CoreSchema:
        """Define the core schema for PyObjectId.

        Both the JSON and the Python validation paths send string input
        through :meth:`validate`. Malformed ids are therefore rejected, and an
        ``ObjectId`` is produced however the model is validated.
        """
        # str -> ObjectId, rejecting anything ObjectId.is_valid refuses.
        from_str = core_schema.chain_schema(
            [
                core_schema.str_schema(),
                core_schema.no_info_plain_validator_function(cls.validate),
            ]
        )
        return core_schema.json_or_python_schema(
            # JSON can only carry the id as a string. This used to be a bare
            # str_schema(), which let invalid ids through model_validate_json()
            # and left the field holding a plain str instead of an ObjectId.
            json_schema=from_str,
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(ObjectId),
                    from_str,
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                str,
                return_schema=core_schema.str_schema(),
                when_used="json",
            ),
        )

    @classmethod
    def validate(cls, value: Any) -> ObjectId:
        """Validate and convert ``value`` to an ``ObjectId``.

        Args:
            value: Candidate id (a string when reached via the core schema).

        Returns:
            A new ``bson.ObjectId`` built from ``value``.

        Raises:
            ValueError: If ``ObjectId.is_valid`` rejects ``value``.
        """
        if not ObjectId.is_valid(value):
            raise ValueError("Invalid ObjectId")
        return ObjectId(value)

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        _core_schema: CoreSchema,
        _handler: Any,
    ) -> JsonSchemaValue:
        """Define the JSON schema for PyObjectId: a plain string.

        Returning a fixed schema here (without calling the handler) means the
        plain validator function in the core schema never has to be
        translated into JSON Schema.
        """
        return {"type": "string"}
|
|
|
|
|
|
class MongoBaseModel(BaseModel):
    """Base model for all MongoDB documents.

    Exposes the MongoDB ``_id`` key as the ``id`` attribute. Undeclared keys
    are kept as extra fields, and values are re-validated on attribute
    assignment.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        # Deprecated in Pydantic v2 but still honored. PyObjectId serializes
        # itself; this is kept for any field typed as a raw ObjectId.
        json_encoders={ObjectId: str},
        # Lets callers use either the alias "_id" or the field name "id".
        populate_by_name=True,
        from_attributes=True,
        validate_assignment=True,
        extra="allow",
    )

    # MongoDB primary key. Read and dumped (with by_alias) as "_id", also
    # accepted as "id" because populate_by_name=True. None when not supplied.
    id: Optional[PyObjectId] = Field(None, alias="_id")

    def get_extra(self, field_name: str, default: Any = None) -> Any:
        """Safely get an extra field value.

        The name is looked up among the model's extra (undeclared) fields
        first, then among its declared fields. Unlike a bare ``getattr``,
        methods and other class attributes (e.g. ``"model_dump"``) are never
        returned; ``default`` is returned instead.

        Args:
            field_name: Name of the field to retrieve.
            default: Value to return if no such field exists.

        Returns:
            The field's value if it exists, otherwise ``default``.
        """
        extras = self.model_extra or {}
        if field_name in extras:
            return extras[field_name]
        if field_name in type(self).model_fields:
            return getattr(self, field_name)
        return default

    def as_dict(self) -> Dict[str, Any]:
        """Convert the model to a dictionary, including extra fields.

        Keys use field aliases, so ``id`` is emitted as ``"_id"``. Values are
        dumped in Python mode, so ObjectIds stay ``ObjectId`` instances.

        NOTE(review): an unset ``id`` is emitted as ``"_id": None``; confirm
        that callers strip it before using the result as an insert payload.

        Returns:
            Dict containing all model fields and extra fields.
        """
        return self.model_dump(by_alias=True)
|
|
|
|
|
|
class MongoDocument(MongoBaseModel):
    """Base document model with creation/update timestamps.

    ``created_at`` and ``updated_at`` are floats taken from
    ``system_arrow.now().timestamp()``.
    """

    created_at: float = Field(default_factory=lambda: system_arrow.now().timestamp())
    updated_at: float = Field(default_factory=lambda: system_arrow.now().timestamp())

    @model_validator(mode="before")
    @classmethod
    def prevent_protected_fields(cls, data: Any) -> Any:
        """Prevent user from setting protected fields like _id and timestamps.

        For dict input, any caller-supplied ``_id``, ``created_at`` and
        ``updated_at`` keys are dropped, and both timestamps are set to the
        current time. Non-dict input is passed through unchanged.

        The dict is copied rather than mutated in place, so the object the
        caller passed in (e.g. to ``model_validate``) is left untouched.

        Args:
            data: Raw input given to model validation.

        Returns:
            A new, cleaned dict for dict input; otherwise ``data`` as-is.
        """
        if not isinstance(data, dict):
            return data

        protected = ("_id", "created_at", "updated_at")
        cleaned = {key: value for key, value in data.items() if key not in protected}

        # Single clock read so a fresh document has created_at == updated_at.
        now = system_arrow.now().timestamp()
        cleaned["created_at"] = now
        cleaned["updated_at"] = now

        # NOTE(review): with populate_by_name=True the id can also arrive as
        # "id", which is not stripped here; confirm whether that is intended.
        return cleaned
|