websockets: implement server.connection.identify method
Provide a remote method by which clients may identify their name, version, and type. Signed-off-by: Eric Callahan <arksine.code@gmail.com> fix Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
b86e86aff2
commit
3667e0d41f
|
@ -38,6 +38,7 @@ if TYPE_CHECKING:
|
||||||
RPCCallback = Callable[..., Coroutine]
|
RPCCallback = Callable[..., Coroutine]
|
||||||
AuthComp = Optional[components.authorization.Authorization]
|
AuthComp = Optional[components.authorization.Authorization]
|
||||||
|
|
||||||
|
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "other"]
|
||||||
SENTINEL = SentinelClass.get_instance()
|
SENTINEL = SentinelClass.get_instance()
|
||||||
|
|
||||||
class Subscribable:
|
class Subscribable:
|
||||||
|
@ -277,6 +278,8 @@ class WebsocketManager(APITransport):
|
||||||
self.closed_event: Optional[asyncio.Event] = None
|
self.closed_event: Optional[asyncio.Event] = None
|
||||||
|
|
||||||
self.rpc.register_method("server.websocket.id", self._handle_id_request)
|
self.rpc.register_method("server.websocket.id", self._handle_id_request)
|
||||||
|
self.rpc.register_method(
|
||||||
|
"server.connection.identify", self._handle_identify)
|
||||||
|
|
||||||
def register_notification(self,
|
def register_notification(self,
|
||||||
event_name: str,
|
event_name: str,
|
||||||
|
@ -337,12 +340,65 @@ class WebsocketManager(APITransport):
|
||||||
) -> Dict[str, int]:
|
) -> Dict[str, int]:
|
||||||
return {'websocket_id': ws.uid}
|
return {'websocket_id': ws.uid}
|
||||||
|
|
||||||
|
async def _handle_identify(self,
|
||||||
|
ws: WebSocket,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, int]:
|
||||||
|
try:
|
||||||
|
name = str(kwargs["client_name"])
|
||||||
|
version = str(kwargs["version"])
|
||||||
|
client_type: str = str(kwargs["type"]).lower()
|
||||||
|
url = str(kwargs["url"])
|
||||||
|
except KeyError as e:
|
||||||
|
missing_key = str(e).split(":")[-1].strip()
|
||||||
|
raise self.server.error(
|
||||||
|
f"No data for argument: {missing_key}"
|
||||||
|
) from None
|
||||||
|
if client_type not in CLIENT_TYPES:
|
||||||
|
raise self.server.error(f"Invalid Client Type: {client_type}")
|
||||||
|
ws.client_data = {
|
||||||
|
"name": name,
|
||||||
|
"version": version,
|
||||||
|
"type": client_type,
|
||||||
|
"url": url
|
||||||
|
}
|
||||||
|
logging.info(
|
||||||
|
f"Websocket {ws.uid} Client Identified - "
|
||||||
|
f"Name: {name}, Version: {version}, Type: {client_type}"
|
||||||
|
)
|
||||||
|
return {'connection_id': ws.uid}
|
||||||
|
|
||||||
def has_websocket(self, ws_id: int) -> bool:
|
def has_websocket(self, ws_id: int) -> bool:
|
||||||
return ws_id in self.websockets
|
return ws_id in self.websockets
|
||||||
|
|
||||||
def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
|
def get_websocket(self, ws_id: int) -> Optional[WebSocket]:
|
||||||
return self.websockets.get(ws_id, None)
|
return self.websockets.get(ws_id, None)
|
||||||
|
|
||||||
|
def get_websockets_by_type(self, client_type: str) -> List[WebSocket]:
|
||||||
|
if not client_type:
|
||||||
|
return []
|
||||||
|
ret: List[WebSocket] = []
|
||||||
|
for ws in self.websockets.values():
|
||||||
|
if ws.client_data.get("type", "") == client_type.lower():
|
||||||
|
ret.append(ws)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def get_websockets_by_name(self, name: str) -> List[WebSocket]:
|
||||||
|
if not name:
|
||||||
|
return []
|
||||||
|
ret: List[WebSocket] = []
|
||||||
|
for ws in self.websockets.values():
|
||||||
|
if ws.client_data.get("name", "").lower() == name.lower():
|
||||||
|
ret.append(ws)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def get_unidentified_websockets(self) -> List[WebSocket]:
|
||||||
|
ret: List[WebSocket] = []
|
||||||
|
for ws in self.websockets.values():
|
||||||
|
if not ws.client_data:
|
||||||
|
ret.append(ws)
|
||||||
|
return ret
|
||||||
|
|
||||||
def add_websocket(self, ws: WebSocket) -> None:
|
def add_websocket(self, ws: WebSocket) -> None:
|
||||||
self.websockets[ws.uid] = ws
|
self.websockets[ws.uid] = ws
|
||||||
logging.debug(f"New Websocket Added: {ws.uid}")
|
logging.debug(f"New Websocket Added: {ws.uid}")
|
||||||
|
@ -393,6 +449,7 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
||||||
self.last_pong_time: float = self.event_loop.get_loop_time()
|
self.last_pong_time: float = self.event_loop.get_loop_time()
|
||||||
self._connected_time: float = 0.
|
self._connected_time: float = 0.
|
||||||
|
self._client_data: Dict[str, str] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hostname(self) -> str:
|
def hostname(self) -> str:
|
||||||
|
@ -402,6 +459,14 @@ class WebSocket(WebSocketHandler, Subscribable):
|
||||||
def start_time(self) -> float:
|
def start_time(self) -> float:
|
||||||
return self._connected_time
|
return self._connected_time
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client_data(self) -> Dict[str, str]:
|
||||||
|
return self._client_data
|
||||||
|
|
||||||
|
@client_data.setter
|
||||||
|
def client_data(self, data: Dict[str, str]) -> None:
|
||||||
|
self._client_data = data
|
||||||
|
|
||||||
def open(self, *args, **kwargs) -> None:
|
def open(self, *args, **kwargs) -> None:
|
||||||
self.set_nodelay(True)
|
self.set_nodelay(True)
|
||||||
self._connected_time = self.event_loop.get_loop_time()
|
self._connected_time = self.event_loop.get_loop_time()
|
||||||
|
|
Loading…
Reference in New Issue