diff --git a/moonraker/websockets.py b/moonraker/websockets.py index 0c67fb3..b53d788 100644 --- a/moonraker/websockets.py +++ b/moonraker/websockets.py @@ -4,108 +4,169 @@ # # This file may be distributed under the terms of the GNU GPLv3 license +from __future__ import annotations import logging import ipaddress -import tornado import json from tornado.ioloop import IOLoop from tornado.websocket import WebSocketHandler, WebSocketClosedError -from utils import ServerError +from tornado.locks import Lock +from utils import ServerError, SentinelClass -class Sentinel: - pass +# Annotation imports +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Callable, + Coroutine, + Type, + TypeVar, + Union, + Dict, + List, +) +if TYPE_CHECKING: + from moonraker import Server + from app import APIDefinition, MoonrakerApp + import components.authorization + _T = TypeVar("_T") + _C = TypeVar("_C", str, bool, float, int) + IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + ConvType = Union[str, bool, float, int] + ArgVal = Union[None, int, float, bool, str] + RPCCallback = Callable[..., Coroutine] + AuthComp = Optional[components.authorization.Authorization] + +SENTINEL = SentinelClass.get_instance() + +class Subscribable: + def send_status(self, status: Dict[str, Any]) -> None: + raise NotImplementedError class WebRequest: - def __init__(self, endpoint, args, action="", - conn=None, ip_addr="", user=None): + def __init__(self, + endpoint: str, + args: Dict[str, Any], + action: Optional[str] = "", + conn: Optional[Subscribable] = None, + ip_addr: str = "", + user: Optional[Dict[str, Any]] = None + ) -> None: self.endpoint = endpoint - self.action = action + self.action = action or "" self.args = args self.conn = conn + self.ip_addr: Optional[IPUnion] = None try: self.ip_addr = ipaddress.ip_address(ip_addr) except Exception: self.ip_addr = None self.current_user = user - def get_endpoint(self): + def get_endpoint(self) -> str: return self.endpoint - def get_action(self): + def get_action(self) -> str: return self.action - def get_args(self): + def get_args(self) -> Dict[str, Any]: return self.args - def get_connection(self): + def get_connection(self) -> Optional[Subscribable]: return self.conn - def get_ip_address(self): + def get_ip_address(self) -> Optional[IPUnion]: return self.ip_addr - def get_current_user(self): + def get_current_user(self) -> Optional[Dict[str, Any]]: return self.current_user - def _get_converted_arg(self, key, default=Sentinel, dtype=str): + def _get_converted_arg(self, + key: str, + default: Union[SentinelClass, _T], + dtype: Type[_C] + ) -> Union[_C, _T]: if key not in self.args: - if default == Sentinel: + if isinstance(default, SentinelClass): raise ServerError(f"No data for argument: {key}") return default val = self.args[key] try: - if dtype != bool: + if dtype is not bool: return dtype(val) else: if isinstance(val, str): val = val.lower() if val in ["true", "false"]: - return True if val == "true" else False + return True if val == "true" else False # type: ignore elif isinstance(val, bool): - return val + return val # type: ignore raise TypeError except Exception: raise ServerError( f"Unable to convert argument [{key}] to {dtype}: " f"value recieved: {val}") - def get(self, key, default=Sentinel): + def get(self, + key: str, + default: Union[SentinelClass, _T] = SENTINEL + ) -> Union[_T, Any]: val = self.args.get(key, default) - if val == Sentinel: + if isinstance(val, SentinelClass): raise ServerError(f"No data for argument: {key}") return val - def get_str(self, key, default=Sentinel): - return self._get_converted_arg(key, default) + def get_str(self, + key: str, + default: Union[SentinelClass, _T] = SENTINEL + ) -> Union[str, _T]: + return self._get_converted_arg(key, default, str) - def get_int(self, key, default=Sentinel): + def get_int(self, + key: str, + default: Union[SentinelClass, _T] = SENTINEL + ) -> Union[int, _T]: return self._get_converted_arg(key, default, int) - def get_float(self, key, default=Sentinel): + def get_float(self, + key: str, + default: Union[SentinelClass, _T] = SENTINEL + ) -> Union[float, _T]: return self._get_converted_arg(key, default, float) - def get_boolean(self, key, default=Sentinel): + def get_boolean(self, + key: str, + default: Union[SentinelClass, _T] = SENTINEL + ) -> Union[bool, _T]: return self._get_converted_arg(key, default, bool) class JsonRPC: - def __init__(self): - self.methods = {} + def __init__(self) -> None: + self.methods: Dict[str, RPCCallback] = {} - def register_method(self, name, method): + def register_method(self, + name: str, + method: RPCCallback + ) -> None: self.methods[name] = method - def remove_method(self, name): + def remove_method(self, name: str) -> None: self.methods.pop(name, None) - async def dispatch(self, data, ws): - response = None + async def dispatch(self, + data: str, + ws: WebSocket + ) -> Optional[str]: + response: Any = None try: - request = json.loads(data) + request: Union[Dict[str, Any], List[dict]] = json.loads(data) except Exception: msg = f"Websocket data not json: {data}" logging.exception(msg) response = self.build_error(-32700, "Parse error") return json.dumps(response) - logging.debug("Websocket Request::" + data) + logging.debug(f"Websocket Request::{data}") if isinstance(request, list): response = [] for req in request: @@ -121,9 +182,12 @@ class JsonRPC: logging.debug("Websocket Response::" + response) return response - async def process_request(self, request, ws): - req_id = request.get('id', None) - rpc_version = request.get('jsonrpc', "") + async def process_request(self, + request: Dict[str, Any], + ws: WebSocket + ) -> Optional[Dict[str, Any]]: + req_id: Optional[int] = request.get('id', None) + rpc_version: str = request.get('jsonrpc', "") method_name = request.get('method', None) if rpc_version != "2.0" or not isinstance(method_name, str): return self.build_error(-32600, "Invalid Request", req_id) @@ -144,7 +208,13 @@ class JsonRPC: response = await self.execute_method(method, req_id, ws) return response - async def execute_method(self, method, req_id, ws, *args, **kwargs): + async def execute_method(self, + method: RPCCallback, + req_id: Optional[int], + ws: WebSocket, + *args, + **kwargs + ) -> Optional[Dict[str, Any]]: try: result = await method(ws, *args, **kwargs) except TypeError as e: @@ -160,14 +230,19 @@ class JsonRPC: else: return self.build_result(result, req_id) - def build_result(self, result, req_id): + def build_result(self, result: Any, req_id: int) -> Dict[str, Any]: return { 'jsonrpc': "2.0", 'result': result, 'id': req_id } - def build_error(self, code, msg, req_id=None, is_exc=False): + def build_error(self, + code: int, + msg: str, + req_id: Optional[int] = None, + is_exc: bool = False + ) -> Dict[str, Any]: log_msg = f"JSON-RPC Request Error: {code}\n{msg}" if is_exc: logging.exception(log_msg) @@ -180,15 +255,18 @@ class JsonRPC: } class WebsocketManager: - def __init__(self, server): + def __init__(self, server: Server) -> None: self.server = server - self.websockets = {} - self.ws_lock = tornado.locks.Lock() + self.websockets: Dict[int, WebSocket] = {} + self.ws_lock = Lock() self.rpc = JsonRPC() self.rpc.register_method("server.websocket.id", self._handle_id_request) - def register_notification(self, event_name, notify_name=None): + def register_notification(self, + event_name: str, + notify_name: Optional[str] = None + ) -> None: if notify_name is None: notify_name = event_name.split(':')[-1] @@ -197,61 +275,74 @@ class WebsocketManager: self.server.register_event_handler( event_name, notify_handler) - def register_local_handler(self, api_def, callback): + def register_local_handler(self, + api_def: APIDefinition, + callback: Callable[[WebRequest], Coroutine] + ) -> None: for ws_method, req_method in \ zip(api_def.ws_methods, api_def.request_methods): rpc_cb = self._generate_local_callback( api_def.endpoint, req_method, callback) self.rpc.register_method(ws_method, rpc_cb) - def register_remote_handler(self, api_def): + def register_remote_handler(self, api_def: APIDefinition) -> None: ws_method = api_def.ws_methods[0] rpc_cb = self._generate_callback(api_def.endpoint) self.rpc.register_method(ws_method, rpc_cb) - def remove_handler(self, ws_method): + def remove_handler(self, ws_method: str) -> None: self.rpc.remove_method(ws_method) - def _generate_callback(self, endpoint): - async def func(ws, **kwargs): + def _generate_callback(self, endpoint: str) -> RPCCallback: + async def func(ws: WebSocket, **kwargs) -> Any: result = await self.server.make_request( WebRequest(endpoint, kwargs, conn=ws, ip_addr=ws.ip_addr, user=ws.current_user)) return result return func - def _generate_local_callback(self, endpoint, request_method, callback): - async def func(ws, **kwargs): + def _generate_local_callback(self, + endpoint: str, + request_method: str, + callback: Callable[[WebRequest], Coroutine] + ) -> RPCCallback: + async def func(ws: WebSocket, **kwargs) -> Any: result = await callback( WebRequest(endpoint, kwargs, request_method, ws, ip_addr=ws.ip_addr, user=ws.current_user)) return result return func - async def _handle_id_request(self, ws, **kwargs): + async def _handle_id_request(self, + ws: WebSocket, + **kwargs + ) -> Dict[str, int]: return {'websocket_id': ws.uid} - def has_websocket(self, ws_id): + def has_websocket(self, ws_id: int) -> bool: return ws_id in self.websockets - def get_websocket(self, ws_id): + def get_websocket(self, ws_id: int) -> Optional[WebSocket]: return self.websockets.get(ws_id, None) - async def add_websocket(self, ws): + async def add_websocket(self, ws: WebSocket) -> None: async with self.ws_lock: self.websockets[ws.uid] = ws logging.info(f"New Websocket Added: {ws.uid}") - async def remove_websocket(self, ws): + async def remove_websocket(self, ws: WebSocket) -> None: async with self.ws_lock: old_ws = self.websockets.pop(ws.uid, None) if old_ws is not None: self.server.remove_subscription(old_ws) logging.info(f"Websocket Removed: {ws.uid}") - async def notify_websockets(self, name, data=Sentinel): - msg = {'jsonrpc': "2.0", 'method': "notify_" + name} - if data != Sentinel: + async def notify_websockets(self, + name: str, + data: Any = SENTINEL + ) -> None: + msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name} + if data != SENTINEL: msg['params'] = [data] async with self.ws_lock: for ws in list(self.websockets.values()): @@ -265,30 +356,30 @@ class WebsocketManager: logging.exception( f"Error sending data over websocket: {ws.uid}") - async def close(self): + async def close(self) -> None: async with self.ws_lock: for ws in list(self.websockets.values()): ws.close() self.websockets = {} -class WebSocket(WebSocketHandler): - def initialize(self): - app = self.settings['parent'] +class WebSocket(WebSocketHandler, Subscribable): + def initialize(self) -> None: + app: MoonrakerApp = self.settings['parent'] self.server = app.get_server() self.wsm = app.get_websocket_manager() self.rpc = self.wsm.rpc self.uid = id(self) - self.is_closed = False - self.ip_addr = self.request.remote_ip + self.is_closed: bool = False + self.ip_addr: str = self.request.remote_ip - async def open(self): + async def open(self, *args, **kwargs) -> None: await self.wsm.add_websocket(self) - def on_message(self, message): + def on_message(self, message: Union[bytes, str]) -> None: io_loop = IOLoop.current() io_loop.spawn_callback(self._process_message, message) - async def _process_message(self, message): + async def _process_message(self, message: str) -> None: try: response = await self.rpc.dispatch(message, self) if response is not None: @@ -296,7 +387,7 @@ class WebSocket(WebSocketHandler): except Exception: logging.exception("Websocket Command Error") - def send_status(self, status): + def send_status(self, status: Dict[str, Any]) -> None: if not status or self.is_closed: return try: @@ -312,14 +403,14 @@ class WebSocket(WebSocketHandler): logging.exception( f"Error sending data over websocket: {self.uid}") - def on_close(self): + def on_close(self) -> None: self.is_closed = True io_loop = IOLoop.current() io_loop.spawn_callback(self.wsm.remove_websocket, self) - def check_origin(self, origin): + def check_origin(self, origin: str) -> bool: if not super(WebSocket, self).check_origin(origin): - auth = self.server.lookup_component('authorization', None) + auth: AuthComp = self.server.lookup_component('authorization', None) if auth is not None: return auth.check_cors(origin) return False @@ -327,6 +418,6 @@ class WebSocket(WebSocketHandler): # Check Authorized User def prepare(self): - auth = self.server.lookup_component('authorization', None) + auth: AuthComp = self.server.lookup_component('authorization', None) if auth is not None: self.current_user = auth.check_authorized(self.request)