websockets: refactor message handling
Implement a write buffer so that all calls to "write_message" are awaited. This allows for more graceful shutdown if the websocket is closed. When Moonraker shuts down, attempt to wait for all websockets to close before exiting. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
188dc4c782
commit
b6f9769488
|
@ -8,9 +8,10 @@ from __future__ import annotations
|
|||
import logging
|
||||
import ipaddress
|
||||
import json
|
||||
import tornado.util
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.websocket import WebSocketHandler, WebSocketClosedError
|
||||
from tornado.locks import Lock
|
||||
from tornado.locks import Event
|
||||
from utils import ServerError, SentinelClass
|
||||
|
||||
# Annotation imports
|
||||
|
@ -269,8 +270,8 @@ class WebsocketManager(APITransport):
|
|||
def __init__(self, server: Server) -> None:
|
||||
self.server = server
|
||||
self.websockets: Dict[int, WebSocket] = {}
|
||||
self.ws_lock = Lock()
|
||||
self.rpc = JsonRPC()
|
||||
self.closed_event: Optional[Event] = None
|
||||
|
||||
self.rpc.register_method("server.websocket.id", self._handle_id_request)
|
||||
|
||||
|
@ -281,8 +282,8 @@ class WebsocketManager(APITransport):
|
|||
if notify_name is None:
|
||||
notify_name = event_name.split(':')[-1]
|
||||
|
||||
async def notify_handler(*args):
|
||||
await self.notify_websockets(notify_name, *args)
|
||||
def notify_handler(*args):
|
||||
self.notify_websockets(notify_name, *args)
|
||||
self.server.register_event_handler(
|
||||
event_name, notify_handler)
|
||||
|
||||
|
@ -339,42 +340,39 @@ class WebsocketManager(APITransport):
|
|||
def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
|
||||
return self.websockets.get(ws_id, None)
|
||||
|
||||
async def add_websocket(self, ws: WebSocket) -> None:
|
||||
async with self.ws_lock:
|
||||
def add_websocket(self, ws: WebSocket) -> None:
|
||||
self.websockets[ws.uid] = ws
|
||||
logging.info(f"New Websocket Added: {ws.uid}")
|
||||
|
||||
async def remove_websocket(self, ws: WebSocket) -> None:
|
||||
async with self.ws_lock:
|
||||
def remove_websocket(self, ws: WebSocket) -> None:
|
||||
old_ws = self.websockets.pop(ws.uid, None)
|
||||
if old_ws is not None:
|
||||
self.server.remove_subscription(old_ws)
|
||||
logging.info(f"Websocket Removed: {ws.uid}")
|
||||
if self.closed_event is not None and not self.websockets:
|
||||
self.closed_event.set()
|
||||
|
||||
async def notify_websockets(self,
|
||||
def notify_websockets(self,
|
||||
name: str,
|
||||
data: Any = SENTINEL
|
||||
) -> None:
|
||||
msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
|
||||
if data != SENTINEL:
|
||||
msg['params'] = [data]
|
||||
async with self.ws_lock:
|
||||
for ws in list(self.websockets.values()):
|
||||
try:
|
||||
ws.write_message(msg)
|
||||
except WebSocketClosedError:
|
||||
self.websockets.pop(ws.uid, None)
|
||||
self.server.remove_subscription(ws)
|
||||
logging.info(f"Websocket Removed: {ws.uid}")
|
||||
except Exception:
|
||||
logging.exception(
|
||||
f"Error sending data over websocket: {ws.uid}")
|
||||
ws.queue_message(msg)
|
||||
|
||||
async def close(self) -> None:
|
||||
async with self.ws_lock:
|
||||
if not self.websockets:
|
||||
return
|
||||
self.closed_event = Event()
|
||||
for ws in list(self.websockets.values()):
|
||||
ws.close()
|
||||
self.websockets = {}
|
||||
try:
|
||||
await self.closed_event.wait(2.)
|
||||
except tornado.util.TimeoutError:
|
||||
pass
|
||||
self.closed_event = None
|
||||
|
||||
class WebSocket(WebSocketHandler, Subscribable):
|
||||
def initialize(self) -> None:
|
||||
|
@ -385,9 +383,12 @@ class WebSocket(WebSocketHandler, Subscribable):
|
|||
self.uid = id(self)
|
||||
self.is_closed: bool = False
|
||||
self.ip_addr: str = self.request.remote_ip
|
||||
self.queue_busy: bool = False
|
||||
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
||||
|
||||
async def open(self, *args, **kwargs) -> None:
|
||||
await self.wsm.add_websocket(self)
|
||||
def open(self, *args, **kwargs) -> None:
|
||||
self.set_nodelay(True)
|
||||
self.wsm.add_websocket(self)
|
||||
|
||||
def on_message(self, message: Union[bytes, str]) -> None:
|
||||
io_loop = IOLoop.current()
|
||||
|
@ -397,30 +398,48 @@ class WebSocket(WebSocketHandler, Subscribable):
|
|||
try:
|
||||
response = await self.rpc.dispatch(message, self)
|
||||
if response is not None:
|
||||
self.write_message(response)
|
||||
self.queue_message(response)
|
||||
except Exception:
|
||||
logging.exception("Websocket Command Error")
|
||||
|
||||
def send_status(self, status: Dict[str, Any]) -> None:
|
||||
if not status or self.is_closed:
|
||||
def queue_message(self, message: Union[str, Dict[str, Any]]):
|
||||
self.message_buf.append(message)
|
||||
if self.queue_busy:
|
||||
return
|
||||
self.queue_busy = True
|
||||
IOLoop.current().spawn_callback(self._process_messages)
|
||||
|
||||
async def _process_messages(self):
|
||||
if self.is_closed:
|
||||
self.message_buf = []
|
||||
self.queue_busy = False
|
||||
return
|
||||
while self.message_buf:
|
||||
msg = self.message_buf.pop(0)
|
||||
try:
|
||||
self.write_message({
|
||||
'jsonrpc': "2.0",
|
||||
'method': "notify_status_update",
|
||||
'params': [status]})
|
||||
await self.write_message(msg)
|
||||
except WebSocketClosedError:
|
||||
self.is_closed = True
|
||||
logging.info(
|
||||
f"Websocket Closed During Status Update: {self.uid}")
|
||||
f"Websocket closed while writing: {self.uid}")
|
||||
break
|
||||
except Exception:
|
||||
logging.exception(
|
||||
f"Error sending data over websocket: {self.uid}")
|
||||
self.queue_busy = False
|
||||
|
||||
def send_status(self, status: Dict[str, Any]) -> None:
|
||||
if not status:
|
||||
return
|
||||
self.queue_message({
|
||||
'jsonrpc': "2.0",
|
||||
'method': "notify_status_update",
|
||||
'params': [status]})
|
||||
|
||||
def on_close(self) -> None:
|
||||
self.is_closed = True
|
||||
io_loop = IOLoop.current()
|
||||
io_loop.spawn_callback(self.wsm.remove_websocket, self)
|
||||
self.message_buf = []
|
||||
self.wsm.remove_websocket(self)
|
||||
|
||||
def check_origin(self, origin: str) -> bool:
|
||||
if not super(WebSocket, self).check_origin(origin):
|
||||
|
|
Loading…
Reference in New Issue