190 lines
5.9 KiB
Python
190 lines
5.9 KiB
Python
import asyncio
|
|
import time
|
|
import logging
|
|
import uvloop
|
|
import threading
|
|
import datetime
|
|
import uuid
|
|
|
|
from typing import Optional, AsyncGenerator, Any, TypeVar, Union
|
|
from contextlib import asynccontextmanager
|
|
from prisma import Prisma
|
|
from prisma.client import _PrismaModel
|
|
|
|
|
|
# Type variable for Prisma model classes. The bound is a string forward
# reference, so the TypeVar itself does not evaluate the ``_PrismaModel`` name.
_PrismaModelT = TypeVar("_PrismaModelT", bound="_PrismaModel")

# Logger used by PrismaService for connect/disconnect and error messages.
logger = logging.getLogger("prisma-service")

# NOTE(review): this configures the *root* logger at import time, which affects
# every program that imports this module; consider moving it to the
# application's entry point.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

# Raise the httpx/httpcore loggers to WARNING so their lower-level records are
# dropped even though the root logger is at INFO.
for _noisy_logger_name in ("httpx", "httpcore"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name
|
|
|
|
|
|
class PrismaService:
    """Own a lazily-connected Prisma client plus a private event-loop thread.

    Construction starts a daemon thread named ``PrismaLoop`` that runs its own
    uvloop event loop. Synchronous code hands coroutines to that loop with
    ``_submit``; coroutines already running on that loop use ``_connect`` /
    ``_session`` directly. ``disconnect`` closes the client, stops the loop and
    joins the thread.
    """

    def __init__(self) -> None:
        # Serializes client creation/teardown in _connect/_disconnect.
        self._lock = asyncio.Lock()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._client: Optional[Prisma] = None
        # Set by the worker thread once its loop is actually running.
        self._loop_ready = threading.Event()
        # Public attributes kept for external callers; this class never reads them.
        self.result: Optional[Any] = None
        self.select: Optional[dict] = None
        self._start_loop_thread()

    def _loop_runner(self) -> None:
        """Thread target: create a uvloop loop, publish it, and run it until stopped.

        ``uvloop.new_event_loop()`` is used instead of installing a process-wide
        event-loop policy from this background thread. The loop is kept in a
        local so the final ``close()`` still works if ``disconnect`` has already
        reset ``self._loop`` to ``None`` (e.g. after a join timeout).
        """
        loop = uvloop.new_event_loop()
        asyncio.set_event_loop(loop)
        self._loop = loop
        # call_soon callbacks run only inside run_forever(), so anyone woken by
        # this event observes loop.is_running() == True.
        loop.call_soon(self._loop_ready.set)
        try:
            loop.run_forever()
        finally:
            loop.close()

    def _submit(self, coro):
        """Run *coro* on the worker loop and block the calling thread for its result.

        Must not be called from the worker loop's own thread: ``result()`` would
        block the very loop it is waiting on.

        Raises:
            RuntimeError: if the worker loop is not running. The coroutine is
                closed first so it does not emit a "never awaited" warning.
            Any exception raised by *coro* is re-raised here.
        """
        loop = self._loop
        if loop is None or not loop.is_running():
            if asyncio.iscoroutine(coro):
                coro.close()
            raise RuntimeError("PrismaService event loop is not running.")
        return asyncio.run_coroutine_threadsafe(coro, loop).result()

    async def _aconnect(self) -> "Prisma":
        """Return the shared client, connecting it first if necessary.

        Delegates to ``_connect`` so creation happens under ``self._lock`` and
        concurrent callers cannot each create (and leak) a client.
        """
        return await self._connect()

    async def _adisconnect(self) -> None:
        """Disconnect and drop the shared client, if any (see ``_disconnect``)."""
        await self._disconnect()

    @asynccontextmanager
    async def _asession(self) -> AsyncGenerator["Prisma", None]:
        """Yield the connected shared client; the client stays open on exit."""
        yield await self._aconnect()

    def _start_loop_thread(self) -> None:
        """Start the worker thread and block until its event loop is running.

        Raises:
            RuntimeError: if the worker thread exits before its loop starts
                (the thread's own traceback is reported by ``threading.excepthook``).
        """
        ready = threading.Event()
        self._loop_ready = ready
        t = threading.Thread(target=self._loop_runner, name="PrismaLoop", daemon=True)
        t.start()
        self._thread = t
        # Poll with a short timeout so a thread that died during startup is
        # detected instead of waiting forever.
        while not ready.wait(timeout=0.05):
            if not t.is_alive() and not ready.is_set():
                raise RuntimeError("PrismaService event loop thread failed to start.")

    async def _connect(self) -> "Prisma":
        """Return the shared client, creating and connecting it on first use.

        The unlocked fast path avoids the lock once connected; the re-check
        under ``self._lock`` makes concurrent first callers connect only once.
        """
        if self._client is not None:
            return self._client
        async with self._lock:
            if self._client is None:
                logger.info("Connecting Prisma client...")
                client = Prisma()
                await client.connect()
                self._client = client
                logger.info("Prisma client connected.")
            return self._client

    async def _disconnect(self) -> None:
        """Disconnect the shared client, if any.

        ``_client`` is cleared even when ``disconnect()`` raises, so a later
        ``_connect`` starts from a fresh client; the error still propagates.
        """
        async with self._lock:
            if self._client is not None:
                try:
                    logger.info("Disconnecting Prisma client...")
                    await self._client.disconnect()
                    logger.info("Prisma client disconnected.")
                finally:
                    self._client = None

    @staticmethod
    def to_dict(result: Union[list, Any], select: Optional[dict] = None):
        """Convert a model, or a list of models, into plain dict(s).

        Each model is iterated as ``(field, value)`` pairs. Only fields whose
        name is a key of *select* are kept; ``select=None`` keeps every field.
        Values are converted by ``_plain_value``: ``None`` is preserved,
        bool/int/float keep their numeric type, everything else becomes a string.

        Returns:
            A list of dicts when *result* is a list, otherwise a single dict.
        """
        if isinstance(result, list):
            return [PrismaService._model_to_dict(item, select) for item in result]
        return PrismaService._model_to_dict(result, select)

    @staticmethod
    def _model_to_dict(model: Any, select: Optional[dict]) -> dict:
        """Build a dict from one model's ``(field, value)`` pairs, filtered by *select*."""
        return {
            key: PrismaService._plain_value(value)
            for key, value in model
            if select is None or key in select
        }

    @staticmethod
    def _plain_value(value: Any) -> Any:
        """Convert one field value to a JSON-friendly scalar.

        ``None`` is kept as ``None``. bool/int/float are coerced to the exact
        builtin type; bool is tested first because it is a subclass of int.
        Everything else (datetime, UUID, str, any other type) becomes ``str(value)``.
        """
        if value is None:
            return None
        if isinstance(value, bool):
            return bool(value)
        if isinstance(value, int):
            return int(value)
        if isinstance(value, float):
            return float(value)
        return str(value)

    @asynccontextmanager
    async def _session(self) -> AsyncGenerator["Prisma", None]:
        """Yield the connected shared client, logging any error raised in the body.

        Exceptions are logged with traceback and re-raised; the client is not
        disconnected on exit.
        """
        client = await self._connect()
        try:
            yield client
        except Exception:
            logger.exception("Database operation error")
            raise

    def _run(self, coro):
        """Run *coro* to completion on a fresh, temporary uvloop event loop.

        Only for synchronous callers. The loop comes from
        ``Runner(loop_factory=uvloop.new_event_loop)`` rather than from a
        process-wide policy change.

        NOTE(review): the temporary loop is closed afterwards; anything bound
        to it (e.g. a client first connected inside *coro*) may not be usable
        from the worker loop — confirm with the callers.

        Raises:
            RuntimeError: if an event loop is already running in this thread
                (the coroutine is closed first).
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop in this thread: safe to start one.
        else:
            if asyncio.iscoroutine(coro):
                coro.close()
            raise RuntimeError("Async run is not allowed. Use sync methods instead.")
        with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
            return runner.run(coro)

    def disconnect(self) -> None:
        """Disconnect the client, stop the worker loop and join its thread.

        Idempotent: once the loop has been torn down, further calls return
        immediately. The thread join waits at most 2 seconds.
        """
        loop, thread = self._loop, self._thread
        if loop is None:
            return
        try:
            self._submit(self._adisconnect())
        finally:
            if loop.is_running():
                loop.call_soon_threadsafe(loop.stop)
            if thread is not None and thread.is_alive():
                thread.join(timeout=2.0)
            self._loop = None
            self._thread = None
|