database: add backup, restore, and compact endpoints

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2024-05-12 20:40:28 -04:00
parent 64ffe22545
commit 3f62bb6fb4
3 changed files with 184 additions and 8 deletions

View File

@ -11,7 +11,8 @@ import operator
import inspect import inspect
import logging import logging
import contextlib import contextlib
from asyncio import Future, Task import time
from asyncio import Future, Task, Lock
from functools import reduce from functools import reduce
from queue import Queue from queue import Queue
from threading import Thread from threading import Thread
@ -39,6 +40,7 @@ from typing import (
if TYPE_CHECKING: if TYPE_CHECKING:
from ..confighelper import ConfigHelper from ..confighelper import ConfigHelper
from ..common import WebRequest from ..common import WebRequest
from .klippy_connection import KlippyConnection
from lmdb import Environment as LmdbEnvironment from lmdb import Environment as LmdbEnvironment
from types import TracebackType from types import TracebackType
DBRecord = Optional[Union[int, float, bool, str, List[Any], Dict[str, Any]]] DBRecord = Optional[Union[int, float, bool, str, List[Any], Dict[str, Any]]]
@ -180,6 +182,7 @@ class MoonrakerDatabase:
self.eventloop = self.server.get_event_loop() self.eventloop = self.server.get_event_loop()
self.registered_namespaces: Set[str] = set(["moonraker", "database"]) self.registered_namespaces: Set[str] = set(["moonraker", "database"])
self.registered_tables: Set[str] = set([NAMESPACE_TABLE, REGISTRATION_TABLE]) self.registered_tables: Set[str] = set([NAMESPACE_TABLE, REGISTRATION_TABLE])
self.backup_lock = Lock()
instance_id: str = self.server.get_app_args()["instance_uuid"] instance_id: str = self.server.get_app_args()["instance_uuid"]
db_path = self._get_database_folder(config) db_path = self._get_database_folder(config)
self._sql_db = db_path.joinpath(SQL_DB_FILENAME) self._sql_db = db_path.joinpath(SQL_DB_FILENAME)
@ -224,6 +227,16 @@ class MoonrakerDatabase:
self.server.register_endpoint( self.server.register_endpoint(
"/server/database/item", RequestType.all(), self._handle_item_request "/server/database/item", RequestType.all(), self._handle_item_request
) )
self.server.register_endpoint(
"/server/database/backup", RequestType.POST | RequestType.DELETE,
self._handle_backup_request
)
self.server.register_endpoint(
"/server/database/restore", RequestType.POST, self._handle_restore_request
)
self.server.register_endpoint(
"/server/database/compact", RequestType.POST, self._handle_compact_request
)
self.server.register_debug_endpoint( self.server.register_debug_endpoint(
"/debug/database/list", RequestType.GET, self._handle_list_request "/debug/database/list", RequestType.GET, self._handle_list_request
) )
@ -429,6 +442,21 @@ class MoonrakerDatabase:
) -> Future[Any]: ) -> Future[Any]:
return self.db_provider.execute_db_function(callback) return self.db_provider.execute_db_function(callback)
def compact_database(self) -> Future[Dict[str, int]]:
return self.db_provider.execute_db_function(
self.db_provider.compact_database
)
def backup_database(self, bkp_path: pathlib.Path) -> Future[None]:
return self.db_provider.execute_db_function(
self.db_provider.backup_database, bkp_path
)
def restore_database(self, restore_path: pathlib.Path) -> Future[Dict[str, Any]]:
return self.db_provider.execute_db_function(
self.db_provider.restore_database, restore_path
)
def register_local_namespace( def register_local_namespace(
self, namespace: str, forbidden: bool = False, parse_keys: bool = False self, namespace: str, forbidden: bool = False, parse_keys: bool = False
) -> NamespaceWrapper: ) -> NamespaceWrapper:
@ -479,6 +507,10 @@ class MoonrakerDatabase:
def get_provider_wrapper(self) -> DBProviderWrapper: def get_provider_wrapper(self) -> DBProviderWrapper:
return self.db_provider.get_provider_wapper() return self.db_provider.get_provider_wapper()
def get_backup_dir(self) -> pathlib.Path:
bkp_dir = pathlib.Path(self.server.get_app_arg("data_path"))
return bkp_dir.joinpath("backup/database").resolve()
def register_table(self, table_def: SqlTableDefinition) -> SqlTableWrapper: def register_table(self, table_def: SqlTableDefinition) -> SqlTableWrapper:
if table_def.name in self.registered_tables: if table_def.name in self.registered_tables:
raise self.server.error(f"Table '{table_def.name}' already registered") raise self.server.error(f"Table '{table_def.name}' already registered")
@ -486,17 +518,78 @@ class MoonrakerDatabase:
self.db_provider.register_table(table_def) self.db_provider.register_table(table_def)
return SqlTableWrapper(self, table_def) return SqlTableWrapper(self, table_def)
async def _handle_compact_request(self, web_request: WebRequest) -> Dict[str, int]:
kconn: KlippyConnection = self.server.lookup_component("klippy_connection")
if kconn.is_printing():
raise self.server.error("Cannot compact when Klipper is printing")
async with self.backup_lock:
return await self.compact_database()
async def _handle_backup_request(self, web_request: WebRequest) -> Dict[str, Any]:
async with self.backup_lock:
request_type = web_request.get_request_type()
if request_type == RequestType.POST:
kconn: KlippyConnection
kconn = self.server.lookup_component("klippy_connection")
if kconn.is_printing():
raise self.server.error("Cannot backup when Klipper is printing")
suffix = time.strftime("%Y%m%d-%H%M%S", time.localtime())
db_name = web_request.get_str("filename", f"sqldb-backup-{suffix}.db")
bkp_dir = self.get_backup_dir()
bkp_path = bkp_dir.joinpath(db_name).resolve()
if bkp_dir not in bkp_path.parents:
raise self.server.error(f"Invalid name {db_name}.")
await self.backup_database(bkp_path)
elif request_type == RequestType.DELETE:
db_name = web_request.get_str("filename")
bkp_dir = self.get_backup_dir()
bkp_path = bkp_dir.joinpath(db_name).resolve()
if bkp_dir not in bkp_path.parents:
raise self.server.error(f"Invalid name {db_name}.")
if not bkp_path.is_file():
raise self.server.error(
f"Backup file {db_name} does not exist", 404
)
await self.eventloop.run_in_thread(bkp_path.unlink)
else:
raise self.server.error("Invalid request type")
return {
"backup_path": str(bkp_path)
}
async def _handle_restore_request(self, web_request: WebRequest) -> Dict[str, Any]:
kconn: KlippyConnection = self.server.lookup_component("klippy_connection")
if kconn.is_printing():
raise self.server.error("Cannot restore when Klipper is printing")
async with self.backup_lock:
db_name = web_request.get_str("filename")
bkp_dir = self.get_backup_dir()
restore_path = bkp_dir.joinpath(db_name).resolve()
if bkp_dir not in restore_path.parents:
raise self.server.error(f"Invalid name {db_name}.")
restore_info = await self.restore_database(restore_path)
self.server.restart(.1)
return restore_info
async def _handle_list_request( async def _handle_list_request(
self, web_request: WebRequest self, web_request: WebRequest
) -> Dict[str, List[str]]: ) -> Dict[str, List[str]]:
path = web_request.get_endpoint() path = web_request.get_endpoint()
ns_list = set(self.db_provider.namespaces) ns_list = set(self.db_provider.namespaces)
bkp_dir = self.get_backup_dir()
backups: List[str] = []
if bkp_dir.is_dir():
backups = [bkp.name for bkp in bkp_dir.iterdir() if bkp.is_file()]
if not path.startswith("/debug/"): if not path.startswith("/debug/"):
ns_list -= self.forbidden_namespaces ns_list -= self.forbidden_namespaces
return {"namespaces": list(ns_list)} return {
"namespaces": list(ns_list),
"backups": backups
}
else: else:
return { return {
"namespaces": list(ns_list), "namespaces": list(ns_list),
"backups": backups,
"tables": list(self.db_provider.tables) "tables": list(self.db_provider.tables)
} }
@ -550,9 +643,11 @@ class MoonrakerDatabase:
return {'namespace': namespace, 'key': key, 'value': val} return {'namespace': namespace, 'key': key, 'value': val}
async def close(self) -> None: async def close(self) -> None:
await self.insert_item( if not self.db_provider.is_restored():
"database", "unsafe_shutdowns", self.unsafe_shutdowns # Don't overwrite unsafe shutdowns on a restored database
) await self.insert_item(
"database", "unsafe_shutdowns", self.unsafe_shutdowns
)
# Stop command thread # Stop command thread
await self.db_provider.stop() await self.db_provider.stop()
@ -628,6 +723,7 @@ class SqliteProvider(Thread):
self._namespaces: Set[str] = set() self._namespaces: Set[str] = set()
self._tables: Set[str] = set() self._tables: Set[str] = set()
self._db_path = db_path self._db_path = db_path
self.restored: bool = False
self.command_queue: Queue[Tuple[Future, Optional[Callable], Tuple[Any, ...]]] self.command_queue: Queue[Tuple[Future, Optional[Callable], Tuple[Any, ...]]]
self.command_queue = Queue() self.command_queue = Queue()
sqlite3.register_converter("record", decode_record) sqlite3.register_converter("record", decode_record)
@ -1204,9 +1300,84 @@ class SqliteProvider(Thread):
f"Stored table prototype:\n{detected_proto}" f"Stored table prototype:\n{detected_proto}"
) )
def compact_database(self, conn: sqlite3.Connection) -> Dict[str, int]:
if self.restored:
raise self.server.error(
"Cannot compact restored database, awaiting restart"
)
cur_size = self._db_path.stat().st_size
conn.execute("VACUUM")
new_size = self._db_path.stat().st_size
return {
"previous_size": cur_size,
"new_size": new_size
}
def backup_database(
self, conn: sqlite3.Connection, bkp_path: pathlib.Path
) -> None:
if self.restored:
raise self.server.error(
"Cannot backup restored database, awaiting restart"
)
parent = bkp_path.parent
if not parent.exists():
parent.mkdir(parents=True, exist_ok=True)
elif bkp_path.exists():
bkp_path.unlink()
bkp_conn = sqlite3.connect(str(bkp_path))
conn.backup(bkp_conn)
bkp_conn.close()
def restore_database(
self, conn: sqlite3.Connection, restore_path: pathlib.Path
) -> Dict[str, Any]:
if self.restored:
raise self.server.error("Database already restored")
if not restore_path.is_file():
raise self.server.error(f"Restoration File {restore_path} does not exist")
restore_conn = sqlite3.connect(str(restore_path))
restore_info = self._validate_restore_db(restore_conn)
restore_conn.backup(conn)
restore_conn.close()
self.restored = True
return restore_info
def _validate_restore_db(
self, restore_conn: sqlite3.Connection
) -> Dict[str, Any]:
cursor = restore_conn.execute(
"SELECT name FROM sqlite_schema WHERE type = 'table'"
)
cursor.arraysize = 100
tables = [row[0] for row in cursor.fetchall()]
if NAMESPACE_TABLE not in tables:
restore_conn.close()
raise self.server.error(
f"Invalid database for restoration, missing table '{NAMESPACE_TABLE}'"
)
missing_tables = self._tables.difference(tables)
if missing_tables:
logging.info(f"Database to restore missing tables: {missing_tables}")
cursor = restore_conn.execute(
f"SELECT DISTINCT namespace FROM {NAMESPACE_TABLE}"
)
cursor.arraysize = 100
namespaces = [row[0] for row in cursor.fetchall()]
missing_ns = self._namespaces.difference(namespaces)
if missing_ns:
logging.info(f"Database to restore missing namespaces: {missing_ns}")
return {
"restored_tables": tables,
"restored_namespaces": namespaces
}
def get_provider_wapper(self) -> DBProviderWrapper: def get_provider_wapper(self) -> DBProviderWrapper:
return DBProviderWrapper(self) return DBProviderWrapper(self)
def is_restored(self) -> bool:
return self.restored
def stop(self) -> Future[None]: def stop(self) -> Future[None]:
fut = self.asyncio_loop.create_future() fut = self.asyncio_loop.create_future()
if not self.is_alive(): if not self.is_alive():

