295 lines
13 KiB
Python
295 lines
13 KiB
Python
import asyncio
|
|
import time
|
|
import logging
|
|
import uvloop
|
|
import threading
|
|
|
|
from datetime import datetime
|
|
from typing import Optional, AsyncGenerator, Protocol, Any
|
|
from contextlib import asynccontextmanager
|
|
from prisma import Prisma
|
|
|
|
|
|
# Module logger used by PrismaService for connection lifecycle and error messages.
logger = logging.getLogger("prisma-service")

# Configure the root logger at import time (no-op if it already has handlers).
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.INFO)

# Drop INFO/DEBUG records from the HTTP client stack; WARNING and above still pass.
for _name in ("httpx", "httpcore"):
    logging.getLogger(_name).setLevel(logging.WARNING)
del _name
|
|
|
|
|
|
class BaseModelClient(Protocol):
    """Structural type for a Prisma model accessor.

    This is the object PrismaService obtains with ``getattr(db, table)``. Only
    the coroutine methods PrismaService awaits are declared; each takes
    keyword arguments only.
    """

    # Reads
    async def find_many(self, **kwargs) -> list[Any]:
        """Return every row matching the given filters."""

    async def find_first(self, **kwargs) -> Any:
        """Return the first matching row, if any."""

    async def find_first_or_raise(self, **kwargs) -> Any:
        """Return the first matching row or raise."""

    async def find_unique(self, **kwargs) -> Any:
        """Return the row matching a unique filter, if any."""

    async def find_unique_or_raise(self, **kwargs) -> Any:
        """Return the row matching a unique filter or raise."""

    # Writes
    async def create(self, **kwargs) -> Any:
        """Insert a row and return it."""

    async def update(self, **kwargs) -> Any:
        """Update one row and return it."""

    async def delete(self, **kwargs) -> Any:
        """Delete one row and return it."""

    async def delete_many(self, **kwargs) -> Any:
        """Delete every matching row."""
