websockets: move JsonRPC and BaseSocketClient to common
Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
201e84cd94
commit
1b9f29db13
|
@ -6,6 +6,9 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import logging
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
from .utils import ServerError, Sentinel
|
from .utils import ServerError, Sentinel
|
||||||
|
|
||||||
# Annotation imports
|
# Annotation imports
|
||||||
|
@ -20,11 +23,14 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
|
Awaitable
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .websockets import BaseSocketClient
|
from .server import Server
|
||||||
|
from .websockets import WebsocketManager
|
||||||
from .components.authorization import Authorization
|
from .components.authorization import Authorization
|
||||||
|
from asyncio import Future
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
_C = TypeVar("_C", str, bool, float, int)
|
_C = TypeVar("_C", str, bool, float, int)
|
||||||
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||||
|
@ -40,6 +46,202 @@ class Subscribable:
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class APIDefinition:
|
||||||
|
def __init__(self,
|
||||||
|
endpoint: str,
|
||||||
|
http_uri: str,
|
||||||
|
jrpc_methods: List[str],
|
||||||
|
request_methods: Union[str, List[str]],
|
||||||
|
transports: List[str],
|
||||||
|
callback: Optional[Callable[[WebRequest], Coroutine]],
|
||||||
|
need_object_parser: bool):
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.uri = http_uri
|
||||||
|
self.jrpc_methods = jrpc_methods
|
||||||
|
if not isinstance(request_methods, list):
|
||||||
|
request_methods = [request_methods]
|
||||||
|
self.request_methods = request_methods
|
||||||
|
self.supported_transports = transports
|
||||||
|
self.callback = callback
|
||||||
|
self.need_object_parser = need_object_parser
|
||||||
|
|
||||||
|
class APITransport:
|
||||||
|
def register_api_handler(self, api_def: APIDefinition) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def remove_api_handler(self, api_def: APIDefinition) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class BaseRemoteConnection(Subscribable):
|
||||||
|
def on_create(self, server: Server) -> None:
|
||||||
|
self.server = server
|
||||||
|
self.eventloop = server.get_event_loop()
|
||||||
|
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
|
||||||
|
self.rpc = self.wsm.rpc
|
||||||
|
self._uid = id(self)
|
||||||
|
self.ip_addr = ""
|
||||||
|
self.is_closed: bool = False
|
||||||
|
self.queue_busy: bool = False
|
||||||
|
self.pending_responses: Dict[int, Future] = {}
|
||||||
|
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
||||||
|
self._connected_time: float = 0.
|
||||||
|
self._identified: bool = False
|
||||||
|
self._client_data: Dict[str, str] = {
|
||||||
|
"name": "unknown",
|
||||||
|
"version": "",
|
||||||
|
"type": "",
|
||||||
|
"url": ""
|
||||||
|
}
|
||||||
|
self._need_auth: bool = False
|
||||||
|
self._user_info: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def user_info(self) -> Optional[Dict[str, Any]]:
|
||||||
|
return self._user_info
|
||||||
|
|
||||||
|
@user_info.setter
|
||||||
|
def user_info(self, uinfo: Dict[str, Any]) -> None:
|
||||||
|
self._user_info = uinfo
|
||||||
|
self._need_auth = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def need_auth(self) -> bool:
|
||||||
|
return self._need_auth
|
||||||
|
|
||||||
|
@property
|
||||||
|
def uid(self) -> int:
|
||||||
|
return self._uid
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hostname(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_time(self) -> float:
|
||||||
|
return self._connected_time
|
||||||
|
|
||||||
|
@property
|
||||||
|
def identified(self) -> bool:
|
||||||
|
return self._identified
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client_data(self) -> Dict[str, str]:
|
||||||
|
return self._client_data
|
||||||
|
|
||||||
|
@client_data.setter
|
||||||
|
def client_data(self, data: Dict[str, str]) -> None:
|
||||||
|
self._client_data = data
|
||||||
|
self._identified = True
|
||||||
|
|
||||||
|
async def _process_message(self, message: str) -> None:
|
||||||
|
try:
|
||||||
|
response = await self.rpc.dispatch(message, self)
|
||||||
|
if response is not None:
|
||||||
|
self.queue_message(response)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Websocket Command Error")
|
||||||
|
|
||||||
|
def queue_message(self, message: Union[str, Dict[str, Any]]):
|
||||||
|
self.message_buf.append(message)
|
||||||
|
if self.queue_busy:
|
||||||
|
return
|
||||||
|
self.queue_busy = True
|
||||||
|
self.eventloop.register_callback(self._write_messages)
|
||||||
|
|
||||||
|
def authenticate(
|
||||||
|
self,
|
||||||
|
token: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
auth: AuthComp = self.server.lookup_component("authorization", None)
|
||||||
|
if auth is None:
|
||||||
|
return
|
||||||
|
if token is not None:
|
||||||
|
self.user_info = auth.validate_jwt(token)
|
||||||
|
elif api_key is not None and self.user_info is None:
|
||||||
|
self.user_info = auth.validate_api_key(api_key)
|
||||||
|
else:
|
||||||
|
self.check_authenticated()
|
||||||
|
|
||||||
|
def check_authenticated(self, path: str = "") -> None:
|
||||||
|
if not self._need_auth:
|
||||||
|
return
|
||||||
|
auth: AuthComp = self.server.lookup_component("authorization", None)
|
||||||
|
if auth is None:
|
||||||
|
return
|
||||||
|
if not auth.is_path_permitted(path):
|
||||||
|
raise self.server.error("Unauthorized", 401)
|
||||||
|
|
||||||
|
def on_user_logout(self, user: str) -> bool:
|
||||||
|
if self._user_info is None:
|
||||||
|
return False
|
||||||
|
if user == self._user_info.get("username", ""):
|
||||||
|
self._user_info = None
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _write_messages(self):
|
||||||
|
if self.is_closed:
|
||||||
|
self.message_buf = []
|
||||||
|
self.queue_busy = False
|
||||||
|
return
|
||||||
|
while self.message_buf:
|
||||||
|
msg = self.message_buf.pop(0)
|
||||||
|
await self.write_to_socket(msg)
|
||||||
|
self.queue_busy = False
|
||||||
|
|
||||||
|
async def write_to_socket(
|
||||||
|
self, message: Union[str, Dict[str, Any]]
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError("Children must implement write_to_socket")
|
||||||
|
|
||||||
|
def send_status(self,
|
||||||
|
status: Dict[str, Any],
|
||||||
|
eventtime: float
|
||||||
|
) -> None:
|
||||||
|
if not status:
|
||||||
|
return
|
||||||
|
self.queue_message({
|
||||||
|
'jsonrpc': "2.0",
|
||||||
|
'method': "notify_status_update",
|
||||||
|
'params': [status, eventtime]})
|
||||||
|
|
||||||
|
def call_method(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
params: Optional[Union[List, Dict[str, Any]]] = None
|
||||||
|
) -> Awaitable:
|
||||||
|
fut = self.eventloop.create_future()
|
||||||
|
msg = {
|
||||||
|
'jsonrpc': "2.0",
|
||||||
|
'method': method,
|
||||||
|
'id': id(fut)
|
||||||
|
}
|
||||||
|
if params is not None:
|
||||||
|
msg["params"] = params
|
||||||
|
self.pending_responses[id(fut)] = fut
|
||||||
|
self.queue_message(msg)
|
||||||
|
return fut
|
||||||
|
|
||||||
|
def send_notification(self, name: str, data: List) -> None:
|
||||||
|
self.wsm.notify_clients(name, data, [self._uid])
|
||||||
|
|
||||||
|
def resolve_pending_response(
|
||||||
|
self, response_id: int, result: Any
|
||||||
|
) -> bool:
|
||||||
|
fut = self.pending_responses.pop(response_id, None)
|
||||||
|
if fut is None:
|
||||||
|
return False
|
||||||
|
if isinstance(result, ServerError):
|
||||||
|
fut.set_exception(result)
|
||||||
|
else:
|
||||||
|
fut.set_result(result)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def close_socket(self, code: int, reason: str) -> None:
|
||||||
|
raise NotImplementedError("Children must implement close_socket()")
|
||||||
|
|
||||||
|
|
||||||
class WebRequest:
|
class WebRequest:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
|
@ -72,8 +274,8 @@ class WebRequest:
|
||||||
def get_subscribable(self) -> Optional[Subscribable]:
|
def get_subscribable(self) -> Optional[Subscribable]:
|
||||||
return self.conn
|
return self.conn
|
||||||
|
|
||||||
def get_client_connection(self) -> Optional[BaseSocketClient]:
|
def get_client_connection(self) -> Optional[BaseRemoteConnection]:
|
||||||
if isinstance(self.conn, BaseSocketClient):
|
if isinstance(self.conn, BaseRemoteConnection):
|
||||||
return self.conn
|
return self.conn
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -142,28 +344,188 @@ class WebRequest:
|
||||||
) -> Union[bool, _T]:
|
) -> Union[bool, _T]:
|
||||||
return self._get_converted_arg(key, default, bool)
|
return self._get_converted_arg(key, default, bool)
|
||||||
|
|
||||||
class APIDefinition:
|
|
||||||
def __init__(self,
|
|
||||||
endpoint: str,
|
|
||||||
http_uri: str,
|
|
||||||
jrpc_methods: List[str],
|
|
||||||
request_methods: Union[str, List[str]],
|
|
||||||
transports: List[str],
|
|
||||||
callback: Optional[Callable[[WebRequest], Coroutine]],
|
|
||||||
need_object_parser: bool):
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.uri = http_uri
|
|
||||||
self.jrpc_methods = jrpc_methods
|
|
||||||
if not isinstance(request_methods, list):
|
|
||||||
request_methods = [request_methods]
|
|
||||||
self.request_methods = request_methods
|
|
||||||
self.supported_transports = transports
|
|
||||||
self.callback = callback
|
|
||||||
self.need_object_parser = need_object_parser
|
|
||||||
|
|
||||||
class APITransport:
|
class JsonRPC:
|
||||||
def register_api_handler(self, api_def: APIDefinition) -> None:
|
def __init__(
|
||||||
raise NotImplementedError
|
self, server: Server, transport: str = "Websocket"
|
||||||
|
) -> None:
|
||||||
|
self.methods: Dict[str, RPCCallback] = {}
|
||||||
|
self.transport = transport
|
||||||
|
self.sanitize_response = False
|
||||||
|
self.verbose = server.is_verbose_enabled()
|
||||||
|
|
||||||
def remove_api_handler(self, api_def: APIDefinition) -> None:
|
def _log_request(self, rpc_obj: Dict[str, Any], ) -> None:
|
||||||
raise NotImplementedError
|
if not self.verbose:
|
||||||
|
return
|
||||||
|
self.sanitize_response = False
|
||||||
|
output = rpc_obj
|
||||||
|
method: Optional[str] = rpc_obj.get("method")
|
||||||
|
params: Dict[str, Any] = rpc_obj.get("params", {})
|
||||||
|
if isinstance(method, str):
|
||||||
|
if (
|
||||||
|
method.startswith("access.") or
|
||||||
|
method == "machine.sudo.password"
|
||||||
|
):
|
||||||
|
self.sanitize_response = True
|
||||||
|
if params and isinstance(params, dict):
|
||||||
|
output = copy.deepcopy(rpc_obj)
|
||||||
|
output["params"] = {key: "<sanitized>" for key in params}
|
||||||
|
elif method == "server.connection.identify":
|
||||||
|
output = copy.deepcopy(rpc_obj)
|
||||||
|
for field in ["access_token", "api_key"]:
|
||||||
|
if field in params:
|
||||||
|
output["params"][field] = "<sanitized>"
|
||||||
|
logging.debug(f"{self.transport} Received::{json.dumps(output)}")
|
||||||
|
|
||||||
|
def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
|
||||||
|
if not self.verbose:
|
||||||
|
return
|
||||||
|
if resp_obj is None:
|
||||||
|
return
|
||||||
|
output = resp_obj
|
||||||
|
if self.sanitize_response and "result" in resp_obj:
|
||||||
|
output = copy.deepcopy(resp_obj)
|
||||||
|
output["result"] = "<sanitized>"
|
||||||
|
self.sanitize_response = False
|
||||||
|
logging.debug(f"{self.transport} Response::{json.dumps(output)}")
|
||||||
|
|
||||||
|
def register_method(self,
|
||||||
|
name: str,
|
||||||
|
method: RPCCallback
|
||||||
|
) -> None:
|
||||||
|
self.methods[name] = method
|
||||||
|
|
||||||
|
def remove_method(self, name: str) -> None:
|
||||||
|
self.methods.pop(name, None)
|
||||||
|
|
||||||
|
async def dispatch(self,
|
||||||
|
data: str,
|
||||||
|
conn: Optional[BaseRemoteConnection] = None
|
||||||
|
) -> Optional[str]:
|
||||||
|
try:
|
||||||
|
obj: Union[Dict[str, Any], List[dict]] = json.loads(data)
|
||||||
|
except Exception:
|
||||||
|
msg = f"{self.transport} data not json: {data}"
|
||||||
|
logging.exception(msg)
|
||||||
|
err = self.build_error(-32700, "Parse error")
|
||||||
|
return json.dumps(err)
|
||||||
|
if isinstance(obj, list):
|
||||||
|
responses: List[Dict[str, Any]] = []
|
||||||
|
for item in obj:
|
||||||
|
self._log_request(item)
|
||||||
|
resp = await self.process_object(item, conn)
|
||||||
|
if resp is not None:
|
||||||
|
self._log_response(resp)
|
||||||
|
responses.append(resp)
|
||||||
|
if responses:
|
||||||
|
return json.dumps(responses)
|
||||||
|
else:
|
||||||
|
self._log_request(obj)
|
||||||
|
response = await self.process_object(obj, conn)
|
||||||
|
if response is not None:
|
||||||
|
self._log_response(response)
|
||||||
|
return json.dumps(response)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def process_object(self,
|
||||||
|
obj: Dict[str, Any],
|
||||||
|
conn: Optional[BaseRemoteConnection]
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
req_id: Optional[int] = obj.get('id', None)
|
||||||
|
rpc_version: str = obj.get('jsonrpc', "")
|
||||||
|
if rpc_version != "2.0":
|
||||||
|
return self.build_error(-32600, "Invalid Request", req_id)
|
||||||
|
method_name = obj.get('method', Sentinel.MISSING)
|
||||||
|
if method_name is Sentinel.MISSING:
|
||||||
|
self.process_response(obj, conn)
|
||||||
|
return None
|
||||||
|
if not isinstance(method_name, str):
|
||||||
|
return self.build_error(-32600, "Invalid Request", req_id)
|
||||||
|
method = self.methods.get(method_name, None)
|
||||||
|
if method is None:
|
||||||
|
return self.build_error(-32601, "Method not found", req_id)
|
||||||
|
params: Dict[str, Any] = {}
|
||||||
|
if 'params' in obj:
|
||||||
|
params = obj['params']
|
||||||
|
if not isinstance(params, dict):
|
||||||
|
return self.build_error(
|
||||||
|
-32602, f"Invalid params:", req_id, True)
|
||||||
|
response = await self.execute_method(method, req_id, conn, params)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def process_response(
|
||||||
|
self, obj: Dict[str, Any], conn: Optional[BaseRemoteConnection]
|
||||||
|
) -> None:
|
||||||
|
if conn is None:
|
||||||
|
logging.debug(f"RPC Response to non-socket request: {obj}")
|
||||||
|
return
|
||||||
|
response_id = obj.get("id")
|
||||||
|
if response_id is None:
|
||||||
|
logging.debug(f"RPC Response with null ID: {obj}")
|
||||||
|
return
|
||||||
|
result = obj.get("result")
|
||||||
|
if result is None:
|
||||||
|
name = conn.client_data["name"]
|
||||||
|
error = obj.get("error")
|
||||||
|
msg = f"Invalid Response: {obj}"
|
||||||
|
code = -32600
|
||||||
|
if isinstance(error, dict):
|
||||||
|
msg = error.get("message", msg)
|
||||||
|
code = error.get("code", code)
|
||||||
|
msg = f"{name} rpc error: {code} {msg}"
|
||||||
|
ret = ServerError(msg, 418)
|
||||||
|
else:
|
||||||
|
ret = result
|
||||||
|
conn.resolve_pending_response(response_id, ret)
|
||||||
|
|
||||||
|
async def execute_method(self,
|
||||||
|
callback: RPCCallback,
|
||||||
|
req_id: Optional[int],
|
||||||
|
conn: Optional[BaseRemoteConnection],
|
||||||
|
params: Dict[str, Any]
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
if conn is not None:
|
||||||
|
params["_socket_"] = conn
|
||||||
|
try:
|
||||||
|
result = await callback(params)
|
||||||
|
except TypeError as e:
|
||||||
|
return self.build_error(
|
||||||
|
-32602, f"Invalid params:\n{e}", req_id, True)
|
||||||
|
except ServerError as e:
|
||||||
|
code = e.status_code
|
||||||
|
if code == 404:
|
||||||
|
code = -32601
|
||||||
|
elif code == 401:
|
||||||
|
code = -32602
|
||||||
|
return self.build_error(code, str(e), req_id, True)
|
||||||
|
except Exception as e:
|
||||||
|
return self.build_error(-31000, str(e), req_id, True)
|
||||||
|
|
||||||
|
if req_id is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return self.build_result(result, req_id)
|
||||||
|
|
||||||
|
def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
'jsonrpc': "2.0",
|
||||||
|
'result': result,
|
||||||
|
'id': req_id
|
||||||
|
}
|
||||||
|
|
||||||
|
def build_error(self,
|
||||||
|
code: int,
|
||||||
|
msg: str,
|
||||||
|
req_id: Optional[int] = None,
|
||||||
|
is_exc: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
|
||||||
|
if is_exc:
|
||||||
|
logging.exception(log_msg)
|
||||||
|
else:
|
||||||
|
logging.info(log_msg)
|
||||||
|
return {
|
||||||
|
'jsonrpc': "2.0",
|
||||||
|
'error': {'code': code, 'message': msg},
|
||||||
|
'id': req_id
|
||||||
|
}
|
||||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
||||||
import pathlib
|
import pathlib
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
from ..websockets import BaseSocketClient
|
from ..common import BaseRemoteConnection
|
||||||
from ..utils import get_unix_peer_credentials
|
from ..utils import get_unix_peer_credentials
|
||||||
|
|
||||||
# Annotation imports
|
# Annotation imports
|
||||||
|
@ -32,7 +32,7 @@ UNIX_BUFFER_LIMIT = 20 * 1024 * 1024
|
||||||
class ExtensionManager:
|
class ExtensionManager:
|
||||||
def __init__(self, config: ConfigHelper) -> None:
|
def __init__(self, config: ConfigHelper) -> None:
|
||||||
self.server = config.get_server()
|
self.server = config.get_server()
|
||||||
self.agents: Dict[str, BaseSocketClient] = {}
|
self.agents: Dict[str, BaseRemoteConnection] = {}
|
||||||
self.uds_server: Optional[asyncio.AbstractServer] = None
|
self.uds_server: Optional[asyncio.AbstractServer] = None
|
||||||
self.server.register_endpoint(
|
self.server.register_endpoint(
|
||||||
"/connection/send_event", ["POST"], self._handle_agent_event,
|
"/connection/send_event", ["POST"], self._handle_agent_event,
|
||||||
|
@ -45,7 +45,7 @@ class ExtensionManager:
|
||||||
"/server/extensions/request", ["POST"], self._handle_call_agent
|
"/server/extensions/request", ["POST"], self._handle_call_agent
|
||||||
)
|
)
|
||||||
|
|
||||||
def register_agent(self, connection: BaseSocketClient) -> None:
|
def register_agent(self, connection: BaseRemoteConnection) -> None:
|
||||||
data = connection.client_data
|
data = connection.client_data
|
||||||
name = data["name"]
|
name = data["name"]
|
||||||
client_type = data["type"]
|
client_type = data["type"]
|
||||||
|
@ -64,7 +64,7 @@ class ExtensionManager:
|
||||||
}
|
}
|
||||||
connection.send_notification("agent_event", [evt])
|
connection.send_notification("agent_event", [evt])
|
||||||
|
|
||||||
def remove_agent(self, connection: BaseSocketClient) -> None:
|
def remove_agent(self, connection: BaseRemoteConnection) -> None:
|
||||||
name = connection.client_data["name"]
|
name = connection.client_data["name"]
|
||||||
if name in self.agents:
|
if name in self.agents:
|
||||||
del self.agents[name]
|
del self.agents[name]
|
||||||
|
@ -135,7 +135,7 @@ class ExtensionManager:
|
||||||
await self.uds_server.wait_closed()
|
await self.uds_server.wait_closed()
|
||||||
self.uds_server = None
|
self.uds_server = None
|
||||||
|
|
||||||
class UnixSocketClient(BaseSocketClient):
|
class UnixSocketClient(BaseRemoteConnection):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
server: Server,
|
server: Server,
|
||||||
|
|
|
@ -13,8 +13,7 @@ import pathlib
|
||||||
import ssl
|
import ssl
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import paho.mqtt.client as paho_mqtt
|
import paho.mqtt.client as paho_mqtt
|
||||||
from ..common import Subscribable, WebRequest, APITransport
|
from ..common import Subscribable, WebRequest, APITransport, JsonRPC
|
||||||
from ..websockets import JsonRPC
|
|
||||||
|
|
||||||
# Annotation imports
|
# Annotation imports
|
||||||
from typing import (
|
from typing import (
|
||||||
|
|
|
@ -32,7 +32,8 @@ from typing import (
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..app import InternalTransport
|
from ..app import InternalTransport
|
||||||
from ..confighelper import ConfigHelper
|
from ..confighelper import ConfigHelper
|
||||||
from ..websockets import WebsocketManager, BaseSocketClient
|
from ..websockets import WebsocketManager
|
||||||
|
from ..common import BaseRemoteConnection
|
||||||
from tornado.websocket import WebSocketClientConnection
|
from tornado.websocket import WebSocketClientConnection
|
||||||
from .database import MoonrakerDatabase
|
from .database import MoonrakerDatabase
|
||||||
from .klippy_apis import KlippyAPI
|
from .klippy_apis import KlippyAPI
|
||||||
|
@ -609,7 +610,7 @@ class SimplyPrint(Subscribable):
|
||||||
is_on = device_info["status"] == "on"
|
is_on = device_info["status"] == "on"
|
||||||
self.send_sp("power_controller", {"on": is_on})
|
self.send_sp("power_controller", {"on": is_on})
|
||||||
|
|
||||||
def _on_websocket_identified(self, ws: BaseSocketClient) -> None:
|
def _on_websocket_identified(self, ws: BaseRemoteConnection) -> None:
|
||||||
if (
|
if (
|
||||||
self.cache.current_wsid is None and
|
self.cache.current_wsid is None and
|
||||||
ws.client_data.get("type", "") == "web"
|
ws.client_data.get("type", "") == "web"
|
||||||
|
@ -622,7 +623,7 @@ class SimplyPrint(Subscribable):
|
||||||
self.cache.current_wsid = ws.uid
|
self.cache.current_wsid = ws.uid
|
||||||
self.send_sp("machine_data", ui_data)
|
self.send_sp("machine_data", ui_data)
|
||||||
|
|
||||||
def _on_websocket_removed(self, ws: BaseSocketClient) -> None:
|
def _on_websocket_removed(self, ws: BaseRemoteConnection) -> None:
|
||||||
if self.cache.current_wsid is None or self.cache.current_wsid != ws.uid:
|
if self.cache.current_wsid is None or self.cache.current_wsid != ws.uid:
|
||||||
return
|
return
|
||||||
ui_data = self._get_ui_info()
|
ui_data = self._get_ui_info()
|
||||||
|
|
|
@ -7,13 +7,17 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import json
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
|
||||||
from tornado.websocket import WebSocketHandler, WebSocketClosedError
|
from tornado.websocket import WebSocketHandler, WebSocketClosedError
|
||||||
from tornado.web import HTTPError
|
from tornado.web import HTTPError
|
||||||
from .common import WebRequest, Subscribable, APITransport, APIDefinition
|
from .common import (
|
||||||
from .utils import ServerError, Sentinel
|
WebRequest,
|
||||||
|
BaseRemoteConnection,
|
||||||
|
APITransport,
|
||||||
|
APIDefinition,
|
||||||
|
JsonRPC
|
||||||
|
)
|
||||||
|
from .utils import ServerError
|
||||||
|
|
||||||
# Annotation imports
|
# Annotation imports
|
||||||
from typing import (
|
from typing import (
|
||||||
|
@ -42,195 +46,10 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"]
|
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"]
|
||||||
|
|
||||||
class JsonRPC:
|
|
||||||
def __init__(
|
|
||||||
self, server: Server, transport: str = "Websocket"
|
|
||||||
) -> None:
|
|
||||||
self.methods: Dict[str, RPCCallback] = {}
|
|
||||||
self.transport = transport
|
|
||||||
self.sanitize_response = False
|
|
||||||
self.verbose = server.is_verbose_enabled()
|
|
||||||
|
|
||||||
def _log_request(self, rpc_obj: Dict[str, Any], ) -> None:
|
|
||||||
if not self.verbose:
|
|
||||||
return
|
|
||||||
self.sanitize_response = False
|
|
||||||
output = rpc_obj
|
|
||||||
method: Optional[str] = rpc_obj.get("method")
|
|
||||||
params: Dict[str, Any] = rpc_obj.get("params", {})
|
|
||||||
if isinstance(method, str):
|
|
||||||
if (
|
|
||||||
method.startswith("access.") or
|
|
||||||
method == "machine.sudo.password"
|
|
||||||
):
|
|
||||||
self.sanitize_response = True
|
|
||||||
if params and isinstance(params, dict):
|
|
||||||
output = copy.deepcopy(rpc_obj)
|
|
||||||
output["params"] = {key: "<sanitized>" for key in params}
|
|
||||||
elif method == "server.connection.identify":
|
|
||||||
output = copy.deepcopy(rpc_obj)
|
|
||||||
for field in ["access_token", "api_key"]:
|
|
||||||
if field in params:
|
|
||||||
output["params"][field] = "<sanitized>"
|
|
||||||
logging.debug(f"{self.transport} Received::{json.dumps(output)}")
|
|
||||||
|
|
||||||
def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
|
|
||||||
if not self.verbose:
|
|
||||||
return
|
|
||||||
if resp_obj is None:
|
|
||||||
return
|
|
||||||
output = resp_obj
|
|
||||||
if self.sanitize_response and "result" in resp_obj:
|
|
||||||
output = copy.deepcopy(resp_obj)
|
|
||||||
output["result"] = "<sanitized>"
|
|
||||||
self.sanitize_response = False
|
|
||||||
logging.debug(f"{self.transport} Response::{json.dumps(output)}")
|
|
||||||
|
|
||||||
def register_method(self,
|
|
||||||
name: str,
|
|
||||||
method: RPCCallback
|
|
||||||
) -> None:
|
|
||||||
self.methods[name] = method
|
|
||||||
|
|
||||||
def remove_method(self, name: str) -> None:
|
|
||||||
self.methods.pop(name, None)
|
|
||||||
|
|
||||||
async def dispatch(self,
|
|
||||||
data: str,
|
|
||||||
conn: Optional[BaseSocketClient] = None
|
|
||||||
) -> Optional[str]:
|
|
||||||
try:
|
|
||||||
obj: Union[Dict[str, Any], List[dict]] = json.loads(data)
|
|
||||||
except Exception:
|
|
||||||
msg = f"{self.transport} data not json: {data}"
|
|
||||||
logging.exception(msg)
|
|
||||||
err = self.build_error(-32700, "Parse error")
|
|
||||||
return json.dumps(err)
|
|
||||||
if isinstance(obj, list):
|
|
||||||
responses: List[Dict[str, Any]] = []
|
|
||||||
for item in obj:
|
|
||||||
self._log_request(item)
|
|
||||||
resp = await self.process_object(item, conn)
|
|
||||||
if resp is not None:
|
|
||||||
self._log_response(resp)
|
|
||||||
responses.append(resp)
|
|
||||||
if responses:
|
|
||||||
return json.dumps(responses)
|
|
||||||
else:
|
|
||||||
self._log_request(obj)
|
|
||||||
response = await self.process_object(obj, conn)
|
|
||||||
if response is not None:
|
|
||||||
self._log_response(response)
|
|
||||||
return json.dumps(response)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def process_object(self,
|
|
||||||
obj: Dict[str, Any],
|
|
||||||
conn: Optional[BaseSocketClient]
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
req_id: Optional[int] = obj.get('id', None)
|
|
||||||
rpc_version: str = obj.get('jsonrpc', "")
|
|
||||||
if rpc_version != "2.0":
|
|
||||||
return self.build_error(-32600, "Invalid Request", req_id)
|
|
||||||
method_name = obj.get('method', Sentinel.MISSING)
|
|
||||||
if method_name is Sentinel.MISSING:
|
|
||||||
self.process_response(obj, conn)
|
|
||||||
return None
|
|
||||||
if not isinstance(method_name, str):
|
|
||||||
return self.build_error(-32600, "Invalid Request", req_id)
|
|
||||||
method = self.methods.get(method_name, None)
|
|
||||||
if method is None:
|
|
||||||
return self.build_error(-32601, "Method not found", req_id)
|
|
||||||
params: Dict[str, Any] = {}
|
|
||||||
if 'params' in obj:
|
|
||||||
params = obj['params']
|
|
||||||
if not isinstance(params, dict):
|
|
||||||
return self.build_error(
|
|
||||||
-32602, f"Invalid params:", req_id, True)
|
|
||||||
response = await self.execute_method(method, req_id, conn, params)
|
|
||||||
return response
|
|
||||||
|
|
||||||
def process_response(
|
|
||||||
self, obj: Dict[str, Any], conn: Optional[BaseSocketClient]
|
|
||||||
) -> None:
|
|
||||||
if conn is None:
|
|
||||||
logging.debug(f"RPC Response to non-socket request: {obj}")
|
|
||||||
return
|
|
||||||
response_id = obj.get("id")
|
|
||||||
if response_id is None:
|
|
||||||
logging.debug(f"RPC Response with null ID: {obj}")
|
|
||||||
return
|
|
||||||
result = obj.get("result")
|
|
||||||
if result is None:
|
|
||||||
name = conn.client_data["name"]
|
|
||||||
error = obj.get("error")
|
|
||||||
msg = f"Invalid Response: {obj}"
|
|
||||||
code = -32600
|
|
||||||
if isinstance(error, dict):
|
|
||||||
msg = error.get("message", msg)
|
|
||||||
code = error.get("code", code)
|
|
||||||
msg = f"{name} rpc error: {code} {msg}"
|
|
||||||
ret = ServerError(msg, 418)
|
|
||||||
else:
|
|
||||||
ret = result
|
|
||||||
conn.resolve_pending_response(response_id, ret)
|
|
||||||
|
|
||||||
async def execute_method(self,
|
|
||||||
callback: RPCCallback,
|
|
||||||
req_id: Optional[int],
|
|
||||||
conn: Optional[BaseSocketClient],
|
|
||||||
params: Dict[str, Any]
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
if conn is not None:
|
|
||||||
params["_socket_"] = conn
|
|
||||||
try:
|
|
||||||
result = await callback(params)
|
|
||||||
except TypeError as e:
|
|
||||||
return self.build_error(
|
|
||||||
-32602, f"Invalid params:\n{e}", req_id, True)
|
|
||||||
except ServerError as e:
|
|
||||||
code = e.status_code
|
|
||||||
if code == 404:
|
|
||||||
code = -32601
|
|
||||||
elif code == 401:
|
|
||||||
code = -32602
|
|
||||||
return self.build_error(code, str(e), req_id, True)
|
|
||||||
except Exception as e:
|
|
||||||
return self.build_error(-31000, str(e), req_id, True)
|
|
||||||
|
|
||||||
if req_id is None:
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return self.build_result(result, req_id)
|
|
||||||
|
|
||||||
def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
'jsonrpc': "2.0",
|
|
||||||
'result': result,
|
|
||||||
'id': req_id
|
|
||||||
}
|
|
||||||
|
|
||||||
def build_error(self,
|
|
||||||
code: int,
|
|
||||||
msg: str,
|
|
||||||
req_id: Optional[int] = None,
|
|
||||||
is_exc: bool = False
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
|
|
||||||
if is_exc:
|
|
||||||
logging.exception(log_msg)
|
|
||||||
else:
|
|
||||||
logging.info(log_msg)
|
|
||||||
return {
|
|
||||||
'jsonrpc': "2.0",
|
|
||||||
'error': {'code': code, 'message': msg},
|
|
||||||
'id': req_id
|
|
||||||
}
|
|
||||||
|
|
||||||
class WebsocketManager(APITransport):
|
class WebsocketManager(APITransport):
|
||||||
def __init__(self, server: Server) -> None:
|
def __init__(self, server: Server) -> None:
|
||||||
self.server = server
|
self.server = server
|
||||||
self.clients: Dict[int, BaseSocketClient] = {}
|
self.clients: Dict[int, BaseRemoteConnection] = {}
|
||||||
self.bridge_connections: Dict[int, BridgeSocket] = {}
|
self.bridge_connections: Dict[int, BridgeSocket] = {}
|
||||||
self.rpc = JsonRPC(server)
|
self.rpc = JsonRPC(server)
|
||||||
self.closed_event: Optional[asyncio.Event] = None
|
self.closed_event: Optional[asyncio.Event] = None
|
||||||
|
@ -289,7 +108,7 @@ class WebsocketManager(APITransport):
|
||||||
callback: Callable[[WebRequest], Coroutine]
|
callback: Callable[[WebRequest], Coroutine]
|
||||||
) -> RPCCallback:
|
) -> RPCCallback:
|
||||||
async def func(args: Dict[str, Any]) -> Any:
|
async def func(args: Dict[str, Any]) -> Any:
|
||||||
sc: BaseSocketClient = args.pop("_socket_")
|
sc: BaseRemoteConnection = args.pop("_socket_")
|
||||||
sc.check_authenticated(path=endpoint)
|
sc.check_authenticated(path=endpoint)
|
||||||
result = await callback(
|
result = await callback(
|
||||||
WebRequest(endpoint, args, request_method, sc,
|
WebRequest(endpoint, args, request_method, sc,
|
||||||
|
@ -298,12 +117,12 @@ class WebsocketManager(APITransport):
|
||||||
return func
|
return func
|
||||||
|
|
||||||
async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
|
async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
|
||||||
sc: BaseSocketClient = args["_socket_"]
|
sc: BaseRemoteConnection = args["_socket_"]
|
||||||
sc.check_authenticated()
|
sc.check_authenticated()
|
||||||
return {'websocket_id': sc.uid}
|
return {'websocket_id': sc.uid}
|
||||||
|
|
||||||
async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
|
async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
|
||||||
sc: BaseSocketClient = args["_socket_"]
|
sc: BaseRemoteConnection = args["_socket_"]
|
||||||
sc.authenticate(
|
sc.authenticate(
|
||||||
token=args.get("access_token", None),
|
token=args.get("access_token", None),
|
||||||
api_key=args.get("api_key", None)
|
api_key=args.get("api_key", None)
|
||||||
|
@ -355,7 +174,7 @@ class WebsocketManager(APITransport):
|
||||||
def has_socket(self, ws_id: int) -> bool:
|
def has_socket(self, ws_id: int) -> bool:
|
||||||
return ws_id in self.clients
|
return ws_id in self.clients
|
||||||
|
|
||||||
def get_client(self, ws_id: int) -> Optional[BaseSocketClient]:
|
def get_client(self, ws_id: int) -> Optional[BaseRemoteConnection]:
|
||||||
sc = self.clients.get(ws_id, None)
|
sc = self.clients.get(ws_id, None)
|
||||||
if sc is None or not isinstance(sc, WebSocket):
|
if sc is None or not isinstance(sc, WebSocket):
|
||||||
return None
|
return None
|
||||||
|
@ -363,37 +182,37 @@ class WebsocketManager(APITransport):
|
||||||
|
|
||||||
def get_clients_by_type(
|
def get_clients_by_type(
|
||||||
self, client_type: str
|
self, client_type: str
|
||||||
) -> List[BaseSocketClient]:
|
) -> List[BaseRemoteConnection]:
|
||||||
if not client_type:
|
if not client_type:
|
||||||
return []
|
return []
|
||||||
ret: List[BaseSocketClient] = []
|
ret: List[BaseRemoteConnection] = []
|
||||||
for sc in self.clients.values():
|
for sc in self.clients.values():
|
||||||
if sc.client_data.get("type", "") == client_type.lower():
|
if sc.client_data.get("type", "") == client_type.lower():
|
||||||
ret.append(sc)
|
ret.append(sc)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_clients_by_name(self, name: str) -> List[BaseSocketClient]:
|
def get_clients_by_name(self, name: str) -> List[BaseRemoteConnection]:
|
||||||
if not name:
|
if not name:
|
||||||
return []
|
return []
|
||||||
ret: List[BaseSocketClient] = []
|
ret: List[BaseRemoteConnection] = []
|
||||||
for sc in self.clients.values():
|
for sc in self.clients.values():
|
||||||
if sc.client_data.get("name", "").lower() == name.lower():
|
if sc.client_data.get("name", "").lower() == name.lower():
|
||||||
ret.append(sc)
|
ret.append(sc)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get_unidentified_clients(self) -> List[BaseSocketClient]:
|
def get_unidentified_clients(self) -> List[BaseRemoteConnection]:
|
||||||
ret: List[BaseSocketClient] = []
|
ret: List[BaseRemoteConnection] = []
|
||||||
for sc in self.clients.values():
|
for sc in self.clients.values():
|
||||||
if not sc.client_data:
|
if not sc.client_data:
|
||||||
ret.append(sc)
|
ret.append(sc)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def add_client(self, sc: BaseSocketClient) -> None:
|
def add_client(self, sc: BaseRemoteConnection) -> None:
|
||||||
self.clients[sc.uid] = sc
|
self.clients[sc.uid] = sc
|
||||||
self.server.send_event("websockets:client_added", sc)
|
self.server.send_event("websockets:client_added", sc)
|
||||||
logging.debug(f"New Websocket Added: {sc.uid}")
|
logging.debug(f"New Websocket Added: {sc.uid}")
|
||||||
|
|
||||||
def remove_client(self, sc: BaseSocketClient) -> None:
|
def remove_client(self, sc: BaseRemoteConnection) -> None:
|
||||||
old_sc = self.clients.pop(sc.uid, None)
|
old_sc = self.clients.pop(sc.uid, None)
|
||||||
if old_sc is not None:
|
if old_sc is not None:
|
||||||
self.server.send_event("websockets:client_removed", sc)
|
self.server.send_event("websockets:client_removed", sc)
|
||||||
|
@ -449,176 +268,7 @@ class WebsocketManager(APITransport):
|
||||||
pass
|
pass
|
||||||
self.closed_event = None
|
self.closed_event = None
|
||||||
|
|
||||||
class BaseSocketClient(Subscribable):
|
class WebSocket(WebSocketHandler, BaseRemoteConnection):
|
||||||
def on_create(self, server: Server) -> None:
|
|
||||||
self.server = server
|
|
||||||
self.eventloop = server.get_event_loop()
|
|
||||||
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
|
|
||||||
self.rpc = self.wsm.rpc
|
|
||||||
self._uid = id(self)
|
|
||||||
self.ip_addr = ""
|
|
||||||
self.is_closed: bool = False
|
|
||||||
self.queue_busy: bool = False
|
|
||||||
self.pending_responses: Dict[int, asyncio.Future] = {}
|
|
||||||
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
|
||||||
self._connected_time: float = 0.
|
|
||||||
self._identified: bool = False
|
|
||||||
self._client_data: Dict[str, str] = {
|
|
||||||
"name": "unknown",
|
|
||||||
"version": "",
|
|
||||||
"type": "",
|
|
||||||
"url": ""
|
|
||||||
}
|
|
||||||
self._need_auth: bool = False
|
|
||||||
self._user_info: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def user_info(self) -> Optional[Dict[str, Any]]:
|
|
||||||
return self._user_info
|
|
||||||
|
|
||||||
@user_info.setter
|
|
||||||
def user_info(self, uinfo: Dict[str, Any]) -> None:
|
|
||||||
self._user_info = uinfo
|
|
||||||
self._need_auth = False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def need_auth(self) -> bool:
|
|
||||||
return self._need_auth
|
|
||||||
|
|
||||||
@property
|
|
||||||
def uid(self) -> int:
|
|
||||||
return self._uid
|
|
||||||
|
|
||||||
@property
|
|
||||||
def hostname(self) -> str:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def start_time(self) -> float:
|
|
||||||
return self._connected_time
|
|
||||||
|
|
||||||
@property
|
|
||||||
def identified(self) -> bool:
|
|
||||||
return self._identified
|
|
||||||
|
|
||||||
@property
|
|
||||||
def client_data(self) -> Dict[str, str]:
|
|
||||||
return self._client_data
|
|
||||||
|
|
||||||
@client_data.setter
|
|
||||||
def client_data(self, data: Dict[str, str]) -> None:
|
|
||||||
self._client_data = data
|
|
||||||
self._identified = True
|
|
||||||
|
|
||||||
async def _process_message(self, message: str) -> None:
|
|
||||||
try:
|
|
||||||
response = await self.rpc.dispatch(message, self)
|
|
||||||
if response is not None:
|
|
||||||
self.queue_message(response)
|
|
||||||
except Exception:
|
|
||||||
logging.exception("Websocket Command Error")
|
|
||||||
|
|
||||||
def queue_message(self, message: Union[str, Dict[str, Any]]):
|
|
||||||
self.message_buf.append(message)
|
|
||||||
if self.queue_busy:
|
|
||||||
return
|
|
||||||
self.queue_busy = True
|
|
||||||
self.eventloop.register_callback(self._write_messages)
|
|
||||||
|
|
||||||
def authenticate(
|
|
||||||
self,
|
|
||||||
token: Optional[str] = None,
|
|
||||||
api_key: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
auth: AuthComp = self.server.lookup_component("authorization", None)
|
|
||||||
if auth is None:
|
|
||||||
return
|
|
||||||
if token is not None:
|
|
||||||
self.user_info = auth.validate_jwt(token)
|
|
||||||
elif api_key is not None and self.user_info is None:
|
|
||||||
self.user_info = auth.validate_api_key(api_key)
|
|
||||||
else:
|
|
||||||
self.check_authenticated()
|
|
||||||
|
|
||||||
def check_authenticated(self, path: str = "") -> None:
|
|
||||||
if not self._need_auth:
|
|
||||||
return
|
|
||||||
auth: AuthComp = self.server.lookup_component("authorization", None)
|
|
||||||
if auth is None:
|
|
||||||
return
|
|
||||||
if not auth.is_path_permitted(path):
|
|
||||||
raise self.server.error("Unauthorized", 401)
|
|
||||||
|
|
||||||
def on_user_logout(self, user: str) -> bool:
|
|
||||||
if self._user_info is None:
|
|
||||||
return False
|
|
||||||
if user == self._user_info.get("username", ""):
|
|
||||||
self._user_info = None
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _write_messages(self):
|
|
||||||
if self.is_closed:
|
|
||||||
self.message_buf = []
|
|
||||||
self.queue_busy = False
|
|
||||||
return
|
|
||||||
while self.message_buf:
|
|
||||||
msg = self.message_buf.pop(0)
|
|
||||||
await self.write_to_socket(msg)
|
|
||||||
self.queue_busy = False
|
|
||||||
|
|
||||||
async def write_to_socket(
|
|
||||||
self, message: Union[str, Dict[str, Any]]
|
|
||||||
) -> None:
|
|
||||||
raise NotImplementedError("Children must implement write_to_socket")
|
|
||||||
|
|
||||||
def send_status(self,
|
|
||||||
status: Dict[str, Any],
|
|
||||||
eventtime: float
|
|
||||||
) -> None:
|
|
||||||
if not status:
|
|
||||||
return
|
|
||||||
self.queue_message({
|
|
||||||
'jsonrpc': "2.0",
|
|
||||||
'method': "notify_status_update",
|
|
||||||
'params': [status, eventtime]})
|
|
||||||
|
|
||||||
def call_method(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
params: Optional[Union[List, Dict[str, Any]]] = None
|
|
||||||
) -> Awaitable:
|
|
||||||
fut = self.eventloop.create_future()
|
|
||||||
msg = {
|
|
||||||
'jsonrpc': "2.0",
|
|
||||||
'method': method,
|
|
||||||
'id': id(fut)
|
|
||||||
}
|
|
||||||
if params is not None:
|
|
||||||
msg["params"] = params
|
|
||||||
self.pending_responses[id(fut)] = fut
|
|
||||||
self.queue_message(msg)
|
|
||||||
return fut
|
|
||||||
|
|
||||||
def send_notification(self, name: str, data: List) -> None:
|
|
||||||
self.wsm.notify_clients(name, data, [self._uid])
|
|
||||||
|
|
||||||
def resolve_pending_response(
|
|
||||||
self, response_id: int, result: Any
|
|
||||||
) -> bool:
|
|
||||||
fut = self.pending_responses.pop(response_id, None)
|
|
||||||
if fut is None:
|
|
||||||
return False
|
|
||||||
if isinstance(result, ServerError):
|
|
||||||
fut.set_exception(result)
|
|
||||||
else:
|
|
||||||
fut.set_result(result)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def close_socket(self, code: int, reason: str) -> None:
|
|
||||||
raise NotImplementedError("Children must implement close_socket()")
|
|
||||||
|
|
||||||
class WebSocket(WebSocketHandler, BaseSocketClient):
|
|
||||||
connection_count: int = 0
|
connection_count: int = 0
|
||||||
|
|
||||||
def initialize(self) -> None:
|
def initialize(self) -> None:
|
||||||
|
|
Loading…
Reference in New Issue