View File

@ -76,9 +76,8 @@ class FileManager:
db_path = db.get_database_path() db_path = db.get_database_path()
self.add_reserved_path("database", db_path, False) self.add_reserved_path("database", db_path, False)
self.add_reserved_path("certs", self.datapath.joinpath("certs"), False) self.add_reserved_path("certs", self.datapath.joinpath("certs"), False)
self.add_reserved_path( self.add_reserved_path("systemd", self.datapath.joinpath("systemd"), False)
"systemd", self.datapath.joinpath("systemd"), False self.add_reserved_path("backup", self.datapath.joinpath("backup"), False)
)
self.gcode_metadata = MetadataStorage(config, db) self.gcode_metadata = MetadataStorage(config, db)
self.sync_lock = NotifySyncLock(config) self.sync_lock = NotifySyncLock(config)
avail_observers: Dict[str, Type[BaseFileSystemObserver]] = { avail_observers: Dict[str, Type[BaseFileSystemObserver]] = {

View File

@ -424,6 +424,12 @@ class Server:
logging.info("Exiting with signal SIGTERM") logging.info("Exiting with signal SIGTERM")
self.event_loop.register_callback(self._stop_server, "terminate") self.event_loop.register_callback(self._stop_server, "terminate")
def restart(self, delay: Optional[float] = None) -> None:
if delay is None:
self.event_loop.register_callback(self._stop_server)
else:
self.event_loop.delay_callback(delay, self._stop_server)
async def _stop_server(self, exit_reason: str = "restart") -> None: async def _stop_server(self, exit_reason: str = "restart") -> None:
self.server_running = False self.server_running = False
# Call each component's "on_exit" method # Call each component's "on_exit" method