klippy_connection: apply a mutex to the subscription request
The subscripition request is reentrant in Klippy. Sending multiple requests from the same connection may create a race condition, so use a lock to prevent reentry. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
7c8d68c0a1
commit
c3697d0656
|
@ -65,6 +65,7 @@ class KlippyConnection:
|
|||
# Connection State
|
||||
self.connection_task: Optional[asyncio.Task] = None
|
||||
self.closing: bool = False
|
||||
self.subscription_lock = asyncio.Lock()
|
||||
self._klippy_info: Dict[str, Any] = {}
|
||||
self._klippy_identified: bool = False
|
||||
self._klippy_initializing: bool = False
|
||||
|
@ -524,57 +525,65 @@ class KlippyConnection:
|
|||
async def _request_subscripton(self,
|
||||
web_request: WebRequest
|
||||
) -> Dict[str, Any]:
|
||||
args = web_request.get_args()
|
||||
conn = web_request.get_subscribable()
|
||||
async with self.subscription_lock:
|
||||
args = web_request.get_args()
|
||||
conn = web_request.get_subscribable()
|
||||
|
||||
# Build the subscription request from a superset of all client
|
||||
# subscriptions
|
||||
sub = args.get('objects', {})
|
||||
if conn is None:
|
||||
raise self.server.error(
|
||||
"No connection associated with subscription request")
|
||||
self.subscriptions[conn] = sub
|
||||
all_subs: Dict[str, Any] = {}
|
||||
# request superset of all client subscriptions
|
||||
for sub in self.subscriptions.values():
|
||||
for obj, items in sub.items():
|
||||
if obj in all_subs:
|
||||
pi = all_subs[obj]
|
||||
if items is None or pi is None:
|
||||
all_subs[obj] = None
|
||||
# Build the subscription request from a superset of all client
|
||||
# subscriptions
|
||||
sub = args.get('objects', {})
|
||||
if conn is None:
|
||||
raise self.server.error(
|
||||
"No connection associated with subscription request")
|
||||
self.subscriptions[conn] = sub
|
||||
all_subs: Dict[str, Any] = {}
|
||||
# request superset of all client subscriptions
|
||||
for sub in self.subscriptions.values():
|
||||
for obj, items in sub.items():
|
||||
if obj in all_subs:
|
||||
pi = all_subs[obj]
|
||||
if items is None or pi is None:
|
||||
all_subs[obj] = None
|
||||
else:
|
||||
uitems = list(set(pi) | set(items))
|
||||
all_subs[obj] = uitems
|
||||
else:
|
||||
uitems = list(set(pi) | set(items))
|
||||
all_subs[obj] = uitems
|
||||
else:
|
||||
all_subs[obj] = items
|
||||
args['objects'] = all_subs
|
||||
args['response_template'] = {'method': "process_status_update"}
|
||||
all_subs[obj] = items
|
||||
args['objects'] = all_subs
|
||||
args['response_template'] = {'method': "process_status_update"}
|
||||
|
||||
result = await self._request_standard(web_request)
|
||||
result = await self._request_standard(web_request, 20.0)
|
||||
|
||||
# prune the status response
|
||||
pruned_status = {}
|
||||
all_status = result['status']
|
||||
sub = self.subscriptions.get(conn, {})
|
||||
for obj, fields in all_status.items():
|
||||
if obj in sub:
|
||||
valid_fields = sub[obj]
|
||||
if valid_fields is None:
|
||||
pruned_status[obj] = fields
|
||||
else:
|
||||
pruned_status[obj] = {k: v for k, v in fields.items()
|
||||
if k in valid_fields}
|
||||
result['status'] = pruned_status
|
||||
return result
|
||||
# prune the status response
|
||||
pruned_status = {}
|
||||
all_status: Dict[str, Any] = result['status']
|
||||
sub = self.subscriptions.get(conn, {})
|
||||
for obj, fields in all_status.items():
|
||||
if obj in sub:
|
||||
valid_fields = sub[obj]
|
||||
if valid_fields is None:
|
||||
pruned_status[obj] = fields
|
||||
else:
|
||||
pruned_status[obj] = {
|
||||
k: v for k, v in fields.items() if k in valid_fields
|
||||
}
|
||||
result['status'] = pruned_status
|
||||
return result
|
||||
|
||||
async def _request_standard(self, web_request: WebRequest) -> Any:
|
||||
async def _request_standard(
|
||||
self, web_request: WebRequest, timeout: Optional[float] = None
|
||||
) -> Any:
|
||||
rpc_method = web_request.get_endpoint()
|
||||
args = web_request.get_args()
|
||||
# Create a base klippy request
|
||||
base_request = KlippyRequest(rpc_method, args)
|
||||
self.pending_requests[base_request.id] = base_request
|
||||
self.event_loop.register_callback(self._write_request, base_request)
|
||||
return await base_request.wait()
|
||||
try:
|
||||
return await asyncio.wait_for(base_request.wait(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self.pending_requests.pop(base_request.id, None)
|
||||
raise self.server.error("Klippy request timed out", 500)
|
||||
|
||||
def remove_subscription(self, conn: Subscribable) -> None:
|
||||
self.subscriptions.pop(conn, None)
|
||||
|
|
Loading…
Reference in New Issue