power: move process_request to the device class

Perform the entire request within a lock to prevent rentry until
the request completes.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-02-26 20:10:27 -05:00
parent 0d93cf2c39
commit 970c8a4181
No known key found for this signature in database
GPG Key ID: 7027245FBBDDF59A
1 changed files with 157 additions and 161 deletions

View File

@ -145,7 +145,7 @@ class PrinterPower:
fname = queue[0].get("filename", "unknown") fname = queue[0].get("filename", "unknown")
logging.debug( logging.debug(
f"Job '{fname}' queued, powering on device [{name}]") f"Job '{fname}' queued, powering on device [{name}]")
await self._process_request(dev, "on") await dev.process_request("on")
async def _handle_list_devices(self, async def _handle_list_devices(self,
web_request: WebRequest web_request: WebRequest
@ -169,7 +169,7 @@ class PrinterPower:
if action not in ["on", "off", "toggle"]: if action not in ["on", "off", "toggle"]:
raise self.server.error( raise self.server.error(
f"Invalid requested action '{action}'") f"Invalid requested action '{action}'")
result = await self._process_request(dev, action) result = await dev.process_request(action)
return {dev_name: result} return {dev_name: result}
async def _handle_batch_power_request(self, async def _handle_batch_power_request(self,
@ -184,44 +184,11 @@ class PrinterPower:
req = ep.split("/")[-1] req = ep.split("/")[-1]
for name, device in requested_devs.items(): for name, device in requested_devs.items():
if device is not None: if device is not None:
result[name] = await self._process_request(device, req) result[name] = await device.process_request(req)
else: else:
result[name] = "device_not_found" result[name] = "device_not_found"
return result return result
async def _process_request(self,
device: PowerDevice,
req: str
) -> str:
base_state: str = device.get_state()
ret = device.refresh_status()
if ret is not None:
await ret
cur_state: str = device.get_state()
if req == "toggle":
req = "on" if cur_state == "off" else "off"
if req in ["on", "off"]:
if req == cur_state:
# device is already in requested state, do nothing
if base_state != cur_state:
device.notify_power_changed()
return cur_state
printing = await self._check_klippy_printing()
if device.get_locked_while_printing() and printing:
raise self.server.error(
f"Unable to change power for {device.get_name()} "
"while printing")
ret = device.set_power(req)
if ret is not None:
await ret
cur_state = device.get_state()
await device.process_power_changed()
elif req != "status":
raise self.server.error(f"Unsupported power request: {req}")
elif base_state != cur_state:
device.notify_power_changed()
return cur_state
def set_device_power(self, device: str, state: Union[bool, str]) -> None: def set_device_power(self, device: str, state: Union[bool, str]) -> None:
request: str = "" request: str = ""
if isinstance(state, bool): if isinstance(state, bool):
@ -238,7 +205,7 @@ class PrinterPower:
return return
event_loop = self.server.get_event_loop() event_loop = self.server.get_event_loop()
event_loop.register_callback( event_loop.register_callback(
self._process_request, self.devices[device], request) self.devices[device].process_request, request)
async def add_device(self, name: str, device: PowerDevice) -> None: async def add_device(self, name: str, device: PowerDevice) -> None:
if name in self.devices: if name in self.devices:
@ -265,6 +232,7 @@ class PowerDevice:
self.name = name_parts[1] self.name = name_parts[1]
self.type: str = config.get('type') self.type: str = config.get('type')
self.state: str = "init" self.state: str = "init"
self.request_lock = asyncio.Lock()
self.locked_while_printing = config.getboolean( self.locked_while_printing = config.getboolean(
'locked_while_printing', False) 'locked_while_printing', False)
self.off_when_shutdown = config.getboolean('off_when_shutdown', False) self.off_when_shutdown = config.getboolean('off_when_shutdown', False)
@ -287,6 +255,13 @@ class PowerDevice:
self.on_when_queued = config.getboolean('on_when_upload_queued', self.on_when_queued = config.getboolean('on_when_upload_queued',
False, deprecate=True) False, deprecate=True)
async def _check_klippy_printing(self) -> bool:
kapis: APIComp = self.server.lookup_component('klippy_apis')
result: Dict[str, Any] = await kapis.query_objects(
{'print_stats': None}, default={})
pstate = result.get('print_stats', {}).get('state', "").lower()
return pstate == "printing"
def _is_bound_to_klipper(self): def _is_bound_to_klipper(self):
return ( return (
self.bound_service is not None and self.bound_service is not None and
@ -391,6 +366,37 @@ class PowerDevice:
) )
return None return None
async def process_request(self, req: str) -> str:
async with self.request_lock:
base_state: str = self.state
ret = self.refresh_status()
if ret is not None:
await ret
cur_state: str = self.state
if req == "toggle":
req = "on" if cur_state == "off" else "off"
if req in ["on", "off"]:
if req == cur_state:
# device is already in requested state, do nothing
if base_state != cur_state:
self.notify_power_changed()
return cur_state
printing = await self._check_klippy_printing()
if self.locked_while_printing and printing:
raise self.server.error(
f"Unable to change power for {self.name} "
"while printing")
ret = self.set_power(req)
if ret is not None:
await ret
cur_state = self.state
await self.process_power_changed()
elif req != "status":
raise self.server.error(f"Unsupported power request: {req}")
elif base_state != cur_state:
self.notify_power_changed()
return cur_state
def refresh_status(self) -> Optional[Coroutine]: def refresh_status(self) -> Optional[Coroutine]:
raise NotImplementedError raise NotImplementedError
@ -410,7 +416,6 @@ class HTTPDevice(PowerDevice):
) -> None: ) -> None:
super().__init__(config) super().__init__(config)
self.client = AsyncHTTPClient() self.client = AsyncHTTPClient()
self.request_mutex = asyncio.Lock()
self.addr: str = config.get("address") self.addr: str = config.get("address")
self.port = config.getint("port", default_port) self.port = config.getint("port", default_port)
self.user = config.load_template("user", default_user).render() self.user = config.load_template("user", default_user).render()
@ -419,8 +424,9 @@ class HTTPDevice(PowerDevice):
self.protocol = config.get("protocol", default_protocol) self.protocol = config.get("protocol", default_protocol)
async def initialize(self) -> None: async def initialize(self) -> None:
super().initialize() async with self.request_lock:
await self.refresh_status() super().initialize()
await self.refresh_status()
async def _send_http_command(self, async def _send_http_command(self,
url: str, url: str,
@ -450,26 +456,24 @@ class HTTPDevice(PowerDevice):
"_send_status_request must be implemented by children") "_send_status_request must be implemented by children")
async def refresh_status(self) -> None: async def refresh_status(self) -> None:
async with self.request_mutex: try:
try: state = await self._send_status_request()
state = await self._send_status_request() except Exception:
except Exception: self.state = "error"
self.state = "error" msg = f"Error Refeshing Device Status: {self.name}"
msg = f"Error Refeshing Device Status: {self.name}" logging.exception(msg)
logging.exception(msg) raise self.server.error(msg) from None
raise self.server.error(msg) from None self.state = state
self.state = state
async def set_power(self, state): async def set_power(self, state):
async with self.request_mutex: try:
try: state = await self._send_power_request(state)
state = await self._send_power_request(state) except Exception:
except Exception: self.state = "error"
self.state = "error" msg = f"Error Setting Device Status: {self.name} to {state}"
msg = f"Error Setting Device Status: {self.name} to {state}" logging.exception(msg)
logging.exception(msg) raise self.server.error(msg) from None
raise self.server.error(msg) from None self.state = state
self.state = state
class GpioDevice(PowerDevice): class GpioDevice(PowerDevice):
@ -544,7 +548,6 @@ class KlipperDevice(PowerDevice):
f" for 'klipper_device' [{config.get_name()}]") f" for 'klipper_device' [{config.get_name()}]")
self.is_shutdown: bool = False self.is_shutdown: bool = False
self.update_fut: Optional[asyncio.Future] = None self.update_fut: Optional[asyncio.Future] = None
self.request_mutex = asyncio.Lock()
self.timer: Optional[float] = config.getfloat( self.timer: Optional[float] = config.getfloat(
'timer', None, above=0.000001) 'timer', None, above=0.000001)
self.timer_handle: Optional[asyncio.TimerHandle] = None self.timer_handle: Optional[asyncio.TimerHandle] = None
@ -596,44 +599,42 @@ class KlipperDevice(PowerDevice):
async def refresh_status(self) -> None: async def refresh_status(self) -> None:
if self.is_shutdown or self.state in ["on", "off", "init"]: if self.is_shutdown or self.state in ["on", "off", "init"]:
return return
async with self.request_mutex: kapis: APIComp = self.server.lookup_component('klippy_apis')
kapis: APIComp = self.server.lookup_component('klippy_apis') req: Dict[str, Optional[List[str]]] = {self.object_name: None}
req: Dict[str, Optional[List[str]]] = {self.object_name: None} data: Optional[Dict[str, Any]]
data: Optional[Dict[str, Any]] data = await kapis.query_objects(req, None)
data = await kapis.query_objects(req, None) if not self._validate_data(data):
if not self._validate_data(data): self.state = "error"
self.state = "error" else:
else: assert data is not None
assert data is not None self._set_state_from_data(data)
self._set_state_from_data(data)
async def set_power(self, state: str) -> None: async def set_power(self, state: str) -> None:
if self.is_shutdown: if self.is_shutdown:
raise self.server.error( raise self.server.error(
f"Power Device {self.name}: Cannot set power for device " f"Power Device {self.name}: Cannot set power for device "
f"when Klipper is shutdown") f"when Klipper is shutdown")
async with self.request_mutex: self._reset_timer()
self._reset_timer() eventloop = self.server.get_event_loop()
eventloop = self.server.get_event_loop() self.update_fut = eventloop.create_future()
self.update_fut = eventloop.create_future() try:
try: kapis: APIComp = self.server.lookup_component('klippy_apis')
kapis: APIComp = self.server.lookup_component('klippy_apis') value = "1" if state == "on" else "0"
value = "1" if state == "on" else "0" await kapis.run_gcode(f"{self.gc_cmd} VALUE={value}")
await kapis.run_gcode(f"{self.gc_cmd} VALUE={value}") await asyncio.wait_for(self.update_fut, 1.)
await asyncio.wait_for(self.update_fut, 1.) except TimeoutError:
except TimeoutError: self.state = "error"
self.state = "error" raise self.server.error(
raise self.server.error( f"Power device {self.name}: Timeout "
f"Power device {self.name}: Timeout " "waiting for device state update")
"waiting for device state update") except Exception:
except Exception: self.state = "error"
self.state = "error" msg = f"Error Toggling Device Power: {self.name}"
msg = f"Error Toggling Device Power: {self.name}" logging.exception(msg)
logging.exception(msg) raise self.server.error(msg) from None
raise self.server.error(msg) from None finally:
finally: self.update_fut = None
self.update_fut = None self._check_timer()
self._check_timer()
def _validate_data(self, data: Optional[Dict[str, Any]]) -> bool: def _validate_data(self, data: Optional[Dict[str, Any]]) -> bool:
if data is None: if data is None:
@ -742,7 +743,6 @@ class TPLinkSmartPlug(PowerDevice):
def __init__(self, config: ConfigHelper) -> None: def __init__(self, config: ConfigHelper) -> None:
super().__init__(config) super().__init__(config)
self.timer = config.get("timer", "") self.timer = config.get("timer", "")
self.request_mutex = asyncio.Lock()
addr_and_output_id = config.get("address").split('/') addr_and_output_id = config.get("address").split('/')
self.addr = addr_and_output_id[0] self.addr = addr_and_output_id[0]
if (len(addr_and_output_id) > 1): if (len(addr_and_output_id) > 1):
@ -826,47 +826,46 @@ class TPLinkSmartPlug(PowerDevice):
return res return res
async def initialize(self) -> None: async def initialize(self) -> None:
super().initialize() async with self.request_lock:
await self.refresh_status() super().initialize()
await self.refresh_status()
async def refresh_status(self) -> None: async def refresh_status(self) -> None:
async with self.request_mutex: try:
try: state: str
state: str res = await self._send_tplink_command("info")
res = await self._send_tplink_command("info") if self.output_id is not None:
if self.output_id is not None: # TPLink device controls multiple devices
# TPLink device controls multiple devices children: Dict[int, Any]
children: Dict[int, Any] children = res['system']['get_sysinfo']['children']
children = res['system']['get_sysinfo']['children'] state = children[self.output_id]['state']
state = children[self.output_id]['state'] else:
else: state = res['system']['get_sysinfo']['relay_state']
state = res['system']['get_sysinfo']['relay_state'] except Exception:
except Exception: self.state = "error"
self.state = "error" msg = f"Error Refeshing Device Status: {self.name}"
msg = f"Error Refeshing Device Status: {self.name}" logging.exception(msg)
logging.exception(msg) raise self.server.error(msg) from None
raise self.server.error(msg) from None self.state = "on" if state else "off"
self.state = "on" if state else "off"
async def set_power(self, state) -> None: async def set_power(self, state) -> None:
async with self.request_mutex: err: int
err: int try:
try: if self.timer != "" and state == "off":
if self.timer != "" and state == "off": await self._send_tplink_command("clear_rules")
await self._send_tplink_command("clear_rules") res = await self._send_tplink_command("count_off")
res = await self._send_tplink_command("count_off") err = res['count_down']['add_rule']['err_code']
err = res['count_down']['add_rule']['err_code'] else:
else: res = await self._send_tplink_command(state)
res = await self._send_tplink_command(state) err = res['system']['set_relay_state']['err_code']
err = res['system']['set_relay_state']['err_code'] except Exception:
except Exception: err = 1
err = 1 logging.exception(f"Power Toggle Error: {self.name}")
logging.exception(f"Power Toggle Error: {self.name}") if err:
if err: self.state = "error"
self.state = "error" raise self.server.error(
raise self.server.error( f"Error Toggling Device Power: {self.name}")
f"Error Toggling Device Power: {self.name}") self.state = state
self.state = state
class Tasmota(HTTPDevice): class Tasmota(HTTPDevice):
@ -1143,7 +1142,6 @@ class MQTTDevice(PowerDevice):
self.mqtt.subscribe_topic( self.mqtt.subscribe_topic(
self.state_topic, self._on_state_update, self.qos) self.state_topic, self._on_state_update, self.qos)
self.query_response: Optional[asyncio.Future] = None self.query_response: Optional[asyncio.Future] = None
self.request_mutex = asyncio.Lock()
self.server.register_event_handler( self.server.register_event_handler(
"mqtt:connected", self._on_mqtt_connected) "mqtt:connected", self._on_mqtt_connected)
self.server.register_event_handler( self.server.register_event_handler(
@ -1151,7 +1149,7 @@ class MQTTDevice(PowerDevice):
def _on_state_update(self, payload: bytes) -> None: def _on_state_update(self, payload: bytes) -> None:
last_state = self.state last_state = self.state
in_request = self.request_mutex.locked() in_request = self.request_lock.locked()
err: Optional[Exception] = None err: Optional[Exception] = None
context = { context = {
'payload': payload.decode() 'payload': payload.decode()
@ -1185,7 +1183,7 @@ class MQTTDevice(PowerDevice):
self.query_response.set_result(response) self.query_response.set_result(response)
async def _on_mqtt_connected(self) -> None: async def _on_mqtt_connected(self) -> None:
async with self.request_mutex: async with self.request_lock:
if self.state in ["on", "off"]: if self.state in ["on", "off"]:
return return
self.state = "init" self.state = "init"
@ -1226,7 +1224,7 @@ class MQTTDevice(PowerDevice):
): ):
self.query_response.set_exception( self.query_response.set_exception(
self.server.error("MQTT Disconnected", 503)) self.server.error("MQTT Disconnected", 503))
async with self.request_mutex: async with self.request_lock:
self.state = "error" self.state = "error"
self.notify_power_changed() self.notify_power_changed()
@ -1239,15 +1237,14 @@ class MQTTDevice(PowerDevice):
raise self.server.error( raise self.server.error(
f"MQTT Power Device {self.name}: " f"MQTT Power Device {self.name}: "
"MQTT Not Connected", 503) "MQTT Not Connected", 503)
async with self.request_mutex: self.query_response = self.eventloop.create_future()
self.query_response = self.eventloop.create_future() try:
try: await self._wait_for_update(self.query_response)
await self._wait_for_update(self.query_response) except Exception:
except Exception: logging.exception(f"MQTT Power Device {self.name}: "
logging.exception(f"MQTT Power Device {self.name}: " "Failed to refresh state")
"Failed to refresh state") self.state = "error"
self.state = "error" self.query_response = None
self.query_response = None
async def _wait_for_update(self, fut: asyncio.Future, async def _wait_for_update(self, fut: asyncio.Future,
do_query: bool = True do_query: bool = True
@ -1265,26 +1262,25 @@ class MQTTDevice(PowerDevice):
raise self.server.error( raise self.server.error(
f"MQTT Power Device {self.name}: " f"MQTT Power Device {self.name}: "
"MQTT Not Connected", 503) "MQTT Not Connected", 503)
async with self.request_mutex: self.query_response = self.eventloop.create_future()
self.query_response = self.eventloop.create_future() new_state = "error"
try:
payload = self.cmd_payload.render({'command': state})
await self.mqtt.publish_topic(
self.cmd_topic, payload, self.qos,
retain=self.retain_cmd_state)
new_state = await self._wait_for_update(
self.query_response, do_query=self.must_query)
except Exception:
logging.exception(
f"MQTT Power Device {self.name}: Failed to set state")
new_state = "error" new_state = "error"
try: self.query_response = None
payload = self.cmd_payload.render({'command': state}) self.state = new_state
await self.mqtt.publish_topic( if self.state == "error":
self.cmd_topic, payload, self.qos, raise self.server.error(
retain=self.retain_cmd_state) f"MQTT Power Device {self.name}: Failed to set "
new_state = await self._wait_for_update( f"device to state '{state}'", 500)
self.query_response, do_query=self.must_query)
except Exception:
logging.exception(
f"MQTT Power Device {self.name}: Failed to set state")
new_state = "error"
self.query_response = None
self.state = new_state
if self.state == "error":
raise self.server.error(
f"MQTT Power Device {self.name}: Failed to set "
f"device to state '{state}'", 500)
# The power component has multiple configuration sections # The power component has multiple configuration sections