wag-managment-api-service-v.../Services/MongoDb/Models/action_models/base.py

116 lines
3.7 KiB
Python

"""Base models for MongoDB documents."""
from typing import Any, Dict, Optional, Union
from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema, core_schema
from ApiLibrary import system_arrow
class PyObjectId(ObjectId):
    """Custom type for handling MongoDB ObjectId in Pydantic models."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: Any,
    ) -> CoreSchema:
        """Define the core schema for PyObjectId.

        Validation:
            * Python input may be an existing ``ObjectId`` instance (accepted
              as is) or a string, which is passed through ``validate``.
            * JSON input must be a string and is passed through ``validate``
              as well, so the validated value is an ``ObjectId`` in both
              modes and malformed ids are rejected in both modes.

        Serialization:
            The value is rendered with ``str()`` when dumping to JSON only.

        Args:
            _source_type: Unused.
            _handler: Unused.

        Returns:
            The pydantic-core schema describing the rules above.
        """
        # One string -> ObjectId pipeline shared by the JSON and Python
        # branches; previously the JSON branch was a bare string schema and
        # never reached ``validate``.
        from_str_schema = core_schema.chain_schema(
            [
                core_schema.str_schema(),
                core_schema.no_info_plain_validator_function(cls.validate),
            ]
        )
        return core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(ObjectId),
                    from_str_schema,
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda x: str(x),
                return_schema=core_schema.str_schema(),
                when_used="json",
            ),
        )

    @classmethod
    def validate(cls, value: Any) -> ObjectId:
        """Validate and convert the value to ObjectId.

        Args:
            value: Candidate value, checked with ``ObjectId.is_valid``.

        Returns:
            ``ObjectId(value)``.

        Raises:
            ValueError: If ``ObjectId.is_valid(value)`` is false.
        """
        if not ObjectId.is_valid(value):
            raise ValueError("Invalid ObjectId")
        return ObjectId(value)

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        _core_schema: CoreSchema,
        _handler: Any,
    ) -> JsonSchemaValue:
        """Define the JSON schema for PyObjectId.

        Both arguments are ignored; the type is always described as a string.
        """
        return {"type": "string"}
class MongoBaseModel(BaseModel):
    """Shared Pydantic base class for MongoDB document models."""

    # Keys that are not declared fields are kept on the model (extra="allow"),
    # fields may be populated by name as well as by alias, and attribute
    # assignment is validated.
    model_config = ConfigDict(
        extra="allow",
        validate_assignment=True,
        from_attributes=True,
        populate_by_name=True,
        json_encoders={ObjectId: str},
        arbitrary_types_allowed=True,
    )

    # Optional document identifier; exposed as ``id`` and aliased to ``_id``.
    id: Optional[PyObjectId] = Field(default=None, alias="_id")

    def get_extra(self, field_name: str, default: Any = None) -> Any:
        """Look up an attribute on the model, falling back to a default.

        The lookup is a plain attribute access on the instance, so any
        attribute with that name is returned, not only extra fields.

        Args:
            field_name: Name of the attribute to read.
            default: Value returned when the attribute lookup fails.

        Returns:
            The attribute value, or ``default`` if it is not present.
        """
        try:
            return getattr(self, field_name)
        except AttributeError:
            return default

    def as_dict(self) -> Dict[str, Any]:
        """Dump the model to a dictionary keyed by field aliases.

        Returns:
            The result of ``model_dump(by_alias=True)``; the ``id`` field is
            therefore emitted under its alias ``_id``.
        """
        dumped = self.model_dump(by_alias=True)
        return dumped
class MongoDocument(MongoBaseModel):
    """Base document model with timestamps.

    ``created_at`` and ``updated_at`` hold the float returned by
    ``system_arrow.now().timestamp()``.
    """

    created_at: float = Field(default_factory=lambda: system_arrow.now().timestamp())
    updated_at: float = Field(default_factory=lambda: system_arrow.now().timestamp())

    @model_validator(mode="before")
    @classmethod
    def prevent_protected_fields(cls, data: Any) -> Any:
        """Prevent user from setting protected fields like _id and timestamps.

        For dict input, any ``_id``, ``created_at`` and ``updated_at`` keys
        are discarded and both timestamps are set to the current time. The
        caller's dict is left untouched: a sanitized copy is returned so that
        validating an existing document does not strip or overwrite its keys
        as a side effect.

        Args:
            data: Raw input handed to the model before field validation.

        Returns:
            A new sanitized dict when ``data`` is a dict; otherwise ``data``
            unchanged.
        """
        if not isinstance(data, dict):
            return data

        protected = ("_id", "created_at", "updated_at")
        # Build a copy instead of popping from ``data`` so the caller's dict
        # is never mutated.
        sanitized = {key: value for key, value in data.items() if key not in protected}

        # Read the clock once so both timestamps are identical on creation.
        now = system_arrow.now().timestamp()
        sanitized["created_at"] = now
        sanitized["updated_at"] = now
        return sanitized