database: refactor to remove duplicate code
Wrap command implementations in with a _run_command() method. All database commands now return a Future object. If the command was run before the eventloop starts its possible to immediately query the Future's result. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
65d1f23352
commit
e029b6c582
|
@ -22,9 +22,9 @@ from typing import (
|
|||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
ItemsView,
|
||||
Callable,
|
||||
Mapping,
|
||||
ValuesView,
|
||||
TypeVar,
|
||||
Tuple,
|
||||
Optional,
|
||||
Union,
|
||||
|
@ -37,6 +37,7 @@ if TYPE_CHECKING:
|
|||
from websockets import WebRequest
|
||||
DBRecord = Union[int, float, bool, str, List[Any], Dict[str, Any]]
|
||||
DBType = Optional[DBRecord]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
DATABASE_VERSION = 1
|
||||
MAX_NAMESPACES = 100
|
||||
|
@ -157,26 +158,34 @@ class MoonrakerDatabase:
|
|||
"/server/database/item", ["GET", "POST", "DELETE"],
|
||||
self._handle_item_request)
|
||||
|
||||
def _run_command(self,
|
||||
command_func: Callable[..., _T],
|
||||
*args
|
||||
) -> Future[_T]:
|
||||
def func_wrapper():
|
||||
with self.thread_lock:
|
||||
return command_func(*args)
|
||||
|
||||
if self.eventloop.is_running():
|
||||
return cast(Future, self.eventloop.run_in_thread(func_wrapper))
|
||||
else:
|
||||
ret = func_wrapper()
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return fut
|
||||
|
||||
def insert_item(self,
|
||||
namespace: str,
|
||||
key: Union[List[str], str],
|
||||
value: DBType
|
||||
) -> Awaitable[None]:
|
||||
if self.eventloop.is_running():
|
||||
return self.eventloop.run_in_thread(
|
||||
self._insert_impl, namespace, key, value)
|
||||
else:
|
||||
self._insert_impl(namespace, key, value)
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(None)
|
||||
return fut
|
||||
) -> Future[None]:
|
||||
return self._run_command(self._insert_impl, namespace, key, value)
|
||||
|
||||
def _insert_impl(self,
|
||||
namespace: str,
|
||||
key: Union[List[str], str],
|
||||
value: DBType
|
||||
) -> None:
|
||||
with self.thread_lock:
|
||||
key_list = self._process_key(key)
|
||||
if namespace not in self.namespaces:
|
||||
self.namespaces[namespace] = self.lmdb_env.open_db(
|
||||
|
@ -200,22 +209,14 @@ class MoonrakerDatabase:
|
|||
namespace: str,
|
||||
key: Union[List[str], str],
|
||||
value: DBType
|
||||
) -> Awaitable[None]:
|
||||
if self.eventloop.is_running():
|
||||
return self.eventloop.run_in_thread(
|
||||
self._update_impl, namespace, key, value)
|
||||
else:
|
||||
self._update_impl(namespace, key, value)
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(None)
|
||||
return fut
|
||||
) -> Future[None]:
|
||||
return self._run_command(self._update_impl, namespace, key, value)
|
||||
|
||||
def _update_impl(self,
|
||||
namespace: str,
|
||||
key: Union[List[str], str],
|
||||
value: DBType
|
||||
) -> None:
|
||||
with self.thread_lock:
|
||||
key_list = self._process_key(key)
|
||||
record = self._get_record(namespace, key_list[0])
|
||||
if len(key_list) == 1:
|
||||
|
@ -247,21 +248,14 @@ class MoonrakerDatabase:
|
|||
key: Union[List[str], str],
|
||||
drop_empty_db: bool = False
|
||||
) -> Future[Any]:
|
||||
if self.eventloop.is_running():
|
||||
return cast(Future, self.eventloop.run_in_thread(
|
||||
self._delete_impl, namespace, key, drop_empty_db))
|
||||
else:
|
||||
ret = self._delete_impl(namespace, key, drop_empty_db)
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return fut
|
||||
return self._run_command(self._delete_impl, namespace, key,
|
||||
drop_empty_db)
|
||||
|
||||
def _delete_impl(self,
|
||||
namespace: str,
|
||||
key: Union[List[str], str],
|
||||
drop_empty_db: bool = False
|
||||
) -> Any:
|
||||
with self.thread_lock:
|
||||
key_list = self._process_key(key)
|
||||
val = record = self._get_record(namespace, key_list[0])
|
||||
remove_record = True
|
||||
|
@ -278,9 +272,7 @@ class MoonrakerDatabase:
|
|||
remove_record = False if record else True
|
||||
if remove_record:
|
||||
db = self.namespaces[namespace]
|
||||
with (
|
||||
self.lmdb_env.begin(write=True, buffers=True, db=db) as txn
|
||||
):
|
||||
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
|
||||
ret = txn.delete(key_list[0].encode())
|
||||
with txn.cursor() as cursor:
|
||||
if not cursor.first() and drop_empty_db:
|
||||
|
@ -299,21 +291,13 @@ class MoonrakerDatabase:
|
|||
key: Optional[Union[List[str], str]] = None,
|
||||
default: Any = SENTINEL
|
||||
) -> Future[Any]:
|
||||
if self.eventloop.is_running():
|
||||
return cast(Future, self.eventloop.run_in_thread(
|
||||
self._get_impl, namespace, key, default))
|
||||
else:
|
||||
ret = self._get_impl(namespace, key, default)
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return ret
|
||||
return self._run_command(self._get_impl, namespace, key, default)
|
||||
|
||||
def _get_impl(self,
|
||||
namespace: str,
|
||||
key: Optional[Union[List[str], str]] = None,
|
||||
default: Any = SENTINEL
|
||||
) -> Any:
|
||||
with self.thread_lock:
|
||||
try:
|
||||
if key is None:
|
||||
return self._get_namespace(namespace)
|
||||
|
@ -331,21 +315,13 @@ class MoonrakerDatabase:
|
|||
def update_namespace(self,
|
||||
namespace: str,
|
||||
value: Mapping[str, DBRecord]
|
||||
) -> Awaitable[None]:
|
||||
if self.eventloop.is_running():
|
||||
return self.eventloop.run_in_thread(
|
||||
self._update_ns_impl, namespace, value)
|
||||
else:
|
||||
self._update_ns_impl(namespace, value)
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(None)
|
||||
return fut
|
||||
) -> Future[None]:
|
||||
return self._run_command(self._update_ns_impl, namespace, value)
|
||||
|
||||
def _update_ns_impl(self,
|
||||
namespace: str,
|
||||
value: Mapping[str, DBRecord]
|
||||
) -> None:
|
||||
with self.thread_lock:
|
||||
if not value:
|
||||
return
|
||||
if namespace not in self.namespaces:
|
||||
|
@ -368,21 +344,13 @@ class MoonrakerDatabase:
|
|||
def clear_namespace(self,
|
||||
namespace: str,
|
||||
drop_empty_db: bool = False
|
||||
) -> Awaitable[None]:
|
||||
if self.eventloop.is_running():
|
||||
return self.eventloop.run_in_thread(
|
||||
self._clear_ns_impl, namespace, drop_empty_db)
|
||||
else:
|
||||
self._clear_ns_impl(namespace, drop_empty_db)
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(None)
|
||||
return fut
|
||||
) -> Future[None]:
|
||||
return self._run_command(self._clear_ns_impl, namespace, drop_empty_db)
|
||||
|
||||
def _clear_ns_impl(self,
|
||||
namespace: str,
|
||||
drop_empty_db: bool = False
|
||||
) -> None:
|
||||
with self.thread_lock:
|
||||
if namespace not in self.namespaces:
|
||||
raise self.server.error(
|
||||
f"Invalid database namespace '{namespace}'")
|
||||
|
@ -395,21 +363,13 @@ class MoonrakerDatabase:
|
|||
def sync_namespace(self,
|
||||
namespace: str,
|
||||
value: Mapping[str, DBRecord]
|
||||
) -> Awaitable[None]:
|
||||
if self.eventloop.is_running():
|
||||
return self.eventloop.run_in_thread(
|
||||
self._sync_ns_impl, namespace, value)
|
||||
else:
|
||||
self._sync_ns_impl(namespace, value)
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(None)
|
||||
return fut
|
||||
) -> Future[None]:
|
||||
return self._run_command(self._sync_ns_impl, namespace, value)
|
||||
|
||||
def _sync_ns_impl(self,
|
||||
namespace: str,
|
||||
value: Mapping[str, DBRecord]
|
||||
) -> None:
|
||||
with self.thread_lock:
|
||||
if not value:
|
||||
return
|
||||
if namespace not in self.namespaces:
|
||||
|
@ -438,17 +398,19 @@ class MoonrakerDatabase:
|
|||
logging.info(f"Error inserting key '{k}' "
|
||||
f"in namespace '{namespace}'")
|
||||
|
||||
async def ns_length_async(self, namespace: str) -> int:
|
||||
return len(await self.ns_keys_async(namespace))
|
||||
def ns_length(self, namespace: str) -> Future[int]:
|
||||
return self._run_command(self._ns_length_impl, namespace)
|
||||
|
||||
def ns_length(self, namespace: str) -> int:
|
||||
return len(self.ns_keys(namespace))
|
||||
def _ns_length_impl(self, namespace: str) -> int:
|
||||
db = self.namespaces[namespace]
|
||||
with self.lmdb_env.begin(db=db) as txn:
|
||||
stats = txn.stat(db)
|
||||
return stats['entries']
|
||||
|
||||
def ns_keys_async(self, namespace: str) -> Awaitable[List[str]]:
|
||||
return self.eventloop.run_in_thread(self.ns_keys, namespace)
|
||||
def ns_keys(self, namespace: str) -> Future[List[str]]:
|
||||
return self._run_command(self._ns_keys_impl, namespace)
|
||||
|
||||
def ns_keys(self, namespace: str) -> List[str]:
|
||||
with self.thread_lock:
|
||||
def _ns_keys_impl(self, namespace: str) -> List[str]:
|
||||
keys: List[str] = []
|
||||
db = self.namespaces[namespace]
|
||||
with self.lmdb_env.begin(db=db) as txn:
|
||||
|
@ -459,31 +421,37 @@ class MoonrakerDatabase:
|
|||
remaining = cursor.next()
|
||||
return keys
|
||||
|
||||
def ns_values_async(self, namespace: str) -> Awaitable[ValuesView]:
|
||||
return self.eventloop.run_in_thread(self.ns_values, namespace)
|
||||
def ns_values(self, namespace: str) -> Future[List[Any]]:
|
||||
return self._run_command(self._ns_values_impl, namespace)
|
||||
|
||||
def ns_values(self, namespace: str) -> ValuesView:
|
||||
with self.thread_lock:
|
||||
def _ns_values_impl(self, namespace: str) -> List[Any]:
|
||||
values: List[Any] = []
|
||||
db = self.namespaces[namespace]
|
||||
with self.lmdb_env.begin(db=db, buffers=True) as txn:
|
||||
with txn.cursor() as cursor:
|
||||
remaining = cursor.first()
|
||||
while remaining:
|
||||
values.append(self._decode_value(cursor.value()))
|
||||
remaining = cursor.next()
|
||||
return values
|
||||
|
||||
def ns_items(self, namespace: str) -> Future[List[Tuple[str, Any]]]:
|
||||
return self._run_command(self._ns_items_impl, namespace)
|
||||
|
||||
def _ns_items_impl(self, namespace: str) -> List[Tuple[str, Any]]:
|
||||
ns = self._get_namespace(namespace)
|
||||
return ns.values()
|
||||
return list(ns.items())
|
||||
|
||||
def ns_items_async(self, namespace: str) -> Awaitable[ItemsView]:
|
||||
return self.eventloop.run_in_thread(self.ns_items, namespace)
|
||||
|
||||
def ns_items(self, namespace: str) -> ItemsView:
|
||||
with self.thread_lock:
|
||||
ns = self._get_namespace(namespace)
|
||||
return ns.items()
|
||||
|
||||
def ns_contains_async(self,
|
||||
def ns_contains(self,
|
||||
namespace: str,
|
||||
key: Union[List[str], str]
|
||||
) -> Awaitable[bool]:
|
||||
return self.eventloop.run_in_thread(
|
||||
self.ns_contains, namespace, key)
|
||||
) -> Future[bool]:
|
||||
return self._run_command(self._ns_contains_impl, namespace, key)
|
||||
|
||||
def ns_contains(self, namespace: str, key: Union[List[str], str]) -> bool:
|
||||
with self.thread_lock:
|
||||
def _ns_contains_impl(self,
|
||||
namespace: str,
|
||||
key: Union[List[str], str]
|
||||
) -> bool:
|
||||
try:
|
||||
key_list = self._process_key(key)
|
||||
record = self._get_record(namespace, key_list[0])
|
||||
|
@ -727,8 +695,8 @@ class NamespaceWrapper:
|
|||
key = [key]
|
||||
return self.db.delete_item(self.namespace, key)
|
||||
|
||||
async def length(self) -> int:
|
||||
return await self.db.ns_length_async(self.namespace)
|
||||
def length(self) -> Future[int]:
|
||||
return self.db.ns_length(self.namespace)
|
||||
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
self._check_sync_method("as_dict")
|
||||
|
@ -750,36 +718,21 @@ class NamespaceWrapper:
|
|||
self._check_sync_method("__contains__")
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
return self.db.ns_contains(self.namespace, key)
|
||||
return self.db.ns_contains(self.namespace, key).result()
|
||||
|
||||
async def contains(self, key: Union[List[str], str]) -> bool:
|
||||
def contains(self, key: Union[List[str], str]) -> Future[bool]:
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
return await self.db.ns_contains_async(self.namespace, key)
|
||||
return self.db.ns_contains(self.namespace, key)
|
||||
|
||||
def keys(self) -> Future[List[str]]:
|
||||
if not self.eventloop.is_running:
|
||||
ret = self.db.ns_keys()
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return fut
|
||||
return cast(Future, self.db.ns_keys_async(self.namespace))
|
||||
return self.db.ns_keys(self.namespace)
|
||||
|
||||
def values(self) -> Future[ValuesView]:
|
||||
if not self.eventloop.is_running:
|
||||
ret = self.db.ns_values()
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return fut
|
||||
return cast(Future, self.db.ns_values_async(self.namespace))
|
||||
def values(self) -> Future[List[Any]]:
|
||||
return self.db.ns_values(self.namespace)
|
||||
|
||||
def items(self) -> Future[ItemsView]:
|
||||
if not self.eventloop.is_running:
|
||||
ret = self.db.ns_items()
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return fut
|
||||
return cast(Future, self.db.ns_items_async(self.namespace))
|
||||
def items(self) -> Future[List[Tuple[str, Any]]]:
|
||||
return self.db.ns_items(self.namespace)
|
||||
|
||||
def pop(self,
|
||||
key: Union[List[str], str],
|
||||
|
|
Loading…
Reference in New Issue