From 3667e0d41f20415a8ce821d01f678afb2783ad69 Mon Sep 17 00:00:00 2001 From: Eric Callahan Date: Fri, 4 Mar 2022 17:20:38 -0500 Subject: [PATCH] websockets: implement server.connection.identify method Provide a remote method by which clients may identify their name, version, and type. Signed-off-by: Eric Callahan fix Signed-off-by: Eric Callahan --- moonraker/websockets.py | 65 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/moonraker/websockets.py b/moonraker/websockets.py index 399cdc6..d2171fe 100644 --- a/moonraker/websockets.py +++ b/moonraker/websockets.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: RPCCallback = Callable[..., Coroutine] AuthComp = Optional[components.authorization.Authorization] +CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "other"] SENTINEL = SentinelClass.get_instance() class Subscribable: @@ -277,6 +278,8 @@ class WebsocketManager(APITransport): self.closed_event: Optional[asyncio.Event] = None self.rpc.register_method("server.websocket.id", self._handle_id_request) + self.rpc.register_method( + "server.connection.identify", self._handle_identify) def register_notification(self, event_name: str, @@ -337,12 +340,65 @@ class WebsocketManager(APITransport): ) -> Dict[str, int]: return {'websocket_id': ws.uid} + async def _handle_identify(self, + ws: WebSocket, + **kwargs + ) -> Dict[str, int]: + try: + name = str(kwargs["client_name"]) + version = str(kwargs["version"]) + client_type: str = str(kwargs["type"]).lower() + url = str(kwargs["url"]) + except KeyError as e: + missing_key = str(e).split(":")[-1].strip() + raise self.server.error( + f"No data for argument: {missing_key}" + ) from None + if client_type not in CLIENT_TYPES: + raise self.server.error(f"Invalid Client Type: {client_type}") + ws.client_data = { + "name": name, + "version": version, + "type": client_type, + "url": url + } + logging.info( + f"Websocket {ws.uid} Client Identified - " + f"Name: {name}, Version: {version}, Type: {client_type}" + ) + return {'connection_id': ws.uid} + def has_websocket(self, ws_id: int) -> bool: return ws_id in self.websockets def get_websocket(self, ws_id: int) -> Optional[WebSocket]: return self.websockets.get(ws_id, None) + def get_websockets_by_type(self, client_type: str) -> List[WebSocket]: + if not client_type: + return [] + ret: List[WebSocket] = [] + for ws in self.websockets.values(): + if ws.client_data.get("type", "") == client_type.lower(): + ret.append(ws) + return ret + + def get_websockets_by_name(self, name: str) -> List[WebSocket]: + if not name: + return [] + ret: List[WebSocket] = [] + for ws in self.websockets.values(): + if ws.client_data.get("name", "").lower() == name.lower(): + ret.append(ws) + return ret + + def get_unidentified_websockets(self) -> List[WebSocket]: + ret: List[WebSocket] = [] + for ws in self.websockets.values(): + if not ws.client_data: + ret.append(ws) + return ret + def add_websocket(self, ws: WebSocket) -> None: self.websockets[ws.uid] = ws logging.debug(f"New Websocket Added: {ws.uid}") @@ -393,6 +449,7 @@ class WebSocket(WebSocketHandler, Subscribable): self.message_buf: List[Union[str, Dict[str, Any]]] = [] self.last_pong_time: float = self.event_loop.get_loop_time() self._connected_time: float = 0. + self._client_data: Dict[str, str] = {} @property def hostname(self) -> str: @@ -402,6 +459,14 @@ class WebSocket(WebSocketHandler, Subscribable): def start_time(self) -> float: return self._connected_time + @property + def client_data(self) -> Dict[str, str]: + return self._client_data + + @client_data.setter + def client_data(self, data: Dict[str, str]) -> None: + self._client_data = data + def open(self, *args, **kwargs) -> None: self.set_nodelay(True) self._connected_time = self.event_loop.get_loop_time()