klippy_connection: apply a mutex to the subscription request
The subscripition request is reentrant in Klippy. Sending multiple requests from the same connection may create a race condition, so use a lock to prevent reentry. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
7c8d68c0a1
commit
c3697d0656
|
@ -65,6 +65,7 @@ class KlippyConnection:
|
||||||
# Connection State
|
# Connection State
|
||||||
self.connection_task: Optional[asyncio.Task] = None
|
self.connection_task: Optional[asyncio.Task] = None
|
||||||
self.closing: bool = False
|
self.closing: bool = False
|
||||||
|
self.subscription_lock = asyncio.Lock()
|
||||||
self._klippy_info: Dict[str, Any] = {}
|
self._klippy_info: Dict[str, Any] = {}
|
||||||
self._klippy_identified: bool = False
|
self._klippy_identified: bool = False
|
||||||
self._klippy_initializing: bool = False
|
self._klippy_initializing: bool = False
|
||||||
|
@ -524,57 +525,65 @@ class KlippyConnection:
|
||||||
async def _request_subscripton(self,
|
async def _request_subscripton(self,
|
||||||
web_request: WebRequest
|
web_request: WebRequest
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
args = web_request.get_args()
|
async with self.subscription_lock:
|
||||||
conn = web_request.get_subscribable()
|
args = web_request.get_args()
|
||||||
|
conn = web_request.get_subscribable()
|
||||||
|
|
||||||
# Build the subscription request from a superset of all client
|
# Build the subscription request from a superset of all client
|
||||||
# subscriptions
|
# subscriptions
|
||||||
sub = args.get('objects', {})
|
sub = args.get('objects', {})
|
||||||
if conn is None:
|
if conn is None:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
"No connection associated with subscription request")
|
"No connection associated with subscription request")
|
||||||
self.subscriptions[conn] = sub
|
self.subscriptions[conn] = sub
|
||||||
all_subs: Dict[str, Any] = {}
|
all_subs: Dict[str, Any] = {}
|
||||||
# request superset of all client subscriptions
|
# request superset of all client subscriptions
|
||||||
for sub in self.subscriptions.values():
|
for sub in self.subscriptions.values():
|
||||||
for obj, items in sub.items():
|
for obj, items in sub.items():
|
||||||
if obj in all_subs:
|
if obj in all_subs:
|
||||||
pi = all_subs[obj]
|
pi = all_subs[obj]
|
||||||
if items is None or pi is None:
|
if items is None or pi is None:
|
||||||
all_subs[obj] = None
|
all_subs[obj] = None
|
||||||
|
else:
|
||||||
|
uitems = list(set(pi) | set(items))
|
||||||
|
all_subs[obj] = uitems
|
||||||
else:
|
else:
|
||||||
uitems = list(set(pi) | set(items))
|
all_subs[obj] = items
|
||||||
all_subs[obj] = uitems
|
args['objects'] = all_subs
|
||||||
else:
|
args['response_template'] = {'method': "process_status_update"}
|
||||||
all_subs[obj] = items
|
|
||||||
args['objects'] = all_subs
|
|
||||||
args['response_template'] = {'method': "process_status_update"}
|
|
||||||
|
|
||||||
result = await self._request_standard(web_request)
|
result = await self._request_standard(web_request, 20.0)
|
||||||
|
|
||||||
# prune the status response
|
# prune the status response
|
||||||
pruned_status = {}
|
pruned_status = {}
|
||||||
all_status = result['status']
|
all_status: Dict[str, Any] = result['status']
|
||||||
sub = self.subscriptions.get(conn, {})
|
sub = self.subscriptions.get(conn, {})
|
||||||
for obj, fields in all_status.items():
|
for obj, fields in all_status.items():
|
||||||
if obj in sub:
|
if obj in sub:
|
||||||
valid_fields = sub[obj]
|
valid_fields = sub[obj]
|
||||||
if valid_fields is None:
|
if valid_fields is None:
|
||||||
pruned_status[obj] = fields
|
pruned_status[obj] = fields
|
||||||
else:
|
else:
|
||||||
pruned_status[obj] = {k: v for k, v in fields.items()
|
pruned_status[obj] = {
|
||||||
if k in valid_fields}
|
k: v for k, v in fields.items() if k in valid_fields
|
||||||
result['status'] = pruned_status
|
}
|
||||||
return result
|
result['status'] = pruned_status
|
||||||
|
return result
|
||||||
|
|
||||||
async def _request_standard(self, web_request: WebRequest) -> Any:
|
async def _request_standard(
|
||||||
|
self, web_request: WebRequest, timeout: Optional[float] = None
|
||||||
|
) -> Any:
|
||||||
rpc_method = web_request.get_endpoint()
|
rpc_method = web_request.get_endpoint()
|
||||||
args = web_request.get_args()
|
args = web_request.get_args()
|
||||||
# Create a base klippy request
|
# Create a base klippy request
|
||||||
base_request = KlippyRequest(rpc_method, args)
|
base_request = KlippyRequest(rpc_method, args)
|
||||||
self.pending_requests[base_request.id] = base_request
|
self.pending_requests[base_request.id] = base_request
|
||||||
self.event_loop.register_callback(self._write_request, base_request)
|
self.event_loop.register_callback(self._write_request, base_request)
|
||||||
return await base_request.wait()
|
try:
|
||||||
|
return await asyncio.wait_for(base_request.wait(), timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.pending_requests.pop(base_request.id, None)
|
||||||
|
raise self.server.error("Klippy request timed out", 500)
|
||||||
|
|
||||||
def remove_subscription(self, conn: Subscribable) -> None:
|
def remove_subscription(self, conn: Subscribable) -> None:
|
||||||
self.subscriptions.pop(conn, None)
|
self.subscriptions.pop(conn, None)
|
||||||
|
|
Loading…
Reference in New Issue