database: combine namespace wrappers
Rather than creating two wrappers, use a single wrapper whose methods always return a future or awaitable. If the operation occurs during the __init__() method of a component it will be syncrhonous, and the result from the future can be immediately queried. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
4d0a43cb25
commit
781e3e6250
|
@ -10,6 +10,7 @@ import json
|
|||
import struct
|
||||
import operator
|
||||
import logging
|
||||
from asyncio import Future, Task
|
||||
from io import BytesIO
|
||||
from functools import reduce
|
||||
from threading import Lock as ThreadLock
|
||||
|
@ -21,7 +22,6 @@ from typing import (
|
|||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Coroutine,
|
||||
ItemsView,
|
||||
ValuesView,
|
||||
Tuple,
|
||||
|
@ -29,6 +29,7 @@ from typing import (
|
|||
Union,
|
||||
Dict,
|
||||
List,
|
||||
cast
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
from confighelper import ConfigHelper
|
||||
|
@ -126,18 +127,20 @@ class MoonrakerDatabase:
|
|||
# be granted by enabling the debug option. Forbidden namespaces
|
||||
# have no API access. This cannot be overridden.
|
||||
self.protected_namespaces = set(self.get_item(
|
||||
"moonraker", "database.protected_namespaces", ["moonraker"]))
|
||||
"moonraker", "database.protected_namespaces",
|
||||
["moonraker"]).result())
|
||||
self.forbidden_namespaces = set(self.get_item(
|
||||
"moonraker", "database.forbidden_namespaces", []))
|
||||
"moonraker", "database.forbidden_namespaces",
|
||||
[]).result())
|
||||
# Track debug access and unsafe shutdowns
|
||||
debug_counter: int = self.get_item(
|
||||
"moonraker", "database.debug_counter", 0)
|
||||
"moonraker", "database.debug_counter", 0).result()
|
||||
if self.enable_debug:
|
||||
debug_counter += 1
|
||||
self.insert_item("moonraker", "database.debug_counter",
|
||||
debug_counter)
|
||||
unsafe_shutdowns: int = self.get_item(
|
||||
"moonraker", "database.unsafe_shutdowns", 0)
|
||||
"moonraker", "database.unsafe_shutdowns", 0).result()
|
||||
msg = f"Unsafe Shutdown Count: {unsafe_shutdowns}"
|
||||
if debug_counter:
|
||||
msg += f"; Database Debug Count: {debug_counter}"
|
||||
|
@ -242,12 +245,15 @@ class MoonrakerDatabase:
|
|||
namespace: str,
|
||||
key: Union[List[str], str],
|
||||
drop_empty_db: bool = False
|
||||
) -> Any:
|
||||
) -> Future[Any]:
|
||||
if self.eventloop.is_running():
|
||||
return self.eventloop.run_in_thread(
|
||||
self._delete_impl, namespace, key, drop_empty_db)
|
||||
return cast(Future, self.eventloop.run_in_thread(
|
||||
self._delete_impl, namespace, key, drop_empty_db))
|
||||
else:
|
||||
return self._delete_impl(namespace, key, drop_empty_db)
|
||||
ret = self._delete_impl(namespace, key, drop_empty_db)
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return fut
|
||||
|
||||
def _delete_impl(self,
|
||||
namespace: str,
|
||||
|
@ -291,12 +297,15 @@ class MoonrakerDatabase:
|
|||
namespace: str,
|
||||
key: Optional[Union[List[str], str]] = None,
|
||||
default: Any = SENTINEL
|
||||
) -> Any:
|
||||
) -> Future[Any]:
|
||||
if self.eventloop.is_running():
|
||||
return self.eventloop.run_in_thread(
|
||||
self._get_impl, namespace, key, default)
|
||||
return cast(Future, self.eventloop.run_in_thread(
|
||||
self._get_impl, namespace, key, default))
|
||||
else:
|
||||
return self._get_impl(namespace, key, default)
|
||||
ret = self._get_impl(namespace, key, default)
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return ret
|
||||
|
||||
def _get_impl(self,
|
||||
namespace: str,
|
||||
|
@ -474,19 +483,6 @@ class MoonrakerDatabase:
|
|||
f"Namespace '{namespace}' not found", 404)
|
||||
return NamespaceWrapper(namespace, self, parse_keys)
|
||||
|
||||
def wrap_async_namespace(self,
|
||||
namespace: str,
|
||||
parse_keys: bool = True
|
||||
) -> AsyncNamespaceWrapper:
|
||||
if self.eventloop.is_running():
|
||||
raise self.server.error(
|
||||
"Cannot wrap a namespace while the "
|
||||
"eventloop is running")
|
||||
if namespace not in self.namespaces:
|
||||
raise self.server.error(
|
||||
f"Namespace '{namespace}' not found", 404)
|
||||
return AsyncNamespaceWrapper(namespace, self, parse_keys)
|
||||
|
||||
def _process_key(self, key: Union[List[str], str]) -> List[str]:
|
||||
try:
|
||||
key_list = key if isinstance(key, list) else key.split('.')
|
||||
|
@ -568,11 +564,6 @@ class MoonrakerDatabase:
|
|||
raise self.server.error(
|
||||
f"Error decoding value {bvalue.decode()}, format: {chr(fmt)}")
|
||||
|
||||
def can_call_sync(self, name: str = "") -> None:
|
||||
if self.eventloop.is_running():
|
||||
raise self.server.error(
|
||||
f"Cannot call method {name} while the eventloop is running")
|
||||
|
||||
async def _handle_list_request(self,
|
||||
web_request: WebRequest
|
||||
) -> Dict[str, List[str]]:
|
||||
|
@ -641,103 +632,6 @@ class MoonrakerDatabase:
|
|||
self.thread_lock.release()
|
||||
|
||||
class NamespaceWrapper:
|
||||
def __init__(self,
|
||||
namespace: str,
|
||||
database: MoonrakerDatabase,
|
||||
parse_keys: bool
|
||||
) -> None:
|
||||
self.namespace = namespace
|
||||
self.db = database
|
||||
# If parse keys is true, keys of a string type
|
||||
# will be passed straight to the DB methods.
|
||||
self.parse_keys = parse_keys
|
||||
|
||||
def to_async_wrapper(self) -> AsyncNamespaceWrapper:
|
||||
return AsyncNamespaceWrapper(self.namespace, self.db, self.parse_keys)
|
||||
|
||||
def insert(self, key: Union[List[str], str], value: DBType) -> None:
|
||||
self.db.can_call_sync("insert")
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
self.db.insert_item(self.namespace, key, value)
|
||||
|
||||
def update_child(self, key: Union[List[str], str], value: DBType) -> None:
|
||||
self.db.can_call_sync("update_child")
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
self.db.update_item(self.namespace, key, value)
|
||||
|
||||
def update(self, value: Dict[str, DBRecord]) -> None:
|
||||
self.db.can_call_sync("update")
|
||||
self.db.update_namespace(self.namespace, value)
|
||||
|
||||
def get(self,
|
||||
key: Union[List[str], str],
|
||||
default: Any = None
|
||||
) -> Any:
|
||||
self.db.can_call_sync("get")
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
return self.db.get_item(self.namespace, key, default)
|
||||
|
||||
def delete(self, key: Union[List[str], str]) -> Any:
|
||||
self.db.can_call_sync("delete")
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
return self.db.delete_item(self.namespace, key)
|
||||
|
||||
def __len__(self) -> int:
|
||||
self.db.can_call_sync("length")
|
||||
return self.db.ns_length(self.namespace)
|
||||
|
||||
def __getitem__(self, key: Union[List[str], str]) -> Any:
|
||||
return self.get(key, default=SENTINEL)
|
||||
|
||||
def __setitem__(self,
|
||||
key: Union[List[str], str],
|
||||
value: DBType
|
||||
) -> None:
|
||||
self.insert(key, value)
|
||||
|
||||
def __delitem__(self, key: Union[List[str], str]):
|
||||
self.delete(key)
|
||||
|
||||
def __contains__(self, key: Union[List[str], str]) -> bool:
|
||||
self.db.can_call_sync("contains")
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
return self.db.ns_contains(self.namespace, key)
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
self.db.can_call_sync("keys")
|
||||
return self.db.ns_keys(self.namespace)
|
||||
|
||||
def values(self) -> ValuesView:
|
||||
self.db.can_call_sync("values")
|
||||
return self.db.ns_values(self.namespace)
|
||||
|
||||
def items(self) -> ItemsView:
|
||||
self.db.can_call_sync("items")
|
||||
return self.db.ns_items(self.namespace)
|
||||
|
||||
def pop(self,
|
||||
key: Union[List[str], str],
|
||||
default: Any = SENTINEL
|
||||
) -> Any:
|
||||
self.db.can_call_sync("pop")
|
||||
try:
|
||||
val = self.delete(key)
|
||||
except Exception:
|
||||
if isinstance(default, SentinelClass):
|
||||
raise
|
||||
val = default
|
||||
return val
|
||||
|
||||
def clear(self) -> None:
|
||||
self.db.can_call_sync("clear")
|
||||
self.db.clear_namespace(self.namespace)
|
||||
|
||||
class AsyncNamespaceWrapper:
|
||||
def __init__(self,
|
||||
namespace: str,
|
||||
database: MoonrakerDatabase,
|
||||
|
@ -746,13 +640,11 @@ class AsyncNamespaceWrapper:
|
|||
self.namespace = namespace
|
||||
self.db = database
|
||||
self.eventloop = database.eventloop
|
||||
self.server = database.server
|
||||
# If parse keys is true, keys of a string type
|
||||
# will be passed straight to the DB methods.
|
||||
self.parse_keys = parse_keys
|
||||
|
||||
def to_sync_wrapper(self) -> NamespaceWrapper:
|
||||
return NamespaceWrapper(self.namespace, self.db, self.parse_keys)
|
||||
|
||||
def insert(self,
|
||||
key: Union[List[str], str],
|
||||
value: DBType
|
||||
|
@ -772,15 +664,15 @@ class AsyncNamespaceWrapper:
|
|||
def update(self, value: Dict[str, DBRecord]) -> Awaitable[None]:
|
||||
return self.db.update_namespace(self.namespace, value)
|
||||
|
||||
async def get(self,
|
||||
key: Union[List[str], str],
|
||||
default: Any = None
|
||||
) -> Any:
|
||||
def get(self,
|
||||
key: Union[List[str], str],
|
||||
default: Any = None
|
||||
) -> Future[Any]:
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
return await self.db.get_item(self.namespace, key, default)
|
||||
return self.db.get_item(self.namespace, key, default)
|
||||
|
||||
def delete(self, key: Union[List[str], str]) -> Awaitable[Any]:
|
||||
def delete(self, key: Union[List[str], str]) -> Future[Any]:
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
return self.db.delete_item(self.namespace, key)
|
||||
|
@ -788,7 +680,7 @@ class AsyncNamespaceWrapper:
|
|||
async def length(self) -> int:
|
||||
return await self.db.ns_length_async(self.namespace)
|
||||
|
||||
def __getitem__(self, key: Union[List[str], str]) -> Coroutine:
|
||||
def __getitem__(self, key: Union[List[str], str]) -> Future[Any]:
|
||||
return self.get(key, default=SENTINEL)
|
||||
|
||||
def __setitem__(self,
|
||||
|
@ -800,24 +692,56 @@ class AsyncNamespaceWrapper:
|
|||
def __delitem__(self, key: Union[List[str], str]):
|
||||
self.delete(key)
|
||||
|
||||
def __contains__(self, key: Union[List[str], str]) -> bool:
|
||||
self._check_sync_method("__contains__")
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
return self.db.ns_contains(self.namespace, key)
|
||||
|
||||
async def contains(self, key: Union[List[str], str]) -> bool:
|
||||
if isinstance(key, str) and not self.parse_keys:
|
||||
key = [key]
|
||||
return await self.db.ns_contains_async(self.namespace, key)
|
||||
|
||||
async def keys(self) -> List[str]:
|
||||
return await self.db.ns_keys_async(self.namespace)
|
||||
def keys(self) -> Future[List[str]]:
|
||||
if not self.eventloop.is_running:
|
||||
ret = self.db.ns_keys()
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return fut
|
||||
return cast(Future, self.db.ns_keys_async(self.namespace))
|
||||
|
||||
async def values(self) -> ValuesView:
|
||||
return await self.db.ns_values_async(self.namespace)
|
||||
def values(self) -> Future[ValuesView]:
|
||||
if not self.eventloop.is_running:
|
||||
ret = self.db.ns_values()
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return fut
|
||||
return cast(Future, self.db.ns_values_async(self.namespace))
|
||||
|
||||
async def items(self) -> ItemsView:
|
||||
return await self.db.ns_items_async(self.namespace)
|
||||
def items(self) -> Future[ItemsView]:
|
||||
if not self.eventloop.is_running:
|
||||
ret = self.db.ns_items()
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(ret)
|
||||
return fut
|
||||
return cast(Future, self.db.ns_items_async(self.namespace))
|
||||
|
||||
def pop(self,
|
||||
key: Union[List[str], str],
|
||||
default: Any = SENTINEL
|
||||
) -> Awaitable[Any]:
|
||||
) -> Union[Future[Any], Task[Any]]:
|
||||
if not self.eventloop.is_running():
|
||||
try:
|
||||
val = self.delete(key).result()
|
||||
except Exception:
|
||||
if isinstance(default, SentinelClass):
|
||||
raise
|
||||
val = default
|
||||
fut = self.eventloop.create_future()
|
||||
fut.set_result(val)
|
||||
return fut
|
||||
|
||||
async def _do_pop() -> Any:
|
||||
try:
|
||||
val = await self.delete(key)
|
||||
|
@ -831,6 +755,11 @@ class AsyncNamespaceWrapper:
|
|||
def clear(self) -> Awaitable[None]:
|
||||
return self.db.clear_namespace(self.namespace)
|
||||
|
||||
def _check_sync_method(self, func_name: str) -> None:
|
||||
if self.eventloop.is_running():
|
||||
raise self.server.error(
|
||||
f"Cannot call method {func_name} while "
|
||||
"the eventloop is running")
|
||||
|
||||
def load_component(config: ConfigHelper) -> MoonrakerDatabase:
|
||||
return MoonrakerDatabase(config)
|
||||
|
|
Loading…
Reference in New Issue