update_manager: use the http_client component

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-02-28 12:37:50 -05:00
parent 5bf112bcbc
commit faf956c65d
No known key found for this signature in database
GPG Key ID: 7027245FBBDDF59A
2 changed files with 67 additions and 232 deletions

View File

@ -9,14 +9,12 @@ import asyncio
import os
import pathlib
import logging
import json
import sys
import shutil
import zipfile
import time
import tempfile
import re
from tornado.httpclient import AsyncHTTPClient
from thirdparty.packagekit import enums as PkEnum
from .base_deploy import BaseDeploy
from .app_deploy import AppDeploy
@ -37,7 +35,6 @@ from typing import (
cast
)
if TYPE_CHECKING:
from tornado.httpclient import HTTPResponse
from moonraker import Server
from confighelper import ConfigHelper
from websockets import WebRequest
@ -47,6 +44,7 @@ if TYPE_CHECKING:
from components.database import NamespaceWrapper
from components.dbus_manager import DbusManager
from components.machine import Machine
from components.http_client import HttpClient
from eventloop import FlexTimer
from dbus_next import Variant
from dbus_next.aio import ProxyInterface
@ -457,13 +455,14 @@ class UpdateManager:
return "ok"
def close(self) -> None:
self.cmd_helper.close()
if self.refresh_timer is not None:
self.refresh_timer.stop()
class CommandHelper:
def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server()
self.http_client: HttpClient
self.http_client = self.server.lookup_component("http_client")
self.debug_enabled = config.getboolean('enable_repo_debug', False)
if self.debug_enabled:
logging.warning("UPDATE MANAGER: REPO DEBUG ENABLED")
@ -471,8 +470,6 @@ class CommandHelper:
self.scmd_error = shell_cmd.error
self.build_shell_command = shell_cmd.build_shell_command
self.pkg_updater: Optional[PackageDeploy] = None
self.http_client = AsyncHTTPClient()
self.github_request_cache: Dict[str, CachedGithubResponse] = {}
# database management
db: DBComp = self.server.lookup_component('database')
@ -497,6 +494,9 @@ class CommandHelper:
def get_server(self) -> Server:
return self.server
def get_http_client(self) -> HttpClient:
return self.http_client
def get_refresh_interval(self) -> float:
return self.refresh_interval
@ -531,13 +531,6 @@ class CommandHelper:
def set_package_updater(self, updater: PackageDeploy) -> None:
self.pkg_updater = updater
def get_rate_limit_stats(self) -> Dict[str, Any]:
return {
'github_rate_limit': self.gh_rate_limit,
'github_requests_remaining': self.gh_limit_remaining,
'github_limit_reset_time': self.gh_limit_reset_time,
}
async def run_cmd(self,
cmd: str,
timeout: float = 20.,
@ -568,135 +561,6 @@ class CommandHelper:
sig_idx=sig_idx)
return result
async def github_api_request(self,
url: str,
retries: int = 5
) -> JsonType:
if (
self.gh_limit_reset_time is not None and
self.gh_limit_remaining == 0
):
curtime = time.time()
if curtime < self.gh_limit_reset_time:
raise self.server.error(
f"GitHub Rate Limit Reached\nRequest: {url}\n"
f"Limit Reset Time: {time.ctime(self.gh_limit_remaining)}")
if url in self.github_request_cache:
cached_request = self.github_request_cache[url]
etag: Optional[str] = cached_request.get_etag()
else:
cached_request = CachedGithubResponse()
etag = None
self.github_request_cache[url] = cached_request
headers = {"Accept": "application/vnd.github.v3+json"}
if etag is not None:
headers['If-None-Match'] = etag
for i in range(retries):
error: Optional[Exception] = None
try:
fut = self.http_client.fetch(
url, headers=headers, connect_timeout=5.,
request_timeout=5., raise_error=False)
resp: HTTPResponse
resp = await asyncio.wait_for(fut, 10.)
except Exception:
logging.exception(
f"Error Processing GitHub API request: {url}")
if i + 1 < retries:
await asyncio.sleep(1.)
continue
etag = resp.headers.get('etag', None)
if etag is not None:
if etag[:2] == "W/":
etag = etag[2:]
logging.info(
"GitHub API Request Processed\n"
f"URL: {url}\n"
f"Response Code: {resp.code}\n"
f"Response Reason: {resp.reason}\n"
f"ETag: {etag}")
if resp.code == 403:
error = self.server.error(
f"Forbidden GitHub Request: {resp.reason}")
elif resp.code == 304:
logging.info(f"Github Request not Modified: {url}")
return cached_request.get_cached_result()
if resp.code != 200:
logging.info(
f"Github Request failed: {resp.code} {resp.reason}")
if i + 1 < retries:
await asyncio.sleep(1.)
continue
# Update rate limit on return success
if 'X-Ratelimit-Limit' in resp.headers:
self.gh_rate_limit = int(resp.headers['X-Ratelimit-Limit'])
self.gh_limit_remaining = int(
resp.headers['X-Ratelimit-Remaining'])
self.gh_limit_reset_time = float(
resp.headers['X-Ratelimit-Reset'])
if error is not None:
raise error
decoded = json.loads(resp.body)
if etag is not None:
cached_request.update_result(etag, decoded)
return decoded
raise self.server.error(
f"Retries exceeded for GitHub API request: {url}")
async def http_download_request(self,
url: str,
content_type: str,
timeout: float = 180.,
retries: int = 5
) -> bytes:
for i in range(retries):
try:
fut = self.http_client.fetch(
url, headers={"Accept": content_type},
connect_timeout=5., request_timeout=timeout)
resp: HTTPResponse
resp = await asyncio.wait_for(fut, timeout + 10.)
except Exception:
logging.exception("Error Processing Download")
if i + 1 == retries:
raise
await asyncio.sleep(1.)
continue
return resp.body
raise self.server.error(
f"Retries exceeded for GitHub API request: {url}")
async def streaming_download_request(self,
url: str,
dest: Union[str, pathlib.Path],
content_type: str,
size: int,
timeout: float = 180.,
retries: int = 5
) -> None:
if isinstance(dest, str):
dest = pathlib.Path(dest)
for i in range(retries):
dl = StreamingDownload(self, dest, size)
try:
fut = self.http_client.fetch(
url, headers={"Accept": content_type},
connect_timeout=5., request_timeout=timeout,
streaming_callback=dl.on_chunk_recd)
resp: HTTPResponse
resp = await asyncio.wait_for(fut, timeout + 10.)
except Exception:
logging.exception("Error Processing Download")
if i + 1 == retries:
raise
await asyncio.sleep(1.)
continue
finally:
await dl.close()
if resp.code < 400:
return
raise self.server.error(f"Retries exceeded for request: {url}")
def notify_update_response(self,
resp: Union[str, bytes],
is_complete: bool = False
@ -723,66 +587,20 @@ class CommandHelper:
return
await self.pkg_updater.install_packages(package_list, **kwargs)
def close(self) -> None:
self.http_client.close()
def get_rate_limit_stats(self):
return self.http_client.github_api_stats()
class CachedGithubResponse:
def __init__(self) -> None:
self.etag: Optional[str] = None
self.cached_result: JsonType = {}
def get_etag(self) -> Optional[str]:
return self.etag
def get_cached_result(self) -> JsonType:
return self.cached_result
def update_result(self, etag: str, result: JsonType) -> None:
self.etag = etag
self.cached_result = result
class StreamingDownload:
def __init__(self,
cmd_helper: CommandHelper,
dest: pathlib.Path,
download_size: int) -> None:
self.cmd_helper = cmd_helper
self.event_loop = cmd_helper.get_server().get_event_loop()
self.name = dest.name
self.file_hdl = dest.open('wb')
self.download_size = download_size
self.total_recd: int = 0
self.last_pct: int = 0
self.chunk_buffer: List[bytes] = []
self.busy_evt: asyncio.Event = asyncio.Event()
self.busy_evt.set()
def on_chunk_recd(self, chunk: bytes) -> None:
if not chunk:
return
self.chunk_buffer.append(chunk)
if not self.busy_evt.is_set():
return
self.busy_evt.clear()
self.event_loop.register_callback(self._process_buffer)
async def close(self):
await self.busy_evt.wait()
self.file_hdl.close()
async def _process_buffer(self):
while self.chunk_buffer:
chunk = self.chunk_buffer.pop(0)
self.total_recd += len(chunk)
pct = int(self.total_recd / self.download_size * 100 + .5)
await self.event_loop.run_in_thread(self.file_hdl.write, chunk)
if pct >= self.last_pct + 5:
self.last_pct = pct
totals = f"{self.total_recd // 1024} KiB / " \
f"{self.download_size // 1024} KiB"
self.cmd_helper.notify_update_response(
f"Downloading {self.name}: {totals} [{pct}%]")
self.busy_evt.set()
def on_download_progress(self,
progress: int,
download_size: int,
downloaded: int
) -> None:
totals = (
f"{downloaded // 1024} KiB / "
f"{download_size// 1024} KiB"
)
self.notify_update_response(
f"Downloading {self.cur_update_app}: {totals} [{progress}%]")
class PackageDeploy(BaseDeploy):
def __init__(self,
@ -1334,22 +1152,32 @@ class WebClientDeploy(BaseDeploy):
async def _get_remote_version(self) -> None:
# Remote state
url = f"https://api.github.com/repos/{self.repo}/releases"
try:
releases = await self.cmd_helper.github_api_request(url)
assert isinstance(releases, list)
except Exception:
logging.exception(f"Client {self.repo}: Github Request Error")
releases = []
result: Dict[str, Any] = {}
for release in releases:
if self.channel == "stable":
if not release['prerelease']:
result = release
break
resource = f"repos/{self.repo}/releases/latest"
else:
resource = f"repos/{self.repo}/releases?per_page=1"
client = self.cmd_helper.get_http_client()
resp = await client.github_api_request(resource, attempts=3)
release: Union[List[Any], Dict[str, Any]] = {}
if resp.status_code == 304:
if self.remote_version == "?" and resp.content:
# Not modified, however we need to restore state from
# cached content
release = resp.json()
else:
# Either not necessary or not possible to restore from cache
return
elif resp.has_error():
logging.info(
f"Client {self.repo}: Github Request Error - {resp.error}")
else:
release = resp.json()
result: Dict[str, Any] = {}
if isinstance(release, list):
if release:
result = release[0]
else:
result = release
break
self.remote_version = result.get('name', "?")
release_asset: Dict[str, Any] = result.get('assets', [{}])[0]
dl_url: str = release_asset.get('browser_download_url', "?")
@ -1395,8 +1223,10 @@ class WebClientDeploy(BaseDeploy):
tempdir = pathlib.Path(tempdirname)
temp_download_file = tempdir.joinpath(f"{self.name}.zip")
temp_persist_dir = tempdir.joinpath(self.name)
await self.cmd_helper.streaming_download_request(
dl_url, temp_download_file, content_type, size)
client = self.cmd_helper.get_http_client()
await client.download_file(
dl_url, content_type, temp_download_file, size,
self.cmd_helper.on_download_progress)
self.cmd_helper.notify_update_response(
f"Download Complete, extracting release to '{self.path}'")
await event_loop.run_in_thread(

View File

@ -185,9 +185,9 @@ class ZipDeploy(AppDeploy):
f"Host repo mismatch, received: {host_repo}, "
f"expected: {self.host_repo}. This could result in "
" a failed update.")
url = f"https://api.github.com/repos/{self.host_repo}/releases"
resource = f"repos/{self.host_repo}/releases"
current_release, latest_release = await self._fetch_github_releases(
url, release_tag)
resource, release_tag)
await self._validate_current_release(release_info, current_release)
if not self.errors:
self.verified = True
@ -196,11 +196,14 @@ class ZipDeploy(AppDeploy):
self._log_zipapp_info()
async def _fetch_github_releases(self,
release_url: str,
resource: str,
current_tag: Optional[str] = None
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
try:
releases = await self.cmd_helper.github_api_request(release_url)
client = self.cmd_helper.get_http_client()
resp = await client.github_api_request(resource, attempts=3)
resp.raise_for_status()
releases = resp.json()
assert isinstance(releases, list)
except Exception:
self.log_exc("Error fetching releases from GitHub")
@ -237,8 +240,8 @@ class ZipDeploy(AppDeploy):
self._add_error(
"RELEASE_INFO not found in current release assets")
info_url, content_type, size = asset_info['RELEASE_INFO']
rinfo_bytes = await self.cmd_helper.http_download_request(
info_url, content_type)
client = self.cmd_helper.get_http_client()
rinfo_bytes = await client.get_file(info_url, content_type)
github_rinfo: Dict[str, Any] = json.loads(rinfo_bytes)
if github_rinfo.get(self.name, {}) != release_info:
self._add_error(
@ -255,8 +258,8 @@ class ZipDeploy(AppDeploy):
asset_info = self._get_asset_urls(release, asset_names)
if "RELEASE_INFO" in asset_info:
asset_url, content_type, size = asset_info['RELEASE_INFO']
rinfo_bytes = await self.cmd_helper.http_download_request(
asset_url, content_type)
client = self.cmd_helper.get_http_client()
rinfo_bytes = await client.get_file(asset_url, content_type)
update_release_info: Dict[str, Any] = json.loads(rinfo_bytes)
update_info = update_release_info.get(self.name, {})
self.lastest_hash = update_info.get('commit_hash', "?")
@ -272,8 +275,8 @@ class ZipDeploy(AppDeploy):
# Only report commit log if versions change
if "COMMIT_LOG" in asset_info:
asset_url, content_type, size = asset_info['COMMIT_LOG']
commit_bytes = await self.cmd_helper.http_download_request(
asset_url, content_type)
client = self.cmd_helper.get_http_client()
commit_bytes = await client.get_file(asset_url, content_type)
commit_info: Dict[str, Any] = json.loads(commit_bytes)
self.commit_log = commit_info.get(self.name, [])
if zip_file_name in asset_info:
@ -379,8 +382,10 @@ class ZipDeploy(AppDeploy):
suffix=self.name, prefix="app") as tempdirname:
tempdir = pathlib.Path(tempdirname)
temp_download_file = tempdir.joinpath(f"{self.name}.zip")
await self.cmd_helper.streaming_download_request(
dl_url, temp_download_file, content_type, size)
client = self.cmd_helper.get_http_client()
await client.download_file(
dl_url, content_type, temp_download_file, size,
self.cmd_helper.on_download_progress)
self.notify_status(
f"Download Complete, extracting release to '{self.path}'")
event_loop = self.server.get_event_loop()
@ -396,8 +401,8 @@ class ZipDeploy(AppDeploy):
hard: bool = False,
force_dep_update: bool = False
) -> None:
url = f"https://api.github.com/repos/{self.host_repo}/releases"
releases = await self._fetch_github_releases(url)
res = f"repos/{self.host_repo}/releases"
releases = await self._fetch_github_releases(res)
await self._process_latest_release(releases[1])
await self.update(force_dep_update=force_dep_update)