database: add a sync method

Similar to the update method, however sync will remove any
keys in the database not in the new value.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-01-30 19:41:11 -05:00
parent 781e3e6250
commit 65d1f23352
1 changed files with 57 additions and 3 deletions

View File

@ -23,6 +23,7 @@ from typing import (
Any,
Awaitable,
ItemsView,
Mapping,
ValuesView,
Tuple,
Optional,
@ -329,7 +330,7 @@ class MoonrakerDatabase:
def update_namespace(self,
namespace: str,
value: Dict[str, DBRecord]
value: Mapping[str, DBRecord]
) -> Awaitable[None]:
if self.eventloop.is_running():
return self.eventloop.run_in_thread(
@ -342,7 +343,7 @@ class MoonrakerDatabase:
def _update_ns_impl(self,
namespace: str,
value: Dict[str, DBRecord]
value: Mapping[str, DBRecord]
) -> None:
with self.thread_lock:
if not value:
@ -391,6 +392,52 @@ class MoonrakerDatabase:
if drop_empty_db:
del self.namespaces[namespace]
def sync_namespace(self,
namespace: str,
value: Mapping[str, DBRecord]
) -> Awaitable[None]:
if self.eventloop.is_running():
return self.eventloop.run_in_thread(
self._sync_ns_impl, namespace, value)
else:
self._sync_ns_impl(namespace, value)
fut = self.eventloop.create_future()
fut.set_result(None)
return fut
def _sync_ns_impl(self,
namespace: str,
value: Mapping[str, DBRecord]
) -> None:
with self.thread_lock:
if not value:
return
if namespace not in self.namespaces:
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
new_keys = set(value.keys())
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
with txn.cursor() as cursor:
remaining = cursor.first()
while remaining:
bkey, bval = cursor.item()
key = bytes(bkey).decode()
if key not in value:
remaining = cursor.delete()
else:
decoded = self._decode_value(bval)
if decoded != value[key]:
cursor.put(self._encode_value(value[key]))
new_keys.remove(key)
remaining = cursor.next()
for k in new_keys:
val = value[k]
ret = txn.put(key.encode(), self._encode_value(val))
if not ret:
logging.info(f"Error inserting key '{k}' "
f"in namespace '{namespace}'")
async def ns_length_async(self, namespace: str) -> int:
return len(await self.ns_keys_async(namespace))
@ -661,9 +708,12 @@ class NamespaceWrapper:
key = [key]
return self.db.update_item(self.namespace, key, value)
def update(self, value: Dict[str, DBRecord]) -> Awaitable[None]:
def update(self, value: Mapping[str, DBRecord]) -> Awaitable[None]:
return self.db.update_namespace(self.namespace, value)
def sync(self, value: Mapping[str, DBRecord]) -> Awaitable[None]:
return self.db.sync_namespace(self.namespace, value)
def get(self,
key: Union[List[str], str],
default: Any = None
@ -680,6 +730,10 @@ class NamespaceWrapper:
async def length(self) -> int:
return await self.db.ns_length_async(self.namespace)
def as_dict(self) -> Dict[str, Any]:
self._check_sync_method("as_dict")
return self.db._get_namespace(self.namespace)
def __getitem__(self, key: Union[List[str], str]) -> Future[Any]:
return self.get(key, default=SENTINEL)