moonraker/tests/fixtures/websocket_client.py

137 lines
4.7 KiB
Python

from __future__ import annotations
import pytest
import json
import asyncio
import tornado.websocket
from typing import (
TYPE_CHECKING,
Union,
Tuple,
Callable,
Dict,
List,
Any,
Optional,
)
if TYPE_CHECKING:
from tornado.websocket import WebSocketClientConnection
class WebsocketError(Exception):
def __init__(self, code, *args: object) -> None:
super().__init__(*args)
self.code = code
class WebsocketClient:
error = WebsocketError
def __init__(self,
type: str = "ws",
port: int = 7010
) -> None:
self.ws: Optional[WebSocketClientConnection] = None
self.pending_requests: Dict[int, asyncio.Future] = {}
self.notify_cbs: Dict[str, List[Callable[..., None]]] = {}
assert type in ["ws", "wss"]
self.url = f"{type}://127.0.0.1:{port}/websocket"
async def connect(self, token: Optional[str] = None) -> None:
url = self.url
if token is not None:
url += f"?token={token}"
self.ws = await tornado.websocket.websocket_connect(
url, connect_timeout=2.,
on_message_callback=self._on_message_received)
async def request(self,
remote_method: str,
args: Dict[str, Any] = {}
) -> Dict[str, Any]:
if self.ws is None:
pytest.fail("Websocket Not Connected")
loop = asyncio.get_running_loop()
fut = loop.create_future()
req, req_id = self._encode_request(remote_method, args)
self.pending_requests[req_id] = fut
await self.ws.write_message(req)
return await asyncio.wait_for(fut, 2.)
def _encode_request(self,
method: str,
args: Dict[str, Any]
) -> Tuple[str, int]:
request: Dict[str, Any] = {
'jsonrpc': "2.0",
'method': method,
}
if args:
request['params'] = args
req_id = id(request)
request["id"] = req_id
return json.dumps(request), req_id
def _on_message_received(self, message: Union[str, bytes, None]) -> None:
if isinstance(message, str):
self._decode_jsonrpc(message)
def _decode_jsonrpc(self, data: str) -> None:
try:
resp: Dict[str, Any] = json.loads(data)
except json.JSONDecodeError:
pytest.fail(f"Websocket JSON Decode Error: {data}")
header = resp.get('jsonrpc', "")
if header != "2.0":
# Invalid Json, set error if we can get the id
pytest.fail(f"Invalid jsonrpc header: {data}")
req_id: Optional[int] = resp.get("id")
method: Optional[str] = resp.get("method")
if method is not None:
if req_id is None:
params = resp.get("params", [])
if not isinstance(params, list):
pytest.fail("jsonrpc notification params"
f"should always be a list: {data}")
if method in self.notify_cbs:
for func in self.notify_cbs[method]:
func(*params)
else:
# This is a request from the server (should not happen)
pytest.fail(f"Server should not request from client: {data}")
elif req_id is not None:
pending_fut = self.pending_requests.pop(req_id, None)
if pending_fut is None:
# No future pending for this response
return
# This is a response
if "result" in resp:
pending_fut.set_result(resp["result"])
elif "error" in resp:
err = resp["error"]
try:
code = err["code"]
msg = err["message"]
except Exception:
pytest.fail(f"Invalid jsonrpc error: {data}")
exc = WebsocketError(code, msg)
pending_fut.set_exception(exc)
else:
pytest.fail(
f"Invalid jsonrpc packet, no result or error: {data}")
else:
# Invalid json
pytest.fail(f"Invalid jsonrpc packet, no id: {data}")
def register_notify_callback(self, name: str, callback) -> None:
if name in self.notify_cbs:
self.notify_cbs[name].append(callback)
else:
self.notify_cbs[name][callback]
def close(self):
for fut in self.pending_requests.values():
if not fut.done():
fut.set_exception(WebsocketError(
0, "Closing Websocket Client"))
if self.ws is not None:
self.ws.close(1000, "Test Complete")