update_manager: combine "init" locks into a single event.

It isn't necessary for each updater to have their own init lock. Combine them all into a single Event that is set after the "initialize_updaters" method completes.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-03-18 06:29:32 -04:00
parent 5e99378466
commit 0c455fcc0d
1 changed files with 4 additions and 41 deletions

View File

@ -78,6 +78,7 @@ class UpdateManager:
% (client_type, section)) % (client_type, section))
self.cmd_request_lock = Lock() self.cmd_request_lock = Lock()
self.initialized_lock = Event()
self.is_refreshing = False self.is_refreshing = False
# Auto Status Refresh # Auto Status Refresh
@ -125,6 +126,7 @@ class UpdateManager:
if asyncio.iscoroutine(ret): if asyncio.iscoroutine(ret):
await ret await ret
self.is_refreshing = False self.is_refreshing = False
self.initialized_lock.set()
async def _set_klipper_repo(self): async def _set_klipper_repo(self):
kinfo = self.server.get_klippy_info() kinfo = self.server.get_klippy_info()
@ -187,6 +189,7 @@ class UpdateManager:
self.server.send_event("update_manager:update_refreshed", uinfo) self.server.send_event("update_manager:update_refreshed", uinfo)
async def _handle_update_request(self, web_request): async def _handle_update_request(self, web_request):
await self.initialized_lock.wait()
if await self._check_klippy_printing(): if await self._check_klippy_printing():
raise self.server.error("Update Refused: Klippy is printing") raise self.server.error("Update Refused: Klippy is printing")
app = web_request.get_endpoint().split("/")[-1] app = web_request.get_endpoint().split("/")[-1]
@ -213,6 +216,7 @@ class UpdateManager:
return "ok" return "ok"
async def _handle_status_request(self, web_request): async def _handle_status_request(self, web_request):
await self.initialized_lock.wait()
check_refresh = web_request.get_boolean('refresh', False) check_refresh = web_request.get_boolean('refresh', False)
# Don't refresh if a print is currently in progress or # Don't refresh if a print is currently in progress or
# if an update is in progress. Just return the current # if an update is in progress. Just return the current
@ -230,7 +234,6 @@ class UpdateManager:
vinfo = {} vinfo = {}
try: try:
for name, updater in list(self.updaters.items()): for name, updater in list(self.updaters.items()):
await updater.check_initialized(120.)
if need_refresh: if need_refresh:
ret = updater.refresh() ret = updater.refresh()
if asyncio.iscoroutine(ret): if asyncio.iscoroutine(ret):
@ -269,7 +272,6 @@ class CommandHelper:
self.gh_rate_limit = None self.gh_rate_limit = None
self.gh_limit_remaining = None self.gh_limit_remaining = None
self.gh_limit_reset_time = None self.gh_limit_reset_time = None
self.gh_init_evt = Event()
# Update In Progress Tracking # Update In Progress Tracking
self.cur_update_app = self.cur_update_id = None self.cur_update_app = self.cur_update_id = None
@ -321,7 +323,6 @@ class CommandHelper:
f"Rate Limit Reset Time: {reset_time}, " f"Rate Limit Reset Time: {reset_time}, "
f"Seconds Since Epoch: {self.gh_limit_reset_time}") f"Seconds Since Epoch: {self.gh_limit_reset_time}")
break break
self.gh_init_evt.set()
async def run_cmd(self, cmd, timeout=20., notify=False, async def run_cmd(self, cmd, timeout=20., notify=False,
retries=1, env=None): retries=1, env=None):
@ -340,14 +341,6 @@ class CommandHelper:
return result return result
async def github_api_request(self, url, etag=None, is_init=False): async def github_api_request(self, url, etag=None, is_init=False):
if not is_init:
timeout = time.time() + 30.
try:
await self.gh_init_evt.wait(timeout)
except Exception:
raise self.server.error(
"Timeout while waiting for GitHub "
"API Rate Limit initialization")
if self.gh_limit_remaining == 0: if self.gh_limit_remaining == 0:
curtime = time.time() curtime = time.time()
if curtime < self.gh_limit_reset_time: if curtime < self.gh_limit_reset_time:
@ -453,7 +446,6 @@ class GitUpdater:
self.repo_path = path self.repo_path = path
origin = config.get("origin").lower() origin = config.get("origin").lower()
self.repo = GitRepo(cmd_helper, path, self.name, origin) self.repo = GitRepo(cmd_helper, path, self.name, origin)
self.init_evt = Event()
self.debug = self.cmd_helper.is_debug_enabled() self.debug = self.cmd_helper.is_debug_enabled()
self.env = config.get("env", env) self.env = config.get("env", env)
dist_packages = None dist_packages = None
@ -531,19 +523,11 @@ class GitUpdater:
logging.debug(log_msg) logging.debug(log_msg)
self.cmd_helper.notify_update_response(log_msg, is_complete) self.cmd_helper.notify_update_response(log_msg, is_complete)
async def check_initialized(self, timeout=None):
if self.init_evt.is_set():
return
if timeout is not None:
timeout = IOLoop.current().time() + timeout
await self.init_evt.wait(timeout)
async def refresh(self): async def refresh(self):
try: try:
await self._update_repo_state() await self._update_repo_state()
except Exception: except Exception:
logging.exception("Error Refreshing git state") logging.exception("Error Refreshing git state")
self.init_evt.set()
async def _update_repo_state(self, need_fetch=True): async def _update_repo_state(self, need_fetch=True):
self.is_valid = False self.is_valid = False
@ -564,7 +548,6 @@ class GitUpdater:
self._log_info("Validity check for git repo passed") self._log_info("Validity check for git repo passed")
async def update(self, update_deps=False): async def update(self, update_deps=False):
await self.check_initialized(20.)
await self.repo.wait_for_init() await self.repo.wait_for_init()
if not self.is_valid: if not self.is_valid:
raise self._log_exc("Update aborted, repo not valid", False) raise self._log_exc("Update aborted, repo not valid", False)
@ -1237,7 +1220,6 @@ class PackageUpdater:
self.server = cmd_helper.get_server() self.server = cmd_helper.get_server()
self.cmd_helper = cmd_helper self.cmd_helper = cmd_helper
self.available_packages = [] self.available_packages = []
self.init_evt = Event()
self.refresh_condition = None self.refresh_condition = None
async def refresh(self, fetch_packages=True): async def refresh(self, fetch_packages=True):
@ -1264,19 +1246,10 @@ class PackageUpdater:
f"\n{pkg_list}") f"\n{pkg_list}")
except Exception: except Exception:
logging.exception("Error Refreshing System Packages") logging.exception("Error Refreshing System Packages")
self.init_evt.set()
self.refresh_condition.notify_all() self.refresh_condition.notify_all()
self.refresh_condition = None self.refresh_condition = None
async def check_initialized(self, timeout=None):
if self.init_evt.is_set():
return
if timeout is not None:
timeout = IOLoop.current().time() + timeout
await self.init_evt.wait(timeout)
async def update(self, *args): async def update(self, *args):
await self.check_initialized(20.)
if self.refresh_condition is not None: if self.refresh_condition is not None:
self.refresh_condition.wait() self.refresh_condition.wait()
self.cmd_helper.notify_update_response("Updating packages...") self.cmd_helper.notify_update_response("Updating packages...")
@ -1319,7 +1292,6 @@ class WebUpdater:
self.version = self.remote_version = self.dl_url = "?" self.version = self.remote_version = self.dl_url = "?"
self.etag = None self.etag = None
self.init_evt = Event()
self.refresh_condition = None self.refresh_condition = None
self._get_local_version() self._get_local_version()
logging.info(f"\nInitializing Client Updater: '{self.name}'," logging.info(f"\nInitializing Client Updater: '{self.name}',"
@ -1333,13 +1305,6 @@ class WebUpdater:
v = f.read() v = f.read()
self.version = v.strip() self.version = v.strip()
async def check_initialized(self, timeout=None):
if self.init_evt.is_set():
return
if timeout is not None:
timeout = IOLoop.current().time() + timeout
await self.init_evt.wait(timeout)
async def refresh(self): async def refresh(self):
if self.refresh_condition is None: if self.refresh_condition is None:
self.refresh_condition = Condition() self.refresh_condition = Condition()
@ -1351,7 +1316,6 @@ class WebUpdater:
await self._get_remote_version() await self._get_remote_version()
except Exception: except Exception:
logging.exception("Error Refreshing Client") logging.exception("Error Refreshing Client")
self.init_evt.set()
self.refresh_condition.notify_all() self.refresh_condition.notify_all()
self.refresh_condition = None self.refresh_condition = None
@ -1378,7 +1342,6 @@ class WebUpdater:
f"url: {self.dl_url}") f"url: {self.dl_url}")
async def update(self, *args): async def update(self, *args):
await self.check_initialized(20.)
if self.refresh_condition is not None: if self.refresh_condition is not None:
# wait for refresh if in progess # wait for refresh if in progess
self.refresh_condition.wait() self.refresh_condition.wait()