From 410db750c698911dca9306fe7fa9b6cd3815b276 Mon Sep 17 00:00:00 2001 From: Arksine Date: Thu, 13 May 2021 10:48:03 -0400 Subject: [PATCH] database: add annotations Signed-off-by: Eric Callahan --- moonraker/components/database.py | 181 ++++++++++++++++++++++--------- 1 file changed, 128 insertions(+), 53 deletions(-) diff --git a/moonraker/components/database.py b/moonraker/components/database.py index 9a7f5c5..24428a1 100644 --- a/moonraker/components/database.py +++ b/moonraker/components/database.py @@ -3,6 +3,8 @@ # Copyright (C) 2021 Eric Callahan # # This file may be distributed under the terms of the GNU GPLv3 license. + +from __future__ import annotations import os import json import struct @@ -11,6 +13,25 @@ import logging from io import BytesIO from functools import reduce import lmdb +from utils import SentinelClass + +# Annotation imports +from typing import ( + TYPE_CHECKING, + Any, + ItemsView, + ValuesView, + Tuple, + Optional, + Union, + Dict, + List, +) +if TYPE_CHECKING: + from confighelper import ConfigHelper + from websockets import WebRequest + DBRecord = Union[int, float, bool, str, List[Any], Dict[str, Any]] + DBType = Optional[DBRecord] DATABASE_VERSION = 1 MAX_NAMESPACES = 50 @@ -34,19 +55,19 @@ RECORD_DECODE_FUNCS = { ord("{"): lambda x: json.load(BytesIO(x)), } -def getitem_with_default(item, field): +SENTINEL = SentinelClass.get_instance() + +def getitem_with_default(item: Dict, field: Any) -> Any: if field not in item: item[field] = {} return item[field] -class Sentinel: - pass class MoonrakerDatabase: - def __init__(self, config): + def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() - self.namespaces = {} - self.enable_debug = config.get("enable_database_debug", False) + self.namespaces: Dict[str, object] = {} + self.enable_debug = config.getboolean("enable_database_debug", False) self.database_path = os.path.expanduser(config.get( 'database_path', "~/.moonraker_database")) if not os.path.isdir(self.database_path): @@ -75,7 +96,8 @@ class MoonrakerDatabase: "moonraker", "database.protected_namespaces", ["moonraker"])) self.forbidden_namespaces = set(self.get_item( "moonraker", "database.forbidden_namespaces", [])) - debug_counter = self.get_item("moonraker", "database.debug_counter", 0) + debug_counter: int = self.get_item( + "moonraker", "database.debug_counter", 0) if self.enable_debug: debug_counter += 1 self.insert_item("moonraker", "database.debug_counter", @@ -88,7 +110,11 @@ class MoonrakerDatabase: "/server/database/item", ["GET", "POST", "DELETE"], self._handle_item_request) - def insert_item(self, namespace, key, value): + def insert_item(self, + namespace: str, + key: Union[List[str], str], + value: DBType + ) -> None: key_list = self._process_key(key) if namespace not in self.namespaces: self.namespaces[namespace] = self.lmdb_env.open_db( @@ -101,23 +127,31 @@ class MoonrakerDatabase: logging.info( f"Warning: Key {key_list[0]} contains a value of type " f"{type(record)}. Overwriting with an object.") - item = reduce(getitem_with_default, key_list[1:-1], record) + item: Dict[str, Any] = reduce(getitem_with_default, key_list[1:-1], + record) item[key_list[-1]] = value if not self._insert_record(namespace, key_list[0], record): logging.info( f"Error inserting key '{key}' in namespace '{namespace}'") - def update_item(self, namespace, key, value): + def update_item(self, + namespace: str, + key: Union[List[str], str], + value: DBType + ) -> None: key_list = self._process_key(key) record = self._get_record(namespace, key_list[0]) if len(key_list) == 1: if isinstance(record, dict) and isinstance(value, dict): record.update(value) else: + assert value is not None record = value else: try: - item = reduce(operator.getitem, key_list[1:-1], record) + assert isinstance(record, dict) + item: Dict[str, Any] = reduce( + operator.getitem, key_list[1:-1], record) except Exception: raise self.server.error( f"Key '{key}' in namespace '{namespace}' not found", 404) @@ -130,13 +164,19 @@ class MoonrakerDatabase: logging.info( f"Error updating key '{key}' in namespace '{namespace}'") - def delete_item(self, namespace, key, drop_empty_db=False): + def delete_item(self, + namespace: str, + key: Union[List[str], str], + drop_empty_db: bool = False + ) -> Any: key_list = self._process_key(key) val = record = self._get_record(namespace, key_list[0]) remove_record = True if len(key_list) > 1: try: - item = reduce(operator.getitem, key_list[1:-1], record) + assert isinstance(record, dict) + item: Dict[str, Any] = reduce( + operator.getitem, key_list[1:-1], record) val = item.pop(key_list[-1]) except Exception: raise self.server.error( @@ -157,25 +197,29 @@ class MoonrakerDatabase: f"Error deleting key '{key}' from namespace '{namespace}'") return val - def get_item(self, namespace, key=None, default=Sentinel): + def get_item(self, + namespace: str, + key: Optional[Union[List[str], str]] = None, + default: Any = SENTINEL + ) -> Any: try: if key is None: return self._get_namespace(namespace) key_list = self._process_key(key) ns = self._get_record(namespace, key_list[0]) - val = reduce(operator.getitem, key_list[1:], ns) + val = reduce(operator.getitem, key_list[1:], ns) # type: ignore except Exception: - if default != Sentinel: + if not isinstance(default, SentinelClass): return default raise self.server.error( f"Key '{key}' in namespace '{namespace}' not found", 404) return val - def ns_length(self, namespace): + def ns_length(self, namespace: str) -> int: return len(self.ns_keys(namespace)) - def ns_keys(self, namespace): - keys = [] + def ns_keys(self, namespace: str) -> List[str]: + keys: List[str] = [] db = self.namespaces[namespace] with self.lmdb_env.begin(db=db) as txn: cursor = txn.cursor() @@ -185,15 +229,15 @@ class MoonrakerDatabase: remaining = cursor.next() return keys - def ns_values(self, namespace): + def ns_values(self, namespace: str) -> ValuesView: ns = self._get_namespace(namespace) return ns.values() - def ns_items(self, namespace): + def ns_items(self, namespace: str) -> ItemsView: ns = self._get_namespace(namespace) return ns.items() - def ns_contains(self, namespace, key): + def ns_contains(self, namespace: str, key: Union[List[str], str]) -> bool: try: key_list = self._process_key(key) if len(key_list) == 1: @@ -204,7 +248,10 @@ class MoonrakerDatabase: return False return True - def register_local_namespace(self, namespace, forbidden=False): + def register_local_namespace(self, + namespace: str, + forbidden: bool = False + ) -> None: if namespace not in self.namespaces: self.namespaces[namespace] = self.lmdb_env.open_db( namespace.encode()) @@ -219,13 +266,16 @@ class MoonrakerDatabase: self.insert_item("moonraker", "database.protected_namespaces", list(self.protected_namespaces)) - def wrap_namespace(self, namespace, parse_keys=True): + def wrap_namespace(self, + namespace: str, + parse_keys: bool = True + ) -> NamespaceWrapper: if namespace not in self.namespaces: raise self.server.error( f"Namespace '{namespace}' not found", 404) return NamespaceWrapper(namespace, self, parse_keys) - def _process_key(self, key): + def _process_key(self, key: Union[List[str], str]) -> List[str]: try: key_list = key if isinstance(key, list) else key.split('.') except Exception: @@ -234,13 +284,19 @@ class MoonrakerDatabase: raise self.server.error(f"Invalid Key Format: '{key}'") return key_list - def _insert_record(self, namespace, key, val): + def _insert_record(self, namespace: str, key: str, val: DBType) -> bool: db = self.namespaces[namespace] + if val is None: + return False with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: ret = txn.put(key.encode(), self._encode_value(val)) return ret - def _get_record(self, namespace, key, force=False): + def _get_record(self, + namespace: str, + key: str, + force: bool = False + ) -> DBRecord: if namespace not in self.namespaces: raise self.server.error( f"Namespace '{namespace}' not found", 404) @@ -254,7 +310,7 @@ class MoonrakerDatabase: f"Key '{key}' in namespace '{namespace}' not found", 404) return self._decode_value(value) - def _get_namespace(self, namespace): + def _get_namespace(self, namespace: str) -> Dict[str, Any]: if namespace not in self.namespaces: raise self.server.error( f"Invalid database namespace '{namespace}'") @@ -268,7 +324,7 @@ class MoonrakerDatabase: result[k] = self._decode_value(value) return result - def _encode_value(self, value): + def _encode_value(self, value: DBRecord) -> bytes: try: enc_func = RECORD_ENCODE_FUNCS[type(value)] return enc_func(value) @@ -276,26 +332,32 @@ class MoonrakerDatabase: raise self.server.error( f"Error encoding val: {value}, type: {type(value)}") - def _decode_value(self, bvalue): + def _decode_value(self, bvalue: bytes) -> DBRecord: fmt = bvalue[0] try: decode_func = RECORD_DECODE_FUNCS[fmt] return decode_func(bvalue) except Exception: raise self.server.error( - f"Error decoding value {bvalue}, format: {chr(fmt)}") + f"Error decoding value {bvalue.decode()}, format: {chr(fmt)}") - async def _handle_list_request(self, web_request): + async def _handle_list_request(self, + web_request: WebRequest + ) -> Dict[str, List[str]]: ns_list = set(self.namespaces.keys()) - self.forbidden_namespaces return {'namespaces': list(ns_list)} - async def _handle_item_request(self, web_request): + async def _handle_item_request(self, + web_request: WebRequest + ) -> Dict[str, Any]: action = web_request.get_action() namespace = web_request.get_str("namespace") if namespace in self.forbidden_namespaces: raise self.server.error( f"Read/Write access to namespace '{namespace}'" " is forbidden", 403) + key: Any + valid_types: Tuple[type, ...] if action != "GET": if namespace in self.protected_namespaces and \ not self.enable_debug: @@ -320,7 +382,7 @@ class MoonrakerDatabase: val = self.delete_item(namespace, key, drop_empty_db=True) return {'namespace': namespace, 'key': key, 'value': val} - def close(self): + def close(self) -> None: # log db stats msg = "" with self.lmdb_env.begin() as txn: @@ -333,24 +395,28 @@ class MoonrakerDatabase: self.lmdb_env.close() class NamespaceWrapper: - def __init__(self, namespace, database, parse_keys): + def __init__(self, + namespace: str, + database: MoonrakerDatabase, + parse_keys: bool + ) -> None: self.namespace = namespace self.db = database # If parse keys is true, keys of a string type # will be passed straight to the DB methods. self.parse_keys = parse_keys - def insert(self, key, value): + def insert(self, key: Union[List[str], str], value: DBType) -> None: if isinstance(key, str) and not self.parse_keys: key = [key] self.db.insert_item(self.namespace, key, value) - def update_child(self, key, value): + def update_child(self, key: Union[List[str], str], value: DBType) -> None: if isinstance(key, str) and not self.parse_keys: key = [key] self.db.update_item(self.namespace, key, value) - def update(self, value): + def update(self, value: Dict[str, Any]) -> None: val_keys = set(value.keys()) new_keys = val_keys - set(self.keys()) update_keys = val_keys - new_keys @@ -359,52 +425,61 @@ class NamespaceWrapper: for key in new_keys: self.insert([key], value[key]) - def get(self, key, default=None): + def get(self, + key: Union[List[str], str], + default: Any = None + ) -> Any: if isinstance(key, str) and not self.parse_keys: key = [key] return self.db.get_item(self.namespace, key, default) - def delete(self, key): + def delete(self, key: Union[List[str], str]) -> Any: if isinstance(key, str) and not self.parse_keys: key = [key] return self.db.delete_item(self.namespace, key) - def __len__(self): + def __len__(self) -> int: return self.db.ns_length(self.namespace) - def __getitem__(self, key): - return self.get(key, default=Sentinel) + def __getitem__(self, key: Union[List[str], str]) -> Any: + return self.get(key, default=SENTINEL) - def __setitem__(self, key, value): + def __setitem__(self, + key: Union[List[str], str], + value: DBType + ) -> None: self.insert(key, value) - def __delitem__(self, key): + def __delitem__(self, key: Union[List[str], str]): self.delete(key) - def __contains__(self, key): + def __contains__(self, key: Union[List[str], str]) -> bool: if isinstance(key, str) and not self.parse_keys: key = [key] return self.db.ns_contains(self.namespace, key) - def keys(self): + def keys(self) -> List[str]: return self.db.ns_keys(self.namespace) - def values(self): + def values(self) -> ValuesView: return self.db.ns_values(self.namespace) - def items(self): + def items(self) -> ItemsView: return self.db.ns_items(self.namespace) - def pop(self, key, default=Sentinel): + def pop(self, + key: Union[List[str], str], + default: Any = SENTINEL + ) -> Any: try: val = self.delete(key) except Exception: - if default == Sentinel: + if isinstance(default, SentinelClass): raise val = default return val - def clear(self): + def clear(self) -> None: keys = self.keys() for k in keys: try: @@ -412,5 +487,5 @@ class NamespaceWrapper: except Exception: pass -def load_component(config): +def load_component(config: ConfigHelper) -> MoonrakerDatabase: return MoonrakerDatabase(config)