diff --git a/moonraker/common.py b/moonraker/common.py index 6ea4b06..3d5884e 100644 --- a/moonraker/common.py +++ b/moonraker/common.py @@ -6,6 +6,9 @@ from __future__ import annotations import ipaddress +import logging +import copy +import json from .utils import ServerError, Sentinel # Annotation imports @@ -20,11 +23,14 @@ from typing import ( Union, Dict, List, + Awaitable ) if TYPE_CHECKING: - from .websockets import BaseSocketClient + from .server import Server + from .websockets import WebsocketManager from .components.authorization import Authorization + from asyncio import Future _T = TypeVar("_T") _C = TypeVar("_C", str, bool, float, int) IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -40,6 +46,202 @@ class Subscribable: ) -> None: raise NotImplementedError +class APIDefinition: + def __init__(self, + endpoint: str, + http_uri: str, + jrpc_methods: List[str], + request_methods: Union[str, List[str]], + transports: List[str], + callback: Optional[Callable[[WebRequest], Coroutine]], + need_object_parser: bool): + self.endpoint = endpoint + self.uri = http_uri + self.jrpc_methods = jrpc_methods + if not isinstance(request_methods, list): + request_methods = [request_methods] + self.request_methods = request_methods + self.supported_transports = transports + self.callback = callback + self.need_object_parser = need_object_parser + +class APITransport: + def register_api_handler(self, api_def: APIDefinition) -> None: + raise NotImplementedError + + def remove_api_handler(self, api_def: APIDefinition) -> None: + raise NotImplementedError + +class BaseRemoteConnection(Subscribable): + def on_create(self, server: Server) -> None: + self.server = server + self.eventloop = server.get_event_loop() + self.wsm: WebsocketManager = self.server.lookup_component("websockets") + self.rpc = self.wsm.rpc + self._uid = id(self) + self.ip_addr = "" + self.is_closed: bool = False + self.queue_busy: bool = False + self.pending_responses: Dict[int, Future] = {} + self.message_buf: List[Union[str, Dict[str, Any]]] = [] + self._connected_time: float = 0. + self._identified: bool = False + self._client_data: Dict[str, str] = { + "name": "unknown", + "version": "", + "type": "", + "url": "" + } + self._need_auth: bool = False + self._user_info: Optional[Dict[str, Any]] = None + + @property + def user_info(self) -> Optional[Dict[str, Any]]: + return self._user_info + + @user_info.setter + def user_info(self, uinfo: Dict[str, Any]) -> None: + self._user_info = uinfo + self._need_auth = False + + @property + def need_auth(self) -> bool: + return self._need_auth + + @property + def uid(self) -> int: + return self._uid + + @property + def hostname(self) -> str: + return "" + + @property + def start_time(self) -> float: + return self._connected_time + + @property + def identified(self) -> bool: + return self._identified + + @property + def client_data(self) -> Dict[str, str]: + return self._client_data + + @client_data.setter + def client_data(self, data: Dict[str, str]) -> None: + self._client_data = data + self._identified = True + + async def _process_message(self, message: str) -> None: + try: + response = await self.rpc.dispatch(message, self) + if response is not None: + self.queue_message(response) + except Exception: + logging.exception("Websocket Command Error") + + def queue_message(self, message: Union[str, Dict[str, Any]]): + self.message_buf.append(message) + if self.queue_busy: + return + self.queue_busy = True + self.eventloop.register_callback(self._write_messages) + + def authenticate( + self, + token: Optional[str] = None, + api_key: Optional[str] = None + ) -> None: + auth: AuthComp = self.server.lookup_component("authorization", None) + if auth is None: + return + if token is not None: + self.user_info = auth.validate_jwt(token) + elif api_key is not None and self.user_info is None: + self.user_info = auth.validate_api_key(api_key) + else: + self.check_authenticated() + + def check_authenticated(self, path: str = "") -> None: + if not self._need_auth: + return + auth: AuthComp = self.server.lookup_component("authorization", None) + if auth is None: + return + if not auth.is_path_permitted(path): + raise self.server.error("Unauthorized", 401) + + def on_user_logout(self, user: str) -> bool: + if self._user_info is None: + return False + if user == self._user_info.get("username", ""): + self._user_info = None + return True + return False + + async def _write_messages(self): + if self.is_closed: + self.message_buf = [] + self.queue_busy = False + return + while self.message_buf: + msg = self.message_buf.pop(0) + await self.write_to_socket(msg) + self.queue_busy = False + + async def write_to_socket( + self, message: Union[str, Dict[str, Any]] + ) -> None: + raise NotImplementedError("Children must implement write_to_socket") + + def send_status(self, + status: Dict[str, Any], + eventtime: float + ) -> None: + if not status: + return + self.queue_message({ + 'jsonrpc': "2.0", + 'method': "notify_status_update", + 'params': [status, eventtime]}) + + def call_method( + self, + method: str, + params: Optional[Union[List, Dict[str, Any]]] = None + ) -> Awaitable: + fut = self.eventloop.create_future() + msg = { + 'jsonrpc': "2.0", + 'method': method, + 'id': id(fut) + } + if params is not None: + msg["params"] = params + self.pending_responses[id(fut)] = fut + self.queue_message(msg) + return fut + + def send_notification(self, name: str, data: List) -> None: + self.wsm.notify_clients(name, data, [self._uid]) + + def resolve_pending_response( + self, response_id: int, result: Any + ) -> bool: + fut = self.pending_responses.pop(response_id, None) + if fut is None: + return False + if isinstance(result, ServerError): + fut.set_exception(result) + else: + fut.set_result(result) + return True + + def close_socket(self, code: int, reason: str) -> None: + raise NotImplementedError("Children must implement close_socket()") + + class WebRequest: def __init__(self, endpoint: str, @@ -72,8 +274,8 @@ class WebRequest: def get_subscribable(self) -> Optional[Subscribable]: return self.conn - def get_client_connection(self) -> Optional[BaseSocketClient]: - if isinstance(self.conn, BaseSocketClient): + def get_client_connection(self) -> Optional[BaseRemoteConnection]: + if isinstance(self.conn, BaseRemoteConnection): return self.conn return None @@ -142,28 +344,188 @@ class WebRequest: ) -> Union[bool, _T]: return self._get_converted_arg(key, default, bool) -class APIDefinition: - def __init__(self, - endpoint: str, - http_uri: str, - jrpc_methods: List[str], - request_methods: Union[str, List[str]], - transports: List[str], - callback: Optional[Callable[[WebRequest], Coroutine]], - need_object_parser: bool): - self.endpoint = endpoint - self.uri = http_uri - self.jrpc_methods = jrpc_methods - if not isinstance(request_methods, list): - request_methods = [request_methods] - self.request_methods = request_methods - self.supported_transports = transports - self.callback = callback - self.need_object_parser = need_object_parser -class APITransport: - def register_api_handler(self, api_def: APIDefinition) -> None: - raise NotImplementedError +class JsonRPC: + def __init__( + self, server: Server, transport: str = "Websocket" + ) -> None: + self.methods: Dict[str, RPCCallback] = {} + self.transport = transport + self.sanitize_response = False + self.verbose = server.is_verbose_enabled() - def remove_api_handler(self, api_def: APIDefinition) -> None: - raise NotImplementedError + def _log_request(self, rpc_obj: Dict[str, Any], ) -> None: + if not self.verbose: + return + self.sanitize_response = False + output = rpc_obj + method: Optional[str] = rpc_obj.get("method") + params: Dict[str, Any] = rpc_obj.get("params", {}) + if isinstance(method, str): + if ( + method.startswith("access.") or + method == "machine.sudo.password" + ): + self.sanitize_response = True + if params and isinstance(params, dict): + output = copy.deepcopy(rpc_obj) + output["params"] = {key: "" for key in params} + elif method == "server.connection.identify": + output = copy.deepcopy(rpc_obj) + for field in ["access_token", "api_key"]: + if field in params: + output["params"][field] = "" + logging.debug(f"{self.transport} Received::{json.dumps(output)}") + + def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None: + if not self.verbose: + return + if resp_obj is None: + return + output = resp_obj + if self.sanitize_response and "result" in resp_obj: + output = copy.deepcopy(resp_obj) + output["result"] = "" + self.sanitize_response = False + logging.debug(f"{self.transport} Response::{json.dumps(output)}") + + def register_method(self, + name: str, + method: RPCCallback + ) -> None: + self.methods[name] = method + + def remove_method(self, name: str) -> None: + self.methods.pop(name, None) + + async def dispatch(self, + data: str, + conn: Optional[BaseRemoteConnection] = None + ) -> Optional[str]: + try: + obj: Union[Dict[str, Any], List[dict]] = json.loads(data) + except Exception: + msg = f"{self.transport} data not json: {data}" + logging.exception(msg) + err = self.build_error(-32700, "Parse error") + return json.dumps(err) + if isinstance(obj, list): + responses: List[Dict[str, Any]] = [] + for item in obj: + self._log_request(item) + resp = await self.process_object(item, conn) + if resp is not None: + self._log_response(resp) + responses.append(resp) + if responses: + return json.dumps(responses) + else: + self._log_request(obj) + response = await self.process_object(obj, conn) + if response is not None: + self._log_response(response) + return json.dumps(response) + return None + + async def process_object(self, + obj: Dict[str, Any], + conn: Optional[BaseRemoteConnection] + ) -> Optional[Dict[str, Any]]: + req_id: Optional[int] = obj.get('id', None) + rpc_version: str = obj.get('jsonrpc', "") + if rpc_version != "2.0": + return self.build_error(-32600, "Invalid Request", req_id) + method_name = obj.get('method', Sentinel.MISSING) + if method_name is Sentinel.MISSING: + self.process_response(obj, conn) + return None + if not isinstance(method_name, str): + return self.build_error(-32600, "Invalid Request", req_id) + method = self.methods.get(method_name, None) + if method is None: + return self.build_error(-32601, "Method not found", req_id) + params: Dict[str, Any] = {} + if 'params' in obj: + params = obj['params'] + if not isinstance(params, dict): + return self.build_error( + -32602, f"Invalid params:", req_id, True) + response = await self.execute_method(method, req_id, conn, params) + return response + + def process_response( + self, obj: Dict[str, Any], conn: Optional[BaseRemoteConnection] + ) -> None: + if conn is None: + logging.debug(f"RPC Response to non-socket request: {obj}") + return + response_id = obj.get("id") + if response_id is None: + logging.debug(f"RPC Response with null ID: {obj}") + return + result = obj.get("result") + if result is None: + name = conn.client_data["name"] + error = obj.get("error") + msg = f"Invalid Response: {obj}" + code = -32600 + if isinstance(error, dict): + msg = error.get("message", msg) + code = error.get("code", code) + msg = f"{name} rpc error: {code} {msg}" + ret = ServerError(msg, 418) + else: + ret = result + conn.resolve_pending_response(response_id, ret) + + async def execute_method(self, + callback: RPCCallback, + req_id: Optional[int], + conn: Optional[BaseRemoteConnection], + params: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + if conn is not None: + params["_socket_"] = conn + try: + result = await callback(params) + except TypeError as e: + return self.build_error( + -32602, f"Invalid params:\n{e}", req_id, True) + except ServerError as e: + code = e.status_code + if code == 404: + code = -32601 + elif code == 401: + code = -32602 + return self.build_error(code, str(e), req_id, True) + except Exception as e: + return self.build_error(-31000, str(e), req_id, True) + + if req_id is None: + return None + else: + return self.build_result(result, req_id) + + def build_result(self, result: Any, req_id: int) -> Dict[str, Any]: + return { + 'jsonrpc': "2.0", + 'result': result, + 'id': req_id + } + + def build_error(self, + code: int, + msg: str, + req_id: Optional[int] = None, + is_exc: bool = False + ) -> Dict[str, Any]: + log_msg = f"JSON-RPC Request Error: {code}\n{msg}" + if is_exc: + logging.exception(log_msg) + else: + logging.info(log_msg) + return { + 'jsonrpc': "2.0", + 'error': {'code': code, 'message': msg}, + 'id': req_id + } diff --git a/moonraker/components/extensions.py b/moonraker/components/extensions.py index 1989145..a2a6d28 100644 --- a/moonraker/components/extensions.py +++ b/moonraker/components/extensions.py @@ -8,7 +8,7 @@ import asyncio import pathlib import logging import json -from ..websockets import BaseSocketClient +from ..common import BaseRemoteConnection from ..utils import get_unix_peer_credentials # Annotation imports @@ -32,7 +32,7 @@ UNIX_BUFFER_LIMIT = 20 * 1024 * 1024 class ExtensionManager: def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() - self.agents: Dict[str, BaseSocketClient] = {} + self.agents: Dict[str, BaseRemoteConnection] = {} self.uds_server: Optional[asyncio.AbstractServer] = None self.server.register_endpoint( "/connection/send_event", ["POST"], self._handle_agent_event, @@ -45,7 +45,7 @@ class ExtensionManager: "/server/extensions/request", ["POST"], self._handle_call_agent ) - def register_agent(self, connection: BaseSocketClient) -> None: + def register_agent(self, connection: BaseRemoteConnection) -> None: data = connection.client_data name = data["name"] client_type = data["type"] @@ -64,7 +64,7 @@ class ExtensionManager: } connection.send_notification("agent_event", [evt]) - def remove_agent(self, connection: BaseSocketClient) -> None: + def remove_agent(self, connection: BaseRemoteConnection) -> None: name = connection.client_data["name"] if name in self.agents: del self.agents[name] @@ -135,7 +135,7 @@ class ExtensionManager: await self.uds_server.wait_closed() self.uds_server = None -class UnixSocketClient(BaseSocketClient): +class UnixSocketClient(BaseRemoteConnection): def __init__( self, server: Server, diff --git a/moonraker/components/mqtt.py b/moonraker/components/mqtt.py index 16ed802..b5d36e2 100644 --- a/moonraker/components/mqtt.py +++ b/moonraker/components/mqtt.py @@ -13,8 +13,7 @@ import pathlib import ssl from collections import deque import paho.mqtt.client as paho_mqtt -from ..common import Subscribable, WebRequest, APITransport -from ..websockets import JsonRPC +from ..common import Subscribable, WebRequest, APITransport, JsonRPC # Annotation imports from typing import ( diff --git a/moonraker/components/simplyprint.py b/moonraker/components/simplyprint.py index b8f5983..9ed89b8 100644 --- a/moonraker/components/simplyprint.py +++ b/moonraker/components/simplyprint.py @@ -32,7 +32,8 @@ from typing import ( if TYPE_CHECKING: from ..app import InternalTransport from ..confighelper import ConfigHelper - from ..websockets import WebsocketManager, BaseSocketClient + from ..websockets import WebsocketManager + from ..common import BaseRemoteConnection from tornado.websocket import WebSocketClientConnection from .database import MoonrakerDatabase from .klippy_apis import KlippyAPI @@ -609,7 +610,7 @@ class SimplyPrint(Subscribable): is_on = device_info["status"] == "on" self.send_sp("power_controller", {"on": is_on}) - def _on_websocket_identified(self, ws: BaseSocketClient) -> None: + def _on_websocket_identified(self, ws: BaseRemoteConnection) -> None: if ( self.cache.current_wsid is None and ws.client_data.get("type", "") == "web" @@ -622,7 +623,7 @@ class SimplyPrint(Subscribable): self.cache.current_wsid = ws.uid self.send_sp("machine_data", ui_data) - def _on_websocket_removed(self, ws: BaseSocketClient) -> None: + def _on_websocket_removed(self, ws: BaseRemoteConnection) -> None: if self.cache.current_wsid is None or self.cache.current_wsid != ws.uid: return ui_data = self._get_ui_info() diff --git a/moonraker/websockets.py b/moonraker/websockets.py index 93c99ea..ada6e8d 100644 --- a/moonraker/websockets.py +++ b/moonraker/websockets.py @@ -7,13 +7,17 @@ from __future__ import annotations import logging import ipaddress -import json import asyncio -import copy from tornado.websocket import WebSocketHandler, WebSocketClosedError from tornado.web import HTTPError -from .common import WebRequest, Subscribable, APITransport, APIDefinition -from .utils import ServerError, Sentinel +from .common import ( + WebRequest, + BaseRemoteConnection, + APITransport, + APIDefinition, + JsonRPC +) +from .utils import ServerError # Annotation imports from typing import ( @@ -42,195 +46,10 @@ if TYPE_CHECKING: CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"] -class JsonRPC: - def __init__( - self, server: Server, transport: str = "Websocket" - ) -> None: - self.methods: Dict[str, RPCCallback] = {} - self.transport = transport - self.sanitize_response = False - self.verbose = server.is_verbose_enabled() - - def _log_request(self, rpc_obj: Dict[str, Any], ) -> None: - if not self.verbose: - return - self.sanitize_response = False - output = rpc_obj - method: Optional[str] = rpc_obj.get("method") - params: Dict[str, Any] = rpc_obj.get("params", {}) - if isinstance(method, str): - if ( - method.startswith("access.") or - method == "machine.sudo.password" - ): - self.sanitize_response = True - if params and isinstance(params, dict): - output = copy.deepcopy(rpc_obj) - output["params"] = {key: "" for key in params} - elif method == "server.connection.identify": - output = copy.deepcopy(rpc_obj) - for field in ["access_token", "api_key"]: - if field in params: - output["params"][field] = "" - logging.debug(f"{self.transport} Received::{json.dumps(output)}") - - def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None: - if not self.verbose: - return - if resp_obj is None: - return - output = resp_obj - if self.sanitize_response and "result" in resp_obj: - output = copy.deepcopy(resp_obj) - output["result"] = "" - self.sanitize_response = False - logging.debug(f"{self.transport} Response::{json.dumps(output)}") - - def register_method(self, - name: str, - method: RPCCallback - ) -> None: - self.methods[name] = method - - def remove_method(self, name: str) -> None: - self.methods.pop(name, None) - - async def dispatch(self, - data: str, - conn: Optional[BaseSocketClient] = None - ) -> Optional[str]: - try: - obj: Union[Dict[str, Any], List[dict]] = json.loads(data) - except Exception: - msg = f"{self.transport} data not json: {data}" - logging.exception(msg) - err = self.build_error(-32700, "Parse error") - return json.dumps(err) - if isinstance(obj, list): - responses: List[Dict[str, Any]] = [] - for item in obj: - self._log_request(item) - resp = await self.process_object(item, conn) - if resp is not None: - self._log_response(resp) - responses.append(resp) - if responses: - return json.dumps(responses) - else: - self._log_request(obj) - response = await self.process_object(obj, conn) - if response is not None: - self._log_response(response) - return json.dumps(response) - return None - - async def process_object(self, - obj: Dict[str, Any], - conn: Optional[BaseSocketClient] - ) -> Optional[Dict[str, Any]]: - req_id: Optional[int] = obj.get('id', None) - rpc_version: str = obj.get('jsonrpc', "") - if rpc_version != "2.0": - return self.build_error(-32600, "Invalid Request", req_id) - method_name = obj.get('method', Sentinel.MISSING) - if method_name is Sentinel.MISSING: - self.process_response(obj, conn) - return None - if not isinstance(method_name, str): - return self.build_error(-32600, "Invalid Request", req_id) - method = self.methods.get(method_name, None) - if method is None: - return self.build_error(-32601, "Method not found", req_id) - params: Dict[str, Any] = {} - if 'params' in obj: - params = obj['params'] - if not isinstance(params, dict): - return self.build_error( - -32602, f"Invalid params:", req_id, True) - response = await self.execute_method(method, req_id, conn, params) - return response - - def process_response( - self, obj: Dict[str, Any], conn: Optional[BaseSocketClient] - ) -> None: - if conn is None: - logging.debug(f"RPC Response to non-socket request: {obj}") - return - response_id = obj.get("id") - if response_id is None: - logging.debug(f"RPC Response with null ID: {obj}") - return - result = obj.get("result") - if result is None: - name = conn.client_data["name"] - error = obj.get("error") - msg = f"Invalid Response: {obj}" - code = -32600 - if isinstance(error, dict): - msg = error.get("message", msg) - code = error.get("code", code) - msg = f"{name} rpc error: {code} {msg}" - ret = ServerError(msg, 418) - else: - ret = result - conn.resolve_pending_response(response_id, ret) - - async def execute_method(self, - callback: RPCCallback, - req_id: Optional[int], - conn: Optional[BaseSocketClient], - params: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - if conn is not None: - params["_socket_"] = conn - try: - result = await callback(params) - except TypeError as e: - return self.build_error( - -32602, f"Invalid params:\n{e}", req_id, True) - except ServerError as e: - code = e.status_code - if code == 404: - code = -32601 - elif code == 401: - code = -32602 - return self.build_error(code, str(e), req_id, True) - except Exception as e: - return self.build_error(-31000, str(e), req_id, True) - - if req_id is None: - return None - else: - return self.build_result(result, req_id) - - def build_result(self, result: Any, req_id: int) -> Dict[str, Any]: - return { - 'jsonrpc': "2.0", - 'result': result, - 'id': req_id - } - - def build_error(self, - code: int, - msg: str, - req_id: Optional[int] = None, - is_exc: bool = False - ) -> Dict[str, Any]: - log_msg = f"JSON-RPC Request Error: {code}\n{msg}" - if is_exc: - logging.exception(log_msg) - else: - logging.info(log_msg) - return { - 'jsonrpc': "2.0", - 'error': {'code': code, 'message': msg}, - 'id': req_id - } - class WebsocketManager(APITransport): def __init__(self, server: Server) -> None: self.server = server - self.clients: Dict[int, BaseSocketClient] = {} + self.clients: Dict[int, BaseRemoteConnection] = {} self.bridge_connections: Dict[int, BridgeSocket] = {} self.rpc = JsonRPC(server) self.closed_event: Optional[asyncio.Event] = None @@ -289,7 +108,7 @@ class WebsocketManager(APITransport): callback: Callable[[WebRequest], Coroutine] ) -> RPCCallback: async def func(args: Dict[str, Any]) -> Any: - sc: BaseSocketClient = args.pop("_socket_") + sc: BaseRemoteConnection = args.pop("_socket_") sc.check_authenticated(path=endpoint) result = await callback( WebRequest(endpoint, args, request_method, sc, @@ -298,12 +117,12 @@ class WebsocketManager(APITransport): return func async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]: - sc: BaseSocketClient = args["_socket_"] + sc: BaseRemoteConnection = args["_socket_"] sc.check_authenticated() return {'websocket_id': sc.uid} async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]: - sc: BaseSocketClient = args["_socket_"] + sc: BaseRemoteConnection = args["_socket_"] sc.authenticate( token=args.get("access_token", None), api_key=args.get("api_key", None) @@ -355,7 +174,7 @@ class WebsocketManager(APITransport): def has_socket(self, ws_id: int) -> bool: return ws_id in self.clients - def get_client(self, ws_id: int) -> Optional[BaseSocketClient]: + def get_client(self, ws_id: int) -> Optional[BaseRemoteConnection]: sc = self.clients.get(ws_id, None) if sc is None or not isinstance(sc, WebSocket): return None @@ -363,37 +182,37 @@ class WebsocketManager(APITransport): def get_clients_by_type( self, client_type: str - ) -> List[BaseSocketClient]: + ) -> List[BaseRemoteConnection]: if not client_type: return [] - ret: List[BaseSocketClient] = [] + ret: List[BaseRemoteConnection] = [] for sc in self.clients.values(): if sc.client_data.get("type", "") == client_type.lower(): ret.append(sc) return ret - def get_clients_by_name(self, name: str) -> List[BaseSocketClient]: + def get_clients_by_name(self, name: str) -> List[BaseRemoteConnection]: if not name: return [] - ret: List[BaseSocketClient] = [] + ret: List[BaseRemoteConnection] = [] for sc in self.clients.values(): if sc.client_data.get("name", "").lower() == name.lower(): ret.append(sc) return ret - def get_unidentified_clients(self) -> List[BaseSocketClient]: - ret: List[BaseSocketClient] = [] + def get_unidentified_clients(self) -> List[BaseRemoteConnection]: + ret: List[BaseRemoteConnection] = [] for sc in self.clients.values(): if not sc.client_data: ret.append(sc) return ret - def add_client(self, sc: BaseSocketClient) -> None: + def add_client(self, sc: BaseRemoteConnection) -> None: self.clients[sc.uid] = sc self.server.send_event("websockets:client_added", sc) logging.debug(f"New Websocket Added: {sc.uid}") - def remove_client(self, sc: BaseSocketClient) -> None: + def remove_client(self, sc: BaseRemoteConnection) -> None: old_sc = self.clients.pop(sc.uid, None) if old_sc is not None: self.server.send_event("websockets:client_removed", sc) @@ -449,176 +268,7 @@ class WebsocketManager(APITransport): pass self.closed_event = None -class BaseSocketClient(Subscribable): - def on_create(self, server: Server) -> None: - self.server = server - self.eventloop = server.get_event_loop() - self.wsm: WebsocketManager = self.server.lookup_component("websockets") - self.rpc = self.wsm.rpc - self._uid = id(self) - self.ip_addr = "" - self.is_closed: bool = False - self.queue_busy: bool = False - self.pending_responses: Dict[int, asyncio.Future] = {} - self.message_buf: List[Union[str, Dict[str, Any]]] = [] - self._connected_time: float = 0. - self._identified: bool = False - self._client_data: Dict[str, str] = { - "name": "unknown", - "version": "", - "type": "", - "url": "" - } - self._need_auth: bool = False - self._user_info: Optional[Dict[str, Any]] = None - - @property - def user_info(self) -> Optional[Dict[str, Any]]: - return self._user_info - - @user_info.setter - def user_info(self, uinfo: Dict[str, Any]) -> None: - self._user_info = uinfo - self._need_auth = False - - @property - def need_auth(self) -> bool: - return self._need_auth - - @property - def uid(self) -> int: - return self._uid - - @property - def hostname(self) -> str: - return "" - - @property - def start_time(self) -> float: - return self._connected_time - - @property - def identified(self) -> bool: - return self._identified - - @property - def client_data(self) -> Dict[str, str]: - return self._client_data - - @client_data.setter - def client_data(self, data: Dict[str, str]) -> None: - self._client_data = data - self._identified = True - - async def _process_message(self, message: str) -> None: - try: - response = await self.rpc.dispatch(message, self) - if response is not None: - self.queue_message(response) - except Exception: - logging.exception("Websocket Command Error") - - def queue_message(self, message: Union[str, Dict[str, Any]]): - self.message_buf.append(message) - if self.queue_busy: - return - self.queue_busy = True - self.eventloop.register_callback(self._write_messages) - - def authenticate( - self, - token: Optional[str] = None, - api_key: Optional[str] = None - ) -> None: - auth: AuthComp = self.server.lookup_component("authorization", None) - if auth is None: - return - if token is not None: - self.user_info = auth.validate_jwt(token) - elif api_key is not None and self.user_info is None: - self.user_info = auth.validate_api_key(api_key) - else: - self.check_authenticated() - - def check_authenticated(self, path: str = "") -> None: - if not self._need_auth: - return - auth: AuthComp = self.server.lookup_component("authorization", None) - if auth is None: - return - if not auth.is_path_permitted(path): - raise self.server.error("Unauthorized", 401) - - def on_user_logout(self, user: str) -> bool: - if self._user_info is None: - return False - if user == self._user_info.get("username", ""): - self._user_info = None - return True - return False - - async def _write_messages(self): - if self.is_closed: - self.message_buf = [] - self.queue_busy = False - return - while self.message_buf: - msg = self.message_buf.pop(0) - await self.write_to_socket(msg) - self.queue_busy = False - - async def write_to_socket( - self, message: Union[str, Dict[str, Any]] - ) -> None: - raise NotImplementedError("Children must implement write_to_socket") - - def send_status(self, - status: Dict[str, Any], - eventtime: float - ) -> None: - if not status: - return - self.queue_message({ - 'jsonrpc': "2.0", - 'method': "notify_status_update", - 'params': [status, eventtime]}) - - def call_method( - self, - method: str, - params: Optional[Union[List, Dict[str, Any]]] = None - ) -> Awaitable: - fut = self.eventloop.create_future() - msg = { - 'jsonrpc': "2.0", - 'method': method, - 'id': id(fut) - } - if params is not None: - msg["params"] = params - self.pending_responses[id(fut)] = fut - self.queue_message(msg) - return fut - - def send_notification(self, name: str, data: List) -> None: - self.wsm.notify_clients(name, data, [self._uid]) - - def resolve_pending_response( - self, response_id: int, result: Any - ) -> bool: - fut = self.pending_responses.pop(response_id, None) - if fut is None: - return False - if isinstance(result, ServerError): - fut.set_exception(result) - else: - fut.set_result(result) - return True - - def close_socket(self, code: int, reason: str) -> None: - raise NotImplementedError("Children must implement close_socket()") - -class WebSocket(WebSocketHandler, BaseSocketClient): +class WebSocket(WebSocketHandler, BaseRemoteConnection): connection_count: int = 0 def initialize(self) -> None: