database: combine namespace wrappers
Rather than creating two wrappers, use a single wrapper whose methods always return a future or awaitable. If the operation occurs during the __init__() method of a component it will be syncrhonous, and the result from the future can be immediately queried. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
4d0a43cb25
commit
781e3e6250
|
@ -10,6 +10,7 @@ import json
|
||||||
import struct
|
import struct
|
||||||
import operator
|
import operator
|
||||||
import logging
|
import logging
|
||||||
|
from asyncio import Future, Task
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from threading import Lock as ThreadLock
|
from threading import Lock as ThreadLock
|
||||||
|
@ -21,7 +22,6 @@ from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Coroutine,
|
|
||||||
ItemsView,
|
ItemsView,
|
||||||
ValuesView,
|
ValuesView,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
@ -29,6 +29,7 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
|
cast
|
||||||
)
|
)
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from confighelper import ConfigHelper
|
from confighelper import ConfigHelper
|
||||||
|
@ -126,18 +127,20 @@ class MoonrakerDatabase:
|
||||||
# be granted by enabling the debug option. Forbidden namespaces
|
# be granted by enabling the debug option. Forbidden namespaces
|
||||||
# have no API access. This cannot be overridden.
|
# have no API access. This cannot be overridden.
|
||||||
self.protected_namespaces = set(self.get_item(
|
self.protected_namespaces = set(self.get_item(
|
||||||
"moonraker", "database.protected_namespaces", ["moonraker"]))
|
"moonraker", "database.protected_namespaces",
|
||||||
|
["moonraker"]).result())
|
||||||
self.forbidden_namespaces = set(self.get_item(
|
self.forbidden_namespaces = set(self.get_item(
|
||||||
"moonraker", "database.forbidden_namespaces", []))
|
"moonraker", "database.forbidden_namespaces",
|
||||||
|
[]).result())
|
||||||
# Track debug access and unsafe shutdowns
|
# Track debug access and unsafe shutdowns
|
||||||
debug_counter: int = self.get_item(
|
debug_counter: int = self.get_item(
|
||||||
"moonraker", "database.debug_counter", 0)
|
"moonraker", "database.debug_counter", 0).result()
|
||||||
if self.enable_debug:
|
if self.enable_debug:
|
||||||
debug_counter += 1
|
debug_counter += 1
|
||||||
self.insert_item("moonraker", "database.debug_counter",
|
self.insert_item("moonraker", "database.debug_counter",
|
||||||
debug_counter)
|
debug_counter)
|
||||||
unsafe_shutdowns: int = self.get_item(
|
unsafe_shutdowns: int = self.get_item(
|
||||||
"moonraker", "database.unsafe_shutdowns", 0)
|
"moonraker", "database.unsafe_shutdowns", 0).result()
|
||||||
msg = f"Unsafe Shutdown Count: {unsafe_shutdowns}"
|
msg = f"Unsafe Shutdown Count: {unsafe_shutdowns}"
|
||||||
if debug_counter:
|
if debug_counter:
|
||||||
msg += f"; Database Debug Count: {debug_counter}"
|
msg += f"; Database Debug Count: {debug_counter}"
|
||||||
|
@ -242,12 +245,15 @@ class MoonrakerDatabase:
|
||||||
namespace: str,
|
namespace: str,
|
||||||
key: Union[List[str], str],
|
key: Union[List[str], str],
|
||||||
drop_empty_db: bool = False
|
drop_empty_db: bool = False
|
||||||
) -> Any:
|
) -> Future[Any]:
|
||||||
if self.eventloop.is_running():
|
if self.eventloop.is_running():
|
||||||
return self.eventloop.run_in_thread(
|
return cast(Future, self.eventloop.run_in_thread(
|
||||||
self._delete_impl, namespace, key, drop_empty_db)
|
self._delete_impl, namespace, key, drop_empty_db))
|
||||||
else:
|
else:
|
||||||
return self._delete_impl(namespace, key, drop_empty_db)
|
ret = self._delete_impl(namespace, key, drop_empty_db)
|
||||||
|
fut = self.eventloop.create_future()
|
||||||
|
fut.set_result(ret)
|
||||||
|
return fut
|
||||||
|
|
||||||
def _delete_impl(self,
|
def _delete_impl(self,
|
||||||
namespace: str,
|
namespace: str,
|
||||||
|
@ -291,12 +297,15 @@ class MoonrakerDatabase:
|
||||||
namespace: str,
|
namespace: str,
|
||||||
key: Optional[Union[List[str], str]] = None,
|
key: Optional[Union[List[str], str]] = None,
|
||||||
default: Any = SENTINEL
|
default: Any = SENTINEL
|
||||||
) -> Any:
|
) -> Future[Any]:
|
||||||
if self.eventloop.is_running():
|
if self.eventloop.is_running():
|
||||||
return self.eventloop.run_in_thread(
|
return cast(Future, self.eventloop.run_in_thread(
|
||||||
self._get_impl, namespace, key, default)
|
self._get_impl, namespace, key, default))
|
||||||
else:
|
else:
|
||||||
return self._get_impl(namespace, key, default)
|
ret = self._get_impl(namespace, key, default)
|
||||||
|
fut = self.eventloop.create_future()
|
||||||
|
fut.set_result(ret)
|
||||||
|
return ret
|
||||||
|
|
||||||
def _get_impl(self,
|
def _get_impl(self,
|
||||||
namespace: str,
|
namespace: str,
|
||||||
|
@ -474,19 +483,6 @@ class MoonrakerDatabase:
|
||||||
f"Namespace '{namespace}' not found", 404)
|
f"Namespace '{namespace}' not found", 404)
|
||||||
return NamespaceWrapper(namespace, self, parse_keys)
|
return NamespaceWrapper(namespace, self, parse_keys)
|
||||||
|
|
||||||
def wrap_async_namespace(self,
|
|
||||||
namespace: str,
|
|
||||||
parse_keys: bool = True
|
|
||||||
) -> AsyncNamespaceWrapper:
|
|
||||||
if self.eventloop.is_running():
|
|
||||||
raise self.server.error(
|
|
||||||
"Cannot wrap a namespace while the "
|
|
||||||
"eventloop is running")
|
|
||||||
if namespace not in self.namespaces:
|
|
||||||
raise self.server.error(
|
|
||||||
f"Namespace '{namespace}' not found", 404)
|
|
||||||
return AsyncNamespaceWrapper(namespace, self, parse_keys)
|
|
||||||
|
|
||||||
def _process_key(self, key: Union[List[str], str]) -> List[str]:
|
def _process_key(self, key: Union[List[str], str]) -> List[str]:
|
||||||
try:
|
try:
|
||||||
key_list = key if isinstance(key, list) else key.split('.')
|
key_list = key if isinstance(key, list) else key.split('.')
|
||||||
|
@ -568,11 +564,6 @@ class MoonrakerDatabase:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"Error decoding value {bvalue.decode()}, format: {chr(fmt)}")
|
f"Error decoding value {bvalue.decode()}, format: {chr(fmt)}")
|
||||||
|
|
||||||
def can_call_sync(self, name: str = "") -> None:
|
|
||||||
if self.eventloop.is_running():
|
|
||||||
raise self.server.error(
|
|
||||||
f"Cannot call method {name} while the eventloop is running")
|
|
||||||
|
|
||||||
async def _handle_list_request(self,
|
async def _handle_list_request(self,
|
||||||
web_request: WebRequest
|
web_request: WebRequest
|
||||||
) -> Dict[str, List[str]]:
|
) -> Dict[str, List[str]]:
|
||||||
|
@ -641,103 +632,6 @@ class MoonrakerDatabase:
|
||||||
self.thread_lock.release()
|
self.thread_lock.release()
|
||||||
|
|
||||||
class NamespaceWrapper:
|
class NamespaceWrapper:
|
||||||
def __init__(self,
|
|
||||||
namespace: str,
|
|
||||||
database: MoonrakerDatabase,
|
|
||||||
parse_keys: bool
|
|
||||||
) -> None:
|
|
||||||
self.namespace = namespace
|
|
||||||
self.db = database
|
|
||||||
# If parse keys is true, keys of a string type
|
|
||||||
# will be passed straight to the DB methods.
|
|
||||||
self.parse_keys = parse_keys
|
|
||||||
|
|
||||||
def to_async_wrapper(self) -> AsyncNamespaceWrapper:
|
|
||||||
return AsyncNamespaceWrapper(self.namespace, self.db, self.parse_keys)
|
|
||||||
|
|
||||||
def insert(self, key: Union[List[str], str], value: DBType) -> None:
|
|
||||||
self.db.can_call_sync("insert")
|
|
||||||
if isinstance(key, str) and not self.parse_keys:
|
|
||||||
key = [key]
|
|
||||||
self.db.insert_item(self.namespace, key, value)
|
|
||||||
|
|
||||||
def update_child(self, key: Union[List[str], str], value: DBType) -> None:
|
|
||||||
self.db.can_call_sync("update_child")
|
|
||||||
if isinstance(key, str) and not self.parse_keys:
|
|
||||||
key = [key]
|
|
||||||
self.db.update_item(self.namespace, key, value)
|
|
||||||
|
|
||||||
def update(self, value: Dict[str, DBRecord]) -> None:
|
|
||||||
self.db.can_call_sync("update")
|
|
||||||
self.db.update_namespace(self.namespace, value)
|
|
||||||
|
|
||||||
def get(self,
|
|
||||||
key: Union[List[str], str],
|
|
||||||
default: Any = None
|
|
||||||
) -> Any:
|
|
||||||
self.db.can_call_sync("get")
|
|
||||||
if isinstance(key, str) and not self.parse_keys:
|
|
||||||
key = [key]
|
|
||||||
return self.db.get_item(self.namespace, key, default)
|
|
||||||
|
|
||||||
def delete(self, key: Union[List[str], str]) -> Any:
|
|
||||||
self.db.can_call_sync("delete")
|
|
||||||
if isinstance(key, str) and not self.parse_keys:
|
|
||||||
key = [key]
|
|
||||||
return self.db.delete_item(self.namespace, key)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
self.db.can_call_sync("length")
|
|
||||||
return self.db.ns_length(self.namespace)
|
|
||||||
|
|
||||||
def __getitem__(self, key: Union[List[str], str]) -> Any:
|
|
||||||
return self.get(key, default=SENTINEL)
|
|
||||||
|
|
||||||
def __setitem__(self,
|
|
||||||
key: Union[List[str], str],
|
|
||||||
value: DBType
|
|
||||||
) -> None:
|
|
||||||
self.insert(key, value)
|
|
||||||
|
|
||||||
def __delitem__(self, key: Union[List[str], str]):
|
|
||||||
self.delete(key)
|
|
||||||
|
|
||||||
def __contains__(self, key: Union[List[str], str]) -> bool:
|
|
||||||
self.db.can_call_sync("contains")
|
|
||||||
if isinstance(key, str) and not self.parse_keys:
|
|
||||||
key = [key]
|
|
||||||
return self.db.ns_contains(self.namespace, key)
|
|
||||||
|
|
||||||
def keys(self) -> List[str]:
|
|
||||||
self.db.can_call_sync("keys")
|
|
||||||
return self.db.ns_keys(self.namespace)
|
|
||||||
|
|
||||||
def values(self) -> ValuesView:
|
|
||||||
self.db.can_call_sync("values")
|
|
||||||
return self.db.ns_values(self.namespace)
|
|
||||||
|
|
||||||
def items(self) -> ItemsView:
|
|
||||||
self.db.can_call_sync("items")
|
|
||||||
return self.db.ns_items(self.namespace)
|
|
||||||
|
|
||||||
def pop(self,
|
|
||||||
key: Union[List[str], str],
|
|
||||||
default: Any = SENTINEL
|
|
||||||
) -> Any:
|
|
||||||
self.db.can_call_sync("pop")
|
|
||||||
try:
|
|
||||||
val = self.delete(key)
|
|
||||||
except Exception:
|
|
||||||
if isinstance(default, SentinelClass):
|
|
||||||
raise
|
|
||||||
val = default
|
|
||||||
return val
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
self.db.can_call_sync("clear")
|
|
||||||
self.db.clear_namespace(self.namespace)
|
|
||||||
|
|
||||||
class AsyncNamespaceWrapper:
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
namespace: str,
|
namespace: str,
|
||||||
database: MoonrakerDatabase,
|
database: MoonrakerDatabase,
|
||||||
|
@ -746,13 +640,11 @@ class AsyncNamespaceWrapper:
|
||||||
self.namespace = namespace
|
self.namespace = namespace
|
||||||
self.db = database
|
self.db = database
|
||||||
self.eventloop = database.eventloop
|
self.eventloop = database.eventloop
|
||||||
|
self.server = database.server
|
||||||
# If parse keys is true, keys of a string type
|
# If parse keys is true, keys of a string type
|
||||||
# will be passed straight to the DB methods.
|
# will be passed straight to the DB methods.
|
||||||
self.parse_keys = parse_keys
|
self.parse_keys = parse_keys
|
||||||
|
|
||||||
def to_sync_wrapper(self) -> NamespaceWrapper:
|
|
||||||
return NamespaceWrapper(self.namespace, self.db, self.parse_keys)
|
|
||||||
|
|
||||||
def insert(self,
|
def insert(self,
|
||||||
key: Union[List[str], str],
|
key: Union[List[str], str],
|
||||||
value: DBType
|
value: DBType
|
||||||
|
@ -772,15 +664,15 @@ class AsyncNamespaceWrapper:
|
||||||
def update(self, value: Dict[str, DBRecord]) -> Awaitable[None]:
|
def update(self, value: Dict[str, DBRecord]) -> Awaitable[None]:
|
||||||
return self.db.update_namespace(self.namespace, value)
|
return self.db.update_namespace(self.namespace, value)
|
||||||
|
|
||||||
async def get(self,
|
def get(self,
|
||||||
key: Union[List[str], str],
|
key: Union[List[str], str],
|
||||||
default: Any = None
|
default: Any = None
|
||||||
) -> Any:
|
) -> Future[Any]:
|
||||||
if isinstance(key, str) and not self.parse_keys:
|
if isinstance(key, str) and not self.parse_keys:
|
||||||
key = [key]
|
key = [key]
|
||||||
return await self.db.get_item(self.namespace, key, default)
|
return self.db.get_item(self.namespace, key, default)
|
||||||
|
|
||||||
def delete(self, key: Union[List[str], str]) -> Awaitable[Any]:
|
def delete(self, key: Union[List[str], str]) -> Future[Any]:
|
||||||
if isinstance(key, str) and not self.parse_keys:
|
if isinstance(key, str) and not self.parse_keys:
|
||||||
key = [key]
|
key = [key]
|
||||||
return self.db.delete_item(self.namespace, key)
|
return self.db.delete_item(self.namespace, key)
|
||||||
|
@ -788,7 +680,7 @@ class AsyncNamespaceWrapper:
|
||||||
async def length(self) -> int:
|
async def length(self) -> int:
|
||||||
return await self.db.ns_length_async(self.namespace)
|
return await self.db.ns_length_async(self.namespace)
|
||||||
|
|
||||||
def __getitem__(self, key: Union[List[str], str]) -> Coroutine:
|
def __getitem__(self, key: Union[List[str], str]) -> Future[Any]:
|
||||||
return self.get(key, default=SENTINEL)
|
return self.get(key, default=SENTINEL)
|
||||||
|
|
||||||
def __setitem__(self,
|
def __setitem__(self,
|
||||||
|
@ -800,24 +692,56 @@ class AsyncNamespaceWrapper:
|
||||||
def __delitem__(self, key: Union[List[str], str]):
|
def __delitem__(self, key: Union[List[str], str]):
|
||||||
self.delete(key)
|
self.delete(key)
|
||||||
|
|
||||||
|
def __contains__(self, key: Union[List[str], str]) -> bool:
|
||||||
|
self._check_sync_method("__contains__")
|
||||||
|
if isinstance(key, str) and not self.parse_keys:
|
||||||
|
key = [key]
|
||||||
|
return self.db.ns_contains(self.namespace, key)
|
||||||
|
|
||||||
async def contains(self, key: Union[List[str], str]) -> bool:
|
async def contains(self, key: Union[List[str], str]) -> bool:
|
||||||
if isinstance(key, str) and not self.parse_keys:
|
if isinstance(key, str) and not self.parse_keys:
|
||||||
key = [key]
|
key = [key]
|
||||||
return await self.db.ns_contains_async(self.namespace, key)
|
return await self.db.ns_contains_async(self.namespace, key)
|
||||||
|
|
||||||
async def keys(self) -> List[str]:
|
def keys(self) -> Future[List[str]]:
|
||||||
return await self.db.ns_keys_async(self.namespace)
|
if not self.eventloop.is_running:
|
||||||
|
ret = self.db.ns_keys()
|
||||||
|
fut = self.eventloop.create_future()
|
||||||
|
fut.set_result(ret)
|
||||||
|
return fut
|
||||||
|
return cast(Future, self.db.ns_keys_async(self.namespace))
|
||||||
|
|
||||||
async def values(self) -> ValuesView:
|
def values(self) -> Future[ValuesView]:
|
||||||
return await self.db.ns_values_async(self.namespace)
|
if not self.eventloop.is_running:
|
||||||
|
ret = self.db.ns_values()
|
||||||
|
fut = self.eventloop.create_future()
|
||||||
|
fut.set_result(ret)
|
||||||
|
return fut
|
||||||
|
return cast(Future, self.db.ns_values_async(self.namespace))
|
||||||
|
|
||||||
async def items(self) -> ItemsView:
|
def items(self) -> Future[ItemsView]:
|
||||||
return await self.db.ns_items_async(self.namespace)
|
if not self.eventloop.is_running:
|
||||||
|
ret = self.db.ns_items()
|
||||||
|
fut = self.eventloop.create_future()
|
||||||
|
fut.set_result(ret)
|
||||||
|
return fut
|
||||||
|
return cast(Future, self.db.ns_items_async(self.namespace))
|
||||||
|
|
||||||
def pop(self,
|
def pop(self,
|
||||||
key: Union[List[str], str],
|
key: Union[List[str], str],
|
||||||
default: Any = SENTINEL
|
default: Any = SENTINEL
|
||||||
) -> Awaitable[Any]:
|
) -> Union[Future[Any], Task[Any]]:
|
||||||
|
if not self.eventloop.is_running():
|
||||||
|
try:
|
||||||
|
val = self.delete(key).result()
|
||||||
|
except Exception:
|
||||||
|
if isinstance(default, SentinelClass):
|
||||||
|
raise
|
||||||
|
val = default
|
||||||
|
fut = self.eventloop.create_future()
|
||||||
|
fut.set_result(val)
|
||||||
|
return fut
|
||||||
|
|
||||||
async def _do_pop() -> Any:
|
async def _do_pop() -> Any:
|
||||||
try:
|
try:
|
||||||
val = await self.delete(key)
|
val = await self.delete(key)
|
||||||
|
@ -831,6 +755,11 @@ class AsyncNamespaceWrapper:
|
||||||
def clear(self) -> Awaitable[None]:
|
def clear(self) -> Awaitable[None]:
|
||||||
return self.db.clear_namespace(self.namespace)
|
return self.db.clear_namespace(self.namespace)
|
||||||
|
|
||||||
|
def _check_sync_method(self, func_name: str) -> None:
|
||||||
|
if self.eventloop.is_running():
|
||||||
|
raise self.server.error(
|
||||||
|
f"Cannot call method {func_name} while "
|
||||||
|
"the eventloop is running")
|
||||||
|
|
||||||
def load_component(config: ConfigHelper) -> MoonrakerDatabase:
|
def load_component(config: ConfigHelper) -> MoonrakerDatabase:
|
||||||
return MoonrakerDatabase(config)
|
return MoonrakerDatabase(config)
|
||||||
|
|
Loading…
Reference in New Issue