"""Base models for MongoDB documents."""

from typing import Any, Dict, Optional, Union

from bson import ObjectId
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema, core_schema

from ApiLibrary import system_arrow


class PyObjectId(ObjectId):
    """Custom type for handling MongoDB ObjectId in Pydantic models.

    Python-mode validation accepts an existing ``ObjectId`` instance or a
    string that ``ObjectId.is_valid`` accepts; JSON-mode validation accepts a
    valid string. In both modes a string is converted to ``bson.ObjectId``.
    The value is serialized to ``str`` only in JSON mode, so a plain
    ``model_dump()`` keeps the ``ObjectId`` instance.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: Any,
    ) -> CoreSchema:
        """Define the core schema for PyObjectId."""
        # str -> validated ObjectId. Used for JSON input as well as Python
        # input, so JSON strings are checked and converted instead of being
        # stored as plain (possibly invalid) str values.
        from_str_schema = core_schema.chain_schema(
            [
                core_schema.str_schema(),
                core_schema.no_info_plain_validator_function(cls.validate),
            ]
        )
        return core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(ObjectId),
                    from_str_schema,
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda x: str(x),
                return_schema=core_schema.str_schema(),
                when_used="json",
            ),
        )

    @classmethod
    def validate(cls, value: Any) -> ObjectId:
        """Validate and convert the value to ObjectId.

        Raises:
            ValueError: If ``ObjectId.is_valid`` rejects the value.
        """
        if not ObjectId.is_valid(value):
            raise ValueError("Invalid ObjectId")
        return ObjectId(value)

    @classmethod
    def __get_pydantic_json_schema__(
        cls,
        _core_schema: CoreSchema,
        _handler: Any,
    ) -> JsonSchemaValue:
        """Define the JSON schema for PyObjectId."""
        return {"type": "string"}


class MongoBaseModel(BaseModel):
    """Base model for all MongoDB documents."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        json_encoders={ObjectId: str},
        populate_by_name=True,
        from_attributes=True,
        validate_assignment=True,
        extra="allow",
    )

    # MongoDB primary key, populated from "_id" (or "id", because
    # populate_by_name=True). Subclasses such as MongoDocument may strip it.
    id: Optional[PyObjectId] = Field(None, alias="_id")

    def get_extra(self, field_name: str, default: Any = None) -> Any:
        """Safely get extra field value.

        Undeclared ("extra") values are looked up in ``model_extra`` first, so
        an extra key that shares its name with a model attribute or method
        (e.g. ``"copy"`` or ``"json"``) returns the stored value rather than
        the attribute. Other names fall back to ``getattr`` as before.

        Args:
            field_name: Name of the extra field to retrieve
            default: Default value to return if field doesn't exist

        Returns:
            Value of the extra field if it exists, otherwise the default value
        """
        # model_extra is None when no extras were stored.
        extras = self.model_extra or {}
        if field_name in extras:
            return extras[field_name]
        return getattr(self, field_name, default)

    def as_dict(self) -> Dict[str, Any]:
        """Convert model to dictionary including all fields and extra fields.

        Keys use field aliases, so the id is emitted under ``"_id"``. The dump
        is in Python mode, so ObjectId values stay ``ObjectId`` instances
        (PyObjectId only stringifies in JSON mode).

        Returns:
            Dict containing all model fields and extra fields
        """
        return self.model_dump(by_alias=True)


class MongoDocument(MongoBaseModel):
    """Base document model with timestamps."""

    created_at: float = Field(
        default_factory=lambda: system_arrow.now().timestamp()
    )
    updated_at: float = Field(
        default_factory=lambda: system_arrow.now().timestamp()
    )

    @model_validator(mode="before")
    @classmethod
    def prevent_protected_fields(cls, data: Any) -> Any:
        """Prevent user from setting protected fields like _id and timestamps.

        For dict input, any supplied ``_id``, ``created_at`` and
        ``updated_at`` are discarded and both timestamps are set to the same
        current time. The caller's dict is copied, never modified in place.
        Non-dict input is returned unchanged.

        NOTE(review): this runs on every dict validation, including documents
        read back from MongoDB, which would lose their ``_id`` and original
        timestamps; and the ``id`` key (accepted via populate_by_name) is not
        stripped. Confirm both are intended.
        """
        if not isinstance(data, dict):
            return data

        # Copy so that model_validate(doc) does not mutate the caller's doc.
        cleaned = dict(data)
        for protected in ("_id", "created_at", "updated_at"):
            cleaned.pop(protected, None)

        # One timestamp so a fresh document has created_at == updated_at.
        now = system_arrow.now().timestamp()
        cleaned["created_at"] = now
        cleaned["updated_at"] = now
        return cleaned