update_manager: Implement async streaming downloads

This allows Moonraker to report download progress. This also resolves potential issues with I/O blocking the event loop.

Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2021-07-04 11:23:10 -04:00
parent f2b48d0f1a
commit 5d35cc0d10
1 changed files with 212 additions and 106 deletions

View File

@ -12,13 +12,13 @@ import json
import sys
import shutil
import zipfile
import io
import time
import tempfile
from concurrent.futures import ThreadPoolExecutor
import tornado.gen
from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.httpclient import AsyncHTTPClient
from tornado.locks import Event, Condition, Lock
from tornado.locks import Event, Lock
from .base_deploy import BaseDeploy
from .app_deploy import AppDeploy
from .git_deploy import GitDeploy
@ -28,6 +28,7 @@ from typing import (
TYPE_CHECKING,
Any,
Optional,
Tuple,
Type,
Union,
Dict,
@ -530,14 +531,18 @@ class CommandHelper:
raise self.server.error(
f"Retries exceeded for GitHub API request: {url}")
async def http_download_request(self, url: str) -> bytes:
async def http_download_request(self,
url: str,
content_type: str,
timeout: float = 180.
) -> bytes:
retries = 5
while retries:
try:
timeout = time.time() + 130.
timeout = time.time() + timeout + 10.
fut = self.http_client.fetch(
url, headers={"Accept": "application/zip"},
connect_timeout=5., request_timeout=120.)
url, headers={"Accept": content_type},
connect_timeout=5., request_timeout=timeout)
resp: HTTPResponse
resp = await tornado.gen.with_timeout(timeout, fut)
except Exception:
@ -551,6 +556,38 @@ class CommandHelper:
raise self.server.error(
f"Retries exceeded for GitHub API request: {url}")
async def streaming_download_request(self,
                                     url: str,
                                     dest: Union[str, pathlib.Path],
                                     content_type: str,
                                     size: int,
                                     timeout: float = 180.
                                     ) -> None:
    """Download `url` to `dest`, streaming chunks to disk as they arrive.

    Progress notifications are emitted through a StreamingDownload
    helper.  The request is retried up to 5 times; the last error is
    re-raised to the caller if all retries are exhausted.

    url:          Remote resource to fetch.
    dest:         Destination file path (str or pathlib.Path).
    content_type: Value sent in the HTTP "Accept" header.
    size:         Expected download size in bytes (used for progress).
    timeout:      Per-attempt request timeout, in seconds.
    """
    if isinstance(dest, str):
        dest = pathlib.Path(dest)
    retries = 5
    while retries:
        dl = StreamingDownload(self, dest, size)
        try:
            # Keep the duration (`timeout`) and the absolute deadline
            # separate.  Previously `timeout` was overwritten with
            # `time.time() + timeout + 10.`, which passed an epoch
            # timestamp as `request_timeout` (a duration, effectively
            # disabling it) and compounded the deadline on each retry.
            deadline = time.time() + timeout + 10.
            fut = self.http_client.fetch(
                url, headers={"Accept": content_type},
                connect_timeout=5., request_timeout=timeout,
                streaming_callback=dl.on_chunk_recd)
            resp: HTTPResponse
            resp = await tornado.gen.with_timeout(deadline, fut)
        except Exception:
            retries -= 1
            logging.exception("Error Processing Download")
            if not retries:
                raise
            await tornado.gen.sleep(1.)
            continue
        finally:
            # Always flush any buffered chunks and close the file handle
            await dl.close()
        if resp.code < 400:
            return
