websockets: move JsonRPC and BaseSocketClient to common

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2023-02-24 06:55:14 -05:00
parent 201e84cd94
commit 1b9f29db13
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
5 changed files with 420 additions and 408 deletions

View File

@ -6,6 +6,9 @@
from __future__ import annotations from __future__ import annotations
import ipaddress import ipaddress
import logging
import copy
import json
from .utils import ServerError, Sentinel from .utils import ServerError, Sentinel
# Annotation imports # Annotation imports
@ -20,11 +23,14 @@ from typing import (
Union, Union,
Dict, Dict,
List, List,
Awaitable
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .websockets import BaseSocketClient from .server import Server
from .websockets import WebsocketManager
from .components.authorization import Authorization from .components.authorization import Authorization
from asyncio import Future
_T = TypeVar("_T") _T = TypeVar("_T")
_C = TypeVar("_C", str, bool, float, int) _C = TypeVar("_C", str, bool, float, int)
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
@ -40,6 +46,202 @@ class Subscribable:
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
class APIDefinition:
def __init__(self,
endpoint: str,
http_uri: str,
jrpc_methods: List[str],
request_methods: Union[str, List[str]],
transports: List[str],
callback: Optional[Callable[[WebRequest], Coroutine]],
need_object_parser: bool):
self.endpoint = endpoint
self.uri = http_uri
self.jrpc_methods = jrpc_methods
if not isinstance(request_methods, list):
request_methods = [request_methods]
self.request_methods = request_methods
self.supported_transports = transports
self.callback = callback
self.need_object_parser = need_object_parser
class APITransport:
def register_api_handler(self, api_def: APIDefinition) -> None:
raise NotImplementedError
def remove_api_handler(self, api_def: APIDefinition) -> None:
raise NotImplementedError
class BaseRemoteConnection(Subscribable):
def on_create(self, server: Server) -> None:
self.server = server
self.eventloop = server.get_event_loop()
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
self.rpc = self.wsm.rpc
self._uid = id(self)
self.ip_addr = ""
self.is_closed: bool = False
self.queue_busy: bool = False
self.pending_responses: Dict[int, Future] = {}
self.message_buf: List[Union[str, Dict[str, Any]]] = []
self._connected_time: float = 0.
self._identified: bool = False
self._client_data: Dict[str, str] = {
"name": "unknown",
"version": "",
"type": "",
"url": ""
}
self._need_auth: bool = False
self._user_info: Optional[Dict[str, Any]] = None
@property
def user_info(self) -> Optional[Dict[str, Any]]:
return self._user_info
@user_info.setter
def user_info(self, uinfo: Dict[str, Any]) -> None:
self._user_info = uinfo
self._need_auth = False
@property
def need_auth(self) -> bool:
return self._need_auth
@property
def uid(self) -> int:
return self._uid
@property
def hostname(self) -> str:
return ""
@property
def start_time(self) -> float:
return self._connected_time
@property
def identified(self) -> bool:
return self._identified
@property
def client_data(self) -> Dict[str, str]:
return self._client_data
@client_data.setter
def client_data(self, data: Dict[str, str]) -> None:
self._client_data = data
self._identified = True
async def _process_message(self, message: str) -> None:
try:
response = await self.rpc.dispatch(message, self)
if response is not None:
self.queue_message(response)
except Exception:
logging.exception("Websocket Command Error")
def queue_message(self, message: Union[str, Dict[str, Any]]):
self.message_buf.append(message)
if self.queue_busy:
return
self.queue_busy = True
self.eventloop.register_callback(self._write_messages)
def authenticate(
self,
token: Optional[str] = None,
api_key: Optional[str] = None
) -> None:
auth: AuthComp = self.server.lookup_component("authorization", None)
if auth is None:
return
if token is not None:
self.user_info = auth.validate_jwt(token)
elif api_key is not None and self.user_info is None:
self.user_info = auth.validate_api_key(api_key)
else:
self.check_authenticated()
def check_authenticated(self, path: str = "") -> None:
if not self._need_auth:
return
auth: AuthComp = self.server.lookup_component("authorization", None)
if auth is None:
return
if not auth.is_path_permitted(path):
raise self.server.error("Unauthorized", 401)
def on_user_logout(self, user: str) -> bool:
if self._user_info is None:
return False
if user == self._user_info.get("username", ""):
self._user_info = None
return True
return False
async def _write_messages(self):
if self.is_closed:
self.message_buf = []
self.queue_busy = False
return
while self.message_buf:
msg = self.message_buf.pop(0)
await self.write_to_socket(msg)
self.queue_busy = False
async def write_to_socket(
self, message: Union[str, Dict[str, Any]]
) -> None:
raise NotImplementedError("Children must implement write_to_socket")
def send_status(self,
status: Dict[str, Any],
eventtime: float
) -> None:
if not status:
return
self.queue_message({
'jsonrpc': "2.0",
'method': "notify_status_update",
'params': [status, eventtime]})
def call_method(
self,
method: str,
params: Optional[Union[List, Dict[str, Any]]] = None
) -> Awaitable:
fut = self.eventloop.create_future()
msg = {
'jsonrpc': "2.0",
'method': method,
'id': id(fut)
}
if params is not None:
msg["params"] = params
self.pending_responses[id(fut)] = fut
self.queue_message(msg)
return fut
def send_notification(self, name: str, data: List) -> None:
self.wsm.notify_clients(name, data, [self._uid])
def resolve_pending_response(
self, response_id: int, result: Any
) -> bool:
fut = self.pending_responses.pop(response_id, None)
if fut is None:
return False
if isinstance(result, ServerError):
fut.set_exception(result)
else:
fut.set_result(result)
return True
def close_socket(self, code: int, reason: str) -> None:
raise NotImplementedError("Children must implement close_socket()")
class WebRequest: class WebRequest:
def __init__(self, def __init__(self,
endpoint: str, endpoint: str,
@ -72,8 +274,8 @@ class WebRequest:
def get_subscribable(self) -> Optional[Subscribable]: def get_subscribable(self) -> Optional[Subscribable]:
return self.conn return self.conn
def get_client_connection(self) -> Optional[BaseSocketClient]: def get_client_connection(self) -> Optional[BaseRemoteConnection]:
if isinstance(self.conn, BaseSocketClient): if isinstance(self.conn, BaseRemoteConnection):
return self.conn return self.conn
return None return None
@ -142,28 +344,188 @@ class WebRequest:
) -> Union[bool, _T]: ) -> Union[bool, _T]:
return self._get_converted_arg(key, default, bool) return self._get_converted_arg(key, default, bool)
class APIDefinition:
def __init__(self,
endpoint: str,
http_uri: str,
jrpc_methods: List[str],
request_methods: Union[str, List[str]],
transports: List[str],
callback: Optional[Callable[[WebRequest], Coroutine]],
need_object_parser: bool):
self.endpoint = endpoint
self.uri = http_uri
self.jrpc_methods = jrpc_methods
if not isinstance(request_methods, list):
request_methods = [request_methods]
self.request_methods = request_methods
self.supported_transports = transports
self.callback = callback
self.need_object_parser = need_object_parser
class APITransport: class JsonRPC:
def register_api_handler(self, api_def: APIDefinition) -> None: def __init__(
raise NotImplementedError self, server: Server, transport: str = "Websocket"
) -> None:
self.methods: Dict[str, RPCCallback] = {}
self.transport = transport
self.sanitize_response = False
self.verbose = server.is_verbose_enabled()
def remove_api_handler(self, api_def: APIDefinition) -> None: def _log_request(self, rpc_obj: Dict[str, Any], ) -> None:
raise NotImplementedError if not self.verbose:
return
self.sanitize_response = False
output = rpc_obj
method: Optional[str] = rpc_obj.get("method")
params: Dict[str, Any] = rpc_obj.get("params", {})
if isinstance(method, str):
if (
method.startswith("access.") or
method == "machine.sudo.password"
):
self.sanitize_response = True
if params and isinstance(params, dict):
output = copy.deepcopy(rpc_obj)
output["params"] = {key: "<sanitized>" for key in params}
elif method == "server.connection.identify":
output = copy.deepcopy(rpc_obj)
for field in ["access_token", "api_key"]:
if field in params:
output["params"][field] = "<sanitized>"
logging.debug(f"{self.transport} Received::{json.dumps(output)}")
def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
if not self.verbose:
return
if resp_obj is None:
return
output = resp_obj
if self.sanitize_response and "result" in resp_obj:
output = copy.deepcopy(resp_obj)
output["result"] = "<sanitized>"
self.sanitize_response = False
logging.debug(f"{self.transport} Response::{json.dumps(output)}")
def register_method(self,
name: str,
method: RPCCallback
) -> None:
self.methods[name] = method
def remove_method(self, name: str) -> None:
self.methods.pop(name, None)
async def dispatch(self,
data: str,
conn: Optional[BaseRemoteConnection] = None
) -> Optional[str]:
try:
obj: Union[Dict[str, Any], List[dict]] = json.loads(data)
except Exception:
msg = f"{self.transport} data not json: {data}"
logging.exception(msg)
err = self.build_error(-32700, "Parse error")
return json.dumps(err)
if isinstance(obj, list):
responses: List[Dict[str, Any]] = []
for item in obj:
self._log_request(item)
resp = await self.process_object(item, conn)
if resp is not None:
self._log_response(resp)
responses.append(resp)
if responses:
return json.dumps(responses)
else:
self._log_request(obj)
response = await self.process_object(obj, conn)
if response is not None:
self._log_response(response)
return json.dumps(response)
return None
async def process_object(self,
obj: Dict[str, Any],
conn: Optional[BaseRemoteConnection]
) -> Optional[Dict[str, Any]]:
req_id: Optional[int] = obj.get('id', None)
rpc_version: str = obj.get('jsonrpc', "")
if rpc_version != "2.0":
return self.build_error(-32600, "Invalid Request", req_id)
method_name = obj.get('method', Sentinel.MISSING)
if method_name is Sentinel.MISSING:
self.process_response(obj, conn)
return None
if not isinstance(method_name, str):
return self.build_error(-32600, "Invalid Request", req_id)
method = self.methods.get(method_name, None)
if method is None:
return self.build_error(-32601, "Method not found", req_id)
params: Dict[str, Any] = {}
if 'params' in obj:
params = obj['params']
if not isinstance(params, dict):
return self.build_error(
-32602, f"Invalid params:", req_id, True)
response = await self.execute_method(method, req_id, conn, params)
return response
def process_response(
self, obj: Dict[str, Any], conn: Optional[BaseRemoteConnection]
) -> None:
if conn is None:
logging.debug(f"RPC Response to non-socket request: {obj}")
return
response_id = obj.get("id")
if response_id is None:
logging.debug(f"RPC Response with null ID: {obj}")
return
result = obj.get("result")
if result is None:
name = conn.client_data["name"]
error = obj.get("error")
msg = f"Invalid Response: {obj}"
code = -32600
if isinstance(error, dict):
msg = error.get("message", msg)
code = error.get("code", code)
msg = f"{name} rpc error: {code} {msg}"
ret = ServerError(msg, 418)
else:
ret = result
conn.resolve_pending_response(response_id, ret)
async def execute_method(self,
callback: RPCCallback,
req_id: Optional[int],
conn: Optional[BaseRemoteConnection],
params: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
if conn is not None:
params["_socket_"] = conn
try:
result = await callback(params)
except TypeError as e:
return self.build_error(
-32602, f"Invalid params:\n{e}", req_id, True)
except ServerError as e:
code = e.status_code
if code == 404:
code = -32601
elif code == 401:
code = -32602
return self.build_error(code, str(e), req_id, True)
except Exception as e:
return self.build_error(-31000, str(e), req_id, True)
if req_id is None:
return None
else:
return self.build_result(result, req_id)
def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
return {
'jsonrpc': "2.0",
'result': result,
'id': req_id
}
def build_error(self,
code: int,
msg: str,
req_id: Optional[int] = None,
is_exc: bool = False
) -> Dict[str, Any]:
log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
if is_exc:
logging.exception(log_msg)
else:
logging.info(log_msg)
return {
'jsonrpc': "2.0",
'error': {'code': code, 'message': msg},
'id': req_id
}

View File

@ -8,7 +8,7 @@ import asyncio
import pathlib import pathlib
import logging import logging
import json import json
from ..websockets import BaseSocketClient from ..common import BaseRemoteConnection
from ..utils import get_unix_peer_credentials from ..utils import get_unix_peer_credentials
# Annotation imports # Annotation imports
@ -32,7 +32,7 @@ UNIX_BUFFER_LIMIT = 20 * 1024 * 1024
class ExtensionManager: class ExtensionManager:
def __init__(self, config: ConfigHelper) -> None: def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server() self.server = config.get_server()
self.agents: Dict[str, BaseSocketClient] = {} self.agents: Dict[str, BaseRemoteConnection] = {}
self.uds_server: Optional[asyncio.AbstractServer] = None self.uds_server: Optional[asyncio.AbstractServer] = None
self.server.register_endpoint( self.server.register_endpoint(
"/connection/send_event", ["POST"], self._handle_agent_event, "/connection/send_event", ["POST"], self._handle_agent_event,
@ -45,7 +45,7 @@ class ExtensionManager:
"/server/extensions/request", ["POST"], self._handle_call_agent "/server/extensions/request", ["POST"], self._handle_call_agent
) )
def register_agent(self, connection: BaseSocketClient) -> None: def register_agent(self, connection: BaseRemoteConnection) -> None:
data = connection.client_data data = connection.client_data
name = data["name"] name = data["name"]
client_type = data["type"] client_type = data["type"]
@ -64,7 +64,7 @@ class ExtensionManager:
} }
connection.send_notification("agent_event", [evt]) connection.send_notification("agent_event", [evt])
def remove_agent(self, connection: BaseSocketClient) -> None: def remove_agent(self, connection: BaseRemoteConnection) -> None:
name = connection.client_data["name"] name = connection.client_data["name"]
if name in self.agents: if name in self.agents:
del self.agents[name] del self.agents[name]
@ -135,7 +135,7 @@ class ExtensionManager:
await self.uds_server.wait_closed() await self.uds_server.wait_closed()
self.uds_server = None self.uds_server = None
class UnixSocketClient(BaseSocketClient): class UnixSocketClient(BaseRemoteConnection):
def __init__( def __init__(
self, self,
server: Server, server: Server,

View File

@ -13,8 +13,7 @@ import pathlib
import ssl import ssl
from collections import deque from collections import deque
import paho.mqtt.client as paho_mqtt import paho.mqtt.client as paho_mqtt
from ..common import Subscribable, WebRequest, APITransport from ..common import Subscribable, WebRequest, APITransport, JsonRPC
from ..websockets import JsonRPC
# Annotation imports # Annotation imports
from typing import ( from typing import (

View File

@ -32,7 +32,8 @@ from typing import (
if TYPE_CHECKING: if TYPE_CHECKING:
from ..app import InternalTransport from ..app import InternalTransport
from ..confighelper import ConfigHelper from ..confighelper import ConfigHelper
from ..websockets import WebsocketManager, BaseSocketClient from ..websockets import WebsocketManager
from ..common import BaseRemoteConnection
from tornado.websocket import WebSocketClientConnection from tornado.websocket import WebSocketClientConnection
from .database import MoonrakerDatabase from .database import MoonrakerDatabase
from .klippy_apis import KlippyAPI from .klippy_apis import KlippyAPI
@ -609,7 +610,7 @@ class SimplyPrint(Subscribable):
is_on = device_info["status"] == "on" is_on = device_info["status"] == "on"
self.send_sp("power_controller", {"on": is_on}) self.send_sp("power_controller", {"on": is_on})
def _on_websocket_identified(self, ws: BaseSocketClient) -> None: def _on_websocket_identified(self, ws: BaseRemoteConnection) -> None:
if ( if (
self.cache.current_wsid is None and self.cache.current_wsid is None and
ws.client_data.get("type", "") == "web" ws.client_data.get("type", "") == "web"
@ -622,7 +623,7 @@ class SimplyPrint(Subscribable):
self.cache.current_wsid = ws.uid self.cache.current_wsid = ws.uid
self.send_sp("machine_data", ui_data) self.send_sp("machine_data", ui_data)
def _on_websocket_removed(self, ws: BaseSocketClient) -> None: def _on_websocket_removed(self, ws: BaseRemoteConnection) -> None:
if self.cache.current_wsid is None or self.cache.current_wsid != ws.uid: if self.cache.current_wsid is None or self.cache.current_wsid != ws.uid:
return return
ui_data = self._get_ui_info() ui_data = self._get_ui_info()

View File

@ -7,13 +7,17 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import ipaddress import ipaddress
import json
import asyncio import asyncio
import copy
from tornado.websocket import WebSocketHandler, WebSocketClosedError from tornado.websocket import WebSocketHandler, WebSocketClosedError
from tornado.web import HTTPError from tornado.web import HTTPError
from .common import WebRequest, Subscribable, APITransport, APIDefinition from .common import (
from .utils import ServerError, Sentinel WebRequest,
BaseRemoteConnection,
APITransport,
APIDefinition,
JsonRPC
)
from .utils import ServerError
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -42,195 +46,10 @@ if TYPE_CHECKING:
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"] CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"]
class JsonRPC:
def __init__(
self, server: Server, transport: str = "Websocket"
) -> None:
self.methods: Dict[str, RPCCallback] = {}
self.transport = transport
self.sanitize_response = False
self.verbose = server.is_verbose_enabled()
def _log_request(self, rpc_obj: Dict[str, Any], ) -> None:
if not self.verbose:
return
self.sanitize_response = False
output = rpc_obj
method: Optional[str] = rpc_obj.get("method")
params: Dict[str, Any] = rpc_obj.get("params", {})
if isinstance(method, str):
if (
method.startswith("access.") or
method == "machine.sudo.password"
):
self.sanitize_response = True
if params and isinstance(params, dict):
output = copy.deepcopy(rpc_obj)
output["params"] = {key: "<sanitized>" for key in params}
elif method == "server.connection.identify":
output = copy.deepcopy(rpc_obj)
for field in ["access_token", "api_key"]:
if field in params:
output["params"][field] = "<sanitized>"
logging.debug(f"{self.transport} Received::{json.dumps(output)}")
def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
if not self.verbose:
return
if resp_obj is None:
return
output = resp_obj
if self.sanitize_response and "result" in resp_obj:
output = copy.deepcopy(resp_obj)
output["result"] = "<sanitized>"
self.sanitize_response = False
logging.debug(f"{self.transport} Response::{json.dumps(output)}")
def register_method(self,
name: str,
method: RPCCallback
) -> None:
self.methods[name] = method
def remove_method(self, name: str) -> None:
self.methods.pop(name, None)
async def dispatch(self,
data: str,
conn: Optional[BaseSocketClient] = None
) -> Optional[str]:
try:
obj: Union[Dict[str, Any], List[dict]] = json.loads(data)
except Exception:
msg = f"{self.transport} data not json: {data}"
logging.exception(msg)
err = self.build_error(-32700, "Parse error")
return json.dumps(err)
if isinstance(obj, list):
responses: List[Dict[str, Any]] = []
for item in obj:
self._log_request(item)
resp = await self.process_object(item, conn)
if resp is not None:
self._log_response(resp)
responses.append(resp)
if responses:
return json.dumps(responses)
else:
self._log_request(obj)
response = await self.process_object(obj, conn)
if response is not None:
self._log_response(response)
return json.dumps(response)
return None
async def process_object(self,
obj: Dict[str, Any],
conn: Optional[BaseSocketClient]
) -> Optional[Dict[str, Any]]:
req_id: Optional[int] = obj.get('id', None)
rpc_version: str = obj.get('jsonrpc', "")
if rpc_version != "2.0":
return self.build_error(-32600, "Invalid Request", req_id)
method_name = obj.get('method', Sentinel.MISSING)
if method_name is Sentinel.MISSING:
self.process_response(obj, conn)
return None
if not isinstance(method_name, str):
return self.build_error(-32600, "Invalid Request", req_id)
method = self.methods.get(method_name, None)
if method is None:
return self.build_error(-32601, "Method not found", req_id)
params: Dict[str, Any] = {}
if 'params' in obj:
params = obj['params']
if not isinstance(params, dict):
return self.build_error(
-32602, f"Invalid params:", req_id, True)
response = await self.execute_method(method, req_id, conn, params)
return response
def process_response(
self, obj: Dict[str, Any], conn: Optional[BaseSocketClient]
) -> None:
if conn is None:
logging.debug(f"RPC Response to non-socket request: {obj}")
return
response_id = obj.get("id")
if response_id is None:
logging.debug(f"RPC Response with null ID: {obj}")
return
result = obj.get("result")
if result is None:
name = conn.client_data["name"]
error = obj.get("error")
msg = f"Invalid Response: {obj}"
code = -32600
if isinstance(error, dict):
msg = error.get("message", msg)
code = error.get("code", code)
msg = f"{name} rpc error: {code} {msg}"
ret = ServerError(msg, 418)
else:
ret = result
conn.resolve_pending_response(response_id, ret)
async def execute_method(self,
callback: RPCCallback,
req_id: Optional[int],
conn: Optional[BaseSocketClient],
params: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
if conn is not None:
params["_socket_"] = conn
try:
result = await callback(params)
except TypeError as e:
return self.build_error(
-32602, f"Invalid params:\n{e}", req_id, True)
except ServerError as e:
code = e.status_code
if code == 404:
code = -32601
elif code == 401:
code = -32602
return self.build_error(code, str(e), req_id, True)
except Exception as e:
return self.build_error(-31000, str(e), req_id, True)
if req_id is None:
return None
else:
return self.build_result(result, req_id)
def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
return {
'jsonrpc': "2.0",
'result': result,
'id': req_id
}
def build_error(self,
code: int,
msg: str,
req_id: Optional[int] = None,
is_exc: bool = False
) -> Dict[str, Any]:
log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
if is_exc:
logging.exception(log_msg)
else:
logging.info(log_msg)
return {
'jsonrpc': "2.0",
'error': {'code': code, 'message': msg},
'id': req_id
}
class WebsocketManager(APITransport): class WebsocketManager(APITransport):
def __init__(self, server: Server) -> None: def __init__(self, server: Server) -> None:
self.server = server self.server = server
self.clients: Dict[int, BaseSocketClient] = {} self.clients: Dict[int, BaseRemoteConnection] = {}
self.bridge_connections: Dict[int, BridgeSocket] = {} self.bridge_connections: Dict[int, BridgeSocket] = {}
self.rpc = JsonRPC(server) self.rpc = JsonRPC(server)
self.closed_event: Optional[asyncio.Event] = None self.closed_event: Optional[asyncio.Event] = None
@ -289,7 +108,7 @@ class WebsocketManager(APITransport):
callback: Callable[[WebRequest], Coroutine] callback: Callable[[WebRequest], Coroutine]
) -> RPCCallback: ) -> RPCCallback:
async def func(args: Dict[str, Any]) -> Any: async def func(args: Dict[str, Any]) -> Any:
sc: BaseSocketClient = args.pop("_socket_") sc: BaseRemoteConnection = args.pop("_socket_")
sc.check_authenticated(path=endpoint) sc.check_authenticated(path=endpoint)
result = await callback( result = await callback(
WebRequest(endpoint, args, request_method, sc, WebRequest(endpoint, args, request_method, sc,
@ -298,12 +117,12 @@ class WebsocketManager(APITransport):
return func return func
async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]: async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
sc: BaseSocketClient = args["_socket_"] sc: BaseRemoteConnection = args["_socket_"]
sc.check_authenticated() sc.check_authenticated()
return {'websocket_id': sc.uid} return {'websocket_id': sc.uid}
async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]: async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
sc: BaseSocketClient = args["_socket_"] sc: BaseRemoteConnection = args["_socket_"]
sc.authenticate( sc.authenticate(
token=args.get("access_token", None), token=args.get("access_token", None),
api_key=args.get("api_key", None) api_key=args.get("api_key", None)
@ -355,7 +174,7 @@ class WebsocketManager(APITransport):
def has_socket(self, ws_id: int) -> bool: def has_socket(self, ws_id: int) -> bool:
return ws_id in self.clients return ws_id in self.clients
def get_client(self, ws_id: int) -> Optional[BaseSocketClient]: def get_client(self, ws_id: int) -> Optional[BaseRemoteConnection]:
sc = self.clients.get(ws_id, None) sc = self.clients.get(ws_id, None)
if sc is None or not isinstance(sc, WebSocket): if sc is None or not isinstance(sc, WebSocket):
return None return None
@ -363,37 +182,37 @@ class WebsocketManager(APITransport):
def get_clients_by_type( def get_clients_by_type(
self, client_type: str self, client_type: str
) -> List[BaseSocketClient]: ) -> List[BaseRemoteConnection]:
if not client_type: if not client_type:
return [] return []
ret: List[BaseSocketClient] = [] ret: List[BaseRemoteConnection] = []
for sc in self.clients.values(): for sc in self.clients.values():
if sc.client_data.get("type", "") == client_type.lower(): if sc.client_data.get("type", "") == client_type.lower():
ret.append(sc) ret.append(sc)
return ret return ret
def get_clients_by_name(self, name: str) -> List[BaseSocketClient]: def get_clients_by_name(self, name: str) -> List[BaseRemoteConnection]:
if not name: if not name:
return [] return []
ret: List[BaseSocketClient] = [] ret: List[BaseRemoteConnection] = []
for sc in self.clients.values(): for sc in self.clients.values():
if sc.client_data.get("name", "").lower() == name.lower(): if sc.client_data.get("name", "").lower() == name.lower():
ret.append(sc) ret.append(sc)
return ret return ret
def get_unidentified_clients(self) -> List[BaseSocketClient]: def get_unidentified_clients(self) -> List[BaseRemoteConnection]:
ret: List[BaseSocketClient] = [] ret: List[BaseRemoteConnection] = []
for sc in self.clients.values(): for sc in self.clients.values():
if not sc.client_data: if not sc.client_data:
ret.append(sc) ret.append(sc)
return ret return ret
def add_client(self, sc: BaseSocketClient) -> None: def add_client(self, sc: BaseRemoteConnection) -> None:
self.clients[sc.uid] = sc self.clients[sc.uid] = sc
self.server.send_event("websockets:client_added", sc) self.server.send_event("websockets:client_added", sc)
logging.debug(f"New Websocket Added: {sc.uid}") logging.debug(f"New Websocket Added: {sc.uid}")
def remove_client(self, sc: BaseSocketClient) -> None: def remove_client(self, sc: BaseRemoteConnection) -> None:
old_sc = self.clients.pop(sc.uid, None) old_sc = self.clients.pop(sc.uid, None)
if old_sc is not None: if old_sc is not None:
self.server.send_event("websockets:client_removed", sc) self.server.send_event("websockets:client_removed", sc)
@ -449,176 +268,7 @@ class WebsocketManager(APITransport):
pass pass
self.closed_event = None self.closed_event = None
class BaseSocketClient(Subscribable): class WebSocket(WebSocketHandler, BaseRemoteConnection):
def on_create(self, server: Server) -> None:
self.server = server
self.eventloop = server.get_event_loop()
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
self.rpc = self.wsm.rpc
self._uid = id(self)
self.ip_addr = ""
self.is_closed: bool = False
self.queue_busy: bool = False
self.pending_responses: Dict[int, asyncio.Future] = {}
self.message_buf: List[Union[str, Dict[str, Any]]] = []
self._connected_time: float = 0.
self._identified: bool = False
self._client_data: Dict[str, str] = {
"name": "unknown",
"version": "",
"type": "",
"url": ""
}
self._need_auth: bool = False
self._user_info: Optional[Dict[str, Any]] = None
@property
def user_info(self) -> Optional[Dict[str, Any]]:
return self._user_info
@user_info.setter
def user_info(self, uinfo: Dict[str, Any]) -> None:
self._user_info = uinfo
self._need_auth = False
@property
def need_auth(self) -> bool:
return self._need_auth
@property
def uid(self) -> int:
return self._uid
@property
def hostname(self) -> str:
return ""
@property
def start_time(self) -> float:
return self._connected_time
@property
def identified(self) -> bool:
return self._identified
@property
def client_data(self) -> Dict[str, str]:
return self._client_data
@client_data.setter
def client_data(self, data: Dict[str, str]) -> None:
self._client_data = data
self._identified = True
async def _process_message(self, message: str) -> None:
try:
response = await self.rpc.dispatch(message, self)
if response is not None:
self.queue_message(response)
except Exception:
logging.exception("Websocket Command Error")
def queue_message(self, message: Union[str, Dict[str, Any]]):
self.message_buf.append(message)
if self.queue_busy:
return
self.queue_busy = True
self.eventloop.register_callback(self._write_messages)
def authenticate(
self,
token: Optional[str] = None,
api_key: Optional[str] = None
) -> None:
auth: AuthComp = self.server.lookup_component("authorization", None)
if auth is None:
return
if token is not None:
self.user_info = auth.validate_jwt(token)
elif api_key is not None and self.user_info is None:
self.user_info = auth.validate_api_key(api_key)
else:
self.check_authenticated()
def check_authenticated(self, path: str = "") -> None:
if not self._need_auth:
return
auth: AuthComp = self.server.lookup_component("authorization", None)
if auth is None:
return
if not auth.is_path_permitted(path):
raise self.server.error("Unauthorized", 401)
def on_user_logout(self, user: str) -> bool:
if self._user_info is None:
return False
if user == self._user_info.get("username", ""):
self._user_info = None
return True
return False
async def _write_messages(self):
if self.is_closed:
self.message_buf = []
self.queue_busy = False
return
while self.message_buf:
msg = self.message_buf.pop(0)
await self.write_to_socket(msg)
self.queue_busy = False
async def write_to_socket(
self, message: Union[str, Dict[str, Any]]
) -> None:
raise NotImplementedError("Children must implement write_to_socket")
def send_status(self,
status: Dict[str, Any],
eventtime: float
) -> None:
if not status:
return
self.queue_message({
'jsonrpc': "2.0",
'method': "notify_status_update",
'params': [status, eventtime]})
def call_method(
self,
method: str,
params: Optional[Union[List, Dict[str, Any]]] = None
) -> Awaitable:
fut = self.eventloop.create_future()
msg = {
'jsonrpc': "2.0",
'method': method,
'id': id(fut)
}
if params is not None:
msg["params"] = params
self.pending_responses[id(fut)] = fut
self.queue_message(msg)
return fut
def send_notification(self, name: str, data: List) -> None:
self.wsm.notify_clients(name, data, [self._uid])
def resolve_pending_response(
self, response_id: int, result: Any
) -> bool:
fut = self.pending_responses.pop(response_id, None)
if fut is None:
return False
if isinstance(result, ServerError):
fut.set_exception(result)
else:
fut.set_result(result)
return True
def close_socket(self, code: int, reason: str) -> None:
raise NotImplementedError("Children must implement close_socket()")
class WebSocket(WebSocketHandler, BaseSocketClient):
connection_count: int = 0 connection_count: int = 0
def initialize(self) -> None: def initialize(self) -> None: