database: combine namespace wrappers

Rather than creating two wrappers, use a single wrapper whose methods
always return a future or awaitable.  If the operation occurs during
the __init__() method of a component it will be syncrhonous, and the
result from the future can be immediately queried.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-01-30 11:57:49 -05:00
parent 4d0a43cb25
commit 781e3e6250
1 changed files with 74 additions and 145 deletions

View File

@ -10,6 +10,7 @@ import json
import struct
import operator
import logging
from asyncio import Future, Task
from io import BytesIO
from functools import reduce
from threading import Lock as ThreadLock
@ -21,7 +22,6 @@ from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Coroutine,
ItemsView,
ValuesView,
Tuple,
@ -29,6 +29,7 @@ from typing import (
Union,
Dict,
List,
cast
)
if TYPE_CHECKING:
from confighelper import ConfigHelper
@ -126,18 +127,20 @@ class MoonrakerDatabase:
# be granted by enabling the debug option. Forbidden namespaces
# have no API access. This cannot be overridden.
self.protected_namespaces = set(self.get_item(
"moonraker", "database.protected_namespaces", ["moonraker"]))
"moonraker", "database.protected_namespaces",
["moonraker"]).result())
self.forbidden_namespaces = set(self.get_item(
"moonraker", "database.forbidden_namespaces", []))
"moonraker", "database.forbidden_namespaces",
[]).result())
# Track debug access and unsafe shutdowns
debug_counter: int = self.get_item(
"moonraker", "database.debug_counter", 0)
"moonraker", "database.debug_counter", 0).result()
if self.enable_debug:
debug_counter += 1
self.insert_item("moonraker", "database.debug_counter",
debug_counter)
unsafe_shutdowns: int = self.get_item(
"moonraker", "database.unsafe_shutdowns", 0)
"moonraker", "database.unsafe_shutdowns", 0).result()
msg = f"Unsafe Shutdown Count: {unsafe_shutdowns}"
if debug_counter:
msg += f"; Database Debug Count: {debug_counter}"
@ -242,12 +245,15 @@ class MoonrakerDatabase:
namespace: str,
key: Union[List[str], str],
drop_empty_db: bool = False
) -> Any:
) -> Future[Any]:
if self.eventloop.is_running():
return self.eventloop.run_in_thread(
self._delete_impl, namespace, key, drop_empty_db)
return cast(Future, self.eventloop.run_in_thread(
self._delete_impl, namespace, key, drop_empty_db))
else:
return self._delete_impl(namespace, key, drop_empty_db)
ret = self._delete_impl(namespace, key, drop_empty_db)
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
def _delete_impl(self,
namespace: str,
@ -291,12 +297,15 @@ class MoonrakerDatabase:
namespace: str,
key: Optional[Union[List[str], str]] = None,
default: Any = SENTINEL
) -> Any:
) -> Future[Any]:
if self.eventloop.is_running():
return self.eventloop.run_in_thread(
self._get_impl, namespace, key, default)
return cast(Future, self.eventloop.run_in_thread(
self._get_impl, namespace, key, default))
else:
return self._get_impl(namespace, key, default)
ret = self._get_impl(namespace, key, default)
fut = self.eventloop.create_future()
fut.set_result(ret)
return ret
def _get_impl(self,
namespace: str,
@ -474,19 +483,6 @@ class MoonrakerDatabase:
f"Namespace '{namespace}' not found", 404)
return NamespaceWrapper(namespace, self, parse_keys)
def wrap_async_namespace(self,
namespace: str,
parse_keys: bool = True
) -> AsyncNamespaceWrapper:
if self.eventloop.is_running():
raise self.server.error(
"Cannot wrap a namespace while the "
"eventloop is running")
if namespace not in self.namespaces:
raise self.server.error(
f"Namespace '{namespace}' not found", 404)
return AsyncNamespaceWrapper(namespace, self, parse_keys)
def _process_key(self, key: Union[List[str], str]) -> List[str]:
try:
key_list = key if isinstance(key, list) else key.split('.')
@ -568,11 +564,6 @@ class MoonrakerDatabase:
raise self.server.error(
f"Error decoding value {bvalue.decode()}, format: {chr(fmt)}")
def can_call_sync(self, name: str = "") -> None:
if self.eventloop.is_running():
raise self.server.error(
f"Cannot call method {name} while the eventloop is running")
async def _handle_list_request(self,
web_request: WebRequest
) -> Dict[str, List[str]]:
@ -641,103 +632,6 @@ class MoonrakerDatabase:
self.thread_lock.release()
class NamespaceWrapper:
def __init__(self,
namespace: str,
database: MoonrakerDatabase,
parse_keys: bool
) -> None:
self.namespace = namespace
self.db = database
# If parse keys is true, keys of a string type
# will be passed straight to the DB methods.
self.parse_keys = parse_keys
def to_async_wrapper(self) -> AsyncNamespaceWrapper:
return AsyncNamespaceWrapper(self.namespace, self.db, self.parse_keys)
def insert(self, key: Union[List[str], str], value: DBType) -> None:
self.db.can_call_sync("insert")
if isinstance(key, str) and not self.parse_keys:
key = [key]
self.db.insert_item(self.namespace, key, value)
def update_child(self, key: Union[List[str], str], value: DBType) -> None:
self.db.can_call_sync("update_child")
if isinstance(key, str) and not self.parse_keys:
key = [key]
self.db.update_item(self.namespace, key, value)
def update(self, value: Dict[str, DBRecord]) -> None:
self.db.can_call_sync("update")
self.db.update_namespace(self.namespace, value)
def get(self,
key: Union[List[str], str],
default: Any = None
) -> Any:
self.db.can_call_sync("get")
if isinstance(key, str) and not self.parse_keys:
key = [key]
return self.db.get_item(self.namespace, key, default)
def delete(self, key: Union[List[str], str]) -> Any:
self.db.can_call_sync("delete")
if isinstance(key, str) and not self.parse_keys:
key = [key]
return self.db.delete_item(self.namespace, key)
def __len__(self) -> int:
self.db.can_call_sync("length")
return self.db.ns_length(self.namespace)
def __getitem__(self, key: Union[List[str], str]) -> Any:
return self.get(key, default=SENTINEL)
def __setitem__(self,
key: Union[List[str], str],
value: DBType
) -> None:
self.insert(key, value)
def __delitem__(self, key: Union[List[str], str]):
self.delete(key)
def __contains__(self, key: Union[List[str], str]) -> bool:
self.db.can_call_sync("contains")
if isinstance(key, str) and not self.parse_keys:
key = [key]
return self.db.ns_contains(self.namespace, key)
def keys(self) -> List[str]:
self.db.can_call_sync("keys")
return self.db.ns_keys(self.namespace)
def values(self) -> ValuesView:
self.db.can_call_sync("values")
return self.db.ns_values(self.namespace)
def items(self) -> ItemsView:
self.db.can_call_sync("items")
return self.db.ns_items(self.namespace)
def pop(self,
key: Union[List[str], str],
default: Any = SENTINEL
) -> Any:
self.db.can_call_sync("pop")
try:
val = self.delete(key)
except Exception:
if isinstance(default, SentinelClass):
raise
val = default
return val
def clear(self) -> None:
self.db.can_call_sync("clear")
self.db.clear_namespace(self.namespace)
class AsyncNamespaceWrapper:
def __init__(self,
namespace: str,
database: MoonrakerDatabase,
@ -746,13 +640,11 @@ class AsyncNamespaceWrapper:
self.namespace = namespace
self.db = database
self.eventloop = database.eventloop
self.server = database.server
# If parse keys is true, keys of a string type
# will be passed straight to the DB methods.
self.parse_keys = parse_keys
def to_sync_wrapper(self) -> NamespaceWrapper:
return NamespaceWrapper(self.namespace, self.db, self.parse_keys)
def insert(self,
key: Union[List[str], str],
value: DBType
@ -772,15 +664,15 @@ class AsyncNamespaceWrapper:
def update(self, value: Dict[str, DBRecord]) -> Awaitable[None]:
return self.db.update_namespace(self.namespace, value)
async def get(self,
key: Union[List[str], str],
default: Any = None
) -> Any:
def get(self,
key: Union[List[str], str],
default: Any = None
) -> Future[Any]:
if isinstance(key, str) and not self.parse_keys:
key = [key]
return await self.db.get_item(self.namespace, key, default)
return self.db.get_item(self.namespace, key, default)
def delete(self, key: Union[List[str], str]) -> Awaitable[Any]:
def delete(self, key: Union[List[str], str]) -> Future[Any]:
if isinstance(key, str) and not self.parse_keys:
key = [key]
return self.db.delete_item(self.namespace, key)
@ -788,7 +680,7 @@ class AsyncNamespaceWrapper:
async def length(self) -> int:
return await self.db.ns_length_async(self.namespace)
def __getitem__(self, key: Union[List[str], str]) -> Coroutine:
def __getitem__(self, key: Union[List[str], str]) -> Future[Any]:
return self.get(key, default=SENTINEL)
def __setitem__(self,
@ -800,24 +692,56 @@ class AsyncNamespaceWrapper:
def __delitem__(self, key: Union[List[str], str]):
self.delete(key)
def __contains__(self, key: Union[List[str], str]) -> bool:
self._check_sync_method("__contains__")
if isinstance(key, str) and not self.parse_keys:
key = [key]
return self.db.ns_contains(self.namespace, key)
async def contains(self, key: Union[List[str], str]) -> bool:
if isinstance(key, str) and not self.parse_keys:
key = [key]
return await self.db.ns_contains_async(self.namespace, key)
async def keys(self) -> List[str]:
return await self.db.ns_keys_async(self.namespace)
def keys(self) -> Future[List[str]]:
if not self.eventloop.is_running:
ret = self.db.ns_keys()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_keys_async(self.namespace))
async def values(self) -> ValuesView:
return await self.db.ns_values_async(self.namespace)
def values(self) -> Future[ValuesView]:
if not self.eventloop.is_running:
ret = self.db.ns_values()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_values_async(self.namespace))
async def items(self) -> ItemsView:
return await self.db.ns_items_async(self.namespace)
def items(self) -> Future[ItemsView]:
if not self.eventloop.is_running:
ret = self.db.ns_items()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_items_async(self.namespace))
def pop(self,
key: Union[List[str], str],
default: Any = SENTINEL
) -> Awaitable[Any]:
) -> Union[Future[Any], Task[Any]]:
if not self.eventloop.is_running():
try:
val = self.delete(key).result()
except Exception:
if isinstance(default, SentinelClass):
raise
val = default
fut = self.eventloop.create_future()
fut.set_result(val)
return fut
async def _do_pop() -> Any:
try:
val = await self.delete(key)
@ -831,6 +755,11 @@ class AsyncNamespaceWrapper:
def clear(self) -> Awaitable[None]:
return self.db.clear_namespace(self.namespace)
def _check_sync_method(self, func_name: str) -> None:
if self.eventloop.is_running():
raise self.server.error(
f"Cannot call method {func_name} while "
"the eventloop is running")
def load_component(config: ConfigHelper) -> MoonrakerDatabase:
return MoonrakerDatabase(config)