websockets: refactor message handling

Implement a write buffer so that all calls to "write_message" are awaited.  This allows for more graceful shutdown if the websocket is closed.

When Moonraker shuts down, attempt to wait for all websockets to close before exiting.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2021-06-26 14:31:28 -04:00
parent 188dc4c782
commit b6f9769488
1 changed files with 71 additions and 52 deletions

View File

@ -8,9 +8,10 @@ from __future__ import annotations
import logging
import ipaddress
import json
import tornado.util
from tornado.ioloop import IOLoop
from tornado.websocket import WebSocketHandler, WebSocketClosedError
from tornado.locks import Lock
from tornado.locks import Event
from utils import ServerError, SentinelClass
# Annotation imports
@ -269,8 +270,8 @@ class WebsocketManager(APITransport):
def __init__(self, server: Server) -> None:
self.server = server
self.websockets: Dict[int, WebSocket] = {}
self.ws_lock = Lock()
self.rpc = JsonRPC()
self.closed_event: Optional[Event] = None
self.rpc.register_method("server.websocket.id", self._handle_id_request)
@ -281,8 +282,8 @@ class WebsocketManager(APITransport):
if notify_name is None:
notify_name = event_name.split(':')[-1]
async def notify_handler(*args):
await self.notify_websockets(notify_name, *args)
def notify_handler(*args):
self.notify_websockets(notify_name, *args)
self.server.register_event_handler(
event_name, notify_handler)
@ -339,42 +340,39 @@ class WebsocketManager(APITransport):
def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
return self.websockets.get(ws_id, None)
async def add_websocket(self, ws: WebSocket) -> None:
async with self.ws_lock:
def add_websocket(self, ws: WebSocket) -> None:
self.websockets[ws.uid] = ws
logging.info(f"New Websocket Added: {ws.uid}")
async def remove_websocket(self, ws: WebSocket) -> None:
async with self.ws_lock:
def remove_websocket(self, ws: WebSocket) -> None:
old_ws = self.websockets.pop(ws.uid, None)
if old_ws is not None:
self.server.remove_subscription(old_ws)
logging.info(f"Websocket Removed: {ws.uid}")
if self.closed_event is not None and not self.websockets:
self.closed_event.set()
async def notify_websockets(self,
def notify_websockets(self,
name: str,
data: Any = SENTINEL
) -> None:
msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
if data != SENTINEL:
msg['params'] = [data]
async with self.ws_lock:
for ws in list(self.websockets.values()):
try:
ws.write_message(msg)
except WebSocketClosedError:
self.websockets.pop(ws.uid, None)
self.server.remove_subscription(ws)
logging.info(f"Websocket Removed: {ws.uid}")
except Exception:
logging.exception(
f"Error sending data over websocket: {ws.uid}")
ws.queue_message(msg)
async def close(self) -> None:
async with self.ws_lock:
if not self.websockets:
return
self.closed_event = Event()
for ws in list(self.websockets.values()):
ws.close()
self.websockets = {}
try:
await self.closed_event.wait(2.)
except tornado.util.TimeoutError:
pass
self.closed_event = None
class WebSocket(WebSocketHandler, Subscribable):
def initialize(self) -> None:
@ -385,9 +383,12 @@ class WebSocket(WebSocketHandler, Subscribable):
self.uid = id(self)
self.is_closed: bool = False
self.ip_addr: str = self.request.remote_ip
self.queue_busy: bool = False
self.message_buf: List[Union[str, Dict[str, Any]]] = []
async def open(self, *args, **kwargs) -> None:
await self.wsm.add_websocket(self)
def open(self, *args, **kwargs) -> None:
self.set_nodelay(True)
self.wsm.add_websocket(self)
def on_message(self, message: Union[bytes, str]) -> None:
io_loop = IOLoop.current()
@ -397,30 +398,48 @@ class WebSocket(WebSocketHandler, Subscribable):
try:
response = await self.rpc.dispatch(message, self)
if response is not None:
self.write_message(response)
self.queue_message(response)
except Exception:
logging.exception("Websocket Command Error")
def send_status(self, status: Dict[str, Any]) -> None:
if not status or self.is_closed:
def queue_message(self, message: Union[str, Dict[str, Any]]):
self.message_buf.append(message)
if self.queue_busy:
return
self.queue_busy = True
IOLoop.current().spawn_callback(self._process_messages)
async def _process_messages(self):
if self.is_closed:
self.message_buf = []
self.queue_busy = False
return
while self.message_buf:
msg = self.message_buf.pop(0)
try:
self.write_message({
'jsonrpc': "2.0",
'method': "notify_status_update",
'params': [status]})
await self.write_message(msg)
except WebSocketClosedError:
self.is_closed = True
logging.info(
f"Websocket Closed During Status Update: {self.uid}")
f"Websocket closed while writing: {self.uid}")
break
except Exception:
logging.exception(
f"Error sending data over websocket: {self.uid}")
self.queue_busy = False
def send_status(self, status: Dict[str, Any]) -> None:
if not status:
return
self.queue_message({
'jsonrpc': "2.0",
'method': "notify_status_update",
'params': [status]})
def on_close(self) -> None:
self.is_closed = True
io_loop = IOLoop.current()
io_loop.spawn_callback(self.wsm.remove_websocket, self)
self.message_buf = []
self.wsm.remove_websocket(self)
def check_origin(self, origin: str) -> bool:
if not super(WebSocket, self).check_origin(origin):