diff --git a/moonraker/websockets.py b/moonraker/websockets.py index 4b84aec..a6f5a46 100644 --- a/moonraker/websockets.py +++ b/moonraker/websockets.py @@ -8,9 +8,10 @@ from __future__ import annotations import logging import ipaddress import json +import tornado.util from tornado.ioloop import IOLoop from tornado.websocket import WebSocketHandler, WebSocketClosedError -from tornado.locks import Lock +from tornado.locks import Event from utils import ServerError, SentinelClass # Annotation imports @@ -269,8 +270,8 @@ class WebsocketManager(APITransport): def __init__(self, server: Server) -> None: self.server = server self.websockets: Dict[int, WebSocket] = {} - self.ws_lock = Lock() self.rpc = JsonRPC() + self.closed_event: Optional[Event] = None self.rpc.register_method("server.websocket.id", self._handle_id_request) @@ -281,8 +282,8 @@ class WebsocketManager(APITransport): if notify_name is None: notify_name = event_name.split(':')[-1] - async def notify_handler(*args): - await self.notify_websockets(notify_name, *args) + def notify_handler(*args): + self.notify_websockets(notify_name, *args) self.server.register_event_handler( event_name, notify_handler) @@ -339,42 +340,39 @@ class WebsocketManager(APITransport): def get_websocket(self, ws_id: int) -> Optional[WebSocket]: return self.websockets.get(ws_id, None) - async def add_websocket(self, ws: WebSocket) -> None: - async with self.ws_lock: - self.websockets[ws.uid] = ws - logging.info(f"New Websocket Added: {ws.uid}") + def add_websocket(self, ws: WebSocket) -> None: + self.websockets[ws.uid] = ws + logging.info(f"New Websocket Added: {ws.uid}") - async def remove_websocket(self, ws: WebSocket) -> None: - async with self.ws_lock: - old_ws = self.websockets.pop(ws.uid, None) - if old_ws is not None: - self.server.remove_subscription(old_ws) - logging.info(f"Websocket Removed: {ws.uid}") + def remove_websocket(self, ws: WebSocket) -> None: + old_ws = self.websockets.pop(ws.uid, None) + if old_ws is not None: + self.server.remove_subscription(old_ws) + logging.info(f"Websocket Removed: {ws.uid}") + if self.closed_event is not None and not self.websockets: + self.closed_event.set() - async def notify_websockets(self, - name: str, - data: Any = SENTINEL - ) -> None: + def notify_websockets(self, + name: str, + data: Any = SENTINEL + ) -> None: msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name} if data != SENTINEL: msg['params'] = [data] - async with self.ws_lock: - for ws in list(self.websockets.values()): - try: - ws.write_message(msg) - except WebSocketClosedError: - self.websockets.pop(ws.uid, None) - self.server.remove_subscription(ws) - logging.info(f"Websocket Removed: {ws.uid}") - except Exception: - logging.exception( - f"Error sending data over websocket: {ws.uid}") + for ws in list(self.websockets.values()): + ws.queue_message(msg) async def close(self) -> None: - async with self.ws_lock: - for ws in list(self.websockets.values()): - ws.close() - self.websockets = {} + if not self.websockets: + return + self.closed_event = Event() + for ws in list(self.websockets.values()): + ws.close() + try: + await self.closed_event.wait(2.) + except tornado.util.TimeoutError: + pass + self.closed_event = None class WebSocket(WebSocketHandler, Subscribable): def initialize(self) -> None: @@ -385,9 +383,12 @@ class WebSocket(WebSocketHandler, Subscribable): self.uid = id(self) self.is_closed: bool = False self.ip_addr: str = self.request.remote_ip + self.queue_busy: bool = False + self.message_buf: List[Union[str, Dict[str, Any]]] = [] - async def open(self, *args, **kwargs) -> None: - await self.wsm.add_websocket(self) + def open(self, *args, **kwargs) -> None: + self.set_nodelay(True) + self.wsm.add_websocket(self) def on_message(self, message: Union[bytes, str]) -> None: io_loop = IOLoop.current() @@ -397,30 +398,48 @@ class WebSocket(WebSocketHandler, Subscribable): try: response = await self.rpc.dispatch(message, self) if response is not None: - self.write_message(response) + self.queue_message(response) except Exception: logging.exception("Websocket Command Error") - def send_status(self, status: Dict[str, Any]) -> None: - if not status or self.is_closed: + def queue_message(self, message: Union[str, Dict[str, Any]]): + self.message_buf.append(message) + if self.queue_busy: return - try: - self.write_message({ - 'jsonrpc': "2.0", - 'method': "notify_status_update", - 'params': [status]}) - except WebSocketClosedError: - self.is_closed = True - logging.info( - f"Websocket Closed During Status Update: {self.uid}") - except Exception: - logging.exception( - f"Error sending data over websocket: {self.uid}") + self.queue_busy = True + IOLoop.current().spawn_callback(self._process_messages) + + async def _process_messages(self): + if self.is_closed: + self.message_buf = [] + self.queue_busy = False + return + while self.message_buf: + msg = self.message_buf.pop(0) + try: + await self.write_message(msg) + except WebSocketClosedError: + self.is_closed = True + logging.info( + f"Websocket closed while writing: {self.uid}") + break + except Exception: + logging.exception( + f"Error sending data over websocket: {self.uid}") + self.queue_busy = False + + def send_status(self, status: Dict[str, Any]) -> None: + if not status: + return + self.queue_message({ + 'jsonrpc': "2.0", + 'method': "notify_status_update", + 'params': [status]}) def on_close(self) -> None: self.is_closed = True - io_loop = IOLoop.current() - io_loop.spawn_callback(self.wsm.remove_websocket, self) + self.message_buf = [] + self.wsm.remove_websocket(self) def check_origin(self, origin: str) -> bool: if not super(WebSocket, self).check_origin(origin):