"""Base models for MongoDB documents.""" from typing import Any, Dict, Optional, Union from bson import ObjectId from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic.json_schema import JsonSchemaValue from pydantic_core import CoreSchema, core_schema from ApiLibrary import system_arrow class PyObjectId(ObjectId): """Custom type for handling MongoDB ObjectId in Pydantic models.""" @classmethod def __get_pydantic_core_schema__( cls, _source_type: Any, _handler: Any, ) -> CoreSchema: """Define the core schema for PyObjectId.""" return core_schema.json_or_python_schema( json_schema=core_schema.str_schema(), python_schema=core_schema.union_schema([ core_schema.is_instance_schema(ObjectId), core_schema.chain_schema([ core_schema.str_schema(), core_schema.no_info_plain_validator_function(cls.validate), ]), ]), serialization=core_schema.plain_serializer_function_ser_schema( lambda x: str(x), return_schema=core_schema.str_schema(), when_used='json', ), ) @classmethod def validate(cls, value: Any) -> ObjectId: """Validate and convert the value to ObjectId.""" if not ObjectId.is_valid(value): raise ValueError("Invalid ObjectId") return ObjectId(value) @classmethod def __get_pydantic_json_schema__( cls, _core_schema: CoreSchema, _handler: Any, ) -> JsonSchemaValue: """Define the JSON schema for PyObjectId.""" return {"type": "string"} class MongoBaseModel(BaseModel): """Base model for all MongoDB documents.""" model_config = ConfigDict( arbitrary_types_allowed=True, json_encoders={ObjectId: str}, populate_by_name=True, from_attributes=True, validate_assignment=True, extra='allow' ) # Optional _id field that will be ignored in create operations id: Optional[PyObjectId] = Field(None, alias="_id") def get_extra(self, field_name: str, default: Any = None) -> Any: """Safely get extra field value. Args: field_name: Name of the extra field to retrieve default: Default value to return if field doesn't exist Returns: Value of the extra field if it exists, otherwise the default value """ return getattr(self, field_name, default) def as_dict(self) -> Dict[str, Any]: """Convert model to dictionary including all fields and extra fields. Returns: Dict containing all model fields and extra fields with proper type conversion """ return self.model_dump(by_alias=True) class MongoDocument(MongoBaseModel): """Base document model with timestamps.""" created_at: float = Field(default_factory=lambda: system_arrow.now().timestamp()) updated_at: float = Field(default_factory=lambda: system_arrow.now().timestamp()) @model_validator(mode='before') @classmethod def prevent_protected_fields(cls, data: Any) -> Any: """Prevent user from setting protected fields like _id and timestamps.""" if isinstance(data, dict): # Remove protected fields from input data.pop('_id', None) data.pop('created_at', None) data.pop('updated_at', None) # Set timestamps data['created_at'] = system_arrow.now().timestamp() data['updated_at'] = system_arrow.now().timestamp() return data