database: reduce duplicate code

Add a _get_db() method that perform the check for existance.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-01-31 19:39:16 -05:00
parent e2a62f80d4
commit 46f74329d3
1 changed files with 18 additions and 40 deletions

View File

@ -336,10 +336,7 @@ class MoonrakerDatabase:
namespace: str, namespace: str,
records: Dict[str, Any] records: Dict[str, Any]
) -> None: ) -> None:
if namespace not in self.namespaces: db = self._get_db(namespace)
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
for key, val in records.items(): for key, val in records.items():
ret = txn.put(key.encode(), self._encode_value(val)) ret = txn.put(key.encode(), self._encode_value(val))
@ -360,10 +357,7 @@ class MoonrakerDatabase:
source_keys: List[str], source_keys: List[str],
dest_keys: List[str] dest_keys: List[str]
) -> None: ) -> None:
if namespace not in self.namespaces: db = self._get_db(namespace)
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
with self.lmdb_env.begin(write=True, db=db) as txn: with self.lmdb_env.begin(write=True, db=db) as txn:
for source, dest in zip(source_keys, dest_keys): for source, dest in zip(source_keys, dest_keys):
val = txn.pop(source.encode()) val = txn.pop(source.encode())
@ -380,10 +374,7 @@ class MoonrakerDatabase:
namespace: str, namespace: str,
keys: List[str] keys: List[str]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if namespace not in self.namespaces: db = self._get_db(namespace)
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
result: Dict[str, Any] = {} result: Dict[str, Any] = {}
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
for key in keys: for key in keys:
@ -402,10 +393,7 @@ class MoonrakerDatabase:
namespace: str, namespace: str,
keys: List[str] keys: List[str]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if namespace not in self.namespaces: db = self._get_db(namespace)
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
result: Dict[str, Any] = {} result: Dict[str, Any] = {}
encoded_keys: List[bytes] = [k.encode() for k in keys] encoded_keys: List[bytes] = [k.encode() for k in keys]
with self.lmdb_env.begin(buffers=True, db=db) as txn: with self.lmdb_env.begin(buffers=True, db=db) as txn:
@ -429,10 +417,7 @@ class MoonrakerDatabase:
) -> None: ) -> None:
if not value: if not value:
return return
if namespace not in self.namespaces: db = self._get_db(namespace)
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
# We only need to update the keys that changed # We only need to update the keys that changed
for key, val in value.items(): for key, val in value.items():
@ -456,10 +441,7 @@ class MoonrakerDatabase:
namespace: str, namespace: str,
drop_empty_db: bool = False drop_empty_db: bool = False
) -> None: ) -> None:
if namespace not in self.namespaces: db = self._get_db(namespace)
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
with self.lmdb_env.begin(write=True, db=db) as txn: with self.lmdb_env.begin(write=True, db=db) as txn:
txn.drop(db, delete=drop_empty_db) txn.drop(db, delete=drop_empty_db)
if drop_empty_db: if drop_empty_db:
@ -477,10 +459,7 @@ class MoonrakerDatabase:
) -> None: ) -> None:
if not value: if not value:
return return
if namespace not in self.namespaces: db = self._get_db(namespace)
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
new_keys = set(value.keys()) new_keys = set(value.keys())
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
with txn.cursor() as cursor: with txn.cursor() as cursor:
@ -507,7 +486,7 @@ class MoonrakerDatabase:
return self._run_command(self._ns_length_impl, namespace) return self._run_command(self._ns_length_impl, namespace)
def _ns_length_impl(self, namespace: str) -> int: def _ns_length_impl(self, namespace: str) -> int:
db = self.namespaces[namespace] db = self._get_db(namespace)
with self.lmdb_env.begin(db=db) as txn: with self.lmdb_env.begin(db=db) as txn:
stats = txn.stat(db) stats = txn.stat(db)
return stats['entries'] return stats['entries']
@ -517,7 +496,7 @@ class MoonrakerDatabase:
def _ns_keys_impl(self, namespace: str) -> List[str]: def _ns_keys_impl(self, namespace: str) -> List[str]:
keys: List[str] = [] keys: List[str] = []
db = self.namespaces[namespace] db = self._get_db(namespace)
with self.lmdb_env.begin(db=db) as txn: with self.lmdb_env.begin(db=db) as txn:
with txn.cursor() as cursor: with txn.cursor() as cursor:
remaining = cursor.first() remaining = cursor.first()
@ -531,7 +510,7 @@ class MoonrakerDatabase:
def _ns_values_impl(self, namespace: str) -> List[Any]: def _ns_values_impl(self, namespace: str) -> List[Any]:
values: List[Any] = [] values: List[Any] = []
db = self.namespaces[namespace] db = self._get_db(namespace)
with self.lmdb_env.begin(db=db, buffers=True) as txn: with self.lmdb_env.begin(db=db, buffers=True) as txn:
with txn.cursor() as cursor: with txn.cursor() as cursor:
remaining = cursor.first() remaining = cursor.first()
@ -603,6 +582,11 @@ class MoonrakerDatabase:
f"Namespace '{namespace}' not found", 404) f"Namespace '{namespace}' not found", 404)
return NamespaceWrapper(namespace, self, parse_keys) return NamespaceWrapper(namespace, self, parse_keys)
def _get_db(self, namespace: str) -> object:
if namespace not in self.namespaces:
raise self.server.error(f"Namespace '{namespace}' not found", 404)
return self.namespaces[namespace]
def _process_key(self, key: Union[List[str], str]) -> List[str]: def _process_key(self, key: Union[List[str], str]) -> List[str]:
try: try:
key_list = key if isinstance(key, list) else key.split('.') key_list = key if isinstance(key, list) else key.split('.')
@ -613,7 +597,7 @@ class MoonrakerDatabase:
return key_list return key_list
def _insert_record(self, namespace: str, key: str, val: DBType) -> bool: def _insert_record(self, namespace: str, key: str, val: DBType) -> bool:
db = self.namespaces[namespace] db = self._get_db(namespace)
if val is None: if val is None:
return False return False
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn:
@ -625,10 +609,7 @@ class MoonrakerDatabase:
key: str, key: str,
force: bool = False force: bool = False
) -> DBRecord: ) -> DBRecord:
if namespace not in self.namespaces: db = self._get_db(namespace)
raise self.server.error(
f"Namespace '{namespace}' not found", 404)
db = self.namespaces[namespace]
with self.lmdb_env.begin(buffers=True, db=db) as txn: with self.lmdb_env.begin(buffers=True, db=db) as txn:
value = txn.get(key.encode()) value = txn.get(key.encode())
if value is None: if value is None:
@ -639,10 +620,7 @@ class MoonrakerDatabase:
return self._decode_value(value) return self._decode_value(value)
def _get_namespace(self, namespace: str) -> Dict[str, Any]: def _get_namespace(self, namespace: str) -> Dict[str, Any]:
if namespace not in self.namespaces: db = self._get_db(namespace)
raise self.server.error(
f"Invalid database namespace '{namespace}'")
db = self.namespaces[namespace]
result = {} result = {}
invalid_key_result = None invalid_key_result = None
with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: with self.lmdb_env.begin(write=True, buffers=True, db=db) as txn: