database: refactor to remove duplicate code

Wrap command implementations in with a _run_command() method.  All
database commands now return a Future object.  If the command was
run before the eventloop starts its possible to immediately query
the Future's result.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-01-31 11:45:22 -05:00
parent 65d1f23352
commit e029b6c582
1 changed files with 235 additions and 282 deletions

View File

@ -22,9 +22,9 @@ from typing import (
TYPE_CHECKING,
Any,
Awaitable,
ItemsView,
Callable,
Mapping,
ValuesView,
TypeVar,
Tuple,
Optional,
Union,
@ -37,6 +37,7 @@ if TYPE_CHECKING:
from websockets import WebRequest
DBRecord = Union[int, float, bool, str, List[Any], Dict[str, Any]]
DBType = Optional[DBRecord]
_T = TypeVar("_T")
DATABASE_VERSION = 1
MAX_NAMESPACES = 100
@ -157,343 +158,310 @@ class MoonrakerDatabase:
"/server/database/item", ["GET", "POST", "DELETE"],
self._handle_item_request)
def _run_command(self,
command_func: Callable[..., _T],
*args
) -> Future[_T]:
def func_wrapper():
with self.thread_lock:
return command_func(*args)
if self.eventloop.is_running():
return cast(Future, self.eventloop.run_in_thread(func_wrapper))
else:
ret = func_wrapper()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
def insert_item(self,
namespace: str,
key: Union[List[str], str],
value: DBType
) -> Awaitable[None]:
if self.eventloop.is_running():
return self.eventloop.run_in_thread(
self._insert_impl, namespace, key, value)
else:
self._insert_impl(namespace, key, value)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
) -> Future[None]:
return self._run_command(self._insert_impl, namespace, key, value)
def _insert_impl(self,
namespace: str,
key: Union[List[str], str],
value: DBType
) -> None:
with self.thread_lock:
key_list = self._process_key(key)
if namespace not in self.namespaces:
self.namespaces[namespace] = self.lmdb_env.open_db(
namespace.encode())
record = value
if len(key_list) > 1:
record = self._get_record(namespace, key_list[0], force=True)
if not isinstance(record, dict):
record = {}
logging.info(
f"Warning: Key {key_list[0]} contains a value of type"
f" {type(record)}. Overwriting with an object.")
item: Dict[str, Any] = reduce(
getitem_with_default, key_list[1:-1], record)
item[key_list[-1]] = value
if not self._insert_record(namespace, key_list[0], record):
key_list = self._process_key(key)
if namespace not in self.namespaces:
self.namespaces[namespace] = self.lmdb_env.open_db(
namespace.encode())
record = value
if len(key_list) > 1:
record = self._get_record(namespace, key_list[0], force=True)
if not isinstance(record, dict):
record = {}
logging.info(
f"Error inserting key '{key}' in namespace '{namespace}'")
f"Warning: Key {key_list[0]} contains a value of type"
f" {type(record)}. Overwriting with an object.")
item: Dict[str, Any] = reduce(
getitem_with_default, key_list[1:-1], record)
item[key_list[-1]] = value
if not self._insert_record(namespace, key_list[0], record):
logging.info(
f"Error inserting key '{key}' in namespace '{namespace}'")
def update_item(self,
namespace: str,
key: Union[List[str], str],
value: DBType
) -> Awaitable[None]:
if self.eventloop.is_running():
return self.eventloop.run_in_thread(
self._update_impl, namespace, key, value)
else:
self._update_impl(namespace, key, value)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
) -> Future[None]:
return self._run_command(self._update_impl, namespace, key, value)
def _update_impl(self,
namespace: str,
key: Union[List[str], str],
value: DBType
) -> None:
with self.thread_lock:
key_list = self._process_key(key)
record = self._get_record(namespace, key_list[0])
if len(key_list) == 1:
if isinstance(record, dict) and isinstance(value, dict):
record.update(value)
else:
assert value is not None
record = value
key_list = self._process_key(key)
record = self._get_record(namespace, key_list[0])
if len(key_list) == 1:
if isinstance(record, dict) and isinstance(value, dict):
record.update(value)
else:
try:
assert isinstance(record, dict)
item: Dict[str, Any] = reduce(
operator.getitem, key_list[1:-1], record)
except Exception:
raise self.server.error(
f"Key '{key}' in namespace '{namespace}' not found",
404)
if isinstance(item[key_list[-1]], dict) \
and isinstance(value, dict):
item[key_list[-1]].update(value)
else:
item[key_list[-1]] = value
if not self._insert_record(namespace, key_list[0], record):
logging.info(
f"Error updating key '{key}' in namespace '{namespace}'")
assert value is not None
record = value
else:
try:
assert isinstance(record, dict)
item: Dict[str, Any] = reduce(
operator.getitem, key_list[1:-1], record)
except Exception:
raise self.server.error(
f"Key '{key}' in namespace '{namespace}' not found",
404)
if isinstance(item[key_list[-1]], dict) \
and isinstance(value, dict):
item[key_list[-1]].update(value)
else:
item[key_list[-1]] = value
if not self._insert_record(namespace, key_list[0], record):
logging.info(
f"Error updating key '{key}' in namespace '{namespace}'")
def delete_item(self,
namespace: str,
key: Union[List[str], str],
drop_empty_db: bool = False
) -> Future[Any]:
if self.eventloop.is_running():
return cast(Future, self.eventloop.run_in_thread(
self._delete_impl, namespace, key, drop_empty_db))
else:
ret = self._delete_impl(namespace, key, drop_empty_db)
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return self._run_command(self._delete_impl, namespace, key,
drop_empty_db)
def _delete_impl(self,
namespace: str,
key: Union[List[str], str],
drop_empty_db: bool = False
) -> Any:
with self.thread_lock:
key_list = self._process_key(key)
val = record = self._get_record(namespace, key_list[0])
remove_record = True
if len(key_list) > 1:
try:
assert isinstance(record, dict)
item: Dict[str, Any] = reduce(
operator.getitem, key_list[1:-1], record)
val = item.pop(key_list[-1])
except Exception:
raise self.server.error(
f"Key '{key}' in namespace '{namespace}' not found",
404)
remove_record = False if record else True
if remove_record:
db = self.namespaces[namespace]
with (
self.lmdb_env.begin(write=True, buffers=True, db=db) as txn
):
ret = txn.delete(key_list[0].encode())
with txn.cursor() as cursor:
if not cursor.first() and drop_empty_db:
txn.drop(db)
del self.namespaces[namespace]
else:
ret = self._insert_record(namespace, key_list[0], record)
if not ret:
logging.info(
f"Error deleting key '{key}' from namespace "
f"'{namespace}'")
return val
key_list = self._process_key(key)
val = record = self._get_record(namespace, key_list[0])
remove_record = True
if len(key_list) > 1:
try:
assert isinstance(record, dict)
item: Dict[str, Any] = reduce(
operator.getitem, key_list[1:-1], record)
val = item.pop(key_list[-1])
except Exception:
raise self.server.error(
f"Key '{key}' in namespace '{namespace}' not found",
404)
remove_record = False if record else True
if remove_record:
db = self.namespaces[namespace]
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
ret = txn.delete(key_list[0].encode())
with txn.cursor() as cursor:
if not cursor.first() and drop_empty_db:
txn.drop(db)
del self.namespaces[namespace]
else:
ret = self._insert_record(namespace, key_list[0], record)
if not ret:
logging.info(
f"Error deleting key '{key}' from namespace "
f"'{namespace}'")
return val
def get_item(self,
namespace: str,
key: Optional[Union[List[str], str]] = None,
default: Any = SENTINEL
) -> Future[Any]:
if self.eventloop.is_running():
return cast(Future, self.eventloop.run_in_thread(
self._get_impl, namespace, key, default))
else:
ret = self._get_impl(namespace, key, default)
fut = self.eventloop.create_future()
fut.set_result(ret)
return ret
return self._run_command(self._get_impl, namespace, key, default)
def _get_impl(self,
namespace: str,
key: Optional[Union[List[str], str]] = None,
default: Any = SENTINEL
) -> Any:
with self.thread_lock:
try:
if key is None:
return self._get_namespace(namespace)
key_list = self._process_key(key)
ns = self._get_record(namespace, key_list[0])
val = reduce(operator.getitem, # type: ignore
key_list[1:], ns)
except Exception:
if not isinstance(default, SentinelClass):
return default
raise self.server.error(
f"Key '{key}' in namespace '{namespace}' not found", 404)
return val
try:
if key is None:
return self._get_namespace(namespace)
key_list = self._process_key(key)
ns = self._get_record(namespace, key_list[0])
val = reduce(operator.getitem, # type: ignore
key_list[1:], ns)
except Exception:
if not isinstance(default, SentinelClass):
return default
raise self.server.error(
f"Key '{key}' in namespace '{namespace}' not found", 404)
return val
def update_namespace(self,
namespace: str,
value: Mapping[str, DBRecord]
) -> Awaitable[None]:
if self.eventloop.is_running():
return self.eventloop.run_in_thread(
self._update_ns_impl, namespace, value)
else:
self._update_ns_impl(namespace, value)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
) -> Future[None]:
return self._run_command(self._update_ns_impl, namespace, value)
def _update_ns_impl(self,
namespace: str,
value: Mapping[str, DBRecord]
) -> None:
with self.thread_lock:
if not value:
return
if namespace not in self.namespaces:
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
# We only need to update the keys that changed
for key, val in value.items():
stored = txn.get(key.encode())
if stored is not None:
decoded = self._decode_value(stored)
if val == decoded:
continue
ret = txn.put(key.encode(), self._encode_value(val))
if not ret:
logging.info(f"Error inserting key '{key}' "
f"in namespace '{namespace}'")
if not value:
return
if namespace not in self.namespaces:
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
# We only need to update the keys that changed
for key, val in value.items():
stored = txn.get(key.encode())
if stored is not None:
decoded = self._decode_value(stored)
if val == decoded:
continue
ret = txn.put(key.encode(), self._encode_value(val))
if not ret:
logging.info(f"Error inserting key '{key}' "
f"in namespace '{namespace}'")
def clear_namespace(self,
namespace: str,
drop_empty_db: bool = False
) -> Awaitable[None]:
if self.eventloop.is_running():
return self.eventloop.run_in_thread(
self._clear_ns_impl, namespace, drop_empty_db)
else:
self._clear_ns_impl(namespace, drop_empty_db)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
) -> Future[None]:
return self._run_command(self._clear_ns_impl, namespace, drop_empty_db)
def _clear_ns_impl(self,
namespace: str,
drop_empty_db: bool = False
) -> None:
with self.thread_lock:
if namespace not in self.namespaces:
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
with self.lmdb_env.begin(write=True, db=db) as txn:
txn.drop(db, delete=drop_empty_db)
if drop_empty_db:
del self.namespaces[namespace]
if namespace not in self.namespaces:
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
with self.lmdb_env.begin(write=True, db=db) as txn:
txn.drop(db, delete=drop_empty_db)
if drop_empty_db:
del self.namespaces[namespace]
def sync_namespace(self,
namespace: str,
value: Mapping[str, DBRecord]
) -> Awaitable[None]:
if self.eventloop.is_running():
return self.eventloop.run_in_thread(
self._sync_ns_impl, namespace, value)
else:
self._sync_ns_impl(namespace, value)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
) -> Future[None]:
return self._run_command(self._sync_ns_impl, namespace, value)
def _sync_ns_impl(self,
namespace: str,
value: Mapping[str, DBRecord]
) -> None:
with self.thread_lock:
if not value:
return
if namespace not in self.namespaces:
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
new_keys = set(value.keys())
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
with txn.cursor() as cursor:
remaining = cursor.first()
while remaining:
bkey, bval = cursor.item()
key = bytes(bkey).decode()
if key not in value:
remaining = cursor.delete()
else:
decoded = self._decode_value(bval)
if decoded != value[key]:
cursor.put(self._encode_value(value[key]))
new_keys.remove(key)
remaining = cursor.next()
for k in new_keys:
val = value[k]
ret = txn.put(key.encode(), self._encode_value(val))
if not ret:
logging.info(f"Error inserting key '{k}' "
f"in namespace '{namespace}'")
async def ns_length_async(self, namespace: str) -> int:
return len(await self.ns_keys_async(namespace))
def ns_length(self, namespace: str) -> int:
return len(self.ns_keys(namespace))
def ns_keys_async(self, namespace: str) -> Awaitable[List[str]]:
return self.eventloop.run_in_thread(self.ns_keys, namespace)
def ns_keys(self, namespace: str) -> List[str]:
with self.thread_lock:
keys: List[str] = []
db = self.namespaces[namespace]
with self.lmdb_env.begin(db=db) as txn:
with txn.cursor() as cursor:
remaining = cursor.first()
while remaining:
keys.append(cursor.key().decode())
if not value:
return
if namespace not in self.namespaces:
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
new_keys = set(value.keys())
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
with txn.cursor() as cursor:
remaining = cursor.first()
while remaining:
bkey, bval = cursor.item()
key = bytes(bkey).decode()
if key not in value:
remaining = cursor.delete()
else:
decoded = self._decode_value(bval)
if decoded != value[key]:
cursor.put(self._encode_value(value[key]))
new_keys.remove(key)
remaining = cursor.next()
return keys
for k in new_keys:
val = value[k]
ret = txn.put(key.encode(), self._encode_value(val))
if not ret:
logging.info(f"Error inserting key '{k}' "
f"in namespace '{namespace}'")
def ns_values_async(self, namespace: str) -> Awaitable[ValuesView]:
return self.eventloop.run_in_thread(self.ns_values, namespace)
def ns_length(self, namespace: str) -> Future[int]:
return self._run_command(self._ns_length_impl, namespace)
def ns_values(self, namespace: str) -> ValuesView:
with self.thread_lock:
ns = self._get_namespace(namespace)
return ns.values()
def _ns_length_impl(self, namespace: str) -> int:
db = self.namespaces[namespace]
with self.lmdb_env.begin(db=db) as txn:
stats = txn.stat(db)
return stats['entries']
def ns_items_async(self, namespace: str) -> Awaitable[ItemsView]:
return self.eventloop.run_in_thread(self.ns_items, namespace)
def ns_keys(self, namespace: str) -> Future[List[str]]:
return self._run_command(self._ns_keys_impl, namespace)
def ns_items(self, namespace: str) -> ItemsView:
with self.thread_lock:
ns = self._get_namespace(namespace)
return ns.items()
def _ns_keys_impl(self, namespace: str) -> List[str]:
keys: List[str] = []
db = self.namespaces[namespace]
with self.lmdb_env.begin(db=db) as txn:
with txn.cursor() as cursor:
remaining = cursor.first()
while remaining:
keys.append(cursor.key().decode())
remaining = cursor.next()
return keys
def ns_contains_async(self,
def ns_values(self, namespace: str) -> Future[List[Any]]:
return self._run_command(self._ns_values_impl, namespace)
def _ns_values_impl(self, namespace: str) -> List[Any]:
values: List[Any] = []
db = self.namespaces[namespace]
with self.lmdb_env.begin(db=db, buffers=True) as txn:
with txn.cursor() as cursor:
remaining = cursor.first()
while remaining:
values.append(self._decode_value(cursor.value()))
remaining = cursor.next()
return values
def ns_items(self, namespace: str) -> Future[List[Tuple[str, Any]]]:
return self._run_command(self._ns_items_impl, namespace)
def _ns_items_impl(self, namespace: str) -> List[Tuple[str, Any]]:
ns = self._get_namespace(namespace)
return list(ns.items())
def ns_contains(self,
namespace: str,
key: Union[List[str], str]
) -> Future[bool]:
return self._run_command(self._ns_contains_impl, namespace, key)
def _ns_contains_impl(self,
namespace: str,
key: Union[List[str], str]
) -> Awaitable[bool]:
return self.eventloop.run_in_thread(
self.ns_contains, namespace, key)
def ns_contains(self, namespace: str, key: Union[List[str], str]) -> bool:
with self.thread_lock:
try:
key_list = self._process_key(key)
record = self._get_record(namespace, key_list[0])
if len(key_list) == 1:
return True
reduce(operator.getitem, # type: ignore
key_list[1:], record)
except Exception:
return False
return True
) -> bool:
try:
key_list = self._process_key(key)
record = self._get_record(namespace, key_list[0])
if len(key_list) == 1:
return True
reduce(operator.getitem, # type: ignore
key_list[1:], record)
except Exception:
return False
return True
def register_local_namespace(self,
namespace: str,
@ -727,8 +695,8 @@ class NamespaceWrapper:
key = [key]
return self.db.delete_item(self.namespace, key)
async def length(self) -> int:
return await self.db.ns_length_async(self.namespace)
def length(self) -> Future[int]:
return self.db.ns_length(self.namespace)
def as_dict(self) -> Dict[str, Any]:
self._check_sync_method("as_dict")
@ -750,36 +718,21 @@ class NamespaceWrapper:
self._check_sync_method("__contains__")
if isinstance(key, str) and not self.parse_keys:
key = [key]
return self.db.ns_contains(self.namespace, key)
return self.db.ns_contains(self.namespace, key).result()
async def contains(self, key: Union[List[str], str]) -> bool:
def contains(self, key: Union[List[str], str]) -> Future[bool]:
if isinstance(key, str) and not self.parse_keys:
key = [key]
return await self.db.ns_contains_async(self.namespace, key)
return self.db.ns_contains(self.namespace, key)
def keys(self) -> Future[List[str]]:
if not self.eventloop.is_running:
ret = self.db.ns_keys()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_keys_async(self.namespace))
return self.db.ns_keys(self.namespace)
def values(self) -> Future[ValuesView]:
if not self.eventloop.is_running:
ret = self.db.ns_values()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_values_async(self.namespace))
def values(self) -> Future[List[Any]]:
return self.db.ns_values(self.namespace)
def items(self) -> Future[ItemsView]:
if not self.eventloop.is_running:
ret = self.db.ns_items()
fut = self.eventloop.create_future()
fut.set_result(ret)
return fut
return cast(Future, self.db.ns_items_async(self.namespace))
def items(self) -> Future[List[Tuple[str, Any]]]:
return self.db.ns_items(self.namespace)
def pop(self,
key: Union[List[str], str],