websockets: refactor message handling

Implement a write buffer so that all calls to "write_message" are awaited.  This allows for more graceful shutdown if the websocket is closed.

When Moonraker shuts down, attempt to wait for all websockets to close before exiting.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2021-06-26 14:31:28 -04:00
parent 188dc4c782
commit b6f9769488
1 changed files with 71 additions and 52 deletions

View File

@ -8,9 +8,10 @@ from __future__ import annotations
import logging import logging
import ipaddress import ipaddress
import json import json
import tornado.util
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler, WebSocketClosedError from tornado.websocket import WebSocketHandler, WebSocketClosedError
from tornado.locks import Lock from tornado.locks import Event
from utils import ServerError, SentinelClass from utils import ServerError, SentinelClass
# Annotation imports # Annotation imports
@ -269,8 +270,8 @@ class WebsocketManager(APITransport):
def __init__(self, server: Server) -> None: def __init__(self, server: Server) -> None:
self.server = server self.server = server
self.websockets: Dict[int, WebSocket] = {} self.websockets: Dict[int, WebSocket] = {}
self.ws_lock = Lock()
self.rpc = JsonRPC() self.rpc = JsonRPC()
self.closed_event: Optional[Event] = None
self.rpc.register_method("server.websocket.id", self._handle_id_request) self.rpc.register_method("server.websocket.id", self._handle_id_request)
@ -281,8 +282,8 @@ class WebsocketManager(APITransport):
if notify_name is None: if notify_name is None:
notify_name = event_name.split(':')[-1] notify_name = event_name.split(':')[-1]
async def notify_handler(*args): def notify_handler(*args):
await self.notify_websockets(notify_name, *args) self.notify_websockets(notify_name, *args)
self.server.register_event_handler( self.server.register_event_handler(
event_name, notify_handler) event_name, notify_handler)
@ -339,42 +340,39 @@ class WebsocketManager(APITransport):
def get_websocket(self, ws_id: int) -> Optional[WebSocket]: def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
return self.websockets.get(ws_id, None) return self.websockets.get(ws_id, None)
async def add_websocket(self, ws: WebSocket) -> None: def add_websocket(self, ws: WebSocket) -> None:
async with self.ws_lock:
self.websockets[ws.uid] = ws self.websockets[ws.uid] = ws
logging.info(f"New Websocket Added: {ws.uid}") logging.info(f"New Websocket Added: {ws.uid}")
async def remove_websocket(self, ws: WebSocket) -> None: def remove_websocket(self, ws: WebSocket) -> None:
async with self.ws_lock:
old_ws = self.websockets.pop(ws.uid, None) old_ws = self.websockets.pop(ws.uid, None)
if old_ws is not None: if old_ws is not None:
self.server.remove_subscription(old_ws) self.server.remove_subscription(old_ws)
logging.info(f"Websocket Removed: {ws.uid}") logging.info(f"Websocket Removed: {ws.uid}")
if self.closed_event is not None and not self.websockets:
self.closed_event.set()
async def notify_websockets(self, def notify_websockets(self,
name: str, name: str,
data: Any = SENTINEL data: Any = SENTINEL
) -> None: ) -> None:
msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name} msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
if data != SENTINEL: if data != SENTINEL:
msg['params'] = [data] msg['params'] = [data]
async with self.ws_lock:
for ws in list(self.websockets.values()): for ws in list(self.websockets.values()):
try: ws.queue_message(msg)
ws.write_message(msg)
except WebSocketClosedError:
self.websockets.pop(ws.uid, None)
self.server.remove_subscription(ws)
logging.info(f"Websocket Removed: {ws.uid}")
except Exception:
logging.exception(
f"Error sending data over websocket: {ws.uid}")
async def close(self) -> None: async def close(self) -> None:
async with self.ws_lock: if not self.websockets:
return
self.closed_event = Event()
for ws in list(self.websockets.values()): for ws in list(self.websockets.values()):
ws.close() ws.close()
self.websockets = {} try:
await self.closed_event.wait(2.)
except tornado.util.TimeoutError:
pass
self.closed_event = None
class WebSocket(WebSocketHandler, Subscribable): class WebSocket(WebSocketHandler, Subscribable):
def initialize(self) -> None: def initialize(self) -> None:
@ -385,9 +383,12 @@ class WebSocket(WebSocketHandler, Subscribable):
self.uid = id(self) self.uid = id(self)
self.is_closed: bool = False self.is_closed: bool = False
self.ip_addr: str = self.request.remote_ip self.ip_addr: str = self.request.remote_ip
self.queue_busy: bool = False
self.message_buf: List[Union[str, Dict[str, Any]]] = []
async def open(self, *args, **kwargs) -> None: def open(self, *args, **kwargs) -> None:
await self.wsm.add_websocket(self) self.set_nodelay(True)
self.wsm.add_websocket(self)
def on_message(self, message: Union[bytes, str]) -> None: def on_message(self, message: Union[bytes, str]) -> None:
io_loop = IOLoop.current() io_loop = IOLoop.current()
@ -397,30 +398,48 @@ class WebSocket(WebSocketHandler, Subscribable):
try: try:
response = await self.rpc.dispatch(message, self) response = await self.rpc.dispatch(message, self)
if response is not None: if response is not None:
self.write_message(response) self.queue_message(response)
except Exception: except Exception:
logging.exception("Websocket Command Error") logging.exception("Websocket Command Error")
def send_status(self, status: Dict[str, Any]) -> None: def queue_message(self, message: Union[str, Dict[str, Any]]):
if not status or self.is_closed: self.message_buf.append(message)
if self.queue_busy:
return return
self.queue_busy = True
IOLoop.current().spawn_callback(self._process_messages)
async def _process_messages(self):
if self.is_closed:
self.message_buf = []
self.queue_busy = False
return
while self.message_buf:
msg = self.message_buf.pop(0)
try: try:
self.write_message({ await self.write_message(msg)
'jsonrpc': "2.0",
'method': "notify_status_update",
'params': [status]})
except WebSocketClosedError: except WebSocketClosedError:
self.is_closed = True self.is_closed = True
logging.info( logging.info(
f"Websocket Closed During Status Update: {self.uid}") f"Websocket closed while writing: {self.uid}")
break
except Exception: except Exception:
logging.exception( logging.exception(
f"Error sending data over websocket: {self.uid}") f"Error sending data over websocket: {self.uid}")
self.queue_busy = False
def send_status(self, status: Dict[str, Any]) -> None:
if not status:
return
self.queue_message({
'jsonrpc': "2.0",
'method': "notify_status_update",
'params': [status]})
def on_close(self) -> None: def on_close(self) -> None:
self.is_closed = True self.is_closed = True
io_loop = IOLoop.current() self.message_buf = []
io_loop.spawn_callback(self.wsm.remove_websocket, self) self.wsm.remove_websocket(self)
def check_origin(self, origin: str) -> bool: def check_origin(self, origin: str) -> bool:
if not super(WebSocket, self).check_origin(origin): if not super(WebSocket, self).check_origin(origin):