database: refactor to remove duplicate code

Wrap command implementations in with a _run_command() method.  All
database commands now return a Future object.  If the command was
run before the eventloop starts its possible to immediately query
the Future's result.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-01-31 11:45:22 -05:00
parent 65d1f23352
commit e029b6c582
1 changed files with 235 additions and 282 deletions

View File

@ -22,9 +22,9 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
ItemsView, Callable,
Mapping, Mapping,
ValuesView, TypeVar,
Tuple, Tuple,
Optional, Optional,
Union, Union,
@ -37,6 +37,7 @@ if TYPE_CHECKING:
from websockets import WebRequest from websockets import WebRequest
DBRecord = Union[int, float, bool, str, List[Any], Dict[str, Any]] DBRecord = Union[int, float, bool, str, List[Any], Dict[str, Any]]
DBType = Optional[DBRecord] DBType = Optional[DBRecord]
_T = TypeVar("_T")
DATABASE_VERSION = 1 DATABASE_VERSION = 1
MAX_NAMESPACES = 100 MAX_NAMESPACES = 100
@ -157,26 +158,34 @@ class MoonrakerDatabase:
"/server/database/item", ["GET", "POST", "DELETE"], "/server/database/item", ["GET", "POST", "DELETE"],
self._handle_item_request) self._handle_item_request)
def _run_command(self,
command_func: Callable[..., _T],
*args
) -> Future[_T]:
def func_wrapper():
with self.thread_lock:
return command_func(*args)
if self.eventloop.is_running():
return cast(Future, self.eventloop.run_in_thread(func_wrapper))
else:
ret = func_wrapper()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
def insert_item(self, def insert_item(self,
namespace: str, namespace: str,
key: Union[List[str], str], key: Union[List[str], str],
value: DBType value: DBType
) -> Awaitable[None]: ) -> Future[None]:
if self.eventloop.is_running(): return self._run_command(self._insert_impl, namespace, key, value)
return self.eventloop.run_in_thread(
self._insert_impl, namespace, key, value)
else:
self._insert_impl(namespace, key, value)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
def _insert_impl(self, def _insert_impl(self,
namespace: str, namespace: str,
key: Union[List[str], str], key: Union[List[str], str],
value: DBType value: DBType
) -> None: ) -> None:
with self.thread_lock:
key_list = self._process_key(key) key_list = self._process_key(key)
if namespace not in self.namespaces: if namespace not in self.namespaces:
self.namespaces[namespace] = self.lmdb_env.open_db( self.namespaces[namespace] = self.lmdb_env.open_db(
@ -200,22 +209,14 @@ class MoonrakerDatabase:
namespace: str, namespace: str,
key: Union[List[str], str], key: Union[List[str], str],
value: DBType value: DBType
) -> Awaitable[None]: ) -> Future[None]:
if self.eventloop.is_running(): return self._run_command(self._update_impl, namespace, key, value)
return self.eventloop.run_in_thread(
self._update_impl, namespace, key, value)
else:
self._update_impl(namespace, key, value)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
def _update_impl(self, def _update_impl(self,
namespace: str, namespace: str,
key: Union[List[str], str], key: Union[List[str], str],
value: DBType value: DBType
) -> None: ) -> None:
with self.thread_lock:
key_list = self._process_key(key) key_list = self._process_key(key)
record = self._get_record(namespace, key_list[0]) record = self._get_record(namespace, key_list[0])
if len(key_list) == 1: if len(key_list) == 1:
@ -247,21 +248,14 @@ class MoonrakerDatabase:
key: Union[List[str], str], key: Union[List[str], str],
drop_empty_db: bool = False drop_empty_db: bool = False
) -> Future[Any]: ) -> Future[Any]:
if self.eventloop.is_running(): return self._run_command(self._delete_impl, namespace, key,
return cast(Future, self.eventloop.run_in_thread( drop_empty_db)
self._delete_impl, namespace, key, drop_empty_db))
else:
ret = self._delete_impl(namespace, key, drop_empty_db)
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
def _delete_impl(self, def _delete_impl(self,
namespace: str, namespace: str,
key: Union[List[str], str], key: Union[List[str], str],
drop_empty_db: bool = False drop_empty_db: bool = False
) -> Any: ) -> Any:
with self.thread_lock:
key_list = self._process_key(key) key_list = self._process_key(key)
val = record = self._get_record(namespace, key_list[0]) val = record = self._get_record(namespace, key_list[0])
remove_record = True remove_record = True
@ -278,9 +272,7 @@ class MoonrakerDatabase:
remove_record = False if record else True remove_record = False if record else True
if remove_record: if remove_record:
db = self.namespaces[namespace] db = self.namespaces[namespace]
with ( with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
self.lmdb_env.begin(write=True, buffers=True, db=db) as txn
):
ret = txn.delete(key_list[0].encode()) ret = txn.delete(key_list[0].encode())
with txn.cursor() as cursor: with txn.cursor() as cursor:
if not cursor.first() and drop_empty_db: if not cursor.first() and drop_empty_db:
@ -299,21 +291,13 @@ class MoonrakerDatabase:
key: Optional[Union[List[str], str]] = None, key: Optional[Union[List[str], str]] = None,
default: Any = SENTINEL default: Any = SENTINEL
) -> Future[Any]: ) -> Future[Any]:
if self.eventloop.is_running(): return self._run_command(self._get_impl, namespace, key, default)
return cast(Future, self.eventloop.run_in_thread(
self._get_impl, namespace, key, default))
else:
ret = self._get_impl(namespace, key, default)
fut = self.eventloop.create_future()
fut.set_result(ret)
return ret
def _get_impl(self, def _get_impl(self,
namespace: str, namespace: str,
key: Optional[Union[List[str], str]] = None, key: Optional[Union[List[str], str]] = None,
default: Any = SENTINEL default: Any = SENTINEL
) -> Any: ) -> Any:
with self.thread_lock:
try: try:
if key is None: if key is None:
return self._get_namespace(namespace) return self._get_namespace(namespace)
@ -331,21 +315,13 @@ class MoonrakerDatabase:
def update_namespace(self, def update_namespace(self,
namespace: str, namespace: str,
value: Mapping[str, DBRecord] value: Mapping[str, DBRecord]
) -> Awaitable[None]: ) -> Future[None]:
if self.eventloop.is_running(): return self._run_command(self._update_ns_impl, namespace, value)
return self.eventloop.run_in_thread(
self._update_ns_impl, namespace, value)
else:
self._update_ns_impl(namespace, value)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
def _update_ns_impl(self, def _update_ns_impl(self,
namespace: str, namespace: str,
value: Mapping[str, DBRecord] value: Mapping[str, DBRecord]
) -> None: ) -> None:
with self.thread_lock:
if not value: if not value:
return return
if namespace not in self.namespaces: if namespace not in self.namespaces:
@ -368,21 +344,13 @@ class MoonrakerDatabase:
def clear_namespace(self, def clear_namespace(self,
namespace: str, namespace: str,
drop_empty_db: bool = False drop_empty_db: bool = False
) -> Awaitable[None]: ) -> Future[None]:
if self.eventloop.is_running(): return self._run_command(self._clear_ns_impl, namespace, drop_empty_db)
return self.eventloop.run_in_thread(
self._clear_ns_impl, namespace, drop_empty_db)
else:
self._clear_ns_impl(namespace, drop_empty_db)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
def _clear_ns_impl(self, def _clear_ns_impl(self,
namespace: str, namespace: str,
drop_empty_db: bool = False drop_empty_db: bool = False
) -> None: ) -> None:
with self.thread_lock:
if namespace not in self.namespaces: if namespace not in self.namespaces:
raise self.server.error( raise self.server.error(
f"Invalid database namespace '{namespace}'") f"Invalid database namespace '{namespace}'")
@ -395,21 +363,13 @@ class MoonrakerDatabase:
def sync_namespace(self, def sync_namespace(self,
namespace: str, namespace: str,
value: Mapping[str, DBRecord] value: Mapping[str, DBRecord]
) -> Awaitable[None]: ) -> Future[None]:
if self.eventloop.is_running(): return self._run_command(self._sync_ns_impl, namespace, value)
return self.eventloop.run_in_thread(
self._sync_ns_impl, namespace, value)
else:
self._sync_ns_impl(namespace, value)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
def _sync_ns_impl(self, def _sync_ns_impl(self,
namespace: str, namespace: str,
value: Mapping[str, DBRecord] value: Mapping[str, DBRecord]
) -> None: ) -> None:
with self.thread_lock:
if not value: if not value:
return return
if namespace not in self.namespaces: if namespace not in self.namespaces:
@ -438,17 +398,19 @@ class MoonrakerDatabase:
logging.info(f"Error inserting key '{k}' " logging.info(f"Error inserting key '{k}' "
f"in namespace '{namespace}'") f"in namespace '{namespace}'")
async def ns_length_async(self, namespace: str) -> int: def ns_length(self, namespace: str) -> Future[int]:
return len(await self.ns_keys_async(namespace)) return self._run_command(self._ns_length_impl, namespace)
def ns_length(self, namespace: str) -> int: def _ns_length_impl(self, namespace: str) -> int:
return len(self.ns_keys(namespace)) db = self.namespaces[namespace]
with self.lmdb_env.begin(db=db) as txn:
stats = txn.stat(db)
return stats['entries']
def ns_keys_async(self, namespace: str) -> Awaitable[List[str]]: def ns_keys(self, namespace: str) -> Future[List[str]]:
return self.eventloop.run_in_thread(self.ns_keys, namespace) return self._run_command(self._ns_keys_impl, namespace)
def ns_keys(self, namespace: str) -> List[str]: def _ns_keys_impl(self, namespace: str) -> List[str]:
with self.thread_lock:
keys: List[str] = [] keys: List[str] = []
db = self.namespaces[namespace] db = self.namespaces[namespace]
with self.lmdb_env.begin(db=db) as txn: with self.lmdb_env.begin(db=db) as txn:
@ -459,31 +421,37 @@ class MoonrakerDatabase:
remaining = cursor.next() remaining = cursor.next()
return keys return keys
def ns_values_async(self, namespace: str) -> Awaitable[ValuesView]: def ns_values(self, namespace: str) -> Future[List[Any]]:
return self.eventloop.run_in_thread(self.ns_values, namespace) return self._run_command(self._ns_values_impl, namespace)
def ns_values(self, namespace: str) -> ValuesView: def _ns_values_impl(self, namespace: str) -> List[Any]:
with self.thread_lock: values: List[Any] = []
db = self.namespaces[namespace]
with self.lmdb_env.begin(db=db, buffers=True) as txn:
with txn.cursor() as cursor:
remaining = cursor.first()
while remaining:
values.append(self._decode_value(cursor.value()))
remaining = cursor.next()
return values
def ns_items(self, namespace: str) -> Future[List[Tuple[str, Any]]]:
return self._run_command(self._ns_items_impl, namespace)
def _ns_items_impl(self, namespace: str) -> List[Tuple[str, Any]]:
ns = self._get_namespace(namespace) ns = self._get_namespace(namespace)
return ns.values() return list(ns.items())
def ns_items_async(self, namespace: str) -> Awaitable[ItemsView]: def ns_contains(self,
return self.eventloop.run_in_thread(self.ns_items, namespace)
def ns_items(self, namespace: str) -> ItemsView:
with self.thread_lock:
ns = self._get_namespace(namespace)
return ns.items()
def ns_contains_async(self,
namespace: str, namespace: str,
key: Union[List[str], str] key: Union[List[str], str]
) -> Awaitable[bool]: ) -> Future[bool]:
return self.eventloop.run_in_thread( return self._run_command(self._ns_contains_impl, namespace, key)
self.ns_contains, namespace, key)
def ns_contains(self, namespace: str, key: Union[List[str], str]) -> bool: def _ns_contains_impl(self,
with self.thread_lock: namespace: str,
key: Union[List[str], str]
) -> bool:
try: try:
key_list = self._process_key(key) key_list = self._process_key(key)
record = self._get_record(namespace, key_list[0]) record = self._get_record(namespace, key_list[0])
@ -727,8 +695,8 @@ class NamespaceWrapper:
key = [key] key = [key]
return self.db.delete_item(self.namespace, key) return self.db.delete_item(self.namespace, key)
async def length(self) -> int: def length(self) -> Future[int]:
return await self.db.ns_length_async(self.namespace) return self.db.ns_length(self.namespace)
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> Dict[str, Any]:
self._check_sync_method("as_dict") self._check_sync_method("as_dict")
@ -750,36 +718,21 @@ class NamespaceWrapper:
self._check_sync_method("__contains__") self._check_sync_method("__contains__")
if isinstance(key, str) and not self.parse_keys: if isinstance(key, str) and not self.parse_keys:
key = [key] key = [key]
return self.db.ns_contains(self.namespace, key) return self.db.ns_contains(self.namespace, key).result()
async def contains(self, key: Union[List[str], str]) -> bool: def contains(self, key: Union[List[str], str]) -> Future[bool]:
if isinstance(key, str) and not self.parse_keys: if isinstance(key, str) and not self.parse_keys:
key = [key] key = [key]
return await self.db.ns_contains_async(self.namespace, key) return self.db.ns_contains(self.namespace, key)
def keys(self) -> Future[List[str]]: def keys(self) -> Future[List[str]]:
if not self.eventloop.is_running: return self.db.ns_keys(self.namespace)
ret = self.db.ns_keys()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_keys_async(self.namespace))
def values(self) -> Future[ValuesView]: def values(self) -> Future[List[Any]]:
if not self.eventloop.is_running: return self.db.ns_values(self.namespace)
ret = self.db.ns_values()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_values_async(self.namespace))
def items(self) -> Future[ItemsView]: def items(self) -> Future[List[Tuple[str, Any]]]:
if not self.eventloop.is_running: return self.db.ns_items(self.namespace)
ret = self.db.ns_items()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_items_async(self.namespace))
def pop(self, def pop(self,
key: Union[List[str], str], key: Union[List[str], str],