websockets: implement server.connection.identify method

Provide a remote method by which clients may identify their
name, version, and type.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>

fix

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-03-04 17:20:38 -05:00
parent b86e86aff2
commit 3667e0d41f
No known key found for this signature in database
GPG Key ID: 7027245FBBDDF59A
1 changed files with 65 additions and 0 deletions

View File

@ -38,6 +38,7 @@ if TYPE_CHECKING:
RPCCallback = Callable[..., Coroutine] RPCCallback = Callable[..., Coroutine]
AuthComp = Optional[components.authorization.Authorization] AuthComp = Optional[components.authorization.Authorization]
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "other"]
SENTINEL = SentinelClass.get_instance() SENTINEL = SentinelClass.get_instance()
class Subscribable: class Subscribable:
@ -277,6 +278,8 @@ class WebsocketManager(APITransport):
self.closed_event: Optional[asyncio.Event] = None self.closed_event: Optional[asyncio.Event] = None
self.rpc.register_method("server.websocket.id", self._handle_id_request) self.rpc.register_method("server.websocket.id", self._handle_id_request)
self.rpc.register_method(
"server.connection.identify", self._handle_identify)
def register_notification(self, def register_notification(self,
event_name: str, event_name: str,
@ -337,12 +340,65 @@ class WebsocketManager(APITransport):
) -> Dict[str, int]: ) -> Dict[str, int]:
return {'websocket_id': ws.uid} return {'websocket_id': ws.uid}
async def _handle_identify(self,
ws: WebSocket,
**kwargs
) -> Dict[str, int]:
try:
name = str(kwargs["client_name"])
version = str(kwargs["version"])
client_type: str = str(kwargs["type"]).lower()
url = str(kwargs["url"])
except KeyError as e:
missing_key = str(e).split(":")[-1].strip()
raise self.server.error(
f"No data for argument: {missing_key}"
) from None
if client_type not in CLIENT_TYPES:
raise self.server.error(f"Invalid Client Type: {client_type}")
ws.client_data = {
"name": name,
"version": version,
"type": client_type,
"url": url
}
logging.info(
f"Websocket {ws.uid} Client Identified - "
f"Name: {name}, Version: {version}, Type: {client_type}"
)
return {'connection_id': ws.uid}
def has_websocket(self, ws_id: int) -> bool: def has_websocket(self, ws_id: int) -> bool:
return ws_id in self.websockets return ws_id in self.websockets
def get_websocket(self, ws_id: int) -> Optional[WebSocket]: def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
return self.websockets.get(ws_id, None) return self.websockets.get(ws_id, None)
def get_websockets_by_type(self, client_type: str) -> List[WebSocket]:
if not client_type:
return []
ret: List[WebSocket] = []
for ws in self.websockets.values():
if ws.client_data.get("type", "") == client_type.lower():
ret.append(ws)
return ret
def get_websockets_by_name(self, name: str) -> List[WebSocket]:
if not name:
return []
ret: List[WebSocket] = []
for ws in self.websockets.values():
if ws.client_data.get("name", "").lower() == name.lower():
ret.append(ws)
return ret
def get_unidentified_websockets(self) -> List[WebSocket]:
ret: List[WebSocket] = []
for ws in self.websockets.values():
if not ws.client_data:
ret.append(ws)
return ret
def add_websocket(self, ws: WebSocket) -> None: def add_websocket(self, ws: WebSocket) -> None:
self.websockets[ws.uid] = ws self.websockets[ws.uid] = ws
logging.debug(f"New Websocket Added: {ws.uid}") logging.debug(f"New Websocket Added: {ws.uid}")
@ -393,6 +449,7 @@ class WebSocket(WebSocketHandler, Subscribable):
self.message_buf: List[Union[str, Dict[str, Any]]] = [] self.message_buf: List[Union[str, Dict[str, Any]]] = []
self.last_pong_time: float = self.event_loop.get_loop_time() self.last_pong_time: float = self.event_loop.get_loop_time()
self._connected_time: float = 0. self._connected_time: float = 0.
self._client_data: Dict[str, str] = {}
@property @property
def hostname(self) -> str: def hostname(self) -> str:
@ -402,6 +459,14 @@ class WebSocket(WebSocketHandler, Subscribable):
def start_time(self) -> float: def start_time(self) -> float:
return self._connected_time return self._connected_time
@property
def client_data(self) -> Dict[str, str]:
return self._client_data
@client_data.setter
def client_data(self, data: Dict[str, str]) -> None:
self._client_data = data
def open(self, *args, **kwargs) -> None: def open(self, *args, **kwargs) -> None:
self.set_nodelay(True) self.set_nodelay(True)
self._connected_time = self.event_loop.get_loop_time() self._connected_time = self.event_loop.get_loop_time()