|
|
|
|
|
|
class PrismaService:
    """Blocking facade over a single, lazily connected async Prisma client.

    A daemon thread named ``PrismaLoop`` runs a uvloop event loop forever. Every
    public method (``find_many``, ``create``, ``update``, ...) builds a
    coroutine, hands it to that loop with
    :func:`asyncio.run_coroutine_threadsafe` and blocks until it completes.
    The Prisma client is created and connected on the first query and reused
    afterwards. :meth:`disconnect` closes it and stops the loop thread.

    Public methods must not be called from the loop thread itself; waiting
    there would deadlock, so :meth:`_submit` rejects it.

    ``table`` arguments name a model accessor on the client, looked up with
    ``getattr(client, table)``. ``select`` is applied after the query, in
    Python: only fields whose names are ``in select`` are kept, and each
    record becomes a plain ``dict``.
    """

    # Seconds __init__ waits for the background loop to start running.
    _LOOP_START_TIMEOUT_S = 5.0

    def __init__(self) -> None:
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._client: Optional[Prisma] = None
        # Guards connect/disconnect so concurrent first queries cannot each
        # create (and leak) their own Prisma client. asyncio.Lock binds to an
        # event loop only when first used (Python 3.10+), so it is safe to
        # build it here, on the caller's thread.
        self._lock = asyncio.Lock()
        self._start_loop_thread()

    # ------------------------------------------------------------------ loop

    def _loop_runner(self, ready: Optional[threading.Event] = None) -> None:
        """Thread target: create a uvloop event loop and run it until stopped.

        Args:
            ready: Optional event, set once ``run_forever`` is actually running
                so the starter knows ``_submit`` will be accepted.
        """
        # uvloop.new_event_loop() creates the loop directly. It does not
        # replace the process-wide event loop policy, which would also change
        # asyncio.run() behaviour elsewhere in the application.
        loop = uvloop.new_event_loop()
        asyncio.set_event_loop(loop)
        self._loop = loop
        if ready is not None:
            # call_soon callbacks only execute inside run_forever(), so this
            # fires after the loop reports is_running() == True.
            loop.call_soon(ready.set)
        try:
            loop.run_forever()
        finally:
            loop.close()

    def _start_loop_thread(self) -> None:
        """Start the ``PrismaLoop`` daemon thread and wait until its loop runs.

        Raises:
            RuntimeError: if the loop is not running within
                ``_LOOP_START_TIMEOUT_S`` seconds (e.g. the thread crashed).
        """
        ready = threading.Event()
        thread = threading.Thread(
            target=self._loop_runner, args=(ready,), name="PrismaLoop", daemon=True
        )
        thread.start()
        self._thread = thread
        # Replaces a busy-wait on self._loop. That wait returned before
        # run_forever() started, so an immediate _submit could fail, and it
        # spun forever if the thread died during startup.
        if not ready.wait(timeout=self._LOOP_START_TIMEOUT_S):
            raise RuntimeError("PrismaService event loop failed to start.")

    def _submit(self, coro):
        """Run *coro* on the background loop and block until it finishes.

        Returns:
            Whatever *coro* returns; exceptions it raises propagate here.

        Raises:
            RuntimeError: if the loop is not running, or if called from the
                loop thread (blocking there would deadlock). *coro* is closed
                in both cases so it does not emit a "never awaited" warning.
        """
        loop = self._loop
        if loop is None or not loop.is_running():
            coro.close()
            raise RuntimeError("PrismaService event loop is not running.")
        if threading.current_thread() is self._thread:
            coro.close()
            raise RuntimeError(
                "PrismaService sync methods cannot be called from its event loop thread."
            )
        return asyncio.run_coroutine_threadsafe(coro, loop).result()

    # ------------------------------------------------------------ connection

    async def _connect(self) -> Prisma:
        """Return the shared client, creating and connecting it on first use."""
        if self._client is not None:
            return self._client  # fast path once connected: no lock needed
        async with self._lock:
            # Re-check under the lock: another coroutine may have finished
            # connecting while this one was waiting.
            if self._client is None:
                logger.info("Connecting Prisma client...")
                client = Prisma()
                await client.connect()
                self._client = client
                logger.info("Prisma client connected.")
            return self._client

    async def _aconnect(self) -> Prisma:
        """Alias of :meth:`_connect`, kept for existing callers."""
        return await self._connect()

    async def _disconnect(self) -> None:
        """Disconnect and forget the shared client; a no-op if not connected.

        The client reference is cleared even if ``disconnect()`` raises.
        """
        async with self._lock:
            if self._client is None:
                return
            logger.info("Disconnecting Prisma client...")
            try:
                await self._client.disconnect()
            finally:
                self._client = None
            logger.info("Prisma client disconnected.")

    async def _adisconnect(self) -> None:
        """Alias of :meth:`_disconnect`, kept for existing callers."""
        await self._disconnect()

    @asynccontextmanager
    async def _asession(self) -> AsyncGenerator[Prisma, None]:
        """Yield the shared client, connecting it first if needed."""
        yield await self._connect()

    @asynccontextmanager
    async def _session(self) -> AsyncGenerator[Prisma, None]:
        """Like :meth:`_asession`, but log any exception from the body, then re-raise it."""
        client = await self._connect()
        try:
            yield client
        except Exception:
            logger.exception("Database operation error")
            raise

    def _run(self, coro):
        """Run *coro* to completion on a fresh, short-lived uvloop loop.

        Refuses to run when the calling thread already has a running event
        loop: it closes *coro* and raises RuntimeError.

        NOTE(review): not used by the public methods. The shared Prisma client
        is connected on the background loop and is presumably tied to it, so
        coroutines touching the client should go through ``_submit`` instead.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # no loop running in this thread: safe to start one
        else:
            coro.close()
            raise RuntimeError("Async run is not allowed. Use sync methods instead.")
        with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
            return runner.run(coro)

    # --------------------------------------------------------------- helpers

    @staticmethod
    def _get_table(db: Prisma, table: str) -> BaseModelClient:
        """Return the model accessor ``db.<table>``.

        Raises:
            ValueError: if *db* has no truthy attribute named *table*.
        """
        table_selected = getattr(db, table, None)
        if not table_selected:
            raise ValueError(f"Table {table} not found")
        return table_selected

    @staticmethod
    def _log_elapsed(operation: str, start: float) -> None:
        """Log at DEBUG the time elapsed since *start*, a ``time.perf_counter()`` value."""
        logger.debug("%s completed in %.2fs", operation, time.perf_counter() - start)

    @staticmethod
    def _pick(record: Any, select) -> dict:
        """Return a dict of the fields of *record* whose names are in *select*.

        *record* may be a mapping or any object that ``dict()`` accepts as
        ``(name, value)`` pairs, such as a pydantic model instance.
        """
        return {key: value for key, value in dict(record).items() if key in select}

    @classmethod
    def _apply_select(cls, result: Any, select: Optional[dict]) -> Any:
        """Project *result* onto the field names in *select*.

        A falsy *select* or falsy *result* (``None``, ``[]``, ``0``) is
        returned unchanged. A list is projected item by item. An ``int`` has
        no fields and is returned unchanged (NOTE(review): ``delete_many`` is
        expected to return a row count). Anything else is projected by
        :meth:`_pick`.
        """
        if not select or not result:
            return result
        if isinstance(result, list):
            return [cls._pick(item, select) for item in result]
        if isinstance(result, int):
            return result
        return cls._pick(result, select)

    # ------------------------------------------------------ async operations
    # Each coroutine below runs on the background loop via _submit and
    # returns the raw Prisma result. ``select`` projection happens in the
    # synchronous wrappers further down.

    async def _a_find_many(self, table: str, query: Optional[dict] = None, take: Optional[int] = None,
                           skip: Optional[int] = None, order: Optional[list[dict]] = None,
                           select: Optional[dict] = None, include: Optional[dict] = None
                           ) -> list[Any]:
        """Run ``find_many`` on *table*. *select* is accepted but not used here."""
        start = time.perf_counter()
        async with self._asession() as db:
            rows = await self._get_table(db, table).find_many(
                where=query, take=take, skip=skip, order=order or [], include=include
            )
        self._log_elapsed("Find many query", start)
        return rows

    async def _a_find_first(self, table: str, query: Optional[dict] = None,
                            order: Optional[list[dict]] = None, include: Optional[dict] = None) -> Any:
        """Run ``find_first`` on *table*."""
        start = time.perf_counter()
        async with self._asession() as db:
            result = await self._get_table(db, table).find_first(
                where=query, order=order or [], include=include
            )
        self._log_elapsed("Find first query", start)
        return result

    async def _a_find_first_or_throw(self, table: str, query: Optional[dict] = None,
                                     order: Optional[list[dict]] = None,
                                     include: Optional[dict] = None) -> Any:
        """Run ``find_first_or_raise`` on *table*; its exception propagates."""
        start = time.perf_counter()
        async with self._asession() as db:
            result = await self._get_table(db, table).find_first_or_raise(
                where=query, order=order or [], include=include
            )
        self._log_elapsed("Find first or throw query", start)
        return result

    async def _a_create(self, table: str, data: dict, include: Optional[dict] = None) -> Any:
        """Run ``create`` on *table* with *data*."""
        start = time.perf_counter()
        async with self._asession() as db:
            result = await self._get_table(db, table).create(data=data, include=include)
        self._log_elapsed("Create operation", start)
        return result

    async def _a_update(self, table: str, where: dict, data: dict, include: Optional[dict] = None) -> Any:
        """Run ``update`` on *table* for the row matching *where*."""
        start = time.perf_counter()
        async with self._asession() as db:
            result = await self._get_table(db, table).update(where=where, data=data, include=include)
        self._log_elapsed("Update operation", start)
        return result

    async def _a_delete(self, table: str, where: dict, include: Optional[dict] = None) -> Any:
        """Run ``delete`` on *table* for the row matching *where*."""
        start = time.perf_counter()
        async with self._asession() as db:
            result = await self._get_table(db, table).delete(where=where, include=include)
        self._log_elapsed("Delete operation", start)
        return result

    async def _a_delete_many(self, table: str, where: dict, include: Optional[dict] = None):
        """Run ``delete_many`` on *table* for every row matching *where*."""
        start = time.perf_counter()
        async with self._asession() as db:
            # Forward ``include`` only when the caller supplied one.
            # NOTE(review): Prisma Client Python's delete_many takes only
            # ``where``; confirm against the generated client.
            kwargs: dict[str, Any] = {"where": where}
            if include is not None:
                kwargs["include"] = include
            result = await self._get_table(db, table).delete_many(**kwargs)
        self._log_elapsed("Delete many operation", start)
        return result

    async def _a_find_unique(self, table: str, query: dict, include: Optional[dict] = None) -> Any:
        """Run ``find_unique`` on *table*."""
        start = time.perf_counter()
        async with self._asession() as db:
            result = await self._get_table(db, table).find_unique(where=query, include=include)
        self._log_elapsed("Find unique query", start)
        return result

    async def _a_find_unique_or_throw(self, table: str, query: dict, include: Optional[dict] = None) -> Any:
        """Run ``find_unique_or_raise`` on *table*; its exception propagates."""
        start = time.perf_counter()
        async with self._asession() as db:
            result = await self._get_table(db, table).find_unique_or_raise(where=query, include=include)
        self._log_elapsed("Find unique or throw query", start)
        return result

    # ------------------------------------------------------- sync public API

    def find_unique_or_throw(self, table: str, query: dict, select: Optional[dict] = None,
                             include: Optional[dict] = None):
        """Return the unique row matching *query*, projected by *select*; raise if absent."""
        result = self._submit(self._a_find_unique_or_throw(table=table, query=query, include=include))
        return self._apply_select(result, select)

    def find_unique(self, table: str, query: dict, select: Optional[dict] = None,
                    include: Optional[dict] = None):
        """Return the unique row matching *query* (or the client's empty result), projected by *select*."""
        result = self._submit(self._a_find_unique(table=table, query=query, include=include))
        return self._apply_select(result, select)

    def find_many(
        self, table: str, query: Optional[dict] = None, take: Optional[int] = None,
        skip: Optional[int] = None, order: Optional[list[dict]] = None,
        select: Optional[dict] = None, include: Optional[dict] = None
    ):
        """Return all rows matching *query*, each projected by *select*."""
        result = self._submit(self._a_find_many(
            table=table, query=query, take=take, skip=skip, order=order, include=include
        ))
        return self._apply_select(result, select)

    def create(self, table: str, data: dict, select: Optional[dict] = None,
               include: Optional[dict] = None):
        """Insert *data* into *table* and return the new row, projected by *select*."""
        result = self._submit(self._a_create(table=table, data=data, include=include))
        return self._apply_select(result, select)

    def find_first_or_throw(self, table: str, query: Optional[dict] = None,
                            order: Optional[list[dict]] = None, select: Optional[dict] = None,
                            include: Optional[dict] = None):
        """Return the first row matching *query*, projected by *select*; raise if absent."""
        result = self._submit(self._a_find_first_or_throw(
            table=table, query=query, order=order, include=include
        ))
        return self._apply_select(result, select)

    def find_first(self, table: str, query: Optional[dict] = None, select: Optional[dict] = None,
                   order: Optional[list[dict]] = None, include: Optional[dict] = None):
        """Return the first row matching *query* (or the client's empty result), projected by *select*."""
        result = self._submit(self._a_find_first(table=table, query=query, order=order, include=include))
        return self._apply_select(result, select)

    def update(self, table: str, where: dict, data: dict, select: Optional[dict] = None,
               include: Optional[dict] = None):
        """Update the row matching *where* with *data* and return it, projected by *select*."""
        result = self._submit(self._a_update(table=table, where=where, data=data, include=include))
        return self._apply_select(result, select)

    def delete(self, table: str, where: dict, select: Optional[dict] = None,
               include: Optional[dict] = None):
        """Delete the row matching *where* and return it, projected by *select*."""
        # ``select`` is applied here, not in the coroutine; _a_delete has no
        # ``select`` parameter (passing it used to raise TypeError).
        result = self._submit(self._a_delete(table=table, where=where, include=include))
        return self._apply_select(result, select)

    def delete_many(self, table: str, where: dict, select: Optional[dict] = None,
                    include: Optional[dict] = None):
        """Delete every row matching *where* and return the client's result.

        *select* only affects list results; a scalar (e.g. a row count) is
        returned unchanged.
        """
        result = self._submit(self._a_delete_many(table=table, where=where, include=include))
        return self._apply_select(result, select)

    def disconnect(self) -> None:
        """Disconnect the client, then stop and join the loop thread.

        Raises:
            RuntimeError: if the loop is not running (e.g. already
                disconnected), as raised by :meth:`_submit`. Loop and thread
                cleanup still runs in that case.
        """
        try:
            self._submit(self._adisconnect())
        finally:
            loop, thread = self._loop, self._thread
            if loop is not None and loop.is_running():
                loop.call_soon_threadsafe(loop.stop)
            # Joining the current thread would raise, so skip it if we are it.
            if thread is not None and thread.is_alive() and thread is not threading.current_thread():
                thread.join(timeout=2.0)
            self._loop = None
            self._thread = None