def notify_update_response(self,
resp: Union[str, bytes],
is_complete: bool = False
@ -589,6 +626,50 @@ class CachedGithubResponse:
self.etag = etag
self.cached_result = result
class StreamingDownload:
    """Sinks streamed HTTP chunks to a file without blocking the event loop.

    `on_chunk_recd` is intended to be used as a tornado
    `streaming_callback`.  Chunks are buffered and written from a worker
    thread while download progress is reported via the command helper in
    5% increments.
    """
    def __init__(self,
                 cmd_helper: CommandHelper,
                 dest: pathlib.Path,
                 download_size: int) -> None:
        self.cmd_helper = cmd_helper
        self.ioloop = IOLoop.current()
        # Short name used in progress notifications
        self.name = dest.name
        self.file_hdl = dest.open('wb')
        self.download_size = download_size
        self.total_recd: int = 0
        # Last percentage reported; used to throttle notifications
        self.last_pct: int = 0
        self.chunk_buffer: List[bytes] = []
        # Set while no writer coroutine is active; `close()` waits on it
        self.busy_evt: Event = Event()
        self.busy_evt.set()
        # One reusable worker for blocking file writes.  The previous
        # implementation constructed (and tore down) a ThreadPoolExecutor
        # for every received chunk, which is needlessly expensive.
        self.write_executor = ThreadPoolExecutor(max_workers=1)

    def on_chunk_recd(self, chunk: bytes) -> None:
        """Buffer a received chunk and kick off the writer if idle."""
        if not chunk:
            return
        self.chunk_buffer.append(chunk)
        if not self.busy_evt.is_set():
            # A writer coroutine is already draining the buffer
            return
        self.busy_evt.clear()
        self.ioloop.spawn_callback(self._process_buffer)

    async def close(self):
        """Wait for pending writes, then release the worker and file."""
        await self.busy_evt.wait()
        self.write_executor.shutdown(wait=False)
        self.file_hdl.close()

    async def _process_buffer(self):
        # Drain the buffer; chunks appended while awaiting the write are
        # picked up by the loop condition, so only one writer runs at a time
        while self.chunk_buffer:
            chunk = self.chunk_buffer.pop(0)
            self.total_recd += len(chunk)
            pct = int(self.total_recd / self.download_size * 100 + .5)
            await self.ioloop.run_in_executor(
                self.write_executor, self.file_hdl.write, chunk)
            if pct >= self.last_pct + 5:
                self.last_pct = pct
                totals = f"{self.total_recd // 1024} KiB / " \
                    f"{self.download_size // 1024} KiB"
                self.cmd_helper.notify_update_response(
                    f"Downloading {self.name}: {totals} [{pct}%]")
        self.busy_evt.set()
class PackageDeploy(BaseDeploy):
def __init__(self,
@ -597,15 +678,16 @@ class PackageDeploy(BaseDeploy):
) -> None:
super().__init__(config, cmd_helper)
self.available_packages: List[str] = []
self.refresh_condition: Optional[Condition] = None
self.refresh_evt: Optional[Event] = None
self.mutex: Lock = Lock()
async def refresh(self, fetch_packages: bool = True) -> None:
# TODO: Use python-apt python lib rather than command line for updates
if self.refresh_condition is None:
self.refresh_condition = Condition()
else:
self.refresh_condition.wait()
if self.refresh_evt is not None:
self.refresh_evt.wait()
return
async with self.mutex:
self.refresh_evt = Event()
try:
if fetch_packages:
await self.cmd_helper.run_cmd(
@ -623,12 +705,11 @@ class PackageDeploy(BaseDeploy):
f"\n{pkg_msg}")
except Exception:
logging.exception("Error Refreshing System Packages")
self.refresh_condition.notify_all()
self.refresh_condition = None
self.refresh_evt.set()
self.refresh_evt = None
async def update(self) -> None:
if self.refresh_condition is not None:
self.refresh_condition.wait()
async with self.mutex:
self.cmd_helper.notify_update_response("Updating packages...")
try:
await self.cmd_helper.run_cmd(
@ -638,8 +719,8 @@ class PackageDeploy(BaseDeploy):
except Exception:
raise self.server.error("Error updating system packages")
self.available_packages = []
self.cmd_helper.notify_update_response("Package update finished...",
is_complete=True)
self.cmd_helper.notify_update_response(
"Package update finished...", is_complete=True)
def get_update_status(self) -> Dict[str, Any]:
return {
@ -665,34 +746,37 @@ class WebClientDeploy(BaseDeploy):
raise config.error(
"Invalid value for option 'persistent_files': "
"'.version' can not be persistent")
self.version: str = "?"
self.remote_version: str = "?"
self.dl_url: str = "?"
self.refresh_condition: Optional[Condition] = None
self._get_local_version()
self.dl_info: Tuple[str, str, int] = ("?", "?", 0)
self.refresh_evt: Optional[Event] = None
self.mutex: Lock = Lock()
logging.info(f"\nInitializing Client Updater: '{self.name}',"
f"\nversion: {self.version}"
f"\npath: {self.path}")
async def _get_local_version(self) -> None:
    """Read the client's installed version from its '.version' file.

    Sets self.version to "?" when the file is missing or unreadable.
    The blocking read is dispatched to a worker thread so it does not
    stall the event loop.
    """
    version_path = self.path.joinpath(".version")
    try:
        # EAFP: the previous is_file()/read sequence could raise if the
        # file was removed between the check and the threaded read.
        with ThreadPoolExecutor(max_workers=1) as tpe:
            version = await IOLoop.current().run_in_executor(
                tpe, version_path.read_text)
        self.version = version.strip()
    except OSError:
        self.version = "?"
async def refresh(self) -> None:
if self.refresh_condition is None:
self.refresh_condition = Condition()
else:
self.refresh_condition.wait()
if self.refresh_evt is not None:
self.refresh_evt.wait()
return
async with self.mutex:
self.refresh_evt = Event()
try:
self._get_local_version()
await self._get_local_version()
await self._get_remote_version()
except Exception:
logging.exception("Error Refreshing Client")
self.refresh_condition.notify_all()
self.refresh_condition = None
self.refresh_evt.set()
self.refresh_evt = None
async def _get_remote_version(self) -> None:
# Remote state
@ -704,24 +788,28 @@ class WebClientDeploy(BaseDeploy):
logging.exception(f"Client {self.repo}: Github Request Error")
result = {}
self.remote_version = result.get('name', "?")
release_assets: Dict[str, Any] = result.get('assets', [{}])[0]
self.dl_url = release_assets.get('browser_download_url', "?")
release_asset: Dict[str, Any] = result.get('assets', [{}])[0]
dl_url: str = release_asset.get('browser_download_url', "?")
content_type: str = release_asset.get('content_type', "?")
size: int = release_asset.get('size', 0)
self.dl_info = (dl_url, content_type, size)
logging.info(
f"Github client Info Received:\nRepo: {self.name}\n"
f"Local Version: {self.version}\n"
f"Remote Version: {self.remote_version}\n"
f"url: {self.dl_url}")
f"url: {dl_url}\n"
f"size: {size}\n"
f"Content Type: {content_type}")
async def update(self) -> None:
if self.refresh_condition is not None:
# wait for refresh if in progess
self.refresh_condition.wait()
async with self.mutex:
if self.remote_version == "?":
await self.refresh()
await self._get_remote_version()
if self.remote_version == "?":
raise self.server.error(
f"Client {self.repo}: Unable to locate update")
if self.dl_url == "?":
dl_url, content_type, size = self.dl_info
if dl_url == "?":
raise self.server.error(
f"Client {self.repo}: Invalid download url")
if self.version == self.remote_version:
@ -729,34 +817,52 @@ class WebClientDeploy(BaseDeploy):
return
self.cmd_helper.notify_update_response(
f"Downloading Client: {self.name}")
archive = await self.cmd_helper.http_download_request(self.dl_url)
with tempfile.TemporaryDirectory(
suffix=self.name, prefix="client") as tempdirname:
tempdir = pathlib.Path(tempdirname)
temp_download_file = tempdir.joinpath(f"{self.name}.zip")
temp_persist_dir = tempdir.joinpath(self.name)
await self.cmd_helper.streaming_download_request(
dl_url, temp_download_file, content_type, size)
self.cmd_helper.notify_update_response(
f"Download Complete, extracting release to '{self.path}'")
with ThreadPoolExecutor(max_workers=1) as tpe:
await IOLoop.current().run_in_executor(
tpe, self._extract_release, temp_persist_dir,
temp_download_file)
self.version = self.remote_version
version_path = self.path.joinpath(".version")
if not version_path.exists():
with ThreadPoolExecutor(max_workers=1) as tpe:
await IOLoop.current().run_in_executor(
tpe, version_path.write_text, self.version)
self.cmd_helper.notify_update_response(
f"Client Update Finished: {self.name}", is_complete=True)
def _extract_release(self,
persist_dir: pathlib.Path,
release_file: pathlib.Path
) -> None:
if not persist_dir.exists():
os.mkdir(persist_dir)
if self.path.is_dir():
# find and move persistent files
for fname in os.listdir(self.path):
src_path = self.path.joinpath(fname)
if fname in self.persistent_files:
dest_dir = tempdir.joinpath(fname).parent
dest_dir = persist_dir.joinpath(fname).parent
os.makedirs(dest_dir, exist_ok=True)
shutil.move(src_path, dest_dir)
shutil.move(str(src_path), str(dest_dir))
shutil.rmtree(self.path)
os.mkdir(self.path)
with zipfile.ZipFile(io.BytesIO(archive)) as zf:
with zipfile.ZipFile(release_file) as zf:
zf.extractall(self.path)
# Move temporary files back into
for fname in os.listdir(tempdir):
src_path = tempdir.joinpath(fname)
for fname in os.listdir(persist_dir):
src_path = persist_dir.joinpath(fname)
dest_dir = self.path.joinpath(fname).parent
os.makedirs(dest_dir, exist_ok=True)
shutil.move(src_path, dest_dir)
self.version = self.remote_version
version_path = self.path.joinpath(".version")
if not version_path.exists():
version_path.write_text(self.version)
self.cmd_helper.notify_update_response(
f"Client Update Finished: {self.name}", is_complete=True)
shutil.move(str(src_path), str(dest_dir))
def get_update_status(self) -> Dict[str, Any]:
return {