update_manager: Implement async streaming downloads

This allows Moonraker to report download progress. This also resolves potential issues with I/O blocking the event loop.

Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2021-07-04 11:23:10 -04:00
parent f2b48d0f1a
commit 5d35cc0d10
1 changed files with 212 additions and 106 deletions

View File

@ -12,13 +12,13 @@ import json
import sys import sys
import shutil import shutil
import zipfile import zipfile
import io
import time import time
import tempfile import tempfile
from concurrent.futures import ThreadPoolExecutor
import tornado.gen import tornado.gen
from tornado.ioloop import IOLoop, PeriodicCallback from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.httpclient import AsyncHTTPClient from tornado.httpclient import AsyncHTTPClient
from tornado.locks import Event, Condition, Lock from tornado.locks import Event, Lock
from .base_deploy import BaseDeploy from .base_deploy import BaseDeploy
from .app_deploy import AppDeploy from .app_deploy import AppDeploy
from .git_deploy import GitDeploy from .git_deploy import GitDeploy
@ -28,6 +28,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Optional, Optional,
Tuple,
Type, Type,
Union, Union,
Dict, Dict,
@ -530,14 +531,18 @@ class CommandHelper:
raise self.server.error( raise self.server.error(
f"Retries exceeded for GitHub API request: {url}") f"Retries exceeded for GitHub API request: {url}")
async def http_download_request(self, url: str) -> bytes: async def http_download_request(self,
url: str,
content_type: str,
timeout: float = 180.
) -> bytes:
retries = 5 retries = 5
while retries: while retries:
try: try:
timeout = time.time() + 130. timeout = time.time() + timeout + 10.
fut = self.http_client.fetch( fut = self.http_client.fetch(
url, headers={"Accept": "application/zip"}, url, headers={"Accept": content_type},
connect_timeout=5., request_timeout=120.) connect_timeout=5., request_timeout=timeout)
resp: HTTPResponse resp: HTTPResponse
resp = await tornado.gen.with_timeout(timeout, fut) resp = await tornado.gen.with_timeout(timeout, fut)
except Exception: except Exception:
@ -551,6 +556,38 @@ class CommandHelper:
raise self.server.error( raise self.server.error(
f"Retries exceeded for GitHub API request: {url}") f"Retries exceeded for GitHub API request: {url}")
async def streaming_download_request(self,
                                     url: str,
                                     dest: Union[str, pathlib.Path],
                                     content_type: str,
                                     size: int,
                                     timeout: float = 180.
                                     ) -> None:
    """Download `url` to the file `dest`, streaming chunks to disk.

    Progress notifications are emitted as chunks arrive (via
    StreamingDownload).  The request is retried up to 5 times, with a
    one second pause between attempts.

    :param url: the resource to fetch
    :param dest: destination file path (str or pathlib.Path)
    :param content_type: value sent in the "Accept" header
    :param size: expected download size in bytes (used for progress
        reporting; may be 0 if unknown)
    :param timeout: per-attempt request timeout in seconds
    :raises: the last fetch error when all retries are exhausted, or a
        server error if an HTTP error status is returned without raising
    """
    if isinstance(dest, str):
        dest = pathlib.Path(dest)
    retries = 5
    while retries:
        dl = StreamingDownload(self, dest, size)
        try:
            # Keep the absolute deadline separate from the per-request
            # duration.  Do not overwrite `timeout` here: it must remain
            # a duration so request_timeout is sane and later retry
            # iterations compute a fresh, correct deadline.
            deadline = time.time() + timeout + 10.
            fut = self.http_client.fetch(
                url, headers={"Accept": content_type},
                connect_timeout=5., request_timeout=timeout,
                streaming_callback=dl.on_chunk_recd)
            resp: HTTPResponse
            resp = await tornado.gen.with_timeout(deadline, fut)
        except Exception:
            retries -= 1
            logging.exception("Error Processing Download")
            if not retries:
                raise
            await tornado.gen.sleep(1.)
            continue
        finally:
            # Always flush buffered chunks and close the file handle,
            # even when the fetch failed.
            await dl.close()
        if resp.code < 400:
            return
        # Defensive: an error status returned without an exception
        # (e.g. raise_error=False) must not spin forever.
        retries -= 1
        if not retries:
            raise self.server.error(
                f"Retries exceeded for download request: {url}")
        await tornado.gen.sleep(1.)
