websockets: add annotations

Implement a "Subscribable" base class for objects that can maintain a status subscription.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-05-11 18:14:07 -04:00
parent 96e69240ca
commit 9c76dbef7a
1 changed files with 163 additions and 72 deletions

View File

@ -4,108 +4,169 @@
# #
# This file may be distributed under the terms of the GNU GPLv3 license # This file may be distributed under the terms of the GNU GPLv3 license
from __future__ import annotations
import logging import logging
import ipaddress import ipaddress
import tornado
import json import json
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler, WebSocketClosedError from tornado.websocket import WebSocketHandler, WebSocketClosedError
from utils import ServerError from tornado.locks import Lock
from utils import ServerError, SentinelClass
class Sentinel: # Annotation imports
pass from typing import (
TYPE_CHECKING,
Any,
Optional,
Callable,
Coroutine,
Type,
TypeVar,
Union,
Dict,
List,
)
if TYPE_CHECKING:
from moonraker import Server
from app import APIDefinition, MoonrakerApp
import components.authorization
_T = TypeVar("_T")
_C = TypeVar("_C", str, bool, float, int)
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
ConvType = Union[str, bool, float, int]
ArgVal = Union[None, int, float, bool, str]
RPCCallback = Callable[..., Coroutine]
AuthComp = Optional[components.authorization.Authorization]
SENTINEL = SentinelClass.get_instance()
class Subscribable:
def send_status(self, status: Dict[str, Any]) -> None:
raise NotImplementedError
class WebRequest: class WebRequest:
def __init__(self, endpoint, args, action="", def __init__(self,
conn=None, ip_addr="", user=None): endpoint: str,
args: Dict[str, Any],
action: Optional[str] = "",
conn: Optional[Subscribable] = None,
ip_addr: str = "",
user: Optional[Dict[str, Any]] = None
) -> None:
self.endpoint = endpoint self.endpoint = endpoint
self.action = action self.action = action or ""
self.args = args self.args = args
self.conn = conn self.conn = conn
self.ip_addr: Optional[IPUnion] = None
try: try:
self.ip_addr = ipaddress.ip_address(ip_addr) self.ip_addr = ipaddress.ip_address(ip_addr)
except Exception: except Exception:
self.ip_addr = None self.ip_addr = None
self.current_user = user self.current_user = user
def get_endpoint(self): def get_endpoint(self) -> str:
return self.endpoint return self.endpoint
def get_action(self): def get_action(self) -> str:
return self.action return self.action
def get_args(self): def get_args(self) -> Dict[str, Any]:
return self.args return self.args
def get_connection(self): def get_connection(self) -> Optional[Subscribable]:
return self.conn return self.conn
def get_ip_address(self): def get_ip_address(self) -> Optional[IPUnion]:
return self.ip_addr return self.ip_addr
def get_current_user(self): def get_current_user(self) -> Optional[Dict[str, Any]]:
return self.current_user return self.current_user
def _get_converted_arg(self, key, default=Sentinel, dtype=str): def _get_converted_arg(self,
key: str,
default: Union[SentinelClass, _T],
dtype: Type[_C]
) -> Union[_C, _T]:
if key not in self.args: if key not in self.args:
if default == Sentinel: if isinstance(default, SentinelClass):
raise ServerError(f"No data for argument: {key}") raise ServerError(f"No data for argument: {key}")
return default return default
val = self.args[key] val = self.args[key]
try: try:
if dtype != bool: if dtype is not bool:
return dtype(val) return dtype(val)
else: else:
if isinstance(val, str): if isinstance(val, str):
val = val.lower() val = val.lower()
if val in ["true", "false"]: if val in ["true", "false"]:
return True if val == "true" else False return True if val == "true" else False # type: ignore
elif isinstance(val, bool): elif isinstance(val, bool):
return val return val # type: ignore
raise TypeError raise TypeError
except Exception: except Exception:
raise ServerError( raise ServerError(
f"Unable to convert argument [{key}] to {dtype}: " f"Unable to convert argument [{key}] to {dtype}: "
f"value recieved: {val}") f"value recieved: {val}")
def get(self, key, default=Sentinel): def get(self,
key: str,
default: Union[SentinelClass, _T] = SENTINEL
) -> Union[_T, Any]:
val = self.args.get(key, default) val = self.args.get(key, default)
if val == Sentinel: if isinstance(val, SentinelClass):
raise ServerError(f"No data for argument: {key}") raise ServerError(f"No data for argument: {key}")
return val return val
def get_str(self, key, default=Sentinel): def get_str(self,
return self._get_converted_arg(key, default) key: str,
default: Union[SentinelClass, _T] = SENTINEL
) -> Union[str, _T]:
return self._get_converted_arg(key, default, str)
def get_int(self, key, default=Sentinel): def get_int(self,
key: str,
default: Union[SentinelClass, _T] = SENTINEL
) -> Union[int, _T]:
return self._get_converted_arg(key, default, int) return self._get_converted_arg(key, default, int)
def get_float(self, key, default=Sentinel): def get_float(self,
key: str,
default: Union[SentinelClass, _T] = SENTINEL
) -> Union[float, _T]:
return self._get_converted_arg(key, default, float) return self._get_converted_arg(key, default, float)
def get_boolean(self, key, default=Sentinel): def get_boolean(self,
key: str,
default: Union[SentinelClass, _T] = SENTINEL
) -> Union[bool, _T]:
return self._get_converted_arg(key, default, bool) return self._get_converted_arg(key, default, bool)
class JsonRPC: class JsonRPC:
def __init__(self): def __init__(self) -> None:
self.methods = {} self.methods: Dict[str, RPCCallback] = {}
def register_method(self, name, method): def register_method(self,
name: str,
method: RPCCallback
) -> None:
self.methods[name] = method self.methods[name] = method
def remove_method(self, name): def remove_method(self, name: str) -> None:
self.methods.pop(name, None) self.methods.pop(name, None)
async def dispatch(self, data, ws): async def dispatch(self,
response = None data: str,
ws: WebSocket
) -> Optional[str]:
response: Any = None
try: try:
request = json.loads(data) request: Union[Dict[str, Any], List[dict]] = json.loads(data)
except Exception: except Exception:
msg = f"Websocket data not json: {data}" msg = f"Websocket data not json: {data}"
logging.exception(msg) logging.exception(msg)
response = self.build_error(-32700, "Parse error") response = self.build_error(-32700, "Parse error")
return json.dumps(response) return json.dumps(response)
logging.debug("Websocket Request::" + data) logging.debug(f"Websocket Request::{data}")
if isinstance(request, list): if isinstance(request, list):
response = [] response = []
for req in request: for req in request:
@ -121,9 +182,12 @@ class JsonRPC:
logging.debug("Websocket Response::" + response) logging.debug("Websocket Response::" + response)
return response return response
async def process_request(self, request, ws): async def process_request(self,
req_id = request.get('id', None) request: Dict[str, Any],
rpc_version = request.get('jsonrpc', "") ws: WebSocket
) -> Optional[Dict[str, Any]]:
req_id: Optional[int] = request.get('id', None)
rpc_version: str = request.get('jsonrpc', "")
method_name = request.get('method', None) method_name = request.get('method', None)
if rpc_version != "2.0" or not isinstance(method_name, str): if rpc_version != "2.0" or not isinstance(method_name, str):
return self.build_error(-32600, "Invalid Request", req_id) return self.build_error(-32600, "Invalid Request", req_id)
@ -144,7 +208,13 @@ class JsonRPC:
response = await self.execute_method(method, req_id, ws) response = await self.execute_method(method, req_id, ws)
return response return response
async def execute_method(self, method, req_id, ws, *args, **kwargs): async def execute_method(self,
method: RPCCallback,
req_id: Optional[int],
ws: WebSocket,
*args,
**kwargs
) -> Optional[Dict[str, Any]]:
try: try:
result = await method(ws, *args, **kwargs) result = await method(ws, *args, **kwargs)
except TypeError as e: except TypeError as e:
@ -160,14 +230,19 @@ class JsonRPC:
else: else:
return self.build_result(result, req_id) return self.build_result(result, req_id)
def build_result(self, result, req_id): def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
return { return {
'jsonrpc': "2.0", 'jsonrpc': "2.0",
'result': result, 'result': result,
'id': req_id 'id': req_id
} }
def build_error(self, code, msg, req_id=None, is_exc=False): def build_error(self,
code: int,
msg: str,
req_id: Optional[int] = None,
is_exc: bool = False
) -> Dict[str, Any]:
log_msg = f"JSON-RPC Request Error: {code}\n{msg}" log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
if is_exc: if is_exc:
logging.exception(log_msg) logging.exception(log_msg)
@ -180,15 +255,18 @@ class JsonRPC:
} }
class WebsocketManager: class WebsocketManager:
def __init__(self, server): def __init__(self, server: Server) -> None:
self.server = server self.server = server
self.websockets = {} self.websockets: Dict[int, WebSocket] = {}
self.ws_lock = tornado.locks.Lock() self.ws_lock = Lock()
self.rpc = JsonRPC() self.rpc = JsonRPC()
self.rpc.register_method("server.websocket.id", self._handle_id_request) self.rpc.register_method("server.websocket.id", self._handle_id_request)
def register_notification(self, event_name, notify_name=None): def register_notification(self,
event_name: str,
notify_name: Optional[str] = None
) -> None:
if notify_name is None: if notify_name is None:
notify_name = event_name.split(':')[-1] notify_name = event_name.split(':')[-1]
@ -197,61 +275,74 @@ class WebsocketManager:
self.server.register_event_handler( self.server.register_event_handler(
event_name, notify_handler) event_name, notify_handler)
def register_local_handler(self, api_def, callback): def register_local_handler(self,
api_def: APIDefinition,
callback: Callable[[WebRequest], Coroutine]
) -> None:
for ws_method, req_method in \ for ws_method, req_method in \
zip(api_def.ws_methods, api_def.request_methods): zip(api_def.ws_methods, api_def.request_methods):
rpc_cb = self._generate_local_callback( rpc_cb = self._generate_local_callback(
api_def.endpoint, req_method, callback) api_def.endpoint, req_method, callback)
self.rpc.register_method(ws_method, rpc_cb) self.rpc.register_method(ws_method, rpc_cb)
def register_remote_handler(self, api_def): def register_remote_handler(self, api_def: APIDefinition) -> None:
ws_method = api_def.ws_methods[0] ws_method = api_def.ws_methods[0]
rpc_cb = self._generate_callback(api_def.endpoint) rpc_cb = self._generate_callback(api_def.endpoint)
self.rpc.register_method(ws_method, rpc_cb) self.rpc.register_method(ws_method, rpc_cb)
def remove_handler(self, ws_method): def remove_handler(self, ws_method: str) -> None:
self.rpc.remove_method(ws_method) self.rpc.remove_method(ws_method)
def _generate_callback(self, endpoint): def _generate_callback(self, endpoint: str) -> RPCCallback:
async def func(ws, **kwargs): async def func(ws: WebSocket, **kwargs) -> Any:
result = await self.server.make_request( result = await self.server.make_request(
WebRequest(endpoint, kwargs, conn=ws, ip_addr=ws.ip_addr, WebRequest(endpoint, kwargs, conn=ws, ip_addr=ws.ip_addr,
user=ws.current_user)) user=ws.current_user))
return result return result
return func return func
def _generate_local_callback(self, endpoint, request_method, callback): def _generate_local_callback(self,
async def func(ws, **kwargs): endpoint: str,
request_method: str,
callback: Callable[[WebRequest], Coroutine]
) -> RPCCallback:
async def func(ws: WebSocket, **kwargs) -> Any:
result = await callback( result = await callback(
WebRequest(endpoint, kwargs, request_method, ws, WebRequest(endpoint, kwargs, request_method, ws,
ip_addr=ws.ip_addr, user=ws.current_user)) ip_addr=ws.ip_addr, user=ws.current_user))
return result return result
return func return func
async def _handle_id_request(self, ws, **kwargs): async def _handle_id_request(self,
ws: WebSocket,
**kwargs
) -> Dict[str, int]:
return {'websocket_id': ws.uid} return {'websocket_id': ws.uid}
def has_websocket(self, ws_id): def has_websocket(self, ws_id: int) -> bool:
return ws_id in self.websockets return ws_id in self.websockets
def get_websocket(self, ws_id): def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
return self.websockets.get(ws_id, None) return self.websockets.get(ws_id, None)
async def add_websocket(self, ws): async def add_websocket(self, ws: WebSocket) -> None:
async with self.ws_lock: async with self.ws_lock:
self.websockets[ws.uid] = ws self.websockets[ws.uid] = ws
logging.info(f"New Websocket Added: {ws.uid}") logging.info(f"New Websocket Added: {ws.uid}")
async def remove_websocket(self, ws): async def remove_websocket(self, ws: WebSocket) -> None:
async with self.ws_lock: async with self.ws_lock:
old_ws = self.websockets.pop(ws.uid, None) old_ws = self.websockets.pop(ws.uid, None)
if old_ws is not None: if old_ws is not None:
self.server.remove_subscription(old_ws) self.server.remove_subscription(old_ws)
logging.info(f"Websocket Removed: {ws.uid}") logging.info(f"Websocket Removed: {ws.uid}")
async def notify_websockets(self, name, data=Sentinel): async def notify_websockets(self,
msg = {'jsonrpc': "2.0", 'method': "notify_" + name} name: str,
if data != Sentinel: data: Any = SENTINEL
) -> None:
msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
if data != SENTINEL:
msg['params'] = [data] msg['params'] = [data]
async with self.ws_lock: async with self.ws_lock:
for ws in list(self.websockets.values()): for ws in list(self.websockets.values()):
@ -265,30 +356,30 @@ class WebsocketManager:
logging.exception( logging.exception(
f"Error sending data over websocket: {ws.uid}") f"Error sending data over websocket: {ws.uid}")
async def close(self): async def close(self) -> None:
async with self.ws_lock: async with self.ws_lock:
for ws in list(self.websockets.values()): for ws in list(self.websockets.values()):
ws.close() ws.close()
self.websockets = {} self.websockets = {}
class WebSocket(WebSocketHandler): class WebSocket(WebSocketHandler, Subscribable):
def initialize(self): def initialize(self) -> None:
app = self.settings['parent'] app: MoonrakerApp = self.settings['parent']
self.server = app.get_server() self.server = app.get_server()
self.wsm = app.get_websocket_manager() self.wsm = app.get_websocket_manager()
self.rpc = self.wsm.rpc self.rpc = self.wsm.rpc
self.uid = id(self) self.uid = id(self)
self.is_closed = False self.is_closed: bool = False
self.ip_addr = self.request.remote_ip self.ip_addr: str = self.request.remote_ip
async def open(self): async def open(self, *args, **kwargs) -> None:
await self.wsm.add_websocket(self) await self.wsm.add_websocket(self)
def on_message(self, message): def on_message(self, message: Union[bytes, str]) -> None:
io_loop = IOLoop.current() io_loop = IOLoop.current()
io_loop.spawn_callback(self._process_message, message) io_loop.spawn_callback(self._process_message, message)
async def _process_message(self, message): async def _process_message(self, message: str) -> None:
try: try:
response = await self.rpc.dispatch(message, self) response = await self.rpc.dispatch(message, self)
if response is not None: if response is not None:
@ -296,7 +387,7 @@ class WebSocket(WebSocketHandler):
except Exception: except Exception:
logging.exception("Websocket Command Error") logging.exception("Websocket Command Error")
def send_status(self, status): def send_status(self, status: Dict[str, Any]) -> None:
if not status or self.is_closed: if not status or self.is_closed:
return return
try: try:
@ -312,14 +403,14 @@ class WebSocket(WebSocketHandler):
logging.exception( logging.exception(
f"Error sending data over websocket: {self.uid}") f"Error sending data over websocket: {self.uid}")
def on_close(self): def on_close(self) -> None:
self.is_closed = True self.is_closed = True
io_loop = IOLoop.current() io_loop = IOLoop.current()
io_loop.spawn_callback(self.wsm.remove_websocket, self) io_loop.spawn_callback(self.wsm.remove_websocket, self)
def check_origin(self, origin): def check_origin(self, origin: str) -> bool:
if not super(WebSocket, self).check_origin(origin): if not super(WebSocket, self).check_origin(origin):
auth = self.server.lookup_component('authorization', None) auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None: if auth is not None:
return auth.check_cors(origin) return auth.check_cors(origin)
return False return False
@ -327,6 +418,6 @@ class WebSocket(WebSocketHandler):
# Check Authorized User # Check Authorized User
def prepare(self): def prepare(self):
auth = self.server.lookup_component('authorization', None) auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None: if auth is not None:
self.current_user = auth.check_authorized(self.request) self.current_user = auth.check_authorized(self.request)