update_manager: Implement async streaming downloads
This allows Moonraker to report download progress. This also resolves potential issues with I/O blocking the event loop. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
f2b48d0f1a
commit
5d35cc0d10
|
@ -12,13 +12,13 @@ import json
|
|||
import sys
|
||||
import shutil
|
||||
import zipfile
|
||||
import io
|
||||
import time
|
||||
import tempfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import tornado.gen
|
||||
from tornado.ioloop import IOLoop, PeriodicCallback
|
||||
from tornado.httpclient import AsyncHTTPClient
|
||||
from tornado.locks import Event, Condition, Lock
|
||||
from tornado.locks import Event, Lock
|
||||
from .base_deploy import BaseDeploy
|
||||
from .app_deploy import AppDeploy
|
||||
from .git_deploy import GitDeploy
|
||||
|
@ -28,6 +28,7 @@ from typing import (
|
|||
TYPE_CHECKING,
|
||||
Any,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
Dict,
|
||||
|
@ -530,14 +531,18 @@ class CommandHelper:
|
|||
raise self.server.error(
|
||||
f"Retries exceeded for GitHub API request: {url}")
|
||||
|
||||
async def http_download_request(self, url: str) -> bytes:
|
||||
async def http_download_request(self,
|
||||
url: str,
|
||||
content_type: str,
|
||||
timeout: float = 180.
|
||||
) -> bytes:
|
||||
retries = 5
|
||||
while retries:
|
||||
try:
|
||||
timeout = time.time() + 130.
|
||||
timeout = time.time() + timeout + 10.
|
||||
fut = self.http_client.fetch(
|
||||
url, headers={"Accept": "application/zip"},
|
||||
connect_timeout=5., request_timeout=120.)
|
||||
url, headers={"Accept": content_type},
|
||||
connect_timeout=5., request_timeout=timeout)
|
||||
resp: HTTPResponse
|
||||
resp = await tornado.gen.with_timeout(timeout, fut)
|
||||
except Exception:
|
||||
|
@ -551,6 +556,38 @@ class CommandHelper:
|
|||
raise self.server.error(
|
||||
f"Retries exceeded for GitHub API request: {url}")
|
||||
|
||||
async def streaming_download_request(self,
|
||||
url: str,
|
||||
dest: Union[str, pathlib.Path],
|
||||
content_type: str,
|
||||
size: int,
|
||||
timeout: float = 180.
|
||||
) -> None:
|
||||
if isinstance(dest, str):
|
||||
dest = pathlib.Path(dest)
|
||||
retries = 5
|
||||
while retries:
|
||||
dl = StreamingDownload(self, dest, size)
|
||||
try:
|
||||
timeout = time.time() + timeout + 10.
|
||||
fut = self.http_client.fetch(
|
||||
url, headers={"Accept": content_type},
|
||||
connect_timeout=5., request_timeout=timeout,
|
||||
streaming_callback=dl.on_chunk_recd)
|
||||
resp: HTTPResponse
|
||||
resp = await tornado.gen.with_timeout(timeout, fut)
|
||||
except Exception:
|
||||
retries -= 1
|
||||
logging.exception("Error Processing Download")
|
||||
if not retries:
|
||||
raise
|
||||
await tornado.gen.sleep(1.)
|
||||
continue
|
||||
finally:
|
||||
await dl.close()
|
||||
if resp.code < 400:
|
||||
return
|
||||
|
||||
def notify_update_response(self,
|
||||
resp: Union[str, bytes],
|
||||
is_complete: bool = False
|
||||
|
@ -589,6 +626,50 @@ class CachedGithubResponse:
|
|||
self.etag = etag
|
||||
self.cached_result = result
|
||||
|
||||
class StreamingDownload:
|
||||
def __init__(self,
|
||||
cmd_helper: CommandHelper,
|
||||
dest: pathlib.Path,
|
||||
download_size: int) -> None:
|
||||
self.cmd_helper = cmd_helper
|
||||
self.ioloop = IOLoop.current()
|
||||
self.name = dest.name
|
||||
self.file_hdl = dest.open('wb')
|
||||
self.download_size = download_size
|
||||
self.total_recd: int = 0
|
||||
self.last_pct: int = 0
|
||||
self.chunk_buffer: List[bytes] = []
|
||||
self.busy_evt: Event = Event()
|
||||
self.busy_evt.set()
|
||||
|
||||
def on_chunk_recd(self, chunk: bytes) -> None:
|
||||
if not chunk:
|
||||
return
|
||||
self.chunk_buffer.append(chunk)
|
||||
if not self.busy_evt.is_set():
|
||||
return
|
||||
self.busy_evt.clear()
|
||||
self.ioloop.spawn_callback(self._process_buffer)
|
||||
|
||||
async def close(self):
|
||||
await self.busy_evt.wait()
|
||||
self.file_hdl.close()
|
||||
|
||||
async def _process_buffer(self):
|
||||
while self.chunk_buffer:
|
||||
chunk = self.chunk_buffer.pop(0)
|
||||
self.total_recd += len(chunk)
|
||||
pct = int(self.total_recd / self.download_size * 100 + .5)
|
||||
with ThreadPoolExecutor(max_workers=1) as tpe:
|
||||
await self.ioloop.run_in_executor(
|
||||
tpe, self.file_hdl.write, chunk)
|
||||
if pct >= self.last_pct + 5:
|
||||
self.last_pct = pct
|
||||
totals = f"{self.total_recd // 1024} KiB / " \
|
||||
f"{self.download_size // 1024} KiB"
|
||||
self.cmd_helper.notify_update_response(
|
||||
f"Downloading {self.name}: {totals} [{pct}%]")
|
||||
self.busy_evt.set()
|
||||
|
||||
class PackageDeploy(BaseDeploy):
|
||||
def __init__(self,
|
||||
|
@ -597,15 +678,16 @@ class PackageDeploy(BaseDeploy):
|
|||
) -> None:
|
||||
super().__init__(config, cmd_helper)
|
||||
self.available_packages: List[str] = []
|
||||
self.refresh_condition: Optional[Condition] = None
|
||||
self.refresh_evt: Optional[Event] = None
|
||||
self.mutex: Lock = Lock()
|
||||
|
||||
async def refresh(self, fetch_packages: bool = True) -> None:
|
||||
# TODO: Use python-apt python lib rather than command line for updates
|
||||
if self.refresh_condition is None:
|
||||
self.refresh_condition = Condition()
|
||||
else:
|
||||
self.refresh_condition.wait()
|
||||
if self.refresh_evt is not None:
|
||||
self.refresh_evt.wait()
|
||||
return
|
||||
async with self.mutex:
|
||||
self.refresh_evt = Event()
|
||||
try:
|
||||
if fetch_packages:
|
||||
await self.cmd_helper.run_cmd(
|
||||
|
@ -623,12 +705,11 @@ class PackageDeploy(BaseDeploy):
|
|||
f"\n{pkg_msg}")
|
||||
except Exception:
|
||||
logging.exception("Error Refreshing System Packages")
|
||||
self.refresh_condition.notify_all()
|
||||
self.refresh_condition = None
|
||||
self.refresh_evt.set()
|
||||
self.refresh_evt = None
|
||||
|
||||
async def update(self) -> None:
|
||||
if self.refresh_condition is not None:
|
||||
self.refresh_condition.wait()
|
||||
async with self.mutex:
|
||||
self.cmd_helper.notify_update_response("Updating packages...")
|
||||
try:
|
||||
await self.cmd_helper.run_cmd(
|
||||
|
@ -638,8 +719,8 @@ class PackageDeploy(BaseDeploy):
|
|||
except Exception:
|
||||
raise self.server.error("Error updating system packages")
|
||||
self.available_packages = []
|
||||
self.cmd_helper.notify_update_response("Package update finished...",
|
||||
is_complete=True)
|
||||
self.cmd_helper.notify_update_response(
|
||||
"Package update finished...", is_complete=True)
|
||||
|
||||
def get_update_status(self) -> Dict[str, Any]:
|
||||
return {
|
||||
|
@ -665,34 +746,37 @@ class WebClientDeploy(BaseDeploy):
|
|||
raise config.error(
|
||||
"Invalid value for option 'persistent_files': "
|
||||
"'.version' can not be persistent")
|
||||
|
||||
self.version: str = "?"
|
||||
self.remote_version: str = "?"
|
||||
self.dl_url: str = "?"
|
||||
self.refresh_condition: Optional[Condition] = None
|
||||
self._get_local_version()
|
||||
self.dl_info: Tuple[str, str, int] = ("?", "?", 0)
|
||||
self.refresh_evt: Optional[Event] = None
|
||||
self.mutex: Lock = Lock()
|
||||
logging.info(f"\nInitializing Client Updater: '{self.name}',"
|
||||
f"\nversion: {self.version}"
|
||||
f"\npath: {self.path}")
|
||||
|
||||
def _get_local_version(self) -> None:
|
||||
async def _get_local_version(self) -> None:
|
||||
version_path = self.path.joinpath(".version")
|
||||
if version_path.is_file():
|
||||
self.version = version_path.read_text().strip()
|
||||
with ThreadPoolExecutor(max_workers=1) as tpe:
|
||||
version = await IOLoop.current().run_in_executor(
|
||||
tpe, version_path.read_text)
|
||||
self.version = version.strip()
|
||||
else:
|
||||
self.version = "?"
|
||||
|
||||
async def refresh(self) -> None:
|
||||
if self.refresh_condition is None:
|
||||
self.refresh_condition = Condition()
|
||||
else:
|
||||
self.refresh_condition.wait()
|
||||
if self.refresh_evt is not None:
|
||||
self.refresh_evt.wait()
|
||||
return
|
||||
async with self.mutex:
|
||||
self.refresh_evt = Event()
|
||||
try:
|
||||
self._get_local_version()
|
||||
await self._get_local_version()
|
||||
await self._get_remote_version()
|
||||
except Exception:
|
||||
logging.exception("Error Refreshing Client")
|
||||
self.refresh_condition.notify_all()
|
||||
self.refresh_condition = None
|
||||
self.refresh_evt.set()
|
||||
self.refresh_evt = None
|
||||
|
||||
async def _get_remote_version(self) -> None:
|
||||
# Remote state
|
||||
|
@ -704,24 +788,28 @@ class WebClientDeploy(BaseDeploy):
|
|||
logging.exception(f"Client {self.repo}: Github Request Error")
|
||||
result = {}
|
||||
self.remote_version = result.get('name', "?")
|
||||
release_assets: Dict[str, Any] = result.get('assets', [{}])[0]
|
||||
self.dl_url = release_assets.get('browser_download_url', "?")
|
||||
release_asset: Dict[str, Any] = result.get('assets', [{}])[0]
|
||||
dl_url: str = release_asset.get('browser_download_url', "?")
|
||||
content_type: str = release_asset.get('content_type', "?")
|
||||
size: int = release_asset.get('size', 0)
|
||||
self.dl_info = (dl_url, content_type, size)
|
||||
logging.info(
|
||||
f"Github client Info Received:\nRepo: {self.name}\n"
|
||||
f"Local Version: {self.version}\n"
|
||||
f"Remote Version: {self.remote_version}\n"
|
||||
f"url: {self.dl_url}")
|
||||
f"url: {dl_url}\n"
|
||||
f"size: {size}\n"
|
||||
f"Content Type: {content_type}")
|
||||
|
||||
async def update(self) -> None:
|
||||
if self.refresh_condition is not None:
|
||||
# wait for refresh if in progess
|
||||
self.refresh_condition.wait()
|
||||
async with self.mutex:
|
||||
if self.remote_version == "?":
|
||||
await self.refresh()
|
||||
await self._get_remote_version()
|
||||
if self.remote_version == "?":
|
||||
raise self.server.error(
|
||||
f"Client {self.repo}: Unable to locate update")
|
||||
if self.dl_url == "?":
|
||||
dl_url, content_type, size = self.dl_info
|
||||
if dl_url == "?":
|
||||
raise self.server.error(
|
||||
f"Client {self.repo}: Invalid download url")
|
||||
if self.version == self.remote_version:
|
||||
|
@ -729,34 +817,52 @@ class WebClientDeploy(BaseDeploy):
|
|||
return
|
||||
self.cmd_helper.notify_update_response(
|
||||
f"Downloading Client: {self.name}")
|
||||
archive = await self.cmd_helper.http_download_request(self.dl_url)
|
||||
with tempfile.TemporaryDirectory(
|
||||
suffix=self.name, prefix="client") as tempdirname:
|
||||
tempdir = pathlib.Path(tempdirname)
|
||||
temp_download_file = tempdir.joinpath(f"{self.name}.zip")
|
||||
temp_persist_dir = tempdir.joinpath(self.name)
|
||||
await self.cmd_helper.streaming_download_request(
|
||||
dl_url, temp_download_file, content_type, size)
|
||||
self.cmd_helper.notify_update_response(
|
||||
f"Download Complete, extracting release to '{self.path}'")
|
||||
with ThreadPoolExecutor(max_workers=1) as tpe:
|
||||
await IOLoop.current().run_in_executor(
|
||||
tpe, self._extract_release, temp_persist_dir,
|
||||
temp_download_file)
|
||||
self.version = self.remote_version
|
||||
version_path = self.path.joinpath(".version")
|
||||
if not version_path.exists():
|
||||
with ThreadPoolExecutor(max_workers=1) as tpe:
|
||||
await IOLoop.current().run_in_executor(
|
||||
tpe, version_path.write_text, self.version)
|
||||
self.cmd_helper.notify_update_response(
|
||||
f"Client Update Finished: {self.name}", is_complete=True)
|
||||
|
||||
def _extract_release(self,
|
||||
persist_dir: pathlib.Path,
|
||||
release_file: pathlib.Path
|
||||
) -> None:
|
||||
if not persist_dir.exists():
|
||||
os.mkdir(persist_dir)
|
||||
if self.path.is_dir():
|
||||
# find and move persistent files
|
||||
for fname in os.listdir(self.path):
|
||||
src_path = self.path.joinpath(fname)
|
||||
if fname in self.persistent_files:
|
||||
dest_dir = tempdir.joinpath(fname).parent
|
||||
dest_dir = persist_dir.joinpath(fname).parent
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
shutil.move(src_path, dest_dir)
|
||||
shutil.move(str(src_path), str(dest_dir))
|
||||
shutil.rmtree(self.path)
|
||||
os.mkdir(self.path)
|
||||
with zipfile.ZipFile(io.BytesIO(archive)) as zf:
|
||||
with zipfile.ZipFile(release_file) as zf:
|
||||
zf.extractall(self.path)
|
||||
# Move temporary files back into
|
||||
for fname in os.listdir(tempdir):
|
||||
src_path = tempdir.joinpath(fname)
|
||||
for fname in os.listdir(persist_dir):
|
||||
src_path = persist_dir.joinpath(fname)
|
||||
dest_dir = self.path.joinpath(fname).parent
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
shutil.move(src_path, dest_dir)
|
||||
self.version = self.remote_version
|
||||
version_path = self.path.joinpath(".version")
|
||||
if not version_path.exists():
|
||||
version_path.write_text(self.version)
|
||||
self.cmd_helper.notify_update_response(
|
||||
f"Client Update Finished: {self.name}", is_complete=True)
|
||||
shutil.move(str(src_path), str(dest_dir))
|
||||
|
||||
def get_update_status(self) -> Dict[str, Any]:
|
||||
return {
|
||||
|
|
Loading…
Reference in New Issue