From e029b6c582ac1305ce2dd74dacf38db2c62ae4fe Mon Sep 17 00:00:00 2001 From: Eric Callahan Date: Mon, 31 Jan 2022 11:45:22 -0500 Subject: [PATCH] database: refactor to remove duplicate code Wrap command implementations in with a _run_command() method. All database commands now return a Future object. If the command was run before the eventloop starts its possible to immediately query the Future's result. Signed-off-by: Eric Callahan --- moonraker/components/database.py | 517 ++++++++++++++----------------- 1 file changed, 235 insertions(+), 282 deletions(-) diff --git a/moonraker/components/database.py b/moonraker/components/database.py index 0fbb7f9..b124a4f 100644 --- a/moonraker/components/database.py +++ b/moonraker/components/database.py @@ -22,9 +22,9 @@ from typing import ( TYPE_CHECKING, Any, Awaitable, - ItemsView, + Callable, Mapping, - ValuesView, + TypeVar, Tuple, Optional, Union, @@ -37,6 +37,7 @@ if TYPE_CHECKING: from websockets import WebRequest DBRecord = Union[int, float, bool, str, List[Any], Dict[str, Any]] DBType = Optional[DBRecord] + _T = TypeVar("_T") DATABASE_VERSION = 1 MAX_NAMESPACES = 100 @@ -157,343 +158,310 @@ class MoonrakerDatabase: "/server/database/item", ["GET", "POST", "DELETE"], self._handle_item_request) + def _run_command(self, + command_func: Callable[..., _T], + *args + ) -> Future[_T]: + def func_wrapper(): + with self.thread_lock: + return command_func(*args) + + if self.eventloop.is_running(): + return cast(Future, self.eventloop.run_in_thread(func_wrapper)) + else: + ret = func_wrapper() + fut = self.eventloop.create_future() + fut.set_result(ret) + return fut + def insert_item(self, namespace: str, key: Union[List[str], str], value: DBType - ) -> Awaitable[None]: - if self.eventloop.is_running(): - return self.eventloop.run_in_thread( - self._insert_impl, namespace, key, value) - else: - self._insert_impl(namespace, key, value) - fut = self.eventloop.create_future() - fut.set_result(None) - return fut + ) -> Future[None]: + return self._run_command(self._insert_impl, namespace, key, value) def _insert_impl(self, namespace: str, key: Union[List[str], str], value: DBType ) -> None: - with self.thread_lock: - key_list = self._process_key(key) - if namespace not in self.namespaces: - self.namespaces[namespace] = self.lmdb_env.open_db( - namespace.encode()) - record = value - if len(key_list) > 1: - record = self._get_record(namespace, key_list[0], force=True) - if not isinstance(record, dict): - record = {} - logging.info( - f"Warning: Key {key_list[0]} contains a value of type" - f" {type(record)}. Overwriting with an object.") - item: Dict[str, Any] = reduce( - getitem_with_default, key_list[1:-1], record) - item[key_list[-1]] = value - if not self._insert_record(namespace, key_list[0], record): + key_list = self._process_key(key) + if namespace not in self.namespaces: + self.namespaces[namespace] = self.lmdb_env.open_db( + namespace.encode()) + record = value + if len(key_list) > 1: + record = self._get_record(namespace, key_list[0], force=True) + if not isinstance(record, dict): + record = {} logging.info( - f"Error inserting key '{key}' in namespace '{namespace}'") + f"Warning: Key {key_list[0]} contains a value of type" + f" {type(record)}. Overwriting with an object.") + item: Dict[str, Any] = reduce( + getitem_with_default, key_list[1:-1], record) + item[key_list[-1]] = value + if not self._insert_record(namespace, key_list[0], record): + logging.info( + f"Error inserting key '{key}' in namespace '{namespace}'") def update_item(self, namespace: str, key: Union[List[str], str], value: DBType - ) -> Awaitable[None]: - if self.eventloop.is_running(): - return self.eventloop.run_in_thread( - self._update_impl, namespace, key, value) - else: - self._update_impl(namespace, key, value) - fut = self.eventloop.create_future() - fut.set_result(None) - return fut + ) -> Future[None]: + return self._run_command(self._update_impl, namespace, key, value) def _update_impl(self, namespace: str, key: Union[List[str], str], value: DBType ) -> None: - with self.thread_lock: - key_list = self._process_key(key) - record = self._get_record(namespace, key_list[0]) - if len(key_list) == 1: - if isinstance(record, dict) and isinstance(value, dict): - record.update(value) - else: - assert value is not None - record = value + key_list = self._process_key(key) + record = self._get_record(namespace, key_list[0]) + if len(key_list) == 1: + if isinstance(record, dict) and isinstance(value, dict): + record.update(value) else: - try: - assert isinstance(record, dict) - item: Dict[str, Any] = reduce( - operator.getitem, key_list[1:-1], record) - except Exception: - raise self.server.error( - f"Key '{key}' in namespace '{namespace}' not found", - 404) - if isinstance(item[key_list[-1]], dict) \ - and isinstance(value, dict): - item[key_list[-1]].update(value) - else: - item[key_list[-1]] = value - if not self._insert_record(namespace, key_list[0], record): - logging.info( - f"Error updating key '{key}' in namespace '{namespace}'") + assert value is not None + record = value + else: + try: + assert isinstance(record, dict) + item: Dict[str, Any] = reduce( + operator.getitem, key_list[1:-1], record) + except Exception: + raise self.server.error( + f"Key '{key}' in namespace '{namespace}' not found", + 404) + if isinstance(item[key_list[-1]], dict) \ + and isinstance(value, dict): + item[key_list[-1]].update(value) + else: + item[key_list[-1]] = value + if not self._insert_record(namespace, key_list[0], record): + logging.info( + f"Error updating key '{key}' in namespace '{namespace}'") def delete_item(self, namespace: str, key: Union[List[str], str], drop_empty_db: bool = False ) -> Future[Any]: - if self.eventloop.is_running(): - return cast(Future, self.eventloop.run_in_thread( - self._delete_impl, namespace, key, drop_empty_db)) - else: - ret = self._delete_impl(namespace, key, drop_empty_db) - fut = self.eventloop.create_future() - fut.set_result(ret) - return fut + return self._run_command(self._delete_impl, namespace, key, + drop_empty_db) def _delete_impl(self, namespace: str, key: Union[List[str], str], drop_empty_db: bool = False ) -> Any: - with self.thread_lock: - key_list = self._process_key(key) - val = record = self._get_record(namespace, key_list[0]) - remove_record = True - if len(key_list) > 1: - try: - assert isinstance(record, dict) - item: Dict[str, Any] = reduce( - operator.getitem, key_list[1:-1], record) - val = item.pop(key_list[-1]) - except Exception: - raise self.server.error( - f"Key '{key}' in namespace '{namespace}' not found", - 404) - remove_record = False if record else True - if remove_record: - db = self.namespaces[namespace] - with ( - self.lmdb_env.begin(write=True, buffers=True, db=db) as txn - ): - ret = txn.delete(key_list[0].encode()) - with txn.cursor() as cursor: - if not cursor.first() and drop_empty_db: - txn.drop(db) - del self.namespaces[namespace] - else: - ret = self._insert_record(namespace, key_list[0], record) - if not ret: - logging.info( - f"Error deleting key '{key}' from namespace " - f"'{namespace}'") - return val + key_list = self._process_key(key) + val = record = self._get_record(namespace, key_list[0]) + remove_record = True + if len(key_list) > 1: + try: + assert isinstance(record, dict) + item: Dict[str, Any] = reduce( + operator.getitem, key_list[1:-1], record) + val = item.pop(key_list[-1]) + except Exception: + raise self.server.error( + f"Key '{key}' in namespace '{namespace}' not found", + 404) + remove_record = False if record else True + if remove_record: + db = self.namespaces[namespace] + with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: + ret = txn.delete(key_list[0].encode()) + with txn.cursor() as cursor: + if not cursor.first() and drop_empty_db: + txn.drop(db) + del self.namespaces[namespace] + else: + ret = self._insert_record(namespace, key_list[0], record) + if not ret: + logging.info( + f"Error deleting key '{key}' from namespace " + f"'{namespace}'") + return val def get_item(self, namespace: str, key: Optional[Union[List[str], str]] = None, default: Any = SENTINEL ) -> Future[Any]: - if self.eventloop.is_running(): - return cast(Future, self.eventloop.run_in_thread( - self._get_impl, namespace, key, default)) - else: - ret = self._get_impl(namespace, key, default) - fut = self.eventloop.create_future() - fut.set_result(ret) - return ret + return self._run_command(self._get_impl, namespace, key, default) def _get_impl(self, namespace: str, key: Optional[Union[List[str], str]] = None, default: Any = SENTINEL ) -> Any: - with self.thread_lock: - try: - if key is None: - return self._get_namespace(namespace) - key_list = self._process_key(key) - ns = self._get_record(namespace, key_list[0]) - val = reduce(operator.getitem, # type: ignore - key_list[1:], ns) - except Exception: - if not isinstance(default, SentinelClass): - return default - raise self.server.error( - f"Key '{key}' in namespace '{namespace}' not found", 404) - return val + try: + if key is None: + return self._get_namespace(namespace) + key_list = self._process_key(key) + ns = self._get_record(namespace, key_list[0]) + val = reduce(operator.getitem, # type: ignore + key_list[1:], ns) + except Exception: + if not isinstance(default, SentinelClass): + return default + raise self.server.error( + f"Key '{key}' in namespace '{namespace}' not found", 404) + return val def update_namespace(self, namespace: str, value: Mapping[str, DBRecord] - ) -> Awaitable[None]: - if self.eventloop.is_running(): - return self.eventloop.run_in_thread( - self._update_ns_impl, namespace, value) - else: - self._update_ns_impl(namespace, value) - fut = self.eventloop.create_future() - fut.set_result(None) - return fut + ) -> Future[None]: + return self._run_command(self._update_ns_impl, namespace, value) def _update_ns_impl(self, namespace: str, value: Mapping[str, DBRecord] ) -> None: - with self.thread_lock: - if not value: - return - if namespace not in self.namespaces: - raise self.server.error( - f"Invalid database namespace '{namespace}'") - db = self.namespaces[namespace] - with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: - # We only need to update the keys that changed - for key, val in value.items(): - stored = txn.get(key.encode()) - if stored is not None: - decoded = self._decode_value(stored) - if val == decoded: - continue - ret = txn.put(key.encode(), self._encode_value(val)) - if not ret: - logging.info(f"Error inserting key '{key}' " - f"in namespace '{namespace}'") + if not value: + return + if namespace not in self.namespaces: + raise self.server.error( + f"Invalid database namespace '{namespace}'") + db = self.namespaces[namespace] + with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: + # We only need to update the keys that changed + for key, val in value.items(): + stored = txn.get(key.encode()) + if stored is not None: + decoded = self._decode_value(stored) + if val == decoded: + continue + ret = txn.put(key.encode(), self._encode_value(val)) + if not ret: + logging.info(f"Error inserting key '{key}' " + f"in namespace '{namespace}'") def clear_namespace(self, namespace: str, drop_empty_db: bool = False - ) -> Awaitable[None]: - if self.eventloop.is_running(): - return self.eventloop.run_in_thread( - self._clear_ns_impl, namespace, drop_empty_db) - else: - self._clear_ns_impl(namespace, drop_empty_db) - fut = self.eventloop.create_future() - fut.set_result(None) - return fut + ) -> Future[None]: + return self._run_command(self._clear_ns_impl, namespace, drop_empty_db) def _clear_ns_impl(self, namespace: str, drop_empty_db: bool = False ) -> None: - with self.thread_lock: - if namespace not in self.namespaces: - raise self.server.error( - f"Invalid database namespace '{namespace}'") - db = self.namespaces[namespace] - with self.lmdb_env.begin(write=True, db=db) as txn: - txn.drop(db, delete=drop_empty_db) - if drop_empty_db: - del self.namespaces[namespace] + if namespace not in self.namespaces: + raise self.server.error( + f"Invalid database namespace '{namespace}'") + db = self.namespaces[namespace] + with self.lmdb_env.begin(write=True, db=db) as txn: + txn.drop(db, delete=drop_empty_db) + if drop_empty_db: + del self.namespaces[namespace] def sync_namespace(self, namespace: str, value: Mapping[str, DBRecord] - ) -> Awaitable[None]: - if self.eventloop.is_running(): - return self.eventloop.run_in_thread( - self._sync_ns_impl, namespace, value) - else: - self._sync_ns_impl(namespace, value) - fut = self.eventloop.create_future() - fut.set_result(None) - return fut + ) -> Future[None]: + return self._run_command(self._sync_ns_impl, namespace, value) def _sync_ns_impl(self, namespace: str, value: Mapping[str, DBRecord] ) -> None: - with self.thread_lock: - if not value: - return - if namespace not in self.namespaces: - raise self.server.error( - f"Invalid database namespace '{namespace}'") - db = self.namespaces[namespace] - new_keys = set(value.keys()) - with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: - with txn.cursor() as cursor: - remaining = cursor.first() - while remaining: - bkey, bval = cursor.item() - key = bytes(bkey).decode() - if key not in value: - remaining = cursor.delete() - else: - decoded = self._decode_value(bval) - if decoded != value[key]: - cursor.put(self._encode_value(value[key])) - new_keys.remove(key) - remaining = cursor.next() - for k in new_keys: - val = value[k] - ret = txn.put(key.encode(), self._encode_value(val)) - if not ret: - logging.info(f"Error inserting key '{k}' " - f"in namespace '{namespace}'") - - async def ns_length_async(self, namespace: str) -> int: - return len(await self.ns_keys_async(namespace)) - - def ns_length(self, namespace: str) -> int: - return len(self.ns_keys(namespace)) - - def ns_keys_async(self, namespace: str) -> Awaitable[List[str]]: - return self.eventloop.run_in_thread(self.ns_keys, namespace) - - def ns_keys(self, namespace: str) -> List[str]: - with self.thread_lock: - keys: List[str] = [] - db = self.namespaces[namespace] - with self.lmdb_env.begin(db=db) as txn: - with txn.cursor() as cursor: - remaining = cursor.first() - while remaining: - keys.append(cursor.key().decode()) + if not value: + return + if namespace not in self.namespaces: + raise self.server.error( + f"Invalid database namespace '{namespace}'") + db = self.namespaces[namespace] + new_keys = set(value.keys()) + with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: + with txn.cursor() as cursor: + remaining = cursor.first() + while remaining: + bkey, bval = cursor.item() + key = bytes(bkey).decode() + if key not in value: + remaining = cursor.delete() + else: + decoded = self._decode_value(bval) + if decoded != value[key]: + cursor.put(self._encode_value(value[key])) + new_keys.remove(key) remaining = cursor.next() - return keys + for k in new_keys: + val = value[k] + ret = txn.put(key.encode(), self._encode_value(val)) + if not ret: + logging.info(f"Error inserting key '{k}' " + f"in namespace '{namespace}'") - def ns_values_async(self, namespace: str) -> Awaitable[ValuesView]: - return self.eventloop.run_in_thread(self.ns_values, namespace) + def ns_length(self, namespace: str) -> Future[int]: + return self._run_command(self._ns_length_impl, namespace) - def ns_values(self, namespace: str) -> ValuesView: - with self.thread_lock: - ns = self._get_namespace(namespace) - return ns.values() + def _ns_length_impl(self, namespace: str) -> int: + db = self.namespaces[namespace] + with self.lmdb_env.begin(db=db) as txn: + stats = txn.stat(db) + return stats['entries'] - def ns_items_async(self, namespace: str) -> Awaitable[ItemsView]: - return self.eventloop.run_in_thread(self.ns_items, namespace) + def ns_keys(self, namespace: str) -> Future[List[str]]: + return self._run_command(self._ns_keys_impl, namespace) - def ns_items(self, namespace: str) -> ItemsView: - with self.thread_lock: - ns = self._get_namespace(namespace) - return ns.items() + def _ns_keys_impl(self, namespace: str) -> List[str]: + keys: List[str] = [] + db = self.namespaces[namespace] + with self.lmdb_env.begin(db=db) as txn: + with txn.cursor() as cursor: + remaining = cursor.first() + while remaining: + keys.append(cursor.key().decode()) + remaining = cursor.next() + return keys - def ns_contains_async(self, + def ns_values(self, namespace: str) -> Future[List[Any]]: + return self._run_command(self._ns_values_impl, namespace) + + def _ns_values_impl(self, namespace: str) -> List[Any]: + values: List[Any] = [] + db = self.namespaces[namespace] + with self.lmdb_env.begin(db=db, buffers=True) as txn: + with txn.cursor() as cursor: + remaining = cursor.first() + while remaining: + values.append(self._decode_value(cursor.value())) + remaining = cursor.next() + return values + + def ns_items(self, namespace: str) -> Future[List[Tuple[str, Any]]]: + return self._run_command(self._ns_items_impl, namespace) + + def _ns_items_impl(self, namespace: str) -> List[Tuple[str, Any]]: + ns = self._get_namespace(namespace) + return list(ns.items()) + + def ns_contains(self, + namespace: str, + key: Union[List[str], str] + ) -> Future[bool]: + return self._run_command(self._ns_contains_impl, namespace, key) + + def _ns_contains_impl(self, namespace: str, key: Union[List[str], str] - ) -> Awaitable[bool]: - return self.eventloop.run_in_thread( - self.ns_contains, namespace, key) - - def ns_contains(self, namespace: str, key: Union[List[str], str]) -> bool: - with self.thread_lock: - try: - key_list = self._process_key(key) - record = self._get_record(namespace, key_list[0]) - if len(key_list) == 1: - return True - reduce(operator.getitem, # type: ignore - key_list[1:], record) - except Exception: - return False - return True + ) -> bool: + try: + key_list = self._process_key(key) + record = self._get_record(namespace, key_list[0]) + if len(key_list) == 1: + return True + reduce(operator.getitem, # type: ignore + key_list[1:], record) + except Exception: + return False + return True def register_local_namespace(self, namespace: str, @@ -727,8 +695,8 @@ class NamespaceWrapper: key = [key] return self.db.delete_item(self.namespace, key) - async def length(self) -> int: - return await self.db.ns_length_async(self.namespace) + def length(self) -> Future[int]: + return self.db.ns_length(self.namespace) def as_dict(self) -> Dict[str, Any]: self._check_sync_method("as_dict") @@ -750,36 +718,21 @@ class NamespaceWrapper: self._check_sync_method("__contains__") if isinstance(key, str) and not self.parse_keys: key = [key] - return self.db.ns_contains(self.namespace, key) + return self.db.ns_contains(self.namespace, key).result() - async def contains(self, key: Union[List[str], str]) -> bool: + def contains(self, key: Union[List[str], str]) -> Future[bool]: if isinstance(key, str) and not self.parse_keys: key = [key] - return await self.db.ns_contains_async(self.namespace, key) + return self.db.ns_contains(self.namespace, key) def keys(self) -> Future[List[str]]: - if not self.eventloop.is_running: - ret = self.db.ns_keys() - fut = self.eventloop.create_future() - fut.set_result(ret) - return fut - return cast(Future, self.db.ns_keys_async(self.namespace)) + return self.db.ns_keys(self.namespace) - def values(self) -> Future[ValuesView]: - if not self.eventloop.is_running: - ret = self.db.ns_values() - fut = self.eventloop.create_future() - fut.set_result(ret) - return fut - return cast(Future, self.db.ns_values_async(self.namespace)) + def values(self) -> Future[List[Any]]: + return self.db.ns_values(self.namespace) - def items(self) -> Future[ItemsView]: - if not self.eventloop.is_running: - ret = self.db.ns_items() - fut = self.eventloop.create_future() - fut.set_result(ret) - return fut - return cast(Future, self.db.ns_items_async(self.namespace)) + def items(self) -> Future[List[Tuple[str, Any]]]: + return self.db.ns_items(self.namespace) def pop(self, key: Union[List[str], str],