def notify_update_response(self, def notify_update_response(self,
resp: Union[str, bytes], resp: Union[str, bytes],
is_complete: bool = False is_complete: bool = False
@ -589,6 +626,50 @@ class CachedGithubResponse:
self.etag = etag self.etag = etag
self.cached_result = result self.cached_result = result
class StreamingDownload:
    """Sink for tornado's streaming_callback that writes received chunks
    to a destination file and periodically notifies download progress
    through the CommandHelper.

    Chunks arrive on the event loop; file writes are offloaded to a
    worker thread so disk I/O never blocks the loop.  `busy_evt` is set
    whenever the buffer has been fully drained, allowing `close()` to
    wait for outstanding writes.
    """
    def __init__(self,
                 cmd_helper: CommandHelper,
                 dest: pathlib.Path,
                 download_size: int) -> None:
        # download_size may be 0 when the remote size is unknown; in
        # that case progress percentages are not reported.
        self.cmd_helper = cmd_helper
        self.ioloop = IOLoop.current()
        self.name = dest.name
        self.file_hdl = dest.open('wb')
        self.download_size = download_size
        self.total_recd: int = 0
        self.last_pct: int = 0
        self.chunk_buffer: List[bytes] = []
        self.busy_evt: Event = Event()
        self.busy_evt.set()

    def on_chunk_recd(self, chunk: bytes) -> None:
        """Queue a received chunk and kick off the writer if idle.

        Runs on the event loop, so the check-and-clear of busy_evt is
        not subject to races.
        """
        if not chunk:
            return
        self.chunk_buffer.append(chunk)
        if not self.busy_evt.is_set():
            # A _process_buffer task is already draining the queue
            return
        self.busy_evt.clear()
        self.ioloop.spawn_callback(self._process_buffer)

    async def close(self):
        """Wait for all buffered chunks to be written, then close the file."""
        await self.busy_evt.wait()
        self.file_hdl.close()

    async def _process_buffer(self):
        # Reuse one single-worker executor for the entire drain rather
        # than constructing and tearing down a thread pool per chunk.
        try:
            with ThreadPoolExecutor(max_workers=1) as tpe:
                while self.chunk_buffer:
                    chunk = self.chunk_buffer.pop(0)
                    self.total_recd += len(chunk)
                    await self.ioloop.run_in_executor(
                        tpe, self.file_hdl.write, chunk)
                    if self.download_size <= 0:
                        # Unknown size: cannot compute a percentage
                        # (also avoids ZeroDivisionError, which would
                        # leave busy_evt cleared and hang close())
                        continue
                    pct = int(self.total_recd / self.download_size * 100 + .5)
                    if pct >= self.last_pct + 5:
                        self.last_pct = pct
                        totals = f"{self.total_recd // 1024} KiB / " \
                                 f"{self.download_size // 1024} KiB"
                        self.cmd_helper.notify_update_response(
                            f"Downloading {self.name}: {totals} [{pct}%]")
        finally:
            # Always release waiters, even if a write raised
            self.busy_evt.set()
class PackageDeploy(BaseDeploy): class PackageDeploy(BaseDeploy):
def __init__(self, def __init__(self,
@ -597,49 +678,49 @@ class PackageDeploy(BaseDeploy):
) -> None: ) -> None:
super().__init__(config, cmd_helper) super().__init__(config, cmd_helper)
self.available_packages: List[str] = [] self.available_packages: List[str] = []
self.refresh_condition: Optional[Condition] = None self.refresh_evt: Optional[Event] = None
self.mutex: Lock = Lock()
async def refresh(self, fetch_packages: bool = True) -> None: async def refresh(self, fetch_packages: bool = True) -> None:
# TODO: Use python-apt python lib rather than command line for updates # TODO: Use python-apt python lib rather than command line for updates
if self.refresh_condition is None: if self.refresh_evt is not None:
self.refresh_condition = Condition() self.refresh_evt.wait()
else:
self.refresh_condition.wait()
return return
try: async with self.mutex:
if fetch_packages: self.refresh_evt = Event()
await self.cmd_helper.run_cmd( try:
f"{APT_CMD} update", timeout=300., retries=3) if fetch_packages:
res = await self.cmd_helper.run_cmd_with_response( await self.cmd_helper.run_cmd(
"apt list --upgradable", timeout=60.) f"{APT_CMD} update", timeout=300., retries=3)
pkg_list = [p.strip() for p in res.split("\n") if p.strip()] res = await self.cmd_helper.run_cmd_with_response(
if pkg_list: "apt list --upgradable", timeout=60.)
pkg_list = pkg_list[2:] pkg_list = [p.strip() for p in res.split("\n") if p.strip()]
self.available_packages = [p.split("/", maxsplit=1)[0] if pkg_list:
for p in pkg_list] pkg_list = pkg_list[2:]
pkg_msg = "\n".join(self.available_packages) self.available_packages = [p.split("/", maxsplit=1)[0]
logging.info( for p in pkg_list]
f"Detected {len(self.available_packages)} package updates:" pkg_msg = "\n".join(self.available_packages)
f"\n{pkg_msg}") logging.info(
except Exception: f"Detected {len(self.available_packages)} package updates:"
logging.exception("Error Refreshing System Packages") f"\n{pkg_msg}")
self.refresh_condition.notify_all() except Exception:
self.refresh_condition = None logging.exception("Error Refreshing System Packages")
self.refresh_evt.set()
self.refresh_evt = None
async def update(self) -> None: async def update(self) -> None:
if self.refresh_condition is not None: async with self.mutex:
self.refresh_condition.wait() self.cmd_helper.notify_update_response("Updating packages...")
self.cmd_helper.notify_update_response("Updating packages...") try:
try: await self.cmd_helper.run_cmd(
await self.cmd_helper.run_cmd( f"{APT_CMD} update", timeout=300., notify=True)
f"{APT_CMD} update", timeout=300., notify=True) await self.cmd_helper.run_cmd(
await self.cmd_helper.run_cmd( f"{APT_CMD} upgrade --yes", timeout=3600., notify=True)
f"{APT_CMD} upgrade --yes", timeout=3600., notify=True) except Exception:
except Exception: raise self.server.error("Error updating system packages")
raise self.server.error("Error updating system packages") self.available_packages = []
self.available_packages = [] self.cmd_helper.notify_update_response(
self.cmd_helper.notify_update_response("Package update finished...", "Package update finished...", is_complete=True)
is_complete=True)
def get_update_status(self) -> Dict[str, Any]: def get_update_status(self) -> Dict[str, Any]:
return { return {
@ -665,34 +746,37 @@ class WebClientDeploy(BaseDeploy):
raise config.error( raise config.error(
"Invalid value for option 'persistent_files': " "Invalid value for option 'persistent_files': "
"'.version' can not be persistent") "'.version' can not be persistent")
self.version: str = "?" self.version: str = "?"
self.remote_version: str = "?" self.remote_version: str = "?"
self.dl_url: str = "?" self.dl_info: Tuple[str, str, int] = ("?", "?", 0)
self.refresh_condition: Optional[Condition] = None self.refresh_evt: Optional[Event] = None
self._get_local_version() self.mutex: Lock = Lock()
logging.info(f"\nInitializing Client Updater: '{self.name}'," logging.info(f"\nInitializing Client Updater: '{self.name}',"
f"\nversion: {self.version}"
f"\npath: {self.path}") f"\npath: {self.path}")
def _get_local_version(self) -> None: async def _get_local_version(self) -> None:
version_path = self.path.joinpath(".version") version_path = self.path.joinpath(".version")
if version_path.is_file(): if version_path.is_file():
self.version = version_path.read_text().strip() with ThreadPoolExecutor(max_workers=1) as tpe:
version = await IOLoop.current().run_in_executor(
tpe, version_path.read_text)
self.version = version.strip()
else:
self.version = "?"
async def refresh(self) -> None: async def refresh(self) -> None:
if self.refresh_condition is None: if self.refresh_evt is not None:
self.refresh_condition = Condition() self.refresh_evt.wait()
else:
self.refresh_condition.wait()
return return
try: async with self.mutex:
self._get_local_version() self.refresh_evt = Event()
await self._get_remote_version() try:
except Exception: await self._get_local_version()
logging.exception("Error Refreshing Client") await self._get_remote_version()
self.refresh_condition.notify_all() except Exception:
self.refresh_condition = None logging.exception("Error Refreshing Client")
self.refresh_evt.set()
self.refresh_evt = None
async def _get_remote_version(self) -> None: async def _get_remote_version(self) -> None:
# Remote state # Remote state
@ -704,59 +788,81 @@ class WebClientDeploy(BaseDeploy):
logging.exception(f"Client {self.repo}: Github Request Error") logging.exception(f"Client {self.repo}: Github Request Error")
result = {} result = {}
self.remote_version = result.get('name', "?") self.remote_version = result.get('name', "?")
release_assets: Dict[str, Any] = result.get('assets', [{}])[0] release_asset: Dict[str, Any] = result.get('assets', [{}])[0]
self.dl_url = release_assets.get('browser_download_url', "?") dl_url: str = release_asset.get('browser_download_url', "?")
content_type: str = release_asset.get('content_type', "?")
size: int = release_asset.get('size', 0)
self.dl_info = (dl_url, content_type, size)
logging.info( logging.info(
f"Github client Info Received:\nRepo: {self.name}\n" f"Github client Info Received:\nRepo: {self.name}\n"
f"Local Version: {self.version}\n" f"Local Version: {self.version}\n"
f"Remote Version: {self.remote_version}\n" f"Remote Version: {self.remote_version}\n"
f"url: {self.dl_url}") f"url: {dl_url}\n"
f"size: {size}\n"
f"Content Type: {content_type}")
async def update(self) -> None: async def update(self) -> None:
if self.refresh_condition is not None: async with self.mutex:
# wait for refresh if in progess
self.refresh_condition.wait()
if self.remote_version == "?":
await self.refresh()
if self.remote_version == "?": if self.remote_version == "?":
await self._get_remote_version()
if self.remote_version == "?":
raise self.server.error(
f"Client {self.repo}: Unable to locate update")
dl_url, content_type, size = self.dl_info
if dl_url == "?":
raise self.server.error( raise self.server.error(
f"Client {self.repo}: Unable to locate update") f"Client {self.repo}: Invalid download url")
if self.dl_url == "?": if self.version == self.remote_version:
raise self.server.error( # Already up to date
f"Client {self.repo}: Invalid download url") return
if self.version == self.remote_version: self.cmd_helper.notify_update_response(
# Already up to date f"Downloading Client: {self.name}")
return with tempfile.TemporaryDirectory(
self.cmd_helper.notify_update_response( suffix=self.name, prefix="client") as tempdirname:
f"Downloading Client: {self.name}") tempdir = pathlib.Path(tempdirname)
archive = await self.cmd_helper.http_download_request(self.dl_url) temp_download_file = tempdir.joinpath(f"{self.name}.zip")
with tempfile.TemporaryDirectory( temp_persist_dir = tempdir.joinpath(self.name)
suffix=self.name, prefix="client") as tempdirname: await self.cmd_helper.streaming_download_request(
tempdir = pathlib.Path(tempdirname) dl_url, temp_download_file, content_type, size)
if self.path.is_dir(): self.cmd_helper.notify_update_response(
# find and move persistent files f"Download Complete, extracting release to '{self.path}'")
for fname in os.listdir(self.path): with ThreadPoolExecutor(max_workers=1) as tpe:
src_path = self.path.joinpath(fname) await IOLoop.current().run_in_executor(
if fname in self.persistent_files: tpe, self._extract_release, temp_persist_dir,
dest_dir = tempdir.joinpath(fname).parent temp_download_file)
os.makedirs(dest_dir, exist_ok=True) self.version = self.remote_version
shutil.move(src_path, dest_dir) version_path = self.path.joinpath(".version")
shutil.rmtree(self.path) if not version_path.exists():
os.mkdir(self.path) with ThreadPoolExecutor(max_workers=1) as tpe:
with zipfile.ZipFile(io.BytesIO(archive)) as zf: await IOLoop.current().run_in_executor(
zf.extractall(self.path) tpe, version_path.write_text, self.version)
# Move temporary files back into self.cmd_helper.notify_update_response(
for fname in os.listdir(tempdir): f"Client Update Finished: {self.name}", is_complete=True)
src_path = tempdir.joinpath(fname)
dest_dir = self.path.joinpath(fname).parent def _extract_release(self,
os.makedirs(dest_dir, exist_ok=True) persist_dir: pathlib.Path,
shutil.move(src_path, dest_dir) release_file: pathlib.Path
self.version = self.remote_version ) -> None:
version_path = self.path.joinpath(".version") if not persist_dir.exists():
if not version_path.exists(): os.mkdir(persist_dir)
version_path.write_text(self.version) if self.path.is_dir():
self.cmd_helper.notify_update_response( # find and move persistent files
f"Client Update Finished: {self.name}", is_complete=True) for fname in os.listdir(self.path):
src_path = self.path.joinpath(fname)
if fname in self.persistent_files:
dest_dir = persist_dir.joinpath(fname).parent
os.makedirs(dest_dir, exist_ok=True)
shutil.move(str(src_path), str(dest_dir))
shutil.rmtree(self.path)
os.mkdir(self.path)
with zipfile.ZipFile(release_file) as zf:
zf.extractall(self.path)
# Move temporary files back into
for fname in os.listdir(persist_dir):
src_path = persist_dir.joinpath(fname)
dest_dir = self.path.joinpath(fname).parent
os.makedirs(dest_dir, exist_ok=True)
shutil.move(str(src_path), str(dest_dir))
def get_update_status(self) -> Dict[str, Any]: def get_update_status(self) -> Dict[str, Any]:
return { return {