websockets: implement server.connection.identify method
Provide a remote method by which clients may identify their name, version, and type. Signed-off-by: Eric Callahan <arksine.code@gmail.com> fix Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
b86e86aff2
commit
3667e0d41f
|
@ -38,6 +38,7 @@ if TYPE_CHECKING:
|
|||
RPCCallback = Callable[..., Coroutine]
|
||||
AuthComp = Optional[components.authorization.Authorization]
|
||||
|
||||
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "other"]
|
||||
SENTINEL = SentinelClass.get_instance()
|
||||
|
||||
class Subscribable:
|
||||
|
@ -277,6 +278,8 @@ class WebsocketManager(APITransport):
|
|||
self.closed_event: Optional[asyncio.Event] = None
|
||||
|
||||
self.rpc.register_method("server.websocket.id", self._handle_id_request)
|
||||
self.rpc.register_method(
|
||||
"server.connection.identify", self._handle_identify)
|
||||
|
||||
def register_notification(self,
|
||||
event_name: str,
|
||||
|
@ -337,12 +340,65 @@ class WebsocketManager(APITransport):
|
|||
) -> Dict[str, int]:
|
||||
return {'websocket_id': ws.uid}
|
||||
|
||||
async def _handle_identify(self,
|
||||
ws: WebSocket,
|
||||
**kwargs
|
||||
) -> Dict[str, int]:
|
||||
try:
|
||||
name = str(kwargs["client_name"])
|
||||
version = str(kwargs["version"])
|
||||
client_type: str = str(kwargs["type"]).lower()
|
||||
url = str(kwargs["url"])
|
||||
except KeyError as e:
|
||||
missing_key = str(e).split(":")[-1].strip()
|
||||
raise self.server.error(
|
||||
f"No data for argument: {missing_key}"
|
||||
) from None
|
||||
if client_type not in CLIENT_TYPES:
|
||||
raise self.server.error(f"Invalid Client Type: {client_type}")
|
||||
ws.client_data = {
|
||||
"name": name,
|
||||
"version": version,
|
||||
"type": client_type,
|
||||
"url": url
|
||||
}
|
||||
logging.info(
|
||||
f"Websocket {ws.uid} Client Identified - "
|
||||
f"Name: {name}, Version: {version}, Type: {client_type}"
|
||||
)
|
||||
return {'connection_id': ws.uid}
|
||||
|
||||
def has_websocket(self, ws_id: int) -> bool:
|
||||
return ws_id in self.websockets
|
||||
|
||||
def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
|
||||
return self.websockets.get(ws_id, None)
|
||||
|
||||
def get_websockets_by_type(self, client_type: str) -> List[WebSocket]:
|
||||
if not client_type:
|
||||
return []
|
||||
ret: List[WebSocket] = []
|
||||
for ws in self.websockets.values():
|
||||
if ws.client_data.get("type", "") == client_type.lower():
|
||||
ret.append(ws)
|
||||
return ret
|
||||
|
||||
def get_websockets_by_name(self, name: str) -> List[WebSocket]:
|
||||
if not name:
|
||||
return []
|
||||
ret: List[WebSocket] = []
|
||||
for ws in self.websockets.values():
|
||||
if ws.client_data.get("name", "").lower() == name.lower():
|
||||
ret.append(ws)
|
||||
return ret
|
||||
|
||||
def get_unidentified_websockets(self) -> List[WebSocket]:
|
||||
ret: List[WebSocket] = []
|
||||
for ws in self.websockets.values():
|
||||
if not ws.client_data:
|
||||
ret.append(ws)
|
||||
return ret
|
||||
|
||||
def add_websocket(self, ws: WebSocket) -> None:
|
||||
self.websockets[ws.uid] = ws
|
||||
logging.debug(f"New Websocket Added: {ws.uid}")
|
||||
|
@ -393,6 +449,7 @@ class WebSocket(WebSocketHandler, Subscribable):
|
|||
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
||||
self.last_pong_time: float = self.event_loop.get_loop_time()
|
||||
self._connected_time: float = 0.
|
||||
self._client_data: Dict[str, str] = {}
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
|
@ -402,6 +459,14 @@ class WebSocket(WebSocketHandler, Subscribable):
|
|||
def start_time(self) -> float:
|
||||
return self._connected_time
|
||||
|
||||
@property
|
||||
def client_data(self) -> Dict[str, str]:
|
||||
return self._client_data
|
||||
|
||||
@client_data.setter
|
||||
def client_data(self, data: Dict[str, str]) -> None:
|
||||
self._client_data = data
|
||||
|
||||
def open(self, *args, **kwargs) -> None:
|
||||
self.set_nodelay(True)
|
||||
self._connected_time = self.event_loop.get_loop_time()
|
||||
|
|
Loading…
Reference in New Issue