diff --git a/moonraker/websockets.py b/moonraker/websockets.py index e727d56..746f949 100644 --- a/moonraker/websockets.py +++ b/moonraker/websockets.py @@ -16,6 +16,7 @@ from utils import ServerError, SentinelClass from typing import ( TYPE_CHECKING, Any, + Awaitable, Optional, Callable, Coroutine, @@ -165,42 +166,47 @@ class JsonRPC: ) -> Optional[str]: response: Any = None try: - request: Union[Dict[str, Any], List[dict]] = json.loads(data) + obj: Union[Dict[str, Any], List[dict]] = json.loads(data) except Exception: msg = f"{self.transport} data not json: {data}" logging.exception(msg) response = self.build_error(-32700, "Parse error") return json.dumps(response) - logging.debug(f"{self.transport} Request::{data}") - if isinstance(request, list): + logging.debug(f"{self.transport} Received::{data}") + if isinstance(obj, list): response = [] - for req in request: - resp = await self.process_request(req, conn) + for item in obj: + resp = await self.process_object(item, conn) if resp is not None: response.append(resp) if not response: response = None else: - response = await self.process_request(request, conn) + response = await self.process_object(obj, conn) if response is not None: response = json.dumps(response) logging.debug(f"{self.transport} Response::{response}") return response - async def process_request(self, - request: Dict[str, Any], - conn: Optional[WebSocket] - ) -> Optional[Dict[str, Any]]: - req_id: Optional[int] = request.get('id', None) - rpc_version: str = request.get('jsonrpc', "") - method_name = request.get('method', None) - if rpc_version != "2.0" or not isinstance(method_name, str): + async def process_object(self, + obj: Dict[str, Any], + conn: Optional[WebSocket] + ) -> Optional[Dict[str, Any]]: + req_id: Optional[int] = obj.get('id', None) + rpc_version: str = obj.get('jsonrpc', "") + if rpc_version != "2.0": + return self.build_error(-32600, "Invalid Request", req_id) + method_name = obj.get('method', SENTINEL) + if method_name is SENTINEL: + self.process_response(obj, conn) + return None + if not isinstance(method_name, str): return self.build_error(-32600, "Invalid Request", req_id) method = self.methods.get(method_name, None) if method is None: return self.build_error(-32601, "Method not found", req_id) - if 'params' in request: - params = request['params'] + if 'params' in obj: + params = obj['params'] if isinstance(params, list): response = await self.execute_method( method, req_id, conn, *params) @@ -213,6 +219,29 @@ class JsonRPC: response = await self.execute_method(method, req_id, conn) return response + def process_response( + self, obj: Dict[str, Any], conn: Optional[WebSocket] + ) -> None: + if conn is None: + logging.debug(f"RPC Response to non-socket request: {obj}") + return + response_id = obj.get("id") + if response_id is None: + logging.debug(f"RPC Response with null ID: {obj}") + return + result = obj.get("result") + if result is None: + error = obj.get("error") + msg = f"Invalid RPC Response: {obj}" + code = 500 + if isinstance(error, dict): + msg = error.get("message", msg) + code = error.get("code", code) + ret = ServerError(msg, code) + else: + ret = result + conn.resolve_pending_response(response_id, ret) + async def execute_method(self, method: RPCCallback, req_id: Optional[int], @@ -449,6 +478,7 @@ class WebSocket(WebSocketHandler, Subscribable): self.is_closed: bool = False self.ip_addr: str = self.request.remote_ip self.queue_busy: bool = False + self.pending_responses: Dict[int, asyncio.Future] = {} self.message_buf: List[Union[str, Dict[str, Any]]] = [] self.last_pong_time: float = self.event_loop.get_loop_time() self._connected_time: float = 0. @@ -546,11 +576,43 @@ class WebSocket(WebSocketHandler, Subscribable): 'method': "notify_status_update", 'params': [status, eventtime]}) + def call_method( + self, + method: str, + params: Optional[Union[List, Dict[str, Any]]] = None + ) -> Awaitable: + fut = self.event_loop.create_future() + msg = { + 'jsonrpc': "2.0", + 'method': method, + 'id': id(fut) + } + if params is not None: + msg["params"] = params + self.pending_responses[id(fut)] = fut + self.queue_message(msg) + return fut + + def resolve_pending_response( + self, response_id: int, result: Any + ) -> bool: + fut = self.pending_responses.pop(response_id, None) + if fut is None: + return False + if isinstance(result, ServerError): + fut.set_exception(result) + else: + fut.set_result(result) + return True + def on_close(self) -> None: self.is_closed = True self.message_buf = [] now = self.event_loop.get_loop_time() pong_elapsed = now - self.last_pong_time + for resp in self.pending_responses.values(): + resp.set_exception(ServerError("Client Socket Disconnected", 500)) + self.pending_responses = {} logging.info(f"Websocket Closed: ID: {self.uid} " f"Close Code: {self.close_code}, " f"Close Reason: {self.close_reason}, "