update_manager: Implement async streaming downloads

This allows Moonraker to report download progress.  This also resolves potential issues with I/O blocking the event loop.

Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2021-07-04 11:23:10 -04:00
parent f2b48d0f1a
commit 5d35cc0d10
1 changed file with 212 additions and 106 deletions

View File

@ -12,13 +12,13 @@ import json
import sys
import shutil
import zipfile
import io
import time
import tempfile
from concurrent.futures import ThreadPoolExecutor
import tornado.gen
from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.httpclient import AsyncHTTPClient
from tornado.locks import Event, Condition, Lock
from tornado.locks import Event, Lock
from .base_deploy import BaseDeploy
from .app_deploy import AppDeploy
from .git_deploy import GitDeploy
@ -28,6 +28,7 @@ from typing import (
TYPE_CHECKING,
Any,
Optional,
Tuple,
Type,
Union,
Dict,
@ -530,14 +531,18 @@ class CommandHelper:
raise self.server.error(
f"Retries exceeded for GitHub API request: {url}")
async def http_download_request(self,
                                url: str,
                                content_type: str,
                                timeout: float = 180.
                                ) -> bytes:
    """Download `url` and return the response body as bytes.

    Retries up to 5 times on failure, sleeping 1 second between
    attempts.  `content_type` is sent as the Accept header and
    `timeout` is the per-attempt request timeout in seconds.
    Raises a server error when all retries are exhausted.
    """
    retries = 5
    while retries:
        try:
            # Compute an absolute deadline for with_timeout() in a
            # separate variable.  The original clobbered the `timeout`
            # parameter with this epoch value and then passed it as
            # request_timeout (a duration), which both made the HTTP
            # timeout absurdly large and compounded on every retry.
            deadline = time.time() + timeout + 10.
            fut = self.http_client.fetch(
                url, headers={"Accept": content_type},
                connect_timeout=5., request_timeout=timeout)
            resp: HTTPResponse
            resp = await tornado.gen.with_timeout(deadline, fut)
        except Exception:
            retries -= 1
            if not retries:
                raise
            await tornado.gen.sleep(1.)
            continue
        return resp.body
    raise self.server.error(
        f"Retries exceeded for GitHub API request: {url}")
async def streaming_download_request(self,
                                     url: str,
                                     dest: Union[str, pathlib.Path],
                                     content_type: str,
                                     size: int,
                                     timeout: float = 180.
                                     ) -> None:
    """Stream a download of `url` to the file at `dest`.

    Progress notifications are emitted as chunks arrive (see
    StreamingDownload).  `size` is the expected download size in
    bytes, used only for progress reporting.  Retries up to 5 times;
    raises when retries are exhausted.
    """
    if isinstance(dest, str):
        dest = pathlib.Path(dest)
    retries = 5
    while retries:
        dl = StreamingDownload(self, dest, size)
        try:
            # Keep the absolute with_timeout() deadline separate from
            # the `timeout` duration; the original overwrote `timeout`
            # with an epoch time and passed it as request_timeout.
            deadline = time.time() + timeout + 10.
            fut = self.http_client.fetch(
                url, headers={"Accept": content_type},
                connect_timeout=5., request_timeout=timeout,
                streaming_callback=dl.on_chunk_recd)
            resp: HTTPResponse
            resp = await tornado.gen.with_timeout(deadline, fut)
        except Exception:
            retries -= 1
            logging.exception("Error Processing Download")
            if not retries:
                raise
            await tornado.gen.sleep(1.)
            continue
        finally:
            await dl.close()
        if resp.code < 400:
            return
        # fetch() raises on HTTP errors by default, but guard against
        # a non-raising error response so a persistent 4xx/5xx cannot
        # spin forever without consuming a retry.
        retries -= 1
    raise self.server.error(
        f"Retries exceeded for download request: {url}")
def notify_update_response(self,
resp: Union[str, bytes],
is_complete: bool = False
@ -589,6 +626,50 @@ class CachedGithubResponse:
self.etag = etag
self.cached_result = result
class StreamingDownload:
    """Sink for tornado streaming_callback chunks.

    Buffers received chunks on the event loop and writes them to the
    destination file in a single worker thread so file I/O never
    blocks the loop.  Emits progress notifications through the
    supplied CommandHelper roughly every 5%.
    """
    def __init__(self,
                 cmd_helper: CommandHelper,
                 dest: pathlib.Path,
                 download_size: int) -> None:
        self.cmd_helper = cmd_helper
        self.ioloop = IOLoop.current()
        self.name = dest.name
        self.file_hdl = dest.open('wb')
        # download_size may legitimately be 0 when the release asset
        # metadata omits the size; see _process_buffer.
        self.download_size = download_size
        self.total_recd: int = 0
        self.last_pct: int = 0
        self.chunk_buffer: List[bytes] = []
        # One long-lived worker thread for all writes.  The original
        # constructed (and tore down) a new ThreadPoolExecutor for
        # every chunk, which is needlessly expensive.
        self.executor = ThreadPoolExecutor(max_workers=1)
        # Set when no _process_buffer callback is draining the queue
        self.busy_evt: Event = Event()
        self.busy_evt.set()

    def on_chunk_recd(self, chunk: bytes) -> None:
        """Streaming callback: queue a chunk for writing."""
        if not chunk:
            return
        self.chunk_buffer.append(chunk)
        if not self.busy_evt.is_set():
            # A drain callback is already running; it will pick this
            # chunk up since it re-checks the buffer each iteration.
            return
        self.busy_evt.clear()
        self.ioloop.spawn_callback(self._process_buffer)

    async def close(self):
        """Wait for pending writes, then release thread and file."""
        await self.busy_evt.wait()
        self.executor.shutdown(wait=True)
        self.file_hdl.close()

    async def _process_buffer(self):
        while self.chunk_buffer:
            chunk = self.chunk_buffer.pop(0)
            self.total_recd += len(chunk)
            if self.download_size > 0:
                pct = int(self.total_recd / self.download_size * 100 + .5)
            else:
                # Unknown total size: avoid the ZeroDivisionError the
                # original raised when the asset reported size == 0.
                pct = 0
            await self.ioloop.run_in_executor(
                self.executor, self.file_hdl.write, chunk)
            if pct >= self.last_pct + 5:
                self.last_pct = pct
                totals = f"{self.total_recd // 1024} KiB / " \
                         f"{self.download_size // 1024} KiB"
                self.cmd_helper.notify_update_response(
                    f"Downloading {self.name}: {totals} [{pct}%]")
        self.busy_evt.set()
class PackageDeploy(BaseDeploy):
def __init__(self,
@ -597,49 +678,49 @@ class PackageDeploy(BaseDeploy):
) -> None:
super().__init__(config, cmd_helper)
self.available_packages: List[str] = []
self.refresh_condition: Optional[Condition] = None
self.refresh_evt: Optional[Event] = None
self.mutex: Lock = Lock()
async def refresh(self, fetch_packages: bool = True) -> None:
    """Refresh the list of upgradable system packages.

    When `fetch_packages` is True an `apt update` is performed first.
    Concurrent callers wait for the in-flight refresh instead of
    starting another.
    """
    # TODO: Use python-apt python lib rather than command line for updates
    if self.refresh_evt is not None:
        # A refresh is already running.  tornado's Event.wait()
        # returns a future — it must be awaited, otherwise this
        # returns immediately without actually waiting.
        await self.refresh_evt.wait()
        return
    async with self.mutex:
        self.refresh_evt = Event()
        try:
            if fetch_packages:
                await self.cmd_helper.run_cmd(
                    f"{APT_CMD} update", timeout=300., retries=3)
            res = await self.cmd_helper.run_cmd_with_response(
                "apt list --upgradable", timeout=60.)
            pkg_list = [p.strip() for p in res.split("\n") if p.strip()]
            if pkg_list:
                # Drop the "Listing..." header lines from apt output
                pkg_list = pkg_list[2:]
                self.available_packages = [p.split("/", maxsplit=1)[0]
                                           for p in pkg_list]
            pkg_msg = "\n".join(self.available_packages)
            logging.info(
                f"Detected {len(self.available_packages)} package updates:"
                f"\n{pkg_msg}")
        except Exception:
            logging.exception("Error Refreshing System Packages")
        finally:
            # Always release waiters, even on unexpected errors
            self.refresh_evt.set()
            self.refresh_evt = None
async def update(self) -> None:
    """Run a full `apt update` / `apt upgrade` cycle.

    Serialized against refresh() via the shared mutex.  Raises a
    server error if either apt command fails.
    """
    async with self.mutex:
        self.cmd_helper.notify_update_response("Updating packages...")
        try:
            await self.cmd_helper.run_cmd(
                f"{APT_CMD} update", timeout=300., notify=True)
            await self.cmd_helper.run_cmd(
                f"{APT_CMD} upgrade --yes", timeout=3600., notify=True)
        except Exception as e:
            # Chain the original failure so the real apt error is
            # preserved in the traceback (the original discarded it)
            raise self.server.error(
                "Error updating system packages") from e
        self.available_packages = []
        self.cmd_helper.notify_update_response(
            "Package update finished...", is_complete=True)
def get_update_status(self) -> Dict[str, Any]:
return {
@ -665,34 +746,37 @@ class WebClientDeploy(BaseDeploy):
raise config.error(
"Invalid value for option 'persistent_files': "
"'.version' can not be persistent")
self.version: str = "?"
self.remote_version: str = "?"
self.dl_url: str = "?"
self.refresh_condition: Optional[Condition] = None
self._get_local_version()
self.dl_info: Tuple[str, str, int] = ("?", "?", 0)
self.refresh_evt: Optional[Event] = None
self.mutex: Lock = Lock()
logging.info(f"\nInitializing Client Updater: '{self.name}',"
f"\nversion: {self.version}"
f"\npath: {self.path}")
async def _get_local_version(self) -> None:
    """Read the installed client version from <path>/.version.

    Sets self.version to "?" when the file is missing or unreadable.
    The read runs in a worker thread so it cannot block the loop.
    """
    version_path = self.path.joinpath(".version")
    try:
        # EAFP: attempt the read directly rather than checking
        # is_file() first, which raced with file removal and let an
        # OSError propagate out of the refresh path.
        with ThreadPoolExecutor(max_workers=1) as tpe:
            version = await IOLoop.current().run_in_executor(
                tpe, version_path.read_text)
        self.version = version.strip()
    except OSError:
        self.version = "?"
async def refresh(self) -> None:
    """Refresh local and remote client version state.

    Concurrent callers wait on the in-flight refresh rather than
    starting a second one.
    """
    if self.refresh_evt is not None:
        # tornado's Event.wait() returns a future; without the await
        # this returned immediately instead of waiting for the
        # in-progress refresh to complete.
        await self.refresh_evt.wait()
        return
    async with self.mutex:
        self.refresh_evt = Event()
        try:
            await self._get_local_version()
            await self._get_remote_version()
        except Exception:
            logging.exception("Error Refreshing Client")
        finally:
            # Always release waiters, even on unexpected errors
            self.refresh_evt.set()
            self.refresh_evt = None
async def _get_remote_version(self) -> None:
# Remote state
@ -704,59 +788,81 @@ class WebClientDeploy(BaseDeploy):
logging.exception(f"Client {self.repo}: Github Request Error")
result = {}
self.remote_version = result.get('name', "?")
release_assets: Dict[str, Any] = result.get('assets', [{}])[0]
self.dl_url = release_assets.get('browser_download_url', "?")
release_asset: Dict[str, Any] = result.get('assets', [{}])[0]
dl_url: str = release_asset.get('browser_download_url', "?")
content_type: str = release_asset.get('content_type', "?")
size: int = release_asset.get('size', 0)
self.dl_info = (dl_url, content_type, size)
logging.info(
f"Github client Info Received:\nRepo: {self.name}\n"
f"Local Version: {self.version}\n"
f"Remote Version: {self.remote_version}\n"
f"url: {self.dl_url}")
f"url: {dl_url}\n"
f"size: {size}\n"
f"Content Type: {content_type}")
async def update(self) -> None:
    """Download and install the latest client release.

    A no-op when already up to date.  Raises a server error when no
    release can be located or the download url is invalid.  The zip
    is streamed to a temporary directory, then extracted in a worker
    thread (persistent files are preserved by _extract_release).

    NOTE(review): this span of the diff contained duplicated
    pre-image lines (refresh_condition / self.dl_url checks) mixed
    with the post-image body; this is the coherent post-commit form.
    """
    async with self.mutex:
        if self.remote_version == "?":
            await self._get_remote_version()
            if self.remote_version == "?":
                raise self.server.error(
                    f"Client {self.repo}: Unable to locate update")
        dl_url, content_type, size = self.dl_info
        if dl_url == "?":
            raise self.server.error(
                f"Client {self.repo}: Invalid download url")
        if self.version == self.remote_version:
            # Already up to date
            return
        self.cmd_helper.notify_update_response(
            f"Downloading Client: {self.name}")
        with tempfile.TemporaryDirectory(
                suffix=self.name, prefix="client") as tempdirname:
            tempdir = pathlib.Path(tempdirname)
            temp_download_file = tempdir.joinpath(f"{self.name}.zip")
            temp_persist_dir = tempdir.joinpath(self.name)
            await self.cmd_helper.streaming_download_request(
                dl_url, temp_download_file, content_type, size)
            self.cmd_helper.notify_update_response(
                f"Download Complete, extracting release to '{self.path}'")
            # Extraction is blocking file I/O; run it off the loop
            with ThreadPoolExecutor(max_workers=1) as tpe:
                await IOLoop.current().run_in_executor(
                    tpe, self._extract_release, temp_persist_dir,
                    temp_download_file)
            self.version = self.remote_version
            version_path = self.path.joinpath(".version")
            if not version_path.exists():
                # Release zip did not ship a .version file; write one
                with ThreadPoolExecutor(max_workers=1) as tpe:
                    await IOLoop.current().run_in_executor(
                        tpe, version_path.write_text, self.version)
            self.cmd_helper.notify_update_response(
                f"Client Update Finished: {self.name}", is_complete=True)
def _extract_release(self,
                     persist_dir: pathlib.Path,
                     release_file: pathlib.Path
                     ) -> None:
    """Replace the client install with the contents of `release_file`.

    Files listed in self.persistent_files are stashed in
    `persist_dir` before the install directory is wiped, then moved
    back after extraction.  Blocking file I/O only — intended to run
    in a worker thread.
    """
    if not persist_dir.exists():
        os.mkdir(persist_dir)
    if self.path.is_dir():
        # Stash persistent files, then wipe the install directory
        for src in self.path.iterdir():
            if src.name in self.persistent_files:
                stash_dest = persist_dir.joinpath(src.name).parent
                os.makedirs(stash_dest, exist_ok=True)
                shutil.move(str(src), str(stash_dest))
        shutil.rmtree(self.path)
    os.mkdir(self.path)
    with zipfile.ZipFile(release_file) as zf:
        zf.extractall(self.path)
    # Restore the stashed persistent files
    for stashed in persist_dir.iterdir():
        restore_dest = self.path.joinpath(stashed.name).parent
        os.makedirs(restore_dest, exist_ok=True)
        shutil.move(str(stashed), str(restore_dest))
def get_update_status(self) -> Dict[str, Any]:
return {