diff --git a/moonraker/components/mqtt.py b/moonraker/components/mqtt.py index 4650cf3..47fc9be 100644 --- a/moonraker/components/mqtt.py +++ b/moonraker/components/mqtt.py @@ -599,18 +599,16 @@ class MQTTClient(APITransport, Subscribable): request_method: str, callback: Callable[[WebRequest], Coroutine] ) -> RPCCallback: - async def func(**kwargs) -> Any: - self._check_timestamp(kwargs) - result = await callback( - WebRequest(endpoint, kwargs, request_method)) + async def func(args: Dict[str, Any]) -> Any: + self._check_timestamp(args) + result = await callback(WebRequest(endpoint, args, request_method)) return result return func def _generate_remote_callback(self, endpoint: str) -> RPCCallback: - async def func(**kwargs) -> Any: - self._check_timestamp(kwargs) - result = await self.klippy.request( - WebRequest(endpoint, kwargs)) + async def func(args: Dict[str, Any]) -> Any: + self._check_timestamp(args) + result = await self.klippy.request(WebRequest(endpoint, args)) return result return func diff --git a/moonraker/websockets.py b/moonraker/websockets.py index ac5b1c1..76f090a 100644 --- a/moonraker/websockets.py +++ b/moonraker/websockets.py @@ -207,18 +207,13 @@ class JsonRPC: method = self.methods.get(method_name, None) if method is None: return self.build_error(-32601, "Method not found", req_id) + params: Dict[str, Any] = {} if 'params' in obj: params = obj['params'] - if isinstance(params, list): - response = await self.execute_method( - method, req_id, conn, *params) - elif isinstance(params, dict): - response = await self.execute_method( - method, req_id, conn, **params) - else: - return self.build_error(-32600, "Invalid Request", req_id) - else: - response = await self.execute_method(method, req_id, conn) + if not isinstance(params, dict): + return self.build_error( + -32602, f"Invalid params:", req_id, True) + response = await self.execute_method(method, req_id, conn, params) return response def process_response( @@ -247,17 +242,15 @@ class JsonRPC: conn.resolve_pending_response(response_id, ret) async def execute_method(self, - method: RPCCallback, + callback: RPCCallback, req_id: Optional[int], conn: Optional[WebSocket], - *args, - **kwargs + params: Dict[str, Any] ) -> Optional[Dict[str, Any]]: + if conn is not None: + params["_socket_"] = conn try: - if conn is not None: - result = await method(conn, *args, **kwargs) - else: - result = await method(*args, **kwargs) + result = await callback(params) except TypeError as e: return self.build_error( -32602, f"Invalid params:\n{e}", req_id, True) @@ -351,9 +344,10 @@ class WebsocketManager(APITransport): self.rpc.remove_method(jrpc_method) def _generate_callback(self, endpoint: str) -> RPCCallback: - async def func(ws: WebSocket, **kwargs) -> Any: + async def func(args: Dict[str, Any]) -> Any: + ws: WebSocket = args.pop("_socket_") result = await self.klippy.request( - WebRequest(endpoint, kwargs, conn=ws, ip_addr=ws.ip_addr, + WebRequest(endpoint, args, conn=ws, ip_addr=ws.ip_addr, user=ws.current_user)) return result return func @@ -363,32 +357,29 @@ class WebsocketManager(APITransport): request_method: str, callback: Callable[[WebRequest], Coroutine] ) -> RPCCallback: - async def func(ws: WebSocket, **kwargs) -> Any: + async def func(args: Dict[str, Any]) -> Any: + ws: WebSocket = args.pop("_socket_") result = await callback( - WebRequest(endpoint, kwargs, request_method, ws, + WebRequest(endpoint, args, request_method, ws, ip_addr=ws.ip_addr, user=ws.current_user)) return result return func - async def _handle_id_request(self, - ws: WebSocket, - **kwargs - ) -> Dict[str, int]: + async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]: + ws: WebSocket = args["_socket_"] return {'websocket_id': ws.uid} - async def _handle_identify(self, - ws: WebSocket, - **kwargs - ) -> Dict[str, int]: + async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]: + ws: WebSocket = args["_socket_"] if ws.identified: raise self.server.error( f"Connection already identified: {ws.client_data}" ) try: - name = str(kwargs["client_name"]) - version = str(kwargs["version"]) - client_type: str = str(kwargs["type"]).lower() - url = str(kwargs["url"]) + name = str(args["client_name"]) + version = str(args["version"]) + client_type: str = str(args["type"]).lower() + url = str(args["url"]) except KeyError as e: missing_key = str(e).split(":")[-1].strip() raise self.server.error(