websockets: refactor message handling
Implement a write buffer so that all calls to "write_message" are awaited. This allows for more graceful shutdown if the websocket is closed. When Moonraker shuts down, attempt to wait for all websockets to close before exiting. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
188dc4c782
commit
b6f9769488
|
@ -8,9 +8,10 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import json
|
import json
|
||||||
|
import tornado.util
|
||||||
from tornado.ioloop import IOLoop
|
from tornado.ioloop import IOLoop
|
||||||
from tornado.websocket import WebSocketHandler, WebSocketClosedError
|
from tornado.websocket import WebSocketHandler, WebSocketClosedError
|
||||||
from tornado.locks import Lock
|
from tornado.locks import Event
|
||||||
from utils import ServerError, SentinelClass
|
from utils import ServerError, SentinelClass
|
||||||
|
|
||||||
# Annotation imports
|
# Annotation imports
|
||||||
|
@ -269,8 +270,8 @@ class WebsocketManager(APITransport):
|
||||||
def __init__(self, server: Server) -> None:
|
def __init__(self, server: Server) -> None:
|
||||||
self.server = server
|
self.server = server
|
||||||
self.websockets: Dict[int, WebSocket] = {}
|
self.websockets: Dict[int, WebSocket] = {}
|
||||||
self.ws_lock = Lock()
|
|
||||||
self.rpc = JsonRPC()
|
self.rpc = JsonRPC()
|
||||||
|
self.closed_event: Optional[Event] = None
|
||||||
|
|
||||||
self.rpc.register_method("server.websocket.id", self._handle_id_request)
|
self.rpc.register_method("server.websocket.id", self._handle_id_request)
|
||||||
|
|
||||||
|
@ -281,8 +282,8 @@ class WebsocketManager(APITransport):
|
||||||
if notify_name is None:
|
if notify_name is None:
|
||||||
notify_name = event_name.split(':')[-1]
|
notify_name = event_name.split(':')[-1]
|
||||||
|
|
||||||
async def notify_handler(*args):
|
def notify_handler(*args):
|
||||||
await self.notify_websockets(notify_name, *args)
|
self.notify_websockets(notify_name, *args)
|
||||||
self.server.register_event_handler(
|
self.server.register_event_handler(
|
||||||
event_name, notify_handler)
|
event_name, notify_handler)
|
||||||
|
|
||||||
|
@ -339,42 +340,39 @@ class WebsocketManager(APITransport):
|
||||||
def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
|
def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
|
||||||
return self.websockets.get(ws_id, None)
|
return self.websockets.get(ws_id, None)
|
||||||
|
|
||||||
async def add_websocket(self, ws: WebSocket) -> None:
|
def add_websocket(self, ws: WebSocket) -> None:
|
||||||
async with self.ws_lock:
|
self.websockets[ws.uid] = ws
|
||||||
self.websockets[ws.uid] = ws
|
logging.info(f"New Websocket Added: {ws.uid}")
|
||||||
logging.info(f"New Websocket Added: {ws.uid}")
|
|
||||||
|
|
||||||
async def remove_websocket(self, ws: WebSocket) -> None:
|
def remove_websocket(self, ws: WebSocket) -> None:
|
||||||
async with self.ws_lock:
|
old_ws = self.websockets.pop(ws.uid, None)
|
||||||
old_ws = self.websockets.pop(ws.uid, None)
|
if old_ws is not None:
|
||||||
if old_ws is not None:
|
self.server.remove_subscription(old_ws)
|
||||||
self.server.remove_subscription(old_ws)
|
logging.info(f"Websocket Removed: {ws.uid}")
|
||||||
logging.info(f"Websocket Removed: {ws.uid}")
|
if self.closed_event is not None and not self.websockets:
|
||||||
|
self.closed_event.set()
|
||||||
|
|
||||||
async def notify_websockets(self,
|
def notify_websockets(self,
|
||||||
name: str,
|
name: str,
|
||||||
data: Any = SENTINEL
|
data: Any = SENTINEL
|
||||||
) -> None:
|
) -> None:
|
||||||
msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
|
msg: Dict[str, Any] = {'jsonrpc': "2.0", 'method': "notify_" + name}
|
||||||
if data != SENTINEL:
|
if data != SENTINEL:
|
||||||
msg['params'] = [data]
|
msg['params'] = [data]
|
||||||
async with self.ws_lock:
|
for ws in list(self.websockets.values()):
|
||||||
for ws in list(self.websockets.values()):
|
ws.queue_message(msg)
|
||||||
try:
|
|
||||||
ws.write_message(msg)
|
|
||||||
except WebSocketClosedError:
|
|
||||||
self.websockets.pop(ws.uid, None)
|
|
||||||
self.server.remove_subscription(ws)
|
|
||||||
logging.info(f"Websocket Removed: {ws.uid}")
|
|
||||||
except Exception:
|
|
||||||
logging.exception(
|
|
||||||
f"Error sending data over websocket: {ws.uid}")
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
async with self.ws_lock:
|
if not self.websockets:
|
||||||
for ws in list(self.websockets.values()):
|
return
|
||||||
ws.close()
|
self.closed_event = Event()
|
||||||
self.websockets = {}
|
for ws in list(self.websockets.values()):
|
||||||
|
ws.close()
|
||||||
|
try:
|
||||||
|
await self.closed_event.wait(2.)
|
||||||
|
except tornado.util.TimeoutError:
|
||||||
|
pass
|
||||||
|
self.closed_event = None
|
||||||
|
|
||||||
class WebSocket(WebSocketHandler, Subscribable):
|
class WebSocket(WebSocketHandler, Subscribable):
|
||||||
def initialize(self) -> None:
|
def initialize(self) -> None:
|
||||||
|
@ -385,9 +383,12 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
self.uid = id(self)
|
self.uid = id(self)
|
||||||
self.is_closed: bool = False
|
self.is_closed: bool = False
|
||||||
self.ip_addr: str = self.request.remote_ip
|
self.ip_addr: str = self.request.remote_ip
|
||||||
|
self.queue_busy: bool = False
|
||||||
|
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
||||||
|
|
||||||
async def open(self, *args, **kwargs) -> None:
|
def open(self, *args, **kwargs) -> None:
|
||||||
await self.wsm.add_websocket(self)
|
self.set_nodelay(True)
|
||||||
|
self.wsm.add_websocket(self)
|
||||||
|
|
||||||
def on_message(self, message: Union[bytes, str]) -> None:
|
def on_message(self, message: Union[bytes, str]) -> None:
|
||||||
io_loop = IOLoop.current()
|
io_loop = IOLoop.current()
|
||||||
|
@ -397,30 +398,48 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
try:
|
try:
|
||||||
response = await self.rpc.dispatch(message, self)
|
response = await self.rpc.dispatch(message, self)
|
||||||
if response is not None:
|
if response is not None:
|
||||||
self.write_message(response)
|
self.queue_message(response)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception("Websocket Command Error")
|
logging.exception("Websocket Command Error")
|
||||||
|
|
||||||
def send_status(self, status: Dict[str, Any]) -> None:
|
def queue_message(self, message: Union[str, Dict[str, Any]]):
|
||||||
if not status or self.is_closed:
|
self.message_buf.append(message)
|
||||||
|
if self.queue_busy:
|
||||||
return
|
return
|
||||||
try:
|
self.queue_busy = True
|
||||||
self.write_message({
|
IOLoop.current().spawn_callback(self._process_messages)
|
||||||
'jsonrpc': "2.0",
|
|
||||||
'method': "notify_status_update",
|
async def _process_messages(self):
|
||||||
'params': [status]})
|
if self.is_closed:
|
||||||
except WebSocketClosedError:
|
self.message_buf = []
|
||||||
self.is_closed = True
|
self.queue_busy = False
|
||||||
logging.info(
|
return
|
||||||
f"Websocket Closed During Status Update: {self.uid}")
|
while self.message_buf:
|
||||||
except Exception:
|
msg = self.message_buf.pop(0)
|
||||||
logging.exception(
|
try:
|
||||||
f"Error sending data over websocket: {self.uid}")
|
await self.write_message(msg)
|
||||||
|
except WebSocketClosedError:
|
||||||
|
self.is_closed = True
|
||||||
|
logging.info(
|
||||||
|
f"Websocket closed while writing: {self.uid}")
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
logging.exception(
|
||||||
|
f"Error sending data over websocket: {self.uid}")
|
||||||
|
self.queue_busy = False
|
||||||
|
|
||||||
|
def send_status(self, status: Dict[str, Any]) -> None:
|
||||||
|
if not status:
|
||||||
|
return
|
||||||
|
self.queue_message({
|
||||||
|
'jsonrpc': "2.0",
|
||||||
|
'method': "notify_status_update",
|
||||||
|
'params': [status]})
|
||||||
|
|
||||||
def on_close(self) -> None:
|
def on_close(self) -> None:
|
||||||
self.is_closed = True
|
self.is_closed = True
|
||||||
io_loop = IOLoop.current()
|
self.message_buf = []
|
||||||
io_loop.spawn_callback(self.wsm.remove_websocket, self)
|
self.wsm.remove_websocket(self)
|
||||||
|
|
||||||
def check_origin(self, origin: str) -> bool:
|
def check_origin(self, origin: str) -> bool:
|
||||||
if not super(WebSocket, self).check_origin(origin):
|
if not super(WebSocket, self).check_origin(origin):
|
||||||
|
|
Loading…
Reference in New Issue