websockets: add annotations

Implement a "Subscribable" base class for objects that can maintain a status subscription.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-05-11 18:14:07 -04:00
parent 96e69240ca
commit 9c76dbef7a
1 changed files with 163 additions and 72 deletions

View File

@ -4,108 +4,169 @@
#
# This file may be distributed under the terms of the GNU GPLv3 license
from __future__ import annotations
import logging
import ipaddress
import tornado
import json
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler, WebSocketClosedError
from utils import ServerError
from tornado.locks import Lock
from utils import ServerError, SentinelClass
class Sentinel:
pass
# Annotation imports
from typing import (
TYPE_CHECKING,
Any,
Optional,
Callable,
Coroutine,
Type,
TypeVar,
Union,
Dict,
List,
)
if TYPE_CHECKING:
from moonraker import Server
from app import APIDefinition, MoonrakerApp
import components.authorization
_T = TypeVar("_T")
_C = TypeVar("_C", str, bool, float, int)
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
ConvType = Union[str, bool, float, int]
ArgVal = Union[None, int, float, bool, str]
RPCCallback = Callable[..., Coroutine]
AuthComp = Optional[components.authorization.Authorization]
SENTINEL = SentinelClass.get_instance()
class Subscribable:
def send_status(self, status: Dict[str, Any]) -> None:
raise NotImplementedError
class WebRequest:
def __init__(self, endpoint, args, action="",
conn=None, ip_addr="", user=None):
def __init__(self,
endpoint: str,
args: Dict[str, Any],
action: Optional[str] = "",
conn: Optional[Subscribable] = None,
ip_addr: str = "",
user: Optional[Dict[str, Any]] = None
) -> None:
self.endpoint = endpoint
self.action = action
self.action = action or ""
self.args = args
self.conn = conn
self.ip_addr: Optional[IPUnion] = None
try:
self.ip_addr = ipaddress.ip_address(ip_addr)
except Exception:
self.ip_addr = None
self.current_user = user
def get_endpoint(self):
def get_endpoint(self) -> str:
return self.endpoint
def get_action(self):
def get_action(self) -> str:
return self.action
def get_args(self):
def get_args(self) -> Dict[str, Any]:
return self.args
def get_connection(self):
def get_connection(self) -> Optional[Subscribable]:
return self.conn
def get_ip_address(self):
def get_ip_address(self) -> Optional[IPUnion]:
return self.ip_addr
def get_current_user(self):
def get_current_user(self) -> Optional[Dict[str, Any]]:
return self.current_user
def _get_converted_arg(self, key, default=Sentinel, dtype=str):
def _get_converted_arg(self,
key: str,
default: Union[SentinelClass, _T],
dtype: Type[_C]
) -> Union[_C, _T]:
if key not in self.args:
if default == Sentinel:
if isinstance(default, SentinelClass):
raise ServerError(f"No data for argument: {key}")
return default
val = self.args[key]
try:
if dtype != bool:
if dtype is not bool:
return dtype(val)
else:
if isinstance(val, str):
val = val.lower()
if val in ["true", "false"]:
return True if val == "true" else False
return True if val == "true" else False # type: ignore
elif isinstance(val, bool):
return val
return val # type: ignore
raise TypeError
except Exception:
raise ServerError(
f"Unable to convert argument [{key}] to {dtype}: "
f"value recieved: {val}")
def get(self, key, default=Sentinel):
def get(self,
key: str,
default: Union[SentinelClass, _T] = SENTINEL
) -> Union[_T, Any]:
val = self.args.get(key, default)
if val == Sentinel:
if isinstance(val, SentinelClass):
raise ServerError(f"No data for argument: {key}")
return val
def get_str(self, key, default=Sentinel):
return self._get_converted_arg(key, default)
def get_str(self,
key: str,
default: Union[SentinelClass, _T] = SENTINEL
) -> Union[str, _T]:
return self._get_converted_arg(key, default, str)
def get_int(self, key, default=Sentinel):
def get_int(self,
key: str,
default: Union[SentinelClass, _T] = SENTINEL
) -> Union[int, _T]:
return self._get_converted_arg(key, default, int)
def get_float(self, key, default=Sentinel):
def get_float(self,
key: str,
default: Union[SentinelClass, _T] = SENTINEL
) -> Union[float, _T]:
return self._get_converted_arg(key, default, float)
def get_boolean(self, key, default=Sentinel):
def get_boolean(self,
key: str,
default: Union[SentinelClass, _T] = SENTINEL
) -> Union[bool, _T]:
return self._get_converted_arg(key, default, bool)
class JsonRPC:
def __init__(self):
self.methods = {}
def __init__(self) -> None:
self.methods: Dict[str, RPCCallback] = {}
def register_method(self, name, method):
def register_method(self,
name: str,
method: RPCCallback
) -> None:
self.methods[name] = method
def remove_method(self, name):
def remove_method(self, name: str) -> None:
self.methods.pop(name, None)
async def dispatch(self, data, ws):
response = None
async def dispatch(self,
data: str,
ws: WebSocket
) -> Optional[str]:
response: Any = None
try:
request = json.loads(data)
request: Union[Dict[str, Any], List[dict]] = json.loads(data)
except Exception:
msg = f"Websocket data not json: {data}"
logging.exception(msg)
response = self.build_error(-32700, "Parse error")
return json.dumps(response)
logging.debug("Websocket Request::" + data)
logging.debug(f"Websocket Request::{data}")
if isinstance(request, list):
response = []
for req in request:
@ -121,9 +182,12 @@ class JsonRPC:
logging.debug("Websocket Response::" + response)
return response
async def process_request(self, request, ws):
req_id = request.get('id', None)
rpc_version = request.get('jsonrpc', "")
async def process_request(self,
request: Dict[str, Any],
ws: WebSocket
) -> Optional[Dict[str, Any]]:
req_id: Optional[int] = request.get('id', None)
rpc_version: str = request.get('jsonrpc', "")
method_name = request.get('method', None)
if rpc_version != "2.0" or not isinstance(method_name, str):
return self.build_error(-32600, "Invalid Request", req_id)
@ -144,7 +208,13 @@ class JsonRPC:
response = await self.execute_method(method, req_id, ws)
return response
async def execute_method(self, method, req_id, ws, *args, **kwargs):
async def execute_method(self,
method: RPCCallback,
req_id: Optional[int],
ws: WebSocket,
*args,
**kwargs
) -> Optional[Dict[str, Any]]:
try:
result = await method(ws, *args, **kwargs)
except TypeError as e:
@ -160,14 +230,19 @@ class JsonRPC:
else:
return self.build_result(result, req_id)
def build_result(self, result, req_id):
def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
return {
'jsonrpc': "2.0",
'result': result,
'id': req_id
}
def build_error(self, code, msg, req_id=None, is_exc=False):
def build_error(self,
code: int,
msg: str,
req_id: Optional[int] = None,
is_exc: bool = False
) -> Dict[str, Any]:
log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
if is_exc:
logging.exception(log_msg)
@ -180,15 +255,18 @@ class JsonRPC:
}
class WebsocketManager:
def __init__(self, server):
def __init__(self, server: Server) -> None:
self.server = server
self.websockets = {}
self.ws_lock = tornado.locks.Lock()
self.websockets: Dict[int, WebSocket] = {}
self.ws_lock = Lock()
self.rpc = JsonRPC()
self.rpc.register_method("server.websocket.id", self._handle_id_request)
def register_notification(self, event_name, notify_name=None):
def register_notification(self,
event_name: str,
notify_name: Optional[str] = None
) -> None:
if notify_name is None:
notify_name = event_name.split(':')[-1]
@ -197,61 +275,74 @@ class WebsocketManager:
self.server.register_event_handler(
event_name, notify_handler)
def register_local_handler(self, api_def, callback):
def register_local_handler(self,
api_def: APIDefinition,
callback: Callable[[WebRequest], Coroutine]
) -> None:
for ws_method, req_method in \
zip(api_def.ws_methods, api_def.request_methods):
rpc_cb = self._generate_local_callback(
api_def.endpoint, req_method, callback)
self.rpc.register_method(ws_method, rpc_cb)
def register_remote_handler(self, api_def):
def register_remote_handler(self, api_def: APIDefinition) -> None:
ws_method = api_def.ws_methods[0]
rpc_cb = self._generate_callback(api_def.endpoint)
self.rpc.register_method(ws_method, rpc_cb)
def remove_handler(self, ws_method):
def remove_handler(self, ws_method: str) -> None:
self.rpc.remove_method(ws_method)
def _generate_callback(self, endpoint):
async def func(ws, **kwargs):
def _generate_callback(self, endpoint: str) -> RPCCallback:
async def func(ws: WebSocket, **kwargs) -> Any:
result = await self.server.make_request(
WebRequest(endpoint, kwargs, conn=ws, ip_addr=ws.ip_addr,
user=ws.current_user))
return result
return func
def _generate_local_callback(self, endpoint, request_method, callback):
async def func(ws, **kwargs):
def _generate_local_callback(self,
endpoint: str,
request_method: str,
callback: Callable[[WebRequest], Coroutine]
) -> RPCCallback:
async def func(ws: WebSocket, **kwargs) -> Any:
result = await callback(
WebRequest(endpoint, kwargs, request_method, ws,
ip_addr=ws.ip_addr, user=ws.current_user))
return result
return func
async def _handle_id_request(self, ws, **kwargs):
async def _handle_id_request(self,
ws: WebSocket,
**kwargs
) -> Dict[str, int]:
return {'websocket_id': ws.uid}
def has_websocket(self, ws_id):
def has_websocket(self, ws_id: int) -> bool:
return ws_id in self.websockets
def get_websocket(self, ws_id):
def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
return self.websockets.get(ws_id, None)
async def add_websocket(self, ws):
async def add_websocket(self, ws: WebSocket) -> None:
async with self.ws_lock:
self.websockets[ws.uid] = ws
logging.info(f"New Websocket Added: {ws.uid}")
async def remove_websocket(self, ws):
async def remove_websocket(self, ws: WebSocket) -> None:
async with self.ws_lock:
old_ws = self.websockets.pop(ws.uid, None)
if old_ws is not None:
self.server.remove_subscription(old_ws)
logging.info(f"Websocket Removed: {ws.uid}")
async def notify_websockets(self, name, data=Sentinel):
msg = {'jsonrpc': "2.0", 'method': "notify_" + name}
if data != Sentinel:
async def notify_websockets(self,
name: str,
data: Any = SENTINEL
) -> None:
msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
if data != SENTINEL:
msg['params'] = [data]
async with self.ws_lock:
for ws in list(self.websockets.values()):
@ -265,30 +356,30 @@ class WebsocketManager:
logging.exception(
f"Error sending data over websocket: {ws.uid}")
async def close(self):
async def close(self) -> None:
async with self.ws_lock:
for ws in list(self.websockets.values()):
ws.close()
self.websockets = {}
class WebSocket(WebSocketHandler):
def initialize(self):
app = self.settings['parent']
class WebSocket(WebSocketHandler, Subscribable):
def initialize(self) -> None:
app: MoonrakerApp = self.settings['parent']
self.server = app.get_server()
self.wsm = app.get_websocket_manager()
self.rpc = self.wsm.rpc
self.uid = id(self)
self.is_closed = False
self.ip_addr = self.request.remote_ip
self.is_closed: bool = False
self.ip_addr: str = self.request.remote_ip
async def open(self):
async def open(self, *args, **kwargs) -> None:
await self.wsm.add_websocket(self)
def on_message(self, message):
def on_message(self, message: Union[bytes, str]) -> None:
io_loop = IOLoop.current()
io_loop.spawn_callback(self._process_message, message)
async def _process_message(self, message):
async def _process_message(self, message: str) -> None:
try:
response = await self.rpc.dispatch(message, self)
if response is not None:
@ -296,7 +387,7 @@ class WebSocket(WebSocketHandler):
except Exception:
logging.exception("Websocket Command Error")
def send_status(self, status):
def send_status(self, status: Dict[str, Any]) -> None:
if not status or self.is_closed:
return
try:
@ -312,14 +403,14 @@ class WebSocket(WebSocketHandler):
logging.exception(
f"Error sending data over websocket: {self.uid}")
def on_close(self):
def on_close(self) -> None:
self.is_closed = True
io_loop = IOLoop.current()
io_loop.spawn_callback(self.wsm.remove_websocket, self)
def check_origin(self, origin):
def check_origin(self, origin: str) -> bool:
if not super(WebSocket, self).check_origin(origin):
auth = self.server.lookup_component('authorization', None)
auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None:
return auth.check_cors(origin)
return False
@ -327,6 +418,6 @@ class WebSocket(WebSocketHandler):
# Check Authorized User
def prepare(self):
auth = self.server.lookup_component('authorization', None)
auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None:
self.current_user = auth.check_authorized(self.request)