websockets: move JsonRPC and BaseSocketClient to common
Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
201e84cd94
commit
1b9f29db13
|
@ -6,6 +6,9 @@
|
|||
|
||||
from __future__ import annotations
|
||||
import ipaddress
|
||||
import logging
|
||||
import copy
|
||||
import json
|
||||
from .utils import ServerError, Sentinel
|
||||
|
||||
# Annotation imports
|
||||
|
@ -20,11 +23,14 @@ from typing import (
|
|||
Union,
|
||||
Dict,
|
||||
List,
|
||||
Awaitable
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .websockets import BaseSocketClient
|
||||
from .server import Server
|
||||
from .websockets import WebsocketManager
|
||||
from .components.authorization import Authorization
|
||||
from asyncio import Future
|
||||
_T = TypeVar("_T")
|
||||
_C = TypeVar("_C", str, bool, float, int)
|
||||
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||
|
@ -40,6 +46,202 @@ class Subscribable:
|
|||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
class APIDefinition:
|
||||
def __init__(self,
|
||||
endpoint: str,
|
||||
http_uri: str,
|
||||
jrpc_methods: List[str],
|
||||
request_methods: Union[str, List[str]],
|
||||
transports: List[str],
|
||||
callback: Optional[Callable[[WebRequest], Coroutine]],
|
||||
need_object_parser: bool):
|
||||
self.endpoint = endpoint
|
||||
self.uri = http_uri
|
||||
self.jrpc_methods = jrpc_methods
|
||||
if not isinstance(request_methods, list):
|
||||
request_methods = [request_methods]
|
||||
self.request_methods = request_methods
|
||||
self.supported_transports = transports
|
||||
self.callback = callback
|
||||
self.need_object_parser = need_object_parser
|
||||
|
||||
class APITransport:
|
||||
def register_api_handler(self, api_def: APIDefinition) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_api_handler(self, api_def: APIDefinition) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
class BaseRemoteConnection(Subscribable):
|
||||
def on_create(self, server: Server) -> None:
|
||||
self.server = server
|
||||
self.eventloop = server.get_event_loop()
|
||||
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
|
||||
self.rpc = self.wsm.rpc
|
||||
self._uid = id(self)
|
||||
self.ip_addr = ""
|
||||
self.is_closed: bool = False
|
||||
self.queue_busy: bool = False
|
||||
self.pending_responses: Dict[int, Future] = {}
|
||||
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
||||
self._connected_time: float = 0.
|
||||
self._identified: bool = False
|
||||
self._client_data: Dict[str, str] = {
|
||||
"name": "unknown",
|
||||
"version": "",
|
||||
"type": "",
|
||||
"url": ""
|
||||
}
|
||||
self._need_auth: bool = False
|
||||
self._user_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
@property
|
||||
def user_info(self) -> Optional[Dict[str, Any]]:
|
||||
return self._user_info
|
||||
|
||||
@user_info.setter
|
||||
def user_info(self, uinfo: Dict[str, Any]) -> None:
|
||||
self._user_info = uinfo
|
||||
self._need_auth = False
|
||||
|
||||
@property
|
||||
def need_auth(self) -> bool:
|
||||
return self._need_auth
|
||||
|
||||
@property
|
||||
def uid(self) -> int:
|
||||
return self._uid
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def start_time(self) -> float:
|
||||
return self._connected_time
|
||||
|
||||
@property
|
||||
def identified(self) -> bool:
|
||||
return self._identified
|
||||
|
||||
@property
|
||||
def client_data(self) -> Dict[str, str]:
|
||||
return self._client_data
|
||||
|
||||
@client_data.setter
|
||||
def client_data(self, data: Dict[str, str]) -> None:
|
||||
self._client_data = data
|
||||
self._identified = True
|
||||
|
||||
async def _process_message(self, message: str) -> None:
|
||||
try:
|
||||
response = await self.rpc.dispatch(message, self)
|
||||
if response is not None:
|
||||
self.queue_message(response)
|
||||
except Exception:
|
||||
logging.exception("Websocket Command Error")
|
||||
|
||||
def queue_message(self, message: Union[str, Dict[str, Any]]):
|
||||
self.message_buf.append(message)
|
||||
if self.queue_busy:
|
||||
return
|
||||
self.queue_busy = True
|
||||
self.eventloop.register_callback(self._write_messages)
|
||||
|
||||
def authenticate(
|
||||
self,
|
||||
token: Optional[str] = None,
|
||||
api_key: Optional[str] = None
|
||||
) -> None:
|
||||
auth: AuthComp = self.server.lookup_component("authorization", None)
|
||||
if auth is None:
|
||||
return
|
||||
if token is not None:
|
||||
self.user_info = auth.validate_jwt(token)
|
||||
elif api_key is not None and self.user_info is None:
|
||||
self.user_info = auth.validate_api_key(api_key)
|
||||
else:
|
||||
self.check_authenticated()
|
||||
|
||||
def check_authenticated(self, path: str = "") -> None:
|
||||
if not self._need_auth:
|
||||
return
|
||||
auth: AuthComp = self.server.lookup_component("authorization", None)
|
||||
if auth is None:
|
||||
return
|
||||
if not auth.is_path_permitted(path):
|
||||
raise self.server.error("Unauthorized", 401)
|
||||
|
||||
def on_user_logout(self, user: str) -> bool:
|
||||
if self._user_info is None:
|
||||
return False
|
||||
if user == self._user_info.get("username", ""):
|
||||
self._user_info = None
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _write_messages(self):
|
||||
if self.is_closed:
|
||||
self.message_buf = []
|
||||
self.queue_busy = False
|
||||
return
|
||||
while self.message_buf:
|
||||
msg = self.message_buf.pop(0)
|
||||
await self.write_to_socket(msg)
|
||||
self.queue_busy = False
|
||||
|
||||
async def write_to_socket(
|
||||
self, message: Union[str, Dict[str, Any]]
|
||||
) -> None:
|
||||
raise NotImplementedError("Children must implement write_to_socket")
|
||||
|
||||
def send_status(self,
|
||||
status: Dict[str, Any],
|
||||
eventtime: float
|
||||
) -> None:
|
||||
if not status:
|
||||
return
|
||||
self.queue_message({
|
||||
'jsonrpc': "2.0",
|
||||
'method': "notify_status_update",
|
||||
'params': [status, eventtime]})
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
method: str,
|
||||
params: Optional[Union[List, Dict[str, Any]]] = None
|
||||
) -> Awaitable:
|
||||
fut = self.eventloop.create_future()
|
||||
msg = {
|
||||
'jsonrpc': "2.0",
|
||||
'method': method,
|
||||
'id': id(fut)
|
||||
}
|
||||
if params is not None:
|
||||
msg["params"] = params
|
||||
self.pending_responses[id(fut)] = fut
|
||||
self.queue_message(msg)
|
||||
return fut
|
||||
|
||||
def send_notification(self, name: str, data: List) -> None:
|
||||
self.wsm.notify_clients(name, data, [self._uid])
|
||||
|
||||
def resolve_pending_response(
|
||||
self, response_id: int, result: Any
|
||||
) -> bool:
|
||||
fut = self.pending_responses.pop(response_id, None)
|
||||
if fut is None:
|
||||
return False
|
||||
if isinstance(result, ServerError):
|
||||
fut.set_exception(result)
|
||||
else:
|
||||
fut.set_result(result)
|
||||
return True
|
||||
|
||||
def close_socket(self, code: int, reason: str) -> None:
|
||||
raise NotImplementedError("Children must implement close_socket()")
|
||||
|
||||
|
||||
class WebRequest:
|
||||
def __init__(self,
|
||||
endpoint: str,
|
||||
|
@ -72,8 +274,8 @@ class WebRequest:
|
|||
def get_subscribable(self) -> Optional[Subscribable]:
|
||||
return self.conn
|
||||
|
||||
def get_client_connection(self) -> Optional[BaseSocketClient]:
|
||||
if isinstance(self.conn, BaseSocketClient):
|
||||
def get_client_connection(self) -> Optional[BaseRemoteConnection]:
|
||||
if isinstance(self.conn, BaseRemoteConnection):
|
||||
return self.conn
|
||||
return None
|
||||
|
||||
|
@ -142,28 +344,188 @@ class WebRequest:
|
|||
) -> Union[bool, _T]:
|
||||
return self._get_converted_arg(key, default, bool)
|
||||
|
||||
class APIDefinition:
|
||||
def __init__(self,
|
||||
endpoint: str,
|
||||
http_uri: str,
|
||||
jrpc_methods: List[str],
|
||||
request_methods: Union[str, List[str]],
|
||||
transports: List[str],
|
||||
callback: Optional[Callable[[WebRequest], Coroutine]],
|
||||
need_object_parser: bool):
|
||||
self.endpoint = endpoint
|
||||
self.uri = http_uri
|
||||
self.jrpc_methods = jrpc_methods
|
||||
if not isinstance(request_methods, list):
|
||||
request_methods = [request_methods]
|
||||
self.request_methods = request_methods
|
||||
self.supported_transports = transports
|
||||
self.callback = callback
|
||||
self.need_object_parser = need_object_parser
|
||||
|
||||
class APITransport:
|
||||
def register_api_handler(self, api_def: APIDefinition) -> None:
|
||||
raise NotImplementedError
|
||||
class JsonRPC:
|
||||
def __init__(
|
||||
self, server: Server, transport: str = "Websocket"
|
||||
) -> None:
|
||||
self.methods: Dict[str, RPCCallback] = {}
|
||||
self.transport = transport
|
||||
self.sanitize_response = False
|
||||
self.verbose = server.is_verbose_enabled()
|
||||
|
||||
def remove_api_handler(self, api_def: APIDefinition) -> None:
|
||||
raise NotImplementedError
|
||||
def _log_request(self, rpc_obj: Dict[str, Any], ) -> None:
|
||||
if not self.verbose:
|
||||
return
|
||||
self.sanitize_response = False
|
||||
output = rpc_obj
|
||||
method: Optional[str] = rpc_obj.get("method")
|
||||
params: Dict[str, Any] = rpc_obj.get("params", {})
|
||||
if isinstance(method, str):
|
||||
if (
|
||||
method.startswith("access.") or
|
||||
method == "machine.sudo.password"
|
||||
):
|
||||
self.sanitize_response = True
|
||||
if params and isinstance(params, dict):
|
||||
output = copy.deepcopy(rpc_obj)
|
||||
output["params"] = {key: "<sanitized>" for key in params}
|
||||
elif method == "server.connection.identify":
|
||||
output = copy.deepcopy(rpc_obj)
|
||||
for field in ["access_token", "api_key"]:
|
||||
if field in params:
|
||||
output["params"][field] = "<sanitized>"
|
||||
logging.debug(f"{self.transport} Received::{json.dumps(output)}")
|
||||
|
||||
def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
|
||||
if not self.verbose:
|
||||
return
|
||||
if resp_obj is None:
|
||||
return
|
||||
output = resp_obj
|
||||
if self.sanitize_response and "result" in resp_obj:
|
||||
output = copy.deepcopy(resp_obj)
|
||||
output["result"] = "<sanitized>"
|
||||
self.sanitize_response = False
|
||||
logging.debug(f"{self.transport} Response::{json.dumps(output)}")
|
||||
|
||||
def register_method(self,
|
||||
name: str,
|
||||
method: RPCCallback
|
||||
) -> None:
|
||||
self.methods[name] = method
|
||||
|
||||
def remove_method(self, name: str) -> None:
|
||||
self.methods.pop(name, None)
|
||||
|
||||
async def dispatch(self,
|
||||
data: str,
|
||||
conn: Optional[BaseRemoteConnection] = None
|
||||
) -> Optional[str]:
|
||||
try:
|
||||
obj: Union[Dict[str, Any], List[dict]] = json.loads(data)
|
||||
except Exception:
|
||||
msg = f"{self.transport} data not json: {data}"
|
||||
logging.exception(msg)
|
||||
err = self.build_error(-32700, "Parse error")
|
||||
return json.dumps(err)
|
||||
if isinstance(obj, list):
|
||||
responses: List[Dict[str, Any]] = []
|
||||
for item in obj:
|
||||
self._log_request(item)
|
||||
resp = await self.process_object(item, conn)
|
||||
if resp is not None:
|
||||
self._log_response(resp)
|
||||
responses.append(resp)
|
||||
if responses:
|
||||
return json.dumps(responses)
|
||||
else:
|
||||
self._log_request(obj)
|
||||
response = await self.process_object(obj, conn)
|
||||
if response is not None:
|
||||
self._log_response(response)
|
||||
return json.dumps(response)
|
||||
return None
|
||||
|
||||
async def process_object(self,
|
||||
obj: Dict[str, Any],
|
||||
conn: Optional[BaseRemoteConnection]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
req_id: Optional[int] = obj.get('id', None)
|
||||
rpc_version: str = obj.get('jsonrpc', "")
|
||||
if rpc_version != "2.0":
|
||||
return self.build_error(-32600, "Invalid Request", req_id)
|
||||
method_name = obj.get('method', Sentinel.MISSING)
|
||||
if method_name is Sentinel.MISSING:
|
||||
self.process_response(obj, conn)
|
||||
return None
|
||||
if not isinstance(method_name, str):
|
||||
return self.build_error(-32600, "Invalid Request", req_id)
|
||||
method = self.methods.get(method_name, None)
|
||||
if method is None:
|
||||
return self.build_error(-32601, "Method not found", req_id)
|
||||
params: Dict[str, Any] = {}
|
||||
if 'params' in obj:
|
||||
params = obj['params']
|
||||
if not isinstance(params, dict):
|
||||
return self.build_error(
|
||||
-32602, f"Invalid params:", req_id, True)
|
||||
response = await self.execute_method(method, req_id, conn, params)
|
||||
return response
|
||||
|
||||
def process_response(
|
||||
self, obj: Dict[str, Any], conn: Optional[BaseRemoteConnection]
|
||||
) -> None:
|
||||
if conn is None:
|
||||
logging.debug(f"RPC Response to non-socket request: {obj}")
|
||||
return
|
||||
response_id = obj.get("id")
|
||||
if response_id is None:
|
||||
logging.debug(f"RPC Response with null ID: {obj}")
|
||||
return
|
||||
result = obj.get("result")
|
||||
if result is None:
|
||||
name = conn.client_data["name"]
|
||||
error = obj.get("error")
|
||||
msg = f"Invalid Response: {obj}"
|
||||
code = -32600
|
||||
if isinstance(error, dict):
|
||||
msg = error.get("message", msg)
|
||||
code = error.get("code", code)
|
||||
msg = f"{name} rpc error: {code} {msg}"
|
||||
ret = ServerError(msg, 418)
|
||||
else:
|
||||
ret = result
|
||||
conn.resolve_pending_response(response_id, ret)
|
||||
|
||||
async def execute_method(self,
|
||||
callback: RPCCallback,
|
||||
req_id: Optional[int],
|
||||
conn: Optional[BaseRemoteConnection],
|
||||
params: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if conn is not None:
|
||||
params["_socket_"] = conn
|
||||
try:
|
||||
result = await callback(params)
|
||||
except TypeError as e:
|
||||
return self.build_error(
|
||||
-32602, f"Invalid params:\n{e}", req_id, True)
|
||||
except ServerError as e:
|
||||
code = e.status_code
|
||||
if code == 404:
|
||||
code = -32601
|
||||
elif code == 401:
|
||||
code = -32602
|
||||
return self.build_error(code, str(e), req_id, True)
|
||||
except Exception as e:
|
||||
return self.build_error(-31000, str(e), req_id, True)
|
||||
|
||||
if req_id is None:
|
||||
return None
|
||||
else:
|
||||
return self.build_result(result, req_id)
|
||||
|
||||
def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
|
||||
return {
|
||||
'jsonrpc': "2.0",
|
||||
'result': result,
|
||||
'id': req_id
|
||||
}
|
||||
|
||||
def build_error(self,
|
||||
code: int,
|
||||
msg: str,
|
||||
req_id: Optional[int] = None,
|
||||
is_exc: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
|
||||
if is_exc:
|
||||
logging.exception(log_msg)
|
||||
else:
|
||||
logging.info(log_msg)
|
||||
return {
|
||||
'jsonrpc': "2.0",
|
||||
'error': {'code': code, 'message': msg},
|
||||
'id': req_id
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import asyncio
|
|||
import pathlib
|
||||
import logging
|
||||
import json
|
||||
from ..websockets import BaseSocketClient
|
||||
from ..common import BaseRemoteConnection
|
||||
from ..utils import get_unix_peer_credentials
|
||||
|
||||
# Annotation imports
|
||||
|
@ -32,7 +32,7 @@ UNIX_BUFFER_LIMIT = 20 * 1024 * 1024
|
|||
class ExtensionManager:
|
||||
def __init__(self, config: ConfigHelper) -> None:
|
||||
self.server = config.get_server()
|
||||
self.agents: Dict[str, BaseSocketClient] = {}
|
||||
self.agents: Dict[str, BaseRemoteConnection] = {}
|
||||
self.uds_server: Optional[asyncio.AbstractServer] = None
|
||||
self.server.register_endpoint(
|
||||
"/connection/send_event", ["POST"], self._handle_agent_event,
|
||||
|
@ -45,7 +45,7 @@ class ExtensionManager:
|
|||
"/server/extensions/request", ["POST"], self._handle_call_agent
|
||||
)
|
||||
|
||||
def register_agent(self, connection: BaseSocketClient) -> None:
|
||||
def register_agent(self, connection: BaseRemoteConnection) -> None:
|
||||
data = connection.client_data
|
||||
name = data["name"]
|
||||
client_type = data["type"]
|
||||
|
@ -64,7 +64,7 @@ class ExtensionManager:
|
|||
}
|
||||
connection.send_notification("agent_event", [evt])
|
||||
|
||||
def remove_agent(self, connection: BaseSocketClient) -> None:
|
||||
def remove_agent(self, connection: BaseRemoteConnection) -> None:
|
||||
name = connection.client_data["name"]
|
||||
if name in self.agents:
|
||||
del self.agents[name]
|
||||
|
@ -135,7 +135,7 @@ class ExtensionManager:
|
|||
await self.uds_server.wait_closed()
|
||||
self.uds_server = None
|
||||
|
||||
class UnixSocketClient(BaseSocketClient):
|
||||
class UnixSocketClient(BaseRemoteConnection):
|
||||
def __init__(
|
||||
self,
|
||||
server: Server,
|
||||
|
|
|
@ -13,8 +13,7 @@ import pathlib
|
|||
import ssl
|
||||
from collections import deque
|
||||
import paho.mqtt.client as paho_mqtt
|
||||
from ..common import Subscribable, WebRequest, APITransport
|
||||
from ..websockets import JsonRPC
|
||||
from ..common import Subscribable, WebRequest, APITransport, JsonRPC
|
||||
|
||||
# Annotation imports
|
||||
from typing import (
|
||||
|
|
|
@ -32,7 +32,8 @@ from typing import (
|
|||
if TYPE_CHECKING:
|
||||
from ..app import InternalTransport
|
||||
from ..confighelper import ConfigHelper
|
||||
from ..websockets import WebsocketManager, BaseSocketClient
|
||||
from ..websockets import WebsocketManager
|
||||
from ..common import BaseRemoteConnection
|
||||
from tornado.websocket import WebSocketClientConnection
|
||||
from .database import MoonrakerDatabase
|
||||
from .klippy_apis import KlippyAPI
|
||||
|
@ -609,7 +610,7 @@ class SimplyPrint(Subscribable):
|
|||
is_on = device_info["status"] == "on"
|
||||
self.send_sp("power_controller", {"on": is_on})
|
||||
|
||||
def _on_websocket_identified(self, ws: BaseSocketClient) -> None:
|
||||
def _on_websocket_identified(self, ws: BaseRemoteConnection) -> None:
|
||||
if (
|
||||
self.cache.current_wsid is None and
|
||||
ws.client_data.get("type", "") == "web"
|
||||
|
@ -622,7 +623,7 @@ class SimplyPrint(Subscribable):
|
|||
self.cache.current_wsid = ws.uid
|
||||
self.send_sp("machine_data", ui_data)
|
||||
|
||||
def _on_websocket_removed(self, ws: BaseSocketClient) -> None:
|
||||
def _on_websocket_removed(self, ws: BaseRemoteConnection) -> None:
|
||||
if self.cache.current_wsid is None or self.cache.current_wsid != ws.uid:
|
||||
return
|
||||
ui_data = self._get_ui_info()
|
||||
|
|
|
@ -7,13 +7,17 @@
|
|||
from __future__ import annotations
|
||||
import logging
|
||||
import ipaddress
|
||||
import json
|
||||
import asyncio
|
||||
import copy
|
||||
from tornado.websocket import WebSocketHandler, WebSocketClosedError
|
||||
from tornado.web import HTTPError
|
||||
from .common import WebRequest, Subscribable, APITransport, APIDefinition
|
||||
from .utils import ServerError, Sentinel
|
||||
from .common import (
|
||||
WebRequest,
|
||||
BaseRemoteConnection,
|
||||
APITransport,
|
||||
APIDefinition,
|
||||
JsonRPC
|
||||
)
|
||||
from .utils import ServerError
|
||||
|
||||
# Annotation imports
|
||||
from typing import (
|
||||
|
@ -42,195 +46,10 @@ if TYPE_CHECKING:
|
|||
|
||||
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"]
|
||||
|
||||
class JsonRPC:
|
||||
def __init__(
|
||||
self, server: Server, transport: str = "Websocket"
|
||||
) -> None:
|
||||
self.methods: Dict[str, RPCCallback] = {}
|
||||
self.transport = transport
|
||||
self.sanitize_response = False
|
||||
self.verbose = server.is_verbose_enabled()
|
||||
|
||||
def _log_request(self, rpc_obj: Dict[str, Any], ) -> None:
|
||||
if not self.verbose:
|
||||
return
|
||||
self.sanitize_response = False
|
||||
output = rpc_obj
|
||||
method: Optional[str] = rpc_obj.get("method")
|
||||
params: Dict[str, Any] = rpc_obj.get("params", {})
|
||||
if isinstance(method, str):
|
||||
if (
|
||||
method.startswith("access.") or
|
||||
method == "machine.sudo.password"
|
||||
):
|
||||
self.sanitize_response = True
|
||||
if params and isinstance(params, dict):
|
||||
output = copy.deepcopy(rpc_obj)
|
||||
output["params"] = {key: "<sanitized>" for key in params}
|
||||
elif method == "server.connection.identify":
|
||||
output = copy.deepcopy(rpc_obj)
|
||||
for field in ["access_token", "api_key"]:
|
||||
if field in params:
|
||||
output["params"][field] = "<sanitized>"
|
||||
logging.debug(f"{self.transport} Received::{json.dumps(output)}")
|
||||
|
||||
def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
|
||||
if not self.verbose:
|
||||
return
|
||||
if resp_obj is None:
|
||||
return
|
||||
output = resp_obj
|
||||
if self.sanitize_response and "result" in resp_obj:
|
||||
output = copy.deepcopy(resp_obj)
|
||||
output["result"] = "<sanitized>"
|
||||
self.sanitize_response = False
|
||||
logging.debug(f"{self.transport} Response::{json.dumps(output)}")
|
||||
|
||||
def register_method(self,
|
||||
name: str,
|
||||
method: RPCCallback
|
||||
) -> None:
|
||||
self.methods[name] = method
|
||||
|
||||
def remove_method(self, name: str) -> None:
|
||||
self.methods.pop(name, None)
|
||||
|
||||
async def dispatch(self,
|
||||
data: str,
|
||||
conn: Optional[BaseSocketClient] = None
|
||||
) -> Optional[str]:
|
||||
try:
|
||||
obj: Union[Dict[str, Any], List[dict]] = json.loads(data)
|
||||
except Exception:
|
||||
msg = f"{self.transport} data not json: {data}"
|
||||
logging.exception(msg)
|
||||
err = self.build_error(-32700, "Parse error")
|
||||
return json.dumps(err)
|
||||
if isinstance(obj, list):
|
||||
responses: List[Dict[str, Any]] = []
|
||||
for item in obj:
|
||||
self._log_request(item)
|
||||
resp = await self.process_object(item, conn)
|
||||
if resp is not None:
|
||||
self._log_response(resp)
|
||||
responses.append(resp)
|
||||
if responses:
|
||||
return json.dumps(responses)
|
||||
else:
|
||||
self._log_request(obj)
|
||||
response = await self.process_object(obj, conn)
|
||||
if response is not None:
|
||||
self._log_response(response)
|
||||
return json.dumps(response)
|
||||
return None
|
||||
|
||||
async def process_object(self,
|
||||
obj: Dict[str, Any],
|
||||
conn: Optional[BaseSocketClient]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
req_id: Optional[int] = obj.get('id', None)
|
||||
rpc_version: str = obj.get('jsonrpc', "")
|
||||
if rpc_version != "2.0":
|
||||
return self.build_error(-32600, "Invalid Request", req_id)
|
||||
method_name = obj.get('method', Sentinel.MISSING)
|
||||
if method_name is Sentinel.MISSING:
|
||||
self.process_response(obj, conn)
|
||||
return None
|
||||
if not isinstance(method_name, str):
|
||||
return self.build_error(-32600, "Invalid Request", req_id)
|
||||
method = self.methods.get(method_name, None)
|
||||
if method is None:
|
||||
return self.build_error(-32601, "Method not found", req_id)
|
||||
params: Dict[str, Any] = {}
|
||||
if 'params' in obj:
|
||||
params = obj['params']
|
||||
if not isinstance(params, dict):
|
||||
return self.build_error(
|
||||
-32602, f"Invalid params:", req_id, True)
|
||||
response = await self.execute_method(method, req_id, conn, params)
|
||||
return response
|
||||
|
||||
def process_response(
|
||||
self, obj: Dict[str, Any], conn: Optional[BaseSocketClient]
|
||||
) -> None:
|
||||
if conn is None:
|
||||
logging.debug(f"RPC Response to non-socket request: {obj}")
|
||||
return
|
||||
response_id = obj.get("id")
|
||||
if response_id is None:
|
||||
logging.debug(f"RPC Response with null ID: {obj}")
|
||||
return
|
||||
result = obj.get("result")
|
||||
if result is None:
|
||||
name = conn.client_data["name"]
|
||||
error = obj.get("error")
|
||||
msg = f"Invalid Response: {obj}"
|
||||
code = -32600
|
||||
if isinstance(error, dict):
|
||||
msg = error.get("message", msg)
|
||||
code = error.get("code", code)
|
||||
msg = f"{name} rpc error: {code} {msg}"
|
||||
ret = ServerError(msg, 418)
|
||||
else:
|
||||
ret = result
|
||||
conn.resolve_pending_response(response_id, ret)
|
||||
|
||||
async def execute_method(self,
|
||||
callback: RPCCallback,
|
||||
req_id: Optional[int],
|
||||
conn: Optional[BaseSocketClient],
|
||||
params: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if conn is not None:
|
||||
params["_socket_"] = conn
|
||||
try:
|
||||
result = await callback(params)
|
||||
except TypeError as e:
|
||||
return self.build_error(
|
||||
-32602, f"Invalid params:\n{e}", req_id, True)
|
||||
except ServerError as e:
|
||||
code = e.status_code
|
||||
if code == 404:
|
||||
code = -32601
|
||||
elif code == 401:
|
||||
code = -32602
|
||||
return self.build_error(code, str(e), req_id, True)
|
||||
except Exception as e:
|
||||
return self.build_error(-31000, str(e), req_id, True)
|
||||
|
||||
if req_id is None:
|
||||
return None
|
||||
else:
|
||||
return self.build_result(result, req_id)
|
||||
|
||||
def build_result(self, result: Any, req_id: int) -> Dict[str, Any]:
|
||||
return {
|
||||
'jsonrpc': "2.0",
|
||||
'result': result,
|
||||
'id': req_id
|
||||
}
|
||||
|
||||
def build_error(self,
|
||||
code: int,
|
||||
msg: str,
|
||||
req_id: Optional[int] = None,
|
||||
is_exc: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
log_msg = f"JSON-RPC Request Error: {code}\n{msg}"
|
||||
if is_exc:
|
||||
logging.exception(log_msg)
|
||||
else:
|
||||
logging.info(log_msg)
|
||||
return {
|
||||
'jsonrpc': "2.0",
|
||||
'error': {'code': code, 'message': msg},
|
||||
'id': req_id
|
||||
}
|
||||
|
||||
class WebsocketManager(APITransport):
|
||||
def __init__(self, server: Server) -> None:
|
||||
self.server = server
|
||||
self.clients: Dict[int, BaseSocketClient] = {}
|
||||
self.clients: Dict[int, BaseRemoteConnection] = {}
|
||||
self.bridge_connections: Dict[int, BridgeSocket] = {}
|
||||
self.rpc = JsonRPC(server)
|
||||
self.closed_event: Optional[asyncio.Event] = None
|
||||
|
@ -289,7 +108,7 @@ class WebsocketManager(APITransport):
|
|||
callback: Callable[[WebRequest], Coroutine]
|
||||
) -> RPCCallback:
|
||||
async def func(args: Dict[str, Any]) -> Any:
|
||||
sc: BaseSocketClient = args.pop("_socket_")
|
||||
sc: BaseRemoteConnection = args.pop("_socket_")
|
||||
sc.check_authenticated(path=endpoint)
|
||||
result = await callback(
|
||||
WebRequest(endpoint, args, request_method, sc,
|
||||
|
@ -298,12 +117,12 @@ class WebsocketManager(APITransport):
|
|||
return func
|
||||
|
||||
async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
|
||||
sc: BaseSocketClient = args["_socket_"]
|
||||
sc: BaseRemoteConnection = args["_socket_"]
|
||||
sc.check_authenticated()
|
||||
return {'websocket_id': sc.uid}
|
||||
|
||||
async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
|
||||
sc: BaseSocketClient = args["_socket_"]
|
||||
sc: BaseRemoteConnection = args["_socket_"]
|
||||
sc.authenticate(
|
||||
token=args.get("access_token", None),
|
||||
api_key=args.get("api_key", None)
|
||||
|
@ -355,7 +174,7 @@ class WebsocketManager(APITransport):
|
|||
def has_socket(self, ws_id: int) -> bool:
|
||||
return ws_id in self.clients
|
||||
|
||||
def get_client(self, ws_id: int) -> Optional[BaseSocketClient]:
|
||||
def get_client(self, ws_id: int) -> Optional[BaseRemoteConnection]:
|
||||
sc = self.clients.get(ws_id, None)
|
||||
if sc is None or not isinstance(sc, WebSocket):
|
||||
return None
|
||||
|
@ -363,37 +182,37 @@ class WebsocketManager(APITransport):
|
|||
|
||||
def get_clients_by_type(
|
||||
self, client_type: str
|
||||
) -> List[BaseSocketClient]:
|
||||
) -> List[BaseRemoteConnection]:
|
||||
if not client_type:
|
||||
return []
|
||||
ret: List[BaseSocketClient] = []
|
||||
ret: List[BaseRemoteConnection] = []
|
||||
for sc in self.clients.values():
|
||||
if sc.client_data.get("type", "") == client_type.lower():
|
||||
ret.append(sc)
|
||||
return ret
|
||||
|
||||
def get_clients_by_name(self, name: str) -> List[BaseSocketClient]:
|
||||
def get_clients_by_name(self, name: str) -> List[BaseRemoteConnection]:
|
||||
if not name:
|
||||
return []
|
||||
ret: List[BaseSocketClient] = []
|
||||
ret: List[BaseRemoteConnection] = []
|
||||
for sc in self.clients.values():
|
||||
if sc.client_data.get("name", "").lower() == name.lower():
|
||||
ret.append(sc)
|
||||
return ret
|
||||
|
||||
def get_unidentified_clients(self) -> List[BaseSocketClient]:
|
||||
ret: List[BaseSocketClient] = []
|
||||
def get_unidentified_clients(self) -> List[BaseRemoteConnection]:
|
||||
ret: List[BaseRemoteConnection] = []
|
||||
for sc in self.clients.values():
|
||||
if not sc.client_data:
|
||||
ret.append(sc)
|
||||
return ret
|
||||
|
||||
def add_client(self, sc: BaseSocketClient) -> None:
|
||||
def add_client(self, sc: BaseRemoteConnection) -> None:
|
||||
self.clients[sc.uid] = sc
|
||||
self.server.send_event("websockets:client_added", sc)
|
||||
logging.debug(f"New Websocket Added: {sc.uid}")
|
||||
|
||||
def remove_client(self, sc: BaseSocketClient) -> None:
|
||||
def remove_client(self, sc: BaseRemoteConnection) -> None:
|
||||
old_sc = self.clients.pop(sc.uid, None)
|
||||
if old_sc is not None:
|
||||
self.server.send_event("websockets:client_removed", sc)
|
||||
|
@ -449,176 +268,7 @@ class WebsocketManager(APITransport):
|
|||
pass
|
||||
self.closed_event = None
|
||||
|
||||
class BaseSocketClient(Subscribable):
|
||||
def on_create(self, server: Server) -> None:
|
||||
self.server = server
|
||||
self.eventloop = server.get_event_loop()
|
||||
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
|
||||
self.rpc = self.wsm.rpc
|
||||
self._uid = id(self)
|
||||
self.ip_addr = ""
|
||||
self.is_closed: bool = False
|
||||
self.queue_busy: bool = False
|
||||
self.pending_responses: Dict[int, asyncio.Future] = {}
|
||||
self.message_buf: List[Union[str, Dict[str, Any]]] = []
|
||||
self._connected_time: float = 0.
|
||||
self._identified: bool = False
|
||||
self._client_data: Dict[str, str] = {
|
||||
"name": "unknown",
|
||||
"version": "",
|
||||
"type": "",
|
||||
"url": ""
|
||||
}
|
||||
self._need_auth: bool = False
|
||||
self._user_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
@property
|
||||
def user_info(self) -> Optional[Dict[str, Any]]:
|
||||
return self._user_info
|
||||
|
||||
@user_info.setter
|
||||
def user_info(self, uinfo: Dict[str, Any]) -> None:
|
||||
self._user_info = uinfo
|
||||
self._need_auth = False
|
||||
|
||||
@property
|
||||
def need_auth(self) -> bool:
|
||||
return self._need_auth
|
||||
|
||||
@property
|
||||
def uid(self) -> int:
|
||||
return self._uid
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def start_time(self) -> float:
|
||||
return self._connected_time
|
||||
|
||||
@property
|
||||
def identified(self) -> bool:
|
||||
return self._identified
|
||||
|
||||
@property
|
||||
def client_data(self) -> Dict[str, str]:
|
||||
return self._client_data
|
||||
|
||||
@client_data.setter
|
||||
def client_data(self, data: Dict[str, str]) -> None:
|
||||
self._client_data = data
|
||||
self._identified = True
|
||||
|
||||
async def _process_message(self, message: str) -> None:
|
||||
try:
|
||||
response = await self.rpc.dispatch(message, self)
|
||||
if response is not None:
|
||||
self.queue_message(response)
|
||||
except Exception:
|
||||
logging.exception("Websocket Command Error")
|
||||
|
||||
def queue_message(self, message: Union[str, Dict[str, Any]]):
|
||||
self.message_buf.append(message)
|
||||
if self.queue_busy:
|
||||
return
|
||||
self.queue_busy = True
|
||||
self.eventloop.register_callback(self._write_messages)
|
||||
|
||||
def authenticate(
|
||||
self,
|
||||
token: Optional[str] = None,
|
||||
api_key: Optional[str] = None
|
||||
) -> None:
|
||||
auth: AuthComp = self.server.lookup_component("authorization", None)
|
||||
if auth is None:
|
||||
return
|
||||
if token is not None:
|
||||
self.user_info = auth.validate_jwt(token)
|
||||
elif api_key is not None and self.user_info is None:
|
||||
self.user_info = auth.validate_api_key(api_key)
|
||||
else:
|
||||
self.check_authenticated()
|
||||
|
||||
def check_authenticated(self, path: str = "") -> None:
|
||||
if not self._need_auth:
|
||||
return
|
||||
auth: AuthComp = self.server.lookup_component("authorization", None)
|
||||
if auth is None:
|
||||
return
|
||||
if not auth.is_path_permitted(path):
|
||||
raise self.server.error("Unauthorized", 401)
|
||||
|
||||
def on_user_logout(self, user: str) -> bool:
|
||||
if self._user_info is None:
|
||||
return False
|
||||
if user == self._user_info.get("username", ""):
|
||||
self._user_info = None
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _write_messages(self):
|
||||
if self.is_closed:
|
||||
self.message_buf = []
|
||||
self.queue_busy = False
|
||||
return
|
||||
while self.message_buf:
|
||||
msg = self.message_buf.pop(0)
|
||||
await self.write_to_socket(msg)
|
||||
self.queue_busy = False
|
||||
|
||||
async def write_to_socket(
|
||||
self, message: Union[str, Dict[str, Any]]
|
||||
) -> None:
|
||||
raise NotImplementedError("Children must implement write_to_socket")
|
||||
|
||||
def send_status(self,
|
||||
status: Dict[str, Any],
|
||||
eventtime: float
|
||||
) -> None:
|
||||
if not status:
|
||||
return
|
||||
self.queue_message({
|
||||
'jsonrpc': "2.0",
|
||||
'method': "notify_status_update",
|
||||
'params': [status, eventtime]})
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
method: str,
|
||||
params: Optional[Union[List, Dict[str, Any]]] = None
|
||||
) -> Awaitable:
|
||||
fut = self.eventloop.create_future()
|
||||
msg = {
|
||||
'jsonrpc': "2.0",
|
||||
'method': method,
|
||||
'id': id(fut)
|
||||
}
|
||||
if params is not None:
|
||||
msg["params"] = params
|
||||
self.pending_responses[id(fut)] = fut
|
||||
self.queue_message(msg)
|
||||
return fut
|
||||
|
||||
def send_notification(self, name: str, data: List) -> None:
|
||||
self.wsm.notify_clients(name, data, [self._uid])
|
||||
|
||||
def resolve_pending_response(
|
||||
self, response_id: int, result: Any
|
||||
) -> bool:
|
||||
fut = self.pending_responses.pop(response_id, None)
|
||||
if fut is None:
|
||||
return False
|
||||
if isinstance(result, ServerError):
|
||||
fut.set_exception(result)
|
||||
else:
|
||||
fut.set_result(result)
|
||||
return True
|
||||
|
||||
def close_socket(self, code: int, reason: str) -> None:
|
||||
raise NotImplementedError("Children must implement close_socket()")
|
||||
|
||||
class WebSocket(WebSocketHandler, BaseSocketClient):
|
||||
class WebSocket(WebSocketHandler, BaseRemoteConnection):
|
||||
connection_count: int = 0
|
||||
|
||||
def initialize(self) -> None:
|
||||
|
|
Loading…
Reference in New Issue