database: combine namespace wrappers

Rather than creating two wrappers, use a single wrapper whose methods
always return a future or awaitable.  If the operation occurs during
the __init__() method of a component it will be syncrhonous, and the
result from the future can be immediately queried.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-01-30 11:57:49 -05:00
parent 4d0a43cb25
commit 781e3e6250
1 changed files with 74 additions and 145 deletions

View File

@ -10,6 +10,7 @@ import json
import struct import struct
import operator import operator
import logging import logging
from asyncio import Future, Task
from io import BytesIO from io import BytesIO
from functools import reduce from functools import reduce
from threading import Lock as ThreadLock from threading import Lock as ThreadLock
@ -21,7 +22,6 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Coroutine,
ItemsView, ItemsView,
ValuesView, ValuesView,
Tuple, Tuple,
@ -29,6 +29,7 @@ from typing import (
Union, Union,
Dict, Dict,
List, List,
cast
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from confighelper import ConfigHelper from confighelper import ConfigHelper
@ -126,18 +127,20 @@ class MoonrakerDatabase:
# be granted by enabling the debug option. Forbidden namespaces # be granted by enabling the debug option. Forbidden namespaces
# have no API access. This cannot be overridden. # have no API access. This cannot be overridden.
self.protected_namespaces = set(self.get_item( self.protected_namespaces = set(self.get_item(
"moonraker", "database.protected_namespaces", ["moonraker"])) "moonraker", "database.protected_namespaces",
["moonraker"]).result())
self.forbidden_namespaces = set(self.get_item( self.forbidden_namespaces = set(self.get_item(
"moonraker", "database.forbidden_namespaces", [])) "moonraker", "database.forbidden_namespaces",
[]).result())
# Track debug access and unsafe shutdowns # Track debug access and unsafe shutdowns
debug_counter: int = self.get_item( debug_counter: int = self.get_item(
"moonraker", "database.debug_counter", 0) "moonraker", "database.debug_counter", 0).result()
if self.enable_debug: if self.enable_debug:
debug_counter += 1 debug_counter += 1
self.insert_item("moonraker", "database.debug_counter", self.insert_item("moonraker", "database.debug_counter",
debug_counter) debug_counter)
unsafe_shutdowns: int = self.get_item( unsafe_shutdowns: int = self.get_item(
"moonraker", "database.unsafe_shutdowns", 0) "moonraker", "database.unsafe_shutdowns", 0).result()
msg = f"Unsafe Shutdown Count: {unsafe_shutdowns}" msg = f"Unsafe Shutdown Count: {unsafe_shutdowns}"
if debug_counter: if debug_counter:
msg += f"; Database Debug Count: {debug_counter}" msg += f"; Database Debug Count: {debug_counter}"
@ -242,12 +245,15 @@ class MoonrakerDatabase:
namespace: str, namespace: str,
key: Union[List[str], str], key: Union[List[str], str],
drop_empty_db: bool = False drop_empty_db: bool = False
) -> Any: ) -> Future[Any]:
if self.eventloop.is_running(): if self.eventloop.is_running():
return self.eventloop.run_in_thread( return cast(Future, self.eventloop.run_in_thread(
self._delete_impl, namespace, key, drop_empty_db) self._delete_impl, namespace, key, drop_empty_db))
else: else:
return self._delete_impl(namespace, key, drop_empty_db) ret = self._delete_impl(namespace, key, drop_empty_db)
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
def _delete_impl(self, def _delete_impl(self,
namespace: str, namespace: str,
@ -291,12 +297,15 @@ class MoonrakerDatabase:
namespace: str, namespace: str,
key: Optional[Union[List[str], str]] = None, key: Optional[Union[List[str], str]] = None,
default: Any = SENTINEL default: Any = SENTINEL
) -> Any: ) -> Future[Any]:
if self.eventloop.is_running(): if self.eventloop.is_running():
return self.eventloop.run_in_thread( return cast(Future, self.eventloop.run_in_thread(
self._get_impl, namespace, key, default) self._get_impl, namespace, key, default))
else: else:
return self._get_impl(namespace, key, default) ret = self._get_impl(namespace, key, default)
fut = self.eventloop.create_future()
fut.set_result(ret)
return ret
def _get_impl(self, def _get_impl(self,
namespace: str, namespace: str,
@ -474,19 +483,6 @@ class MoonrakerDatabase:
f"Namespace '{namespace}' not found", 404) f"Namespace '{namespace}' not found", 404)
return NamespaceWrapper(namespace, self, parse_keys) return NamespaceWrapper(namespace, self, parse_keys)
def wrap_async_namespace(self,
namespace: str,
parse_keys: bool = True
) -> AsyncNamespaceWrapper:
if self.eventloop.is_running():
raise self.server.error(
"Cannot wrap a namespace while the "
"eventloop is running")
if namespace not in self.namespaces:
raise self.server.error(
f"Namespace '{namespace}' not found", 404)
return AsyncNamespaceWrapper(namespace, self, parse_keys)
def _process_key(self, key: Union[List[str], str]) -> List[str]: def _process_key(self, key: Union[List[str], str]) -> List[str]:
try: try:
key_list = key if isinstance(key, list) else key.split('.') key_list = key if isinstance(key, list) else key.split('.')
@ -568,11 +564,6 @@ class MoonrakerDatabase:
raise self.server.error( raise self.server.error(
f"Error decoding value {bvalue.decode()}, format: {chr(fmt)}") f"Error decoding value {bvalue.decode()}, format: {chr(fmt)}")
def can_call_sync(self, name: str = "") -> None:
if self.eventloop.is_running():
raise self.server.error(
f"Cannot call method {name} while the eventloop is running")
async def _handle_list_request(self, async def _handle_list_request(self,
web_request: WebRequest web_request: WebRequest
) -> Dict[str, List[str]]: ) -> Dict[str, List[str]]:
@ -641,103 +632,6 @@ class MoonrakerDatabase:
self.thread_lock.release() self.thread_lock.release()
class NamespaceWrapper: class NamespaceWrapper:
def __init__(self,
namespace: str,
database: MoonrakerDatabase,
parse_keys: bool
) -> None:
self.namespace = namespace
self.db = database
# If parse keys is true, keys of a string type
# will be passed straight to the DB methods.
self.parse_keys = parse_keys
def to_async_wrapper(self) -> AsyncNamespaceWrapper:
return AsyncNamespaceWrapper(self.namespace, self.db, self.parse_keys)
def insert(self, key: Union[List[str], str], value: DBType) -> None:
self.db.can_call_sync("insert")
if isinstance(key, str) and not self.parse_keys:
key = [key]
self.db.insert_item(self.namespace, key, value)
def update_child(self, key: Union[List[str], str], value: DBType) -> None:
self.db.can_call_sync("update_child")
if isinstance(key, str) and not self.parse_keys:
key = [key]
self.db.update_item(self.namespace, key, value)
def update(self, value: Dict[str, DBRecord]) -> None:
self.db.can_call_sync("update")
self.db.update_namespace(self.namespace, value)
def get(self,
key: Union[List[str], str],
default: Any = None
) -> Any:
self.db.can_call_sync("get")
if isinstance(key, str) and not self.parse_keys:
key = [key]
return self.db.get_item(self.namespace, key, default)
def delete(self, key: Union[List[str], str]) -> Any:
self.db.can_call_sync("delete")
if isinstance(key, str) and not self.parse_keys:
key = [key]
return self.db.delete_item(self.namespace, key)
def __len__(self) -> int:
self.db.can_call_sync("length")
return self.db.ns_length(self.namespace)
def __getitem__(self, key: Union[List[str], str]) -> Any:
return self.get(key, default=SENTINEL)
def __setitem__(self,
key: Union[List[str], str],
value: DBType
) -> None:
self.insert(key, value)
def __delitem__(self, key: Union[List[str], str]):
self.delete(key)
def __contains__(self, key: Union[List[str], str]) -> bool:
self.db.can_call_sync("contains")
if isinstance(key, str) and not self.parse_keys:
key = [key]
return self.db.ns_contains(self.namespace, key)
def keys(self) -> List[str]:
self.db.can_call_sync("keys")
return self.db.ns_keys(self.namespace)
def values(self) -> ValuesView:
self.db.can_call_sync("values")
return self.db.ns_values(self.namespace)
def items(self) -> ItemsView:
self.db.can_call_sync("items")
return self.db.ns_items(self.namespace)
def pop(self,
key: Union[List[str], str],
default: Any = SENTINEL
) -> Any:
self.db.can_call_sync("pop")
try:
val = self.delete(key)
except Exception:
if isinstance(default, SentinelClass):
raise
val = default
return val
def clear(self) -> None:
self.db.can_call_sync("clear")
self.db.clear_namespace(self.namespace)
class AsyncNamespaceWrapper:
def __init__(self, def __init__(self,
namespace: str, namespace: str,
database: MoonrakerDatabase, database: MoonrakerDatabase,
@ -746,13 +640,11 @@ class AsyncNamespaceWrapper:
self.namespace = namespace self.namespace = namespace
self.db = database self.db = database
self.eventloop = database.eventloop self.eventloop = database.eventloop
self.server = database.server
# If parse keys is true, keys of a string type # If parse keys is true, keys of a string type
# will be passed straight to the DB methods. # will be passed straight to the DB methods.
self.parse_keys = parse_keys self.parse_keys = parse_keys
def to_sync_wrapper(self) -> NamespaceWrapper:
return NamespaceWrapper(self.namespace, self.db, self.parse_keys)
def insert(self, def insert(self,
key: Union[List[str], str], key: Union[List[str], str],
value: DBType value: DBType
@ -772,15 +664,15 @@ class AsyncNamespaceWrapper:
def update(self, value: Dict[str, DBRecord]) -> Awaitable[None]: def update(self, value: Dict[str, DBRecord]) -> Awaitable[None]:
return self.db.update_namespace(self.namespace, value) return self.db.update_namespace(self.namespace, value)
async def get(self, def get(self,
key: Union[List[str], str], key: Union[List[str], str],
default: Any = None default: Any = None
) -> Any: ) -> Future[Any]:
if isinstance(key, str) and not self.parse_keys: if isinstance(key, str) and not self.parse_keys:
key = [key] key = [key]
return await self.db.get_item(self.namespace, key, default) return self.db.get_item(self.namespace, key, default)
def delete(self, key: Union[List[str], str]) -> Awaitable[Any]: def delete(self, key: Union[List[str], str]) -> Future[Any]:
if isinstance(key, str) and not self.parse_keys: if isinstance(key, str) and not self.parse_keys:
key = [key] key = [key]
return self.db.delete_item(self.namespace, key) return self.db.delete_item(self.namespace, key)
@ -788,7 +680,7 @@ class AsyncNamespaceWrapper:
async def length(self) -> int: async def length(self) -> int:
return await self.db.ns_length_async(self.namespace) return await self.db.ns_length_async(self.namespace)
def __getitem__(self, key: Union[List[str], str]) -> Coroutine: def __getitem__(self, key: Union[List[str], str]) -> Future[Any]:
return self.get(key, default=SENTINEL) return self.get(key, default=SENTINEL)
def __setitem__(self, def __setitem__(self,
@ -800,24 +692,56 @@ class AsyncNamespaceWrapper:
def __delitem__(self, key: Union[List[str], str]): def __delitem__(self, key: Union[List[str], str]):
self.delete(key) self.delete(key)
def __contains__(self, key: Union[List[str], str]) -> bool:
self._check_sync_method("__contains__")
if isinstance(key, str) and not self.parse_keys:
key = [key]
return self.db.ns_contains(self.namespace, key)
async def contains(self, key: Union[List[str], str]) -> bool: async def contains(self, key: Union[List[str], str]) -> bool:
if isinstance(key, str) and not self.parse_keys: if isinstance(key, str) and not self.parse_keys:
key = [key] key = [key]
return await self.db.ns_contains_async(self.namespace, key) return await self.db.ns_contains_async(self.namespace, key)
async def keys(self) -> List[str]: def keys(self) -> Future[List[str]]:
return await self.db.ns_keys_async(self.namespace) if not self.eventloop.is_running:
ret = self.db.ns_keys()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_keys_async(self.namespace))
async def values(self) -> ValuesView: def values(self) -> Future[ValuesView]:
return await self.db.ns_values_async(self.namespace) if not self.eventloop.is_running:
ret = self.db.ns_values()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_values_async(self.namespace))
async def items(self) -> ItemsView: def items(self) -> Future[ItemsView]:
return await self.db.ns_items_async(self.namespace) if not self.eventloop.is_running:
ret = self.db.ns_items()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_items_async(self.namespace))
def pop(self, def pop(self,
key: Union[List[str], str], key: Union[List[str], str],
default: Any = SENTINEL default: Any = SENTINEL
) -> Awaitable[Any]: ) -> Union[Future[Any], Task[Any]]:
if not self.eventloop.is_running():
try:
val = self.delete(key).result()
except Exception:
if isinstance(default, SentinelClass):
raise
val = default
fut = self.eventloop.create_future()
fut.set_result(val)
return fut
async def _do_pop() -> Any: async def _do_pop() -> Any:
try: try:
val = await self.delete(key) val = await self.delete(key)
@ -831,6 +755,11 @@ class AsyncNamespaceWrapper:
def clear(self) -> Awaitable[None]: def clear(self) -> Awaitable[None]:
return self.db.clear_namespace(self.namespace) return self.db.clear_namespace(self.namespace)
def _check_sync_method(self, func_name: str) -> None:
if self.eventloop.is_running():
raise self.server.error(
f"Cannot call method {func_name} while "
"the eventloop is running")
def load_component(config: ConfigHelper) -> MoonrakerDatabase: def load_component(config: ConfigHelper) -> MoonrakerDatabase:
return MoonrakerDatabase(config) return MoonrakerDatabase(config)