database: add annotations

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-05-13 10:48:03 -04:00
parent b91df6642d
commit 410db750c6
1 changed files with 128 additions and 53 deletions

View File

@ -3,6 +3,8 @@
# Copyright (C) 2021 Eric Callahan <arksine.code@gmail.com> # Copyright (C) 2021 Eric Callahan <arksine.code@gmail.com>
# #
# This file may be distributed under the terms of the GNU GPLv3 license. # This file may be distributed under the terms of the GNU GPLv3 license.
from __future__ import annotations
import os import os
import json import json
import struct import struct
@ -11,6 +13,25 @@ import logging
from io import BytesIO from io import BytesIO
from functools import reduce from functools import reduce
import lmdb import lmdb
from utils import SentinelClass
# Annotation imports
from typing import (
TYPE_CHECKING,
Any,
ItemsView,
ValuesView,
Tuple,
Optional,
Union,
Dict,
List,
)
if TYPE_CHECKING:
from confighelper import ConfigHelper
from websockets import WebRequest
DBRecord = Union[int, float, bool, str, List[Any], Dict[str, Any]]
DBType = Optional[DBRecord]
DATABASE_VERSION = 1 DATABASE_VERSION = 1
MAX_NAMESPACES = 50 MAX_NAMESPACES = 50
@ -34,19 +55,19 @@ RECORD_DECODE_FUNCS = {
ord("{"): lambda x: json.load(BytesIO(x)), ord("{"): lambda x: json.load(BytesIO(x)),
} }
def getitem_with_default(item, field): SENTINEL = SentinelClass.get_instance()
def getitem_with_default(item: Dict, field: Any) -> Any:
if field not in item: if field not in item:
item[field] = {} item[field] = {}
return item[field] return item[field]
class Sentinel:
pass
class MoonrakerDatabase: class MoonrakerDatabase:
def __init__(self, config): def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server() self.server = config.get_server()
self.namespaces = {} self.namespaces: Dict[str, object] = {}
self.enable_debug = config.get("enable_database_debug", False) self.enable_debug = config.getboolean("enable_database_debug", False)
self.database_path = os.path.expanduser(config.get( self.database_path = os.path.expanduser(config.get(
'database_path', "~/.moonraker_database")) 'database_path', "~/.moonraker_database"))
if not os.path.isdir(self.database_path): if not os.path.isdir(self.database_path):
@ -75,7 +96,8 @@ class MoonrakerDatabase:
"moonraker", "database.protected_namespaces", ["moonraker"])) "moonraker", "database.protected_namespaces", ["moonraker"]))
self.forbidden_namespaces = set(self.get_item( self.forbidden_namespaces = set(self.get_item(
"moonraker", "database.forbidden_namespaces", [])) "moonraker", "database.forbidden_namespaces", []))
debug_counter = self.get_item("moonraker", "database.debug_counter", 0) debug_counter: int = self.get_item(
"moonraker", "database.debug_counter", 0)
if self.enable_debug: if self.enable_debug:
debug_counter += 1 debug_counter += 1
self.insert_item("moonraker", "database.debug_counter", self.insert_item("moonraker", "database.debug_counter",
@ -88,7 +110,11 @@ class MoonrakerDatabase:
"/server/database/item", ["GET", "POST", "DELETE"], "/server/database/item", ["GET", "POST", "DELETE"],
self._handle_item_request) self._handle_item_request)
def insert_item(self, namespace, key, value): def insert_item(self,
namespace: str,
key: Union[List[str], str],
value: DBType
) -> None:
key_list = self._process_key(key) key_list = self._process_key(key)
if namespace not in self.namespaces: if namespace not in self.namespaces:
self.namespaces[namespace] = self.lmdb_env.open_db( self.namespaces[namespace] = self.lmdb_env.open_db(
@ -101,23 +127,31 @@ class MoonrakerDatabase:
logging.info( logging.info(
f"Warning: Key {key_list[0]} contains a value of type " f"Warning: Key {key_list[0]} contains a value of type "
f"{type(record)}. Overwriting with an object.") f"{type(record)}. Overwriting with an object.")
item = reduce(getitem_with_default, key_list[1:-1], record) item: Dict[str, Any] = reduce(getitem_with_default, key_list[1:-1],
record)
item[key_list[-1]] = value item[key_list[-1]] = value
if not self._insert_record(namespace, key_list[0], record): if not self._insert_record(namespace, key_list[0], record):
logging.info( logging.info(
f"Error inserting key '{key}' in namespace '{namespace}'") f"Error inserting key '{key}' in namespace '{namespace}'")
def update_item(self, namespace, key, value): def update_item(self,
namespace: str,
key: Union[List[str], str],
value: DBType
) -> None:
key_list = self._process_key(key) key_list = self._process_key(key)
record = self._get_record(namespace, key_list[0]) record = self._get_record(namespace, key_list[0])
if len(key_list) == 1: if len(key_list) == 1:
if isinstance(record, dict) and isinstance(value, dict): if isinstance(record, dict) and isinstance(value, dict):
record.update(value) record.update(value)
else: else:
assert value is not None
record = value record = value
else: else:
try: try:
item = reduce(operator.getitem, key_list[1:-1], record) assert isinstance(record, dict)
item: Dict[str, Any] = reduce(
operator.getitem, key_list[1:-1], record)
except Exception: except Exception:
raise self.server.error( raise self.server.error(
f"Key '{key}' in namespace '{namespace}' not found", 404) f"Key '{key}' in namespace '{namespace}' not found", 404)
@ -130,13 +164,19 @@ class MoonrakerDatabase:
logging.info( logging.info(
f"Error updating key '{key}' in namespace '{namespace}'") f"Error updating key '{key}' in namespace '{namespace}'")
def delete_item(self, namespace, key, drop_empty_db=False): def delete_item(self,
namespace: str,
key: Union[List[str], str],
drop_empty_db: bool = False
) -> Any:
key_list = self._process_key(key) key_list = self._process_key(key)
val = record = self._get_record(namespace, key_list[0]) val = record = self._get_record(namespace, key_list[0])
remove_record = True remove_record = True
if len(key_list) > 1: if len(key_list) > 1:
try: try:
item = reduce(operator.getitem, key_list[1:-1], record) assert isinstance(record, dict)
item: Dict[str, Any] = reduce(
operator.getitem, key_list[1:-1], record)
val = item.pop(key_list[-1]) val = item.pop(key_list[-1])
except Exception: except Exception:
raise self.server.error( raise self.server.error(
@ -157,25 +197,29 @@ class MoonrakerDatabase:
f"Error deleting key '{key}' from namespace '{namespace}'") f"Error deleting key '{key}' from namespace '{namespace}'")
return val return val
def get_item(self, namespace, key=None, default=Sentinel): def get_item(self,
namespace: str,
key: Optional[Union[List[str], str]] = None,
default: Any = SENTINEL
) -> Any:
try: try:
if key is None: if key is None:
return self._get_namespace(namespace) return self._get_namespace(namespace)
key_list = self._process_key(key) key_list = self._process_key(key)
ns = self._get_record(namespace, key_list[0]) ns = self._get_record(namespace, key_list[0])
val = reduce(operator.getitem, key_list[1:], ns) val = reduce(operator.getitem, key_list[1:], ns) # type: ignore
except Exception: except Exception:
if default != Sentinel: if not isinstance(default, SentinelClass):
return default return default
raise self.server.error( raise self.server.error(
f"Key '{key}' in namespace '{namespace}' not found", 404) f"Key '{key}' in namespace '{namespace}' not found", 404)
return val return val
def ns_length(self, namespace): def ns_length(self, namespace: str) -> int:
return len(self.ns_keys(namespace)) return len(self.ns_keys(namespace))
def ns_keys(self, namespace): def ns_keys(self, namespace: str) -> List[str]:
keys = [] keys: List[str] = []
db = self.namespaces[namespace] db = self.namespaces[namespace]
with self.lmdb_env.begin(db=db) as txn: with self.lmdb_env.begin(db=db) as txn:
cursor = txn.cursor() cursor = txn.cursor()
@ -185,15 +229,15 @@ class MoonrakerDatabase:
remaining = cursor.next() remaining = cursor.next()
return keys return keys
def ns_values(self, namespace): def ns_values(self, namespace: str) -> ValuesView:
ns = self._get_namespace(namespace) ns = self._get_namespace(namespace)
return ns.values() return ns.values()
def ns_items(self, namespace): def ns_items(self, namespace: str) -> ItemsView:
ns = self._get_namespace(namespace) ns = self._get_namespace(namespace)
return ns.items() return ns.items()
def ns_contains(self, namespace, key): def ns_contains(self, namespace: str, key: Union[List[str], str]) -> bool:
try: try:
key_list = self._process_key(key) key_list = self._process_key(key)
if len(key_list) == 1: if len(key_list) == 1:
@ -204,7 +248,10 @@ class MoonrakerDatabase:
return False return False
return True return True
def register_local_namespace(self, namespace, forbidden=False): def register_local_namespace(self,
namespace: str,
forbidden: bool = False
) -> None:
if namespace not in self.namespaces: if namespace not in self.namespaces:
self.namespaces[namespace] = self.lmdb_env.open_db( self.namespaces[namespace] = self.lmdb_env.open_db(
namespace.encode()) namespace.encode())
@ -219,13 +266,16 @@ class MoonrakerDatabase:
self.insert_item("moonraker", "database.protected_namespaces", self.insert_item("moonraker", "database.protected_namespaces",
list(self.protected_namespaces)) list(self.protected_namespaces))
def wrap_namespace(self, namespace, parse_keys=True): def wrap_namespace(self,
namespace: str,
parse_keys: bool = True
) -> NamespaceWrapper:
if namespace not in self.namespaces: if namespace not in self.namespaces:
raise self.server.error( raise self.server.error(
f"Namespace '{namespace}' not found", 404) f"Namespace '{namespace}' not found", 404)
return NamespaceWrapper(namespace, self, parse_keys) return NamespaceWrapper(namespace, self, parse_keys)
def _process_key(self, key): def _process_key(self, key: Union[List[str], str]) -> List[str]:
try: try:
key_list = key if isinstance(key, list) else key.split('.') key_list = key if isinstance(key, list) else key.split('.')
except Exception: except Exception:
@ -234,13 +284,19 @@ class MoonrakerDatabase:
raise self.server.error(f"Invalid Key Format: '{key}'") raise self.server.error(f"Invalid Key Format: '{key}'")
return key_list return key_list
def _insert_record(self, namespace, key, val): def _insert_record(self, namespace: str, key: str, val: DBType) -> bool:
db = self.namespaces[namespace] db = self.namespaces[namespace]
if val is None:
return False
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
ret = txn.put(key.encode(), self._encode_value(val)) ret = txn.put(key.encode(), self._encode_value(val))
return ret return ret
def _get_record(self, namespace, key, force=False): def _get_record(self,
namespace: str,
key: str,
force: bool = False
) -> DBRecord:
if namespace not in self.namespaces: if namespace not in self.namespaces:
raise self.server.error( raise self.server.error(
f"Namespace '{namespace}' not found", 404) f"Namespace '{namespace}' not found", 404)
@ -254,7 +310,7 @@ class MoonrakerDatabase:
f"Key '{key}' in namespace '{namespace}' not found", 404) f"Key '{key}' in namespace '{namespace}' not found", 404)
return self._decode_value(value) return self._decode_value(value)
def _get_namespace(self, namespace): def _get_namespace(self, namespace: str) -> Dict[str, Any]:
if namespace not in self.namespaces: if namespace not in self.namespaces:
raise self.server.error( raise self.server.error(
f"Invalid database namespace '{namespace}'") f"Invalid database namespace '{namespace}'")
@ -268,7 +324,7 @@ class MoonrakerDatabase:
result[k] = self._decode_value(value) result[k] = self._decode_value(value)
return result return result
def _encode_value(self, value): def _encode_value(self, value: DBRecord) -> bytes:
try: try:
enc_func = RECORD_ENCODE_FUNCS[type(value)] enc_func = RECORD_ENCODE_FUNCS[type(value)]
return enc_func(value) return enc_func(value)
@ -276,26 +332,32 @@ class MoonrakerDatabase:
raise self.server.error( raise self.server.error(
f"Error encoding val: {value}, type: {type(value)}") f"Error encoding val: {value}, type: {type(value)}")
def _decode_value(self, bvalue): def _decode_value(self, bvalue: bytes) -> DBRecord:
fmt = bvalue[0] fmt = bvalue[0]
try: try:
decode_func = RECORD_DECODE_FUNCS[fmt] decode_func = RECORD_DECODE_FUNCS[fmt]
return decode_func(bvalue) return decode_func(bvalue)
except Exception: except Exception:
raise self.server.error( raise self.server.error(
f"Error decoding value {bvalue}, format: {chr(fmt)}") f"Error decoding value {bvalue.decode()}, format: {chr(fmt)}")
async def _handle_list_request(self, web_request): async def _handle_list_request(self,
web_request: WebRequest
) -> Dict[str, List[str]]:
ns_list = set(self.namespaces.keys()) - self.forbidden_namespaces ns_list = set(self.namespaces.keys()) - self.forbidden_namespaces
return {'namespaces': list(ns_list)} return {'namespaces': list(ns_list)}
async def _handle_item_request(self, web_request): async def _handle_item_request(self,
web_request: WebRequest
) -> Dict[str, Any]:
action = web_request.get_action() action = web_request.get_action()
namespace = web_request.get_str("namespace") namespace = web_request.get_str("namespace")
if namespace in self.forbidden_namespaces: if namespace in self.forbidden_namespaces:
raise self.server.error( raise self.server.error(
f"Read/Write access to namespace '{namespace}'" f"Read/Write access to namespace '{namespace}'"
" is forbidden", 403) " is forbidden", 403)
key: Any
valid_types: Tuple[type, ...]
if action != "GET": if action != "GET":
if namespace in self.protected_namespaces and \ if namespace in self.protected_namespaces and \
not self.enable_debug: not self.enable_debug:
@ -320,7 +382,7 @@ class MoonrakerDatabase:
val = self.delete_item(namespace, key, drop_empty_db=True) val = self.delete_item(namespace, key, drop_empty_db=True)
return {'namespace': namespace, 'key': key, 'value': val} return {'namespace': namespace, 'key': key, 'value': val}
def close(self): def close(self) -> None:
# log db stats # log db stats
msg = "" msg = ""
with self.lmdb_env.begin() as txn: with self.lmdb_env.begin() as txn:
@ -333,24 +395,28 @@ class MoonrakerDatabase:
self.lmdb_env.close() self.lmdb_env.close()
class NamespaceWrapper: class NamespaceWrapper:
def __init__(self, namespace, database, parse_keys): def __init__(self,
namespace: str,
database: MoonrakerDatabase,
parse_keys: bool
) -> None:
self.namespace = namespace self.namespace = namespace
self.db = database self.db = database
# If parse keys is true, keys of a string type # If parse keys is true, keys of a string type
# will be passed straight to the DB methods. # will be passed straight to the DB methods.
self.parse_keys = parse_keys self.parse_keys = parse_keys
def insert(self, key, value): def insert(self, key: Union[List[str], str], value: DBType) -> None:
if isinstance(key, str) and not self.parse_keys: if isinstance(key, str) and not self.parse_keys:
key = [key] key = [key]
self.db.insert_item(self.namespace, key, value) self.db.insert_item(self.namespace, key, value)
def update_child(self, key, value): def update_child(self, key: Union[List[str], str], value: DBType) -> None:
if isinstance(key, str) and not self.parse_keys: if isinstance(key, str) and not self.parse_keys:
key = [key] key = [key]
self.db.update_item(self.namespace, key, value) self.db.update_item(self.namespace, key, value)
def update(self, value): def update(self, value: Dict[str, Any]) -> None:
val_keys = set(value.keys()) val_keys = set(value.keys())
new_keys = val_keys - set(self.keys()) new_keys = val_keys - set(self.keys())
update_keys = val_keys - new_keys update_keys = val_keys - new_keys
@ -359,52 +425,61 @@ class NamespaceWrapper:
for key in new_keys: for key in new_keys:
self.insert([key], value[key]) self.insert([key], value[key])
def get(self, key, default=None): def get(self,
key: Union[List[str], str],
default: Any = None
) -> Any:
if isinstance(key, str) and not self.parse_keys: if isinstance(key, str) and not self.parse_keys:
key = [key] key = [key]
return self.db.get_item(self.namespace, key, default) return self.db.get_item(self.namespace, key, default)
def delete(self, key): def delete(self, key: Union[List[str], str]) -> Any:
if isinstance(key, str) and not self.parse_keys: if isinstance(key, str) and not self.parse_keys:
key = [key] key = [key]
return self.db.delete_item(self.namespace, key) return self.db.delete_item(self.namespace, key)
def __len__(self): def __len__(self) -> int:
return self.db.ns_length(self.namespace) return self.db.ns_length(self.namespace)
def __getitem__(self, key): def __getitem__(self, key: Union[List[str], str]) -> Any:
return self.get(key, default=Sentinel) return self.get(key, default=SENTINEL)
def __setitem__(self, key, value): def __setitem__(self,
key: Union[List[str], str],
value: DBType
) -> None:
self.insert(key, value) self.insert(key, value)
def __delitem__(self, key): def __delitem__(self, key: Union[List[str], str]):
self.delete(key) self.delete(key)
def __contains__(self, key): def __contains__(self, key: Union[List[str], str]) -> bool:
if isinstance(key, str) and not self.parse_keys: if isinstance(key, str) and not self.parse_keys:
key = [key] key = [key]
return self.db.ns_contains(self.namespace, key) return self.db.ns_contains(self.namespace, key)
def keys(self): def keys(self) -> List[str]:
return self.db.ns_keys(self.namespace) return self.db.ns_keys(self.namespace)
def values(self): def values(self) -> ValuesView:
return self.db.ns_values(self.namespace) return self.db.ns_values(self.namespace)
def items(self): def items(self) -> ItemsView:
return self.db.ns_items(self.namespace) return self.db.ns_items(self.namespace)
def pop(self, key, default=Sentinel): def pop(self,
key: Union[List[str], str],
default: Any = SENTINEL
) -> Any:
try: try:
val = self.delete(key) val = self.delete(key)
except Exception: except Exception:
if default == Sentinel: if isinstance(default, SentinelClass):
raise raise
val = default val = default
return val return val
def clear(self): def clear(self) -> None:
keys = self.keys() keys = self.keys()
for k in keys: for k in keys:
try: try:
@ -412,5 +487,5 @@ class NamespaceWrapper:
except Exception: except Exception:
pass pass
def load_component(config): def load_component(config: ConfigHelper) -> MoonrakerDatabase:
return MoonrakerDatabase(config) return MoonrakerDatabase(config)