From 5d35cc0d10ba7d571c9a70e1743eee6414273d87 Mon Sep 17 00:00:00 2001 From: Eric Callahan Date: Sun, 4 Jul 2021 11:23:10 -0400 Subject: [PATCH] update_manager: Implement async streaming downloads This allows Moonraker to report download progress. This also resolves potential issues with I/O blocking the event loop. Signed-off-by: Eric Callahan --- .../update_manager/update_manager.py | 318 ++++++++++++------ 1 file changed, 212 insertions(+), 106 deletions(-) diff --git a/moonraker/components/update_manager/update_manager.py b/moonraker/components/update_manager/update_manager.py index 8b9e0e7..309965b 100644 --- a/moonraker/components/update_manager/update_manager.py +++ b/moonraker/components/update_manager/update_manager.py @@ -12,13 +12,13 @@ import json import sys import shutil import zipfile -import io import time import tempfile +from concurrent.futures import ThreadPoolExecutor import tornado.gen from tornado.ioloop import IOLoop, PeriodicCallback from tornado.httpclient import AsyncHTTPClient -from tornado.locks import Event, Condition, Lock +from tornado.locks import Event, Lock from .base_deploy import BaseDeploy from .app_deploy import AppDeploy from .git_deploy import GitDeploy @@ -28,6 +28,7 @@ from typing import ( TYPE_CHECKING, Any, Optional, + Tuple, Type, Union, Dict, @@ -530,14 +531,18 @@ class CommandHelper: raise self.server.error( f"Retries exceeded for GitHub API request: {url}") - async def http_download_request(self, url: str) -> bytes: + async def http_download_request(self, + url: str, + content_type: str, + timeout: float = 180. + ) -> bytes: retries = 5 while retries: try: - timeout = time.time() + 130. + timeout = time.time() + timeout + 10. fut = self.http_client.fetch( - url, headers={"Accept": "application/zip"}, - connect_timeout=5., request_timeout=120.) 
async def streaming_download_request(self,
                                     url: str,
                                     dest: Union[str, pathlib.Path],
                                     content_type: str,
                                     size: int,
                                     timeout: float = 180.
                                     ) -> None:
    """Download `url` to the file at `dest`, streaming chunks to disk.

    Chunks are handed to a StreamingDownload helper which writes them
    off the event loop and emits progress notifications.  Retries up
    to 5 times with a one second pause between attempts, re-raising
    the last error when retries are exhausted.

    `timeout` is the per-attempt request timeout in seconds; `size`
    is the expected download size in bytes (progress reporting only).
    """
    if isinstance(dest, str):
        dest = pathlib.Path(dest)
    retries = 5
    while retries:
        dl = StreamingDownload(self, dest, size)
        try:
            # 'deadline' is an absolute (epoch) time as required by
            # tornado.gen.with_timeout().  Keep it separate from
            # 'timeout', which fetch() expects as a *relative*
            # request_timeout.  The original rebound 'timeout' to the
            # absolute value, producing an enormous request_timeout
            # and compounding the deadline on every retry.
            deadline = time.time() + timeout + 10.
            fut = self.http_client.fetch(
                url, headers={"Accept": content_type},
                connect_timeout=5., request_timeout=timeout,
                streaming_callback=dl.on_chunk_recd)
            resp: HTTPResponse
            resp = await tornado.gen.with_timeout(deadline, fut)
        except Exception:
            retries -= 1
            logging.exception("Error Processing Download")
            if not retries:
                raise
            await tornado.gen.sleep(1.)
            continue
        finally:
            # Always drain buffered chunks and close the file handle,
            # whether the attempt succeeded, failed, or will retry
            await dl.close()
        if resp.code < 400:
            return
        # NOTE(review): fetch() raises HTTPError for codes >= 400 by
        # default, so this fall-through should be unreachable; were it
        # reached it would loop without decrementing 'retries' —
        # confirm upstream intent.
class StreamingDownload:
    """Sink for tornado's streaming_callback during a file download.

    Buffers received chunks and writes them to `dest` on a worker
    thread so file I/O never blocks the event loop, emitting a
    progress notification via the command helper roughly every 5%.
    """
    def __init__(self,
                 cmd_helper: CommandHelper,
                 dest: pathlib.Path,
                 download_size: int) -> None:
        self.cmd_helper = cmd_helper
        self.ioloop = IOLoop.current()
        self.name = dest.name
        self.file_hdl = dest.open('wb')
        self.download_size = download_size
        self.total_recd: int = 0
        self.last_pct: int = 0
        self.chunk_buffer: List[bytes] = []
        # Set while no chunk processing is in flight; close() waits on it
        self.busy_evt: Event = Event()
        self.busy_evt.set()
        # One worker thread reused for every write.  The original spun
        # up a fresh ThreadPoolExecutor per chunk, paying thread
        # creation/teardown cost for each one.
        self.tpe = ThreadPoolExecutor(max_workers=1)

    def on_chunk_recd(self, chunk: bytes) -> None:
        # Invoked on the event loop by tornado's streaming_callback
        if not chunk:
            return
        self.chunk_buffer.append(chunk)
        if not self.busy_evt.is_set():
            # A _process_buffer task is already draining the buffer
            return
        self.busy_evt.clear()
        self.ioloop.spawn_callback(self._process_buffer)

    async def close(self) -> None:
        # Wait for in-flight writes to finish before closing the file
        await self.busy_evt.wait()
        self.tpe.shutdown(wait=False)
        self.file_hdl.close()

    async def _process_buffer(self) -> None:
        while self.chunk_buffer:
            chunk = self.chunk_buffer.pop(0)
            self.total_recd += len(chunk)
            await self.ioloop.run_in_executor(
                self.tpe, self.file_hdl.write, chunk)
            if self.download_size > 0:
                # Guard against a zero 'size' reported by the remote,
                # which previously raised ZeroDivisionError here
                pct = int(self.total_recd / self.download_size * 100 + .5)
                if pct >= self.last_pct + 5:
                    self.last_pct = pct
                    totals = f"{self.total_recd // 1024} KiB / " \
                             f"{self.download_size // 1024} KiB"
                    self.cmd_helper.notify_update_response(
                        f"Downloading {self.name}: {totals} [{pct}%]")
        self.busy_evt.set()
async def refresh(self, fetch_packages: bool = True) -> None:
    """Refresh the cached list of upgradable system packages.

    When a refresh is already running, wait for it to complete and
    reuse its result rather than starting a second apt transaction.
    """
    # TODO: Use python-apt python lib rather than command line for updates
    if self.refresh_evt is not None:
        # The original called wait() without awaiting the returned
        # future, so callers returned before the refresh finished.
        await self.refresh_evt.wait()
        return
    async with self.mutex:
        self.refresh_evt = Event()
        try:
            if fetch_packages:
                await self.cmd_helper.run_cmd(
                    f"{APT_CMD} update", timeout=300., retries=3)
            res = await self.cmd_helper.run_cmd_with_response(
                "apt list --upgradable", timeout=60.)
            pkg_list = [p.strip() for p in res.split("\n") if p.strip()]
            if pkg_list:
                # Skip apt's "Listing..." header lines
                pkg_list = pkg_list[2:]
                self.available_packages = [p.split("/", maxsplit=1)[0]
                                           for p in pkg_list]
            pkg_msg = "\n".join(self.available_packages)
            logging.info(
                f"Detected {len(self.available_packages)} package updates:"
                f"\n{pkg_msg}")
        except Exception:
            # Best effort: a failed refresh leaves the prior package
            # list intact rather than propagating
            logging.exception("Error Refreshing System Packages")
        self.refresh_evt.set()
        self.refresh_evt = None

async def update(self) -> None:
    """Run 'apt update' then 'apt upgrade --yes', notifying progress.

    Raises a server error if either apt command fails.
    """
    async with self.mutex:
        self.cmd_helper.notify_update_response("Updating packages...")
        try:
            await self.cmd_helper.run_cmd(
                f"{APT_CMD} update", timeout=300., notify=True)
            await self.cmd_helper.run_cmd(
                f"{APT_CMD} upgrade --yes", timeout=3600., notify=True)
        except Exception:
            raise self.server.error("Error updating system packages")
        self.available_packages = []
        self.cmd_helper.notify_update_response(
            "Package update finished...", is_complete=True)
async def _get_local_version(self) -> None:
    """Read the installed client's .version file off the event loop.

    Falls back to "?" when no .version file is present.
    """
    version_path = self.path.joinpath(".version")
    if version_path.is_file():
        # read_text() blocks; push it to a worker thread
        with ThreadPoolExecutor(max_workers=1) as tpe:
            version = await IOLoop.current().run_in_executor(
                tpe, version_path.read_text)
        self.version = version.strip()
    else:
        self.version = "?"

async def refresh(self) -> None:
    """Refresh local and remote client version state.

    When a refresh is already running, wait for it to complete
    instead of starting another one.
    """
    if self.refresh_evt is not None:
        # The original returned without awaiting the in-progress
        # refresh (wait() returns a future that must be awaited)
        await self.refresh_evt.wait()
        return
    async with self.mutex:
        self.refresh_evt = Event()
        try:
            await self._get_local_version()
            await self._get_remote_version()
        except Exception:
            logging.exception("Error Refreshing Client")
        self.refresh_evt.set()
        self.refresh_evt = None
self.dl_info = (dl_url, content_type, size) logging.info( f"Github client Info Received:\nRepo: {self.name}\n" f"Local Version: {self.version}\n" f"Remote Version: {self.remote_version}\n" - f"url: {self.dl_url}") + f"url: {dl_url}\n" + f"size: {size}\n" + f"Content Type: {content_type}") async def update(self) -> None: - if self.refresh_condition is not None: - # wait for refresh if in progess - self.refresh_condition.wait() - if self.remote_version == "?": - await self.refresh() + async with self.mutex: if self.remote_version == "?": + await self._get_remote_version() + if self.remote_version == "?": + raise self.server.error( + f"Client {self.repo}: Unable to locate update") + dl_url, content_type, size = self.dl_info + if dl_url == "?": raise self.server.error( - f"Client {self.repo}: Unable to locate update") - if self.dl_url == "?": - raise self.server.error( - f"Client {self.repo}: Invalid download url") - if self.version == self.remote_version: - # Already up to date - return - self.cmd_helper.notify_update_response( - f"Downloading Client: {self.name}") - archive = await self.cmd_helper.http_download_request(self.dl_url) - with tempfile.TemporaryDirectory( - suffix=self.name, prefix="client") as tempdirname: - tempdir = pathlib.Path(tempdirname) - if self.path.is_dir(): - # find and move persistent files - for fname in os.listdir(self.path): - src_path = self.path.joinpath(fname) - if fname in self.persistent_files: - dest_dir = tempdir.joinpath(fname).parent - os.makedirs(dest_dir, exist_ok=True) - shutil.move(src_path, dest_dir) - shutil.rmtree(self.path) - os.mkdir(self.path) - with zipfile.ZipFile(io.BytesIO(archive)) as zf: - zf.extractall(self.path) - # Move temporary files back into - for fname in os.listdir(tempdir): - src_path = tempdir.joinpath(fname) - dest_dir = self.path.joinpath(fname).parent - os.makedirs(dest_dir, exist_ok=True) - shutil.move(src_path, dest_dir) - self.version = self.remote_version - version_path = 
self.path.joinpath(".version") - if not version_path.exists(): - version_path.write_text(self.version) - self.cmd_helper.notify_update_response( - f"Client Update Finished: {self.name}", is_complete=True) + f"Client {self.repo}: Invalid download url") + if self.version == self.remote_version: + # Already up to date + return + self.cmd_helper.notify_update_response( + f"Downloading Client: {self.name}") + with tempfile.TemporaryDirectory( + suffix=self.name, prefix="client") as tempdirname: + tempdir = pathlib.Path(tempdirname) + temp_download_file = tempdir.joinpath(f"{self.name}.zip") + temp_persist_dir = tempdir.joinpath(self.name) + await self.cmd_helper.streaming_download_request( + dl_url, temp_download_file, content_type, size) + self.cmd_helper.notify_update_response( + f"Download Complete, extracting release to '{self.path}'") + with ThreadPoolExecutor(max_workers=1) as tpe: + await IOLoop.current().run_in_executor( + tpe, self._extract_release, temp_persist_dir, + temp_download_file) + self.version = self.remote_version + version_path = self.path.joinpath(".version") + if not version_path.exists(): + with ThreadPoolExecutor(max_workers=1) as tpe: + await IOLoop.current().run_in_executor( + tpe, version_path.write_text, self.version) + self.cmd_helper.notify_update_response( + f"Client Update Finished: {self.name}", is_complete=True) + + def _extract_release(self, + persist_dir: pathlib.Path, + release_file: pathlib.Path + ) -> None: + if not persist_dir.exists(): + os.mkdir(persist_dir) + if self.path.is_dir(): + # find and move persistent files + for fname in os.listdir(self.path): + src_path = self.path.joinpath(fname) + if fname in self.persistent_files: + dest_dir = persist_dir.joinpath(fname).parent + os.makedirs(dest_dir, exist_ok=True) + shutil.move(str(src_path), str(dest_dir)) + shutil.rmtree(self.path) + os.mkdir(self.path) + with zipfile.ZipFile(release_file) as zf: + zf.extractall(self.path) + # Move temporary files back into + for fname in 
os.listdir(persist_dir): + src_path = persist_dir.joinpath(fname) + dest_dir = self.path.joinpath(fname).parent + os.makedirs(dest_dir, exist_ok=True) + shutil.move(str(src_path), str(dest_dir)) def get_update_status(self) -> Dict[str, Any]: return {