diff --git a/moonraker/websockets.py b/moonraker/websockets.py index 84a38a1..77bafc3 100644 --- a/moonraker/websockets.py +++ b/moonraker/websockets.py @@ -15,10 +15,11 @@ class Sentinel: pass class WebRequest: - def __init__(self, endpoint, args, action=""): + def __init__(self, endpoint, args, action="", conn=None): self.endpoint = endpoint self.action = action self.args = args + self.conn = conn def get_endpoint(self): return self.endpoint @@ -29,6 +30,9 @@ class WebRequest: def get_args(self): return self.args + def get_connection(self): + return self.conn + def _get_converted_arg(self, key, default=Sentinel, dtype=str): if key not in self.args: if default == Sentinel: @@ -79,7 +83,7 @@ class JsonRPC: def remove_method(self, name): self.methods.pop(name, None) - async def dispatch(self, data): + async def dispatch(self, data, ws): response = None try: request = json.loads(data) @@ -92,19 +96,19 @@ class JsonRPC: if isinstance(request, list): response = [] for req in request: - resp = await self.process_request(req) + resp = await self.process_request(req, ws) if resp is not None: response.append(resp) if not response: response = None else: - response = await self.process_request(request) + response = await self.process_request(request, ws) if response is not None: response = json.dumps(response) logging.debug("Websocket Response::" + response) return response - async def process_request(self, request): + async def process_request(self, request, ws): req_id = request.get('id', None) rpc_version = request.get('jsonrpc', "") method_name = request.get('method', None) @@ -116,18 +120,20 @@ class JsonRPC: if 'params' in request: params = request['params'] if isinstance(params, list): - response = await self.execute_method(method, req_id, *params) + response = await self.execute_method( + method, req_id, ws, *params) elif isinstance(params, dict): - response = await self.execute_method(method, req_id, **params) + response = await self.execute_method( + method, req_id, ws, **params) else: return self.build_error(-32600, "Invalid Request", req_id) else: - response = await self.execute_method(method, req_id) + response = await self.execute_method(method, req_id, ws) return response - async def execute_method(self, method, req_id, *args, **kwargs): + async def execute_method(self, method, req_id, ws, *args, **kwargs): try: - result = await method(*args, **kwargs) + result = await method(ws, *args, **kwargs) except TypeError as e: return self.build_error(-32603, f"Invalid params:\n{e}", req_id) except ServerError as e: @@ -209,16 +215,16 @@ class WebsocketManager: self.rpc.remove_method(ws_method) def _generate_callback(self, endpoint): - async def func(**kwargs): + async def func(ws, **kwargs): result = await self.server.make_request( - WebRequest(endpoint, kwargs)) + WebRequest(endpoint, kwargs, conn=ws)) return result return func def _generate_local_callback(self, endpoint, request_method, callback): - async def func(**kwargs): + async def func(ws, **kwargs): result = await callback( - WebRequest(endpoint, kwargs, request_method)) + WebRequest(endpoint, kwargs, request_method, ws)) return result return func @@ -274,7 +280,7 @@ class WebSocket(WebSocketHandler): async def _process_message(self, message): try: - response = await self.rpc.dispatch(message) + response = await self.rpc.dispatch(message, self) if response is not None: self.write_message(response) except Exception: