jsonrpc: share one instance among all transports

This change refactors the APIDefiniton into a dataclass, allowing
defs to be shared directly among HTTP and RPC requests.  In
addition, all transports now share one instance of JSONRPC,
removing duplicate registration.  API Defintiions are registered
with the RPC Dispatcher, and it validates the Transport type.
In addition tranports may perform their own validation prior
to request execution.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2023-11-26 11:45:55 -05:00
parent 8b2d9b26f5
commit bfeb096f31
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
8 changed files with 211 additions and 254 deletions

View File

@ -24,6 +24,7 @@ from tornado.http1connection import HTTP1Connection
from tornado.log import access_log
from .utils import ServerError, source_info
from .common import (
JsonRPC,
WebRequest,
APIDefinition,
APITransport,
@ -50,7 +51,6 @@ from typing import (
Union,
Dict,
List,
Tuple,
AsyncGenerator,
)
if TYPE_CHECKING:
@ -121,34 +121,26 @@ class MutableRouter(tornado.web.ReversibleRuleRouter):
class InternalTransport(APITransport):
def __init__(self, server: Server) -> None:
self.server = server
self.callbacks: Dict[str, Tuple[str, RequestType, APICallback]] = {}
def register_api_handler(self, api_def: APIDefinition) -> None:
ep = api_def.endpoint
cb = api_def.callback
for req_type, rpc_method in api_def.rpc_methods.items():
self.callbacks[rpc_method] = (ep, req_type, cb)
def remove_api_handler(self, api_def: APIDefinition) -> None:
for rpc_method in api_def.rpc_methods.values():
self.callbacks.pop(rpc_method, None)
async def call_method(self,
method_name: str,
request_arguments: Dict[str, Any] = {},
**kwargs
) -> Any:
if method_name not in self.callbacks:
rpc: JsonRPC = self.server.lookup_component("jsonrpc")
method_info = rpc.get_method(method_name)
if method_info is None:
raise self.server.error(f"No method {method_name} available")
req_type, api_definition = method_info
if TransportType.INTERNAL not in api_definition.transports:
raise self.server.error(f"No method {method_name} available")
ep, req_type, func = self.callbacks[method_name]
# Request arguments can be suppplied either through a dict object
# or via keyword arugments
args = request_arguments or kwargs
return await func(WebRequest(ep, dict(args), req_type))
return await api_definition.request(args, req_type, self)
class MoonrakerApp:
def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server()
self.json_rpc = JsonRPC(self.server)
self.http_server: Optional[HTTPServer] = None
self.secure_server: Optional[HTTPServer] = None
self.template_cache: Dict[str, JinjaTemplate] = {}
@ -180,14 +172,7 @@ class MoonrakerApp:
)
self._route_prefix = f"/{rp}"
home_pattern = f"{self._route_prefix}/?"
# Set Up Websocket and Authorization Managers
self.wsm = WebsocketManager(self.server)
self.internal_transport = InternalTransport(self.server)
self.api_transports: Dict[TransportType, APITransport] = {
TransportType.WEBSOCKET: self.wsm,
TransportType.INTERNAL: self.internal_transport
}
mimetypes.add_type('text/plain', '.log')
mimetypes.add_type('text/plain', '.gcode')
@ -228,9 +213,8 @@ class MoonrakerApp:
# Register Server Components
self.server.register_component("application", self)
self.server.register_component("websockets", self.wsm)
self.server.register_component("internal_transport",
self.internal_transport)
self.server.register_component("jsonrpc", self.json_rpc)
self.server.register_component("internal_transport", self.internal_transport)
def _get_path_option(
self, config: ConfigHelper, option: str
@ -318,13 +302,6 @@ class MoonrakerApp:
if self.secure_server is not None:
self.secure_server.stop()
await self.secure_server.close_all_connections()
await self.wsm.close()
def register_api_transport(
self, trtype: TransportType, api_transport: APITransport
) -> Dict[str, APIDefinition]:
self.api_transports[trtype] = api_transport
return APIDefinition.get_cache()
def register_endpoint(
self,
@ -356,9 +333,10 @@ class MoonrakerApp:
f"{self._route_prefix}{http_path}", DynamicRequestHandler, params
)
self.registered_base_handlers.append(http_path)
for trtype, api_transport in self.api_transports.items():
if trtype in transports:
api_transport.register_api_handler(api_def)
for request_type, method_name in api_def.rpc_items():
transports = api_def.transports & ~TransportType.HTTP
logging.info(f"Registering RPC Method: ({transports}) {method_name}")
self.json_rpc.register_method(method_name, request_type, api_def)
def register_static_file_handler(
self, pattern: str, file_path: str, force: bool = False
@ -416,8 +394,8 @@ class MoonrakerApp:
if api_def.http_path in self.registered_base_handlers:
self.registered_base_handlers.remove(api_def.http_path)
self.mutable_router.remove_handler(api_def.http_path)
for api_transport in self.api_transports.values():
api_transport.remove_api_handler(api_def)
for method_name in api_def.rpc_methods:
self.json_rpc.remove_method(method_name)
async def load_template(self, asset_name: str) -> JinjaTemplate:
if asset_name in self.template_cache:
@ -475,8 +453,7 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
except Exception:
pass
else:
wsm: WebsocketManager = self.server.lookup_component(
"websockets")
wsm: WebsocketManager = self.server.lookup_component("websockets")
conn = wsm.get_client(conn_id)
if not isinstance(conn, WebSocket):
return None
@ -644,25 +621,18 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
async def delete(self, *args, **kwargs) -> None:
await self._process_http_request(RequestType.DELETE)
async def _do_request(
self, args: Dict[str, Any], conn: Optional[WebSocket], req_type: RequestType
) -> Any:
return await self.api_defintion.callback(
WebRequest(
self.endpoint, args, req_type, conn=conn,
ip_addr=self.request.remote_ip or "", user=self.current_user
)
)
async def _process_http_request(self, req_type: RequestType) -> None:
if req_type not in self.api_defintion.request_types:
raise tornado.web.HTTPError(405)
conn = self.get_associated_websocket()
args = self.parse_args()
transport = self.get_associated_websocket()
req = f"{self.request.method} {self.request.path}"
self._log_debug(f"HTTP Request::{req}", args)
try:
result = await self._do_request(args, conn, req_type)
ip = self.request.remote_ip or ""
result = await self.api_defintion.request(
args, req_type, transport, ip, self.current_user
)
except ServerError as e:
raise tornado.web.HTTPError(
e.status_code, reason=str(e)) from e

View File

@ -11,6 +11,7 @@ import logging
import copy
import re
from enum import Enum, Flag, auto
from dataclasses import dataclass
from .utils import ServerError, Sentinel
from .utils import json_wrapper as jsonw
@ -26,7 +27,9 @@ from typing import (
Union,
Dict,
List,
Awaitable
Awaitable,
ClassVar,
Tuple
)
if TYPE_CHECKING:
@ -157,31 +160,34 @@ class KlippyState(ExtendedEnum):
def startup_complete(self) -> bool:
return self.value > 2
class Subscribable:
def send_status(
self, status: Dict[str, Any], eventtime: float
) -> None:
raise NotImplementedError
@dataclass(frozen=True)
class APIDefinition:
_cache: Dict[str, APIDefinition] = {}
def __init__(
endpoint: str
http_path: str
rpc_methods: List[str]
request_types: RequestType
transports: TransportType
callback: Callable[[WebRequest], Coroutine]
_cache: ClassVar[Dict[str, APIDefinition]] = {}
def request(
self,
endpoint: str,
http_path: str,
rpc_methods: Dict[RequestType, str],
request_types: RequestType,
transports: TransportType,
callback: Callable[[WebRequest], Coroutine],
need_object_parser: bool
) -> None:
self.endpoint = endpoint
self.http_path = http_path
self.rpc_methods = rpc_methods
self.request_types = request_types
self.supported_transports = transports
self.callback = callback
self.need_object_parser = need_object_parser
args: Dict[str, Any],
request_type: RequestType,
transport: Optional[APITransport] = None,
ip_addr: str = "",
user: Optional[Dict[str, Any]] = None
) -> Coroutine:
return self.callback(
WebRequest(self.endpoint, args, request_type, transport, ip_addr, user)
)
@property
def need_object_parser(self) -> bool:
return self.endpoint.startswith("objects/")
def rpc_items(self) -> zip[Tuple[RequestType, str]]:
return zip(self.request_types, self.rpc_methods)
@classmethod
def create(
@ -210,30 +216,30 @@ class APIDefinition:
f"Invalid endpoint name '{endpoint}', must start with one of "
f"the following: {prefixes}"
)
jrpc_methods: Dict[RequestType, str] = {}
rpc_methods: List[str] = []
if is_remote:
# Request Types have no meaning for remote requests. Therefore
# both GET and POST http requests are accepted. JRPC requests do
# not need an associated RequestType, so the unknown value is used.
request_types = RequestType.GET | RequestType.POST
jrpc_methods[RequestType(0)] = http_path[1:].replace('/', '.')
else:
rpc_methods.append(http_path[1:].replace('/', '.'))
elif transports != TransportType.HTTP:
name_parts = http_path[1:].split('/')
if len(request_types) > 1:
for rtype in request_types:
func_name = rtype.name.lower() + "_" + name_parts[-1]
jrpc_methods[rtype] = ".".join(name_parts[:-1] + [func_name])
rpc_methods.append(".".join(name_parts[:-1] + [func_name]))
else:
jrpc_methods[request_types] = ".".join(name_parts)
if len(request_types) != len(jrpc_methods):
rpc_methods.append(".".join(name_parts))
if len(request_types) != len(rpc_methods):
raise ServerError(
"Invalid API definition. Number of websocket methods must "
"match the number of request methods"
)
need_object_parser = endpoint.startswith("objects/")
api_def = cls(
endpoint, http_path, jrpc_methods, request_types,
transports, callback, need_object_parser
endpoint, http_path, rpc_methods, request_types,
transports, callback
)
cls._cache[endpoint] = api_def
return api_def
@ -247,18 +253,26 @@ class APIDefinition:
return cls._cache
class APITransport:
def register_api_handler(self, api_def: APIDefinition) -> None:
@property
def transport_type(self) -> TransportType:
return TransportType.INTERNAL
def screen_rpc_request(
self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
) -> None:
return None
def send_status(
self, status: Dict[str, Any], eventtime: float
) -> None:
raise NotImplementedError
def remove_api_handler(self, api_def: APIDefinition) -> None:
raise NotImplementedError
class BaseRemoteConnection(Subscribable):
class BaseRemoteConnection(APITransport):
def on_create(self, server: Server) -> None:
self.server = server
self.eventloop = server.get_event_loop()
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
self.rpc = self.wsm.rpc
self.rpc: JsonRPC = self.server.lookup_component("jsonrpc")
self._uid = id(self)
self.ip_addr = ""
self.is_closed: bool = False
@ -314,6 +328,15 @@ class BaseRemoteConnection(Subscribable):
self._client_data = data
self._identified = True
@property
def transport_type(self) -> TransportType:
return TransportType.WEBSOCKET
def screen_rpc_request(
self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
) -> None:
self.check_authenticated(api_def.endpoint)
async def _process_message(self, message: str) -> None:
try:
response = await self.rpc.dispatch(message, self)
@ -442,14 +465,14 @@ class WebRequest:
endpoint: str,
args: Dict[str, Any],
request_type: RequestType = RequestType(0),
conn: Optional[Subscribable] = None,
transport: Optional[APITransport] = None,
ip_addr: str = "",
user: Optional[Dict[str, Any]] = None
) -> None:
self.endpoint = endpoint
self.request_type = request_type
self.args = args
self.conn = conn
self.transport = transport
self.request_type = request_type
self.ip_addr: Optional[IPUnion] = None
try:
self.ip_addr = ipaddress.ip_address(ip_addr)
@ -469,12 +492,12 @@ class WebRequest:
def get_args(self) -> Dict[str, Any]:
return self.args
def get_subscribable(self) -> Optional[Subscribable]:
return self.conn
def get_subscribable(self) -> Optional[APITransport]:
return self.transport
def get_client_connection(self) -> Optional[BaseRemoteConnection]:
if isinstance(self.conn, BaseRemoteConnection):
return self.conn
if isinstance(self.transport, BaseRemoteConnection):
return self.transport
return None
def get_ip_address(self) -> Optional[IPUnion]:
@ -595,15 +618,12 @@ class WebRequest:
class JsonRPC:
def __init__(
self, server: Server, transport: str = "Websocket"
) -> None:
self.methods: Dict[str, RPCCallback] = {}
self.transport = transport
def __init__(self, server: Server) -> None:
self.methods: Dict[str, Tuple[RequestType, APIDefinition]] = {}
self.sanitize_response = False
self.verbose = server.is_verbose_enabled()
def _log_request(self, rpc_obj: Dict[str, Any], ) -> None:
def _log_request(self, rpc_obj: Dict[str, Any], trtype: TransportType) -> None:
if not self.verbose:
return
self.sanitize_response = False
@ -624,9 +644,11 @@ class JsonRPC:
for field in ["access_token", "api_key"]:
if field in params:
output["params"][field] = "<sanitized>"
logging.debug(f"{self.transport} Received::{jsonw.dumps(output).decode()}")
logging.debug(f"{trtype} Received::{jsonw.dumps(output).decode()}")
def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
def _log_response(
self, resp_obj: Optional[Dict[str, Any]], trtype: TransportType
) -> None:
if not self.verbose:
return
if resp_obj is None:
@ -636,67 +658,84 @@ class JsonRPC:
output = copy.deepcopy(resp_obj)
output["result"] = "<sanitized>"
self.sanitize_response = False
logging.debug(f"{self.transport} Response::{jsonw.dumps(output).decode()}")
logging.debug(f"{trtype} Response::{jsonw.dumps(output).decode()}")
def register_method(self,
name: str,
method: RPCCallback
) -> None:
self.methods[name] = method
def register_method(
self,
name: str,
request_type: RequestType,
api_definition: APIDefinition
) -> None:
self.methods[name] = (request_type, api_definition)
def get_method(self, name: str) -> Optional[Tuple[RequestType, APIDefinition]]:
return self.methods.get(name, None)
def remove_method(self, name: str) -> None:
self.methods.pop(name, None)
async def dispatch(self,
data: str,
conn: Optional[BaseRemoteConnection] = None
) -> Optional[bytes]:
async def dispatch(
self,
data: Union[str, bytes],
transport: APITransport
) -> Optional[bytes]:
transport_type = transport.transport_type
try:
obj: Union[Dict[str, Any], List[dict]] = jsonw.loads(data)
except Exception:
msg = f"{self.transport} data not json: {data}"
if isinstance(data, bytes):
data = data.decode()
msg = f"{transport_type} data not valid json: {data}"
logging.exception(msg)
err = self.build_error(-32700, "Parse error")
return jsonw.dumps(err)
if isinstance(obj, list):
responses: List[Dict[str, Any]] = []
for item in obj:
self._log_request(item)
resp = await self.process_object(item, conn)
self._log_request(item, transport_type)
resp = await self.process_object(item, transport)
if resp is not None:
self._log_response(resp)
self._log_response(resp, transport_type)
responses.append(resp)
if responses:
return jsonw.dumps(responses)
else:
self._log_request(obj)
response = await self.process_object(obj, conn)
self._log_request(obj, transport_type)
response = await self.process_object(obj, transport)
if response is not None:
self._log_response(response)
self._log_response(response, transport_type)
return jsonw.dumps(response)
return None
async def process_object(self,
obj: Dict[str, Any],
conn: Optional[BaseRemoteConnection]
) -> Optional[Dict[str, Any]]:
async def process_object(
self,
obj: Dict[str, Any],
transport: APITransport
) -> Optional[Dict[str, Any]]:
req_id: Optional[int] = obj.get('id', None)
rpc_version: str = obj.get('jsonrpc', "")
if rpc_version != "2.0":
return self.build_error(-32600, "Invalid Request", req_id)
method_name = obj.get('method', Sentinel.MISSING)
if method_name is Sentinel.MISSING:
self.process_response(obj, conn)
self.process_response(obj, transport)
return None
if not isinstance(method_name, str):
return self.build_error(
-32600, "Invalid Request", req_id, method_name=str(method_name)
)
method = self.methods.get(method_name, None)
if method is None:
method_info = self.methods.get(method_name, None)
if method_info is None:
return self.build_error(
-32601, "Method not found", req_id, method_name=method_name
)
request_type, api_definition = method_info
transport_type = transport.transport_type
if transport_type not in api_definition.transports:
return self.build_error(
-32601, f"Method not found for transport {transport_type.name}",
req_id, method_name=method_name
)
params: Dict[str, Any] = {}
if 'params' in obj:
params = obj['params']
@ -704,12 +743,14 @@ class JsonRPC:
return self.build_error(
-32602, "Invalid params:", req_id, method_name=method_name
)
return await self.execute_method(method_name, method, req_id, conn, params)
return await self.execute_method(
method_name, request_type, api_definition, req_id, transport, params
)
def process_response(
self, obj: Dict[str, Any], conn: Optional[BaseRemoteConnection]
self, obj: Dict[str, Any], conn: APITransport
) -> None:
if conn is None:
if not isinstance(conn, BaseRemoteConnection):
logging.debug(f"RPC Response to non-socket request: {obj}")
return
response_id = obj.get("id")
@ -734,15 +775,21 @@ class JsonRPC:
async def execute_method(
self,
method_name: str,
callback: RPCCallback,
request_type: RequestType,
api_definition: APIDefinition,
req_id: Optional[int],
conn: Optional[BaseRemoteConnection],
transport: APITransport,
params: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
if conn is not None:
params["_socket_"] = conn
try:
result = await callback(params)
transport.screen_rpc_request(api_definition, request_type, params)
if isinstance(transport, BaseRemoteConnection):
result = await api_definition.request(
params, request_type, transport, transport.ip_addr,
transport.user_info
)
else:
result = await api_definition.request(params, request_type, transport)
except TypeError as e:
return self.build_error(
-32602, f"Invalid params:\n{e}", req_id, True, method_name

View File

@ -6,7 +6,7 @@
from __future__ import annotations
from ..utils import Sentinel
from ..common import WebRequest, Subscribable, RequestType
from ..common import WebRequest, APITransport, RequestType
# Annotation imports
from typing import (
@ -38,7 +38,7 @@ STATUS_ENDPOINT = "objects/query"
OBJ_LIST_ENDPOINT = "objects/list"
REG_METHOD_ENDPOINT = "register_remote_method"
class KlippyAPI(Subscribable):
class KlippyAPI(APITransport):
def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server()
self.klippy: Klippy = self.server.lookup_component("klippy_connection")
@ -103,7 +103,7 @@ class KlippyAPI(Subscribable):
default: Any = Sentinel.MISSING
) -> Any:
try:
req = WebRequest(method, params, conn=self)
req = WebRequest(method, params, transport=self)
result = await self.klippy.request(req)
except self.server.error:
if default is Sentinel.MISSING:

View File

@ -15,10 +15,8 @@ import paho.mqtt.client as paho_mqtt
from ..common import (
TransportType,
RequestType,
Subscribable,
WebRequest,
APITransport,
JsonRPC,
KlippyState
)
from ..utils import json_wrapper as jsonw
@ -38,9 +36,9 @@ from typing import (
Deque,
)
if TYPE_CHECKING:
from ..app import APIDefinition
from ..confighelper import ConfigHelper
from ..klippy_connection import KlippyConnection as Klippy
from ..common import JsonRPC, APIDefinition
FlexCallback = Callable[[bytes], Optional[Coroutine]]
RPCCallback = Callable[..., Coroutine]
@ -249,7 +247,7 @@ class AIOHelper:
logging.info("MQTT Misc Loop Complete")
class MQTTClient(APITransport, Subscribable):
class MQTTClient(APITransport):
def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server()
self.eventloop = self.server.get_event_loop()
@ -318,7 +316,6 @@ class MQTTClient(APITransport, Subscribable):
)
# Subscribe to API requests
self.json_rpc = JsonRPC(self.server, transport="MQTT")
self.api_request_topic = f"{self.instance_name}/moonraker/api/request"
self.api_resp_topic = f"{self.instance_name}/moonraker/api/response"
self.klipper_status_topic = f"{self.instance_name}/klipper/status"
@ -342,10 +339,6 @@ class MQTTClient(APITransport, Subscribable):
self.timestamp_deque: Deque = deque(maxlen=20)
self.api_qos = config.getint('api_qos', self.qos)
if config.getboolean("enable_moonraker_api", True):
api_cache = self.server.register_api_transport(TransportType.MQTT, self)
for api_def in api_cache.values():
if TransportType.MQTT in api_def.supported_transports:
self.register_api_handler(api_def)
self.subscribe_topic(self.api_request_topic,
self._process_api_request,
self.api_qos)
@ -378,7 +371,7 @@ class MQTTClient(APITransport, Subscribable):
args = {'objects': self.status_objs}
try:
await self.klippy.request(
WebRequest("objects/subscribe", args, conn=self)
WebRequest("objects/subscribe", args, transport=self)
)
except self.server.error:
pass
@ -683,38 +676,19 @@ class MQTTClient(APITransport, Subscribable):
}
async def _process_api_request(self, payload: bytes) -> None:
response = await self.json_rpc.dispatch(payload.decode())
rpc: JsonRPC = self.server.lookup_component("jsonrpc")
response = await rpc.dispatch(payload, self)
if response is not None:
await self.publish_topic(self.api_resp_topic, response,
self.api_qos)
def register_api_handler(self, api_def: APIDefinition) -> None:
for req_type, rpc_method in api_def.rpc_methods.items():
rpc_cb = self._generate_rpc_callback(
api_def.endpoint, req_type, api_def.callback
)
self.json_rpc.register_method(rpc_method, rpc_cb)
logging.info(
"Registering MQTT JSON-RPC methods: "
f"{', '.join(api_def.rpc_methods.values())}")
@property
def transport_type(self) -> TransportType:
return TransportType.MQTT
def remove_api_handler(self, api_def: APIDefinition) -> None:
for jrpc_method in api_def.rpc_methods.values():
self.json_rpc.remove_method(jrpc_method)
def _generate_rpc_callback(
self,
endpoint: str,
request_type: RequestType,
callback: Callable[[WebRequest], Coroutine]
) -> RPCCallback:
async def func(args: Dict[str, Any]) -> Any:
self._check_timestamp(args)
result = await callback(WebRequest(endpoint, args, request_type))
return result
return func
def _check_timestamp(self, args: Dict[str, Any]) -> None:
def screen_rpc_request(
self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
) -> None:
ts = args.pop("mqtt_timestamp", None)
if ts is not None:
if ts in self.timestamp_deque:

View File

@ -17,7 +17,7 @@ import logging.handlers
import tempfile
from queue import SimpleQueue
from ..loghelper import LocalQueueHandler
from ..common import Subscribable, WebRequest, JobEvent, KlippyState
from ..common import APITransport, WebRequest, JobEvent, KlippyState
from ..utils import json_wrapper as jsonw
from typing import (
@ -58,7 +58,7 @@ PRE_SETUP_EVENTS = [
"ping"
]
class SimplyPrint(Subscribable):
class SimplyPrint(APITransport):
def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server()
self._logger = ProtoLogger(config)
@ -585,7 +585,8 @@ class SimplyPrint(Subscribable):
klippy = self.server.lookup_component("klippy_connection")
try:
resp: Dict[str, Dict[str, Any]] = await klippy.request(
WebRequest("objects/subscribe", args, conn=self))
WebRequest("objects/subscribe", args, transport=self)
)
status: Dict[str, Any] = resp.get("status", {})
except self.server.error:
status = {}

View File

@ -32,7 +32,7 @@ from typing import (
)
if TYPE_CHECKING:
from .server import Server
from .common import WebRequest, Subscribable, BaseRemoteConnection
from .common import WebRequest, APITransport, BaseRemoteConnection
from .confighelper import ConfigHelper
from .components.klippy_apis import KlippyAPI
from .components.file_manager.file_manager import FileManager
@ -80,7 +80,7 @@ class KlippyConnection:
self.init_attempts: int = 0
self._state: KlippyState = KlippyState.DISCONNECTED
self._state.set_message("Klippy Disconnected")
self.subscriptions: Dict[Subscribable, Subscription] = {}
self.subscriptions: Dict[APITransport, Subscription] = {}
self.subscription_cache: Dict[str, Dict[str, Any]] = {}
# Setup remote methods accessable to Klippy. Note that all
# registered remote methods should be of the notification type,
@ -657,7 +657,7 @@ class KlippyConnection:
finally:
self.pending_requests.pop(base_request.id, None)
def remove_subscription(self, conn: Subscribable) -> None:
def remove_subscription(self, conn: APITransport) -> None:
self.subscriptions.pop(conn, None)
def is_connected(self) -> bool:

View File

@ -26,6 +26,7 @@ from .klippy_connection import KlippyConnection
from .utils import ServerError, Sentinel, get_software_info, json_wrapper
from .loghelper import LogManager
from .common import RequestType
from .websockets import WebsocketManager
# Annotation imports
from typing import (
@ -42,7 +43,6 @@ from typing import (
)
if TYPE_CHECKING:
from .common import WebRequest
from .websockets import WebsocketManager
from .components.file_manager.file_manager import FileManager
from .components.machine import Machine
from .components.extensions import ExtensionManager
@ -96,8 +96,8 @@ class Server:
self.register_debug_endpoint = app.register_debug_endpoint
self.register_static_file_handler = app.register_static_file_handler
self.register_upload_handler = app.register_upload_handler
self.register_api_transport = app.register_api_transport
self.log_manager.set_server(self)
self.websocket_manager = WebsocketManager(config)
for warning in args.get("startup_warnings", []):
self.add_warning(warning)
@ -309,8 +309,7 @@ class Server:
def register_notification(
self, event_name: str, notify_name: Optional[str] = None
) -> None:
wsm: WebsocketManager = self.lookup_component("websockets")
wsm.register_notification(event_name, notify_name)
self.websocket_manager.register_notification(event_name, notify_name)
def register_event_handler(
self, event: str, callback: FlexCallback
@ -391,6 +390,7 @@ class Server:
await asyncio.sleep(.1)
try:
await self.moonraker_app.close()
await self.websocket_manager.close()
except Exception:
logging.exception("Error Closing App")
@ -434,7 +434,6 @@ class Server:
reg_dirs = []
if file_manager is not None:
reg_dirs = file_manager.get_registered_dirs()
wsm: WebsocketManager = self.lookup_component('websockets')
mreqs = self.klippy_connection.missing_requirements
if raw:
warnings = list(self.warnings.values())
@ -449,7 +448,7 @@ class Server:
'failed_components': self.failed_components,
'registered_directories': reg_dirs,
'warnings': warnings,
'websocket_count': wsm.get_count(),
'websocket_count': self.websocket_manager.get_count(),
'moonraker_version': self.app_args['software_version'],
'missing_klippy_requirements': mreqs,
'api_version': API_VERSION,

View File

@ -14,9 +14,7 @@ from .common import (
RequestType,
WebRequest,
BaseRemoteConnection,
APITransport,
APIDefinition,
JsonRPC,
TransportType,
)
from .utils import ServerError
@ -36,6 +34,7 @@ from typing import (
if TYPE_CHECKING:
from .server import Server
from .klippy_connection import KlippyConnection as Klippy
from .confighelper import ConfigHelper
from .components.extensions import ExtensionManager
from .components.authorization import Authorization
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
@ -46,17 +45,21 @@ if TYPE_CHECKING:
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"]
class WebsocketManager(APITransport):
def __init__(self, server: Server) -> None:
self.server = server
class WebsocketManager:
def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server()
self.clients: Dict[int, BaseRemoteConnection] = {}
self.bridge_connections: Dict[int, BridgeSocket] = {}
self.rpc = JsonRPC(server)
self.closed_event: Optional[asyncio.Event] = None
self.rpc.register_method("server.websocket.id", self._handle_id_request)
self.rpc.register_method(
"server.connection.identify", self._handle_identify)
self.server.register_endpoint(
"/server/websocket/id", RequestType.GET, self._handle_id_request,
TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/server/connection/identify", RequestType.POST, self._handle_identify,
TransportType.WEBSOCKET
)
self.server.register_component("websockets", self)
def register_notification(
self,
@ -75,64 +78,27 @@ class WebsocketManager(APITransport):
self.notify_clients(notify_name, args)
self.server.register_event_handler(event_name, notify_handler)
def register_api_handler(self, api_def: APIDefinition) -> None:
for req_type, rpc_method in api_def.rpc_methods.items():
rpc_cb = self._generate_rpc_callback(
api_def.endpoint, req_type, api_def.callback
)
self.rpc.register_method(rpc_method, rpc_cb)
logging.info(
"Registering Websocket JSON-RPC methods: "
f"{', '.join(api_def.rpc_methods.values())}"
)
def remove_api_handler(self, api_def: APIDefinition) -> None:
for rpc_method in api_def.rpc_methods.values():
self.rpc.remove_method(rpc_method)
def _generate_rpc_callback(
self,
endpoint: str,
request_type: RequestType,
callback: Callable[[WebRequest], Coroutine]
) -> RPCCallback:
async def func(args: Dict[str, Any]) -> Any:
sc: BaseRemoteConnection = args.pop("_socket_")
sc.check_authenticated(path=endpoint)
result = await callback(
WebRequest(
endpoint, args, request_type, sc,
ip_addr=sc.ip_addr, user=sc.user_info
)
)
return result
return func
async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
sc: BaseRemoteConnection = args["_socket_"]
sc.check_authenticated()
async def _handle_id_request(self, web_request: WebRequest) -> Dict[str, int]:
sc = web_request.get_client_connection()
assert sc is not None
return {'websocket_id': sc.uid}
async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
sc: BaseRemoteConnection = args["_socket_"]
sc.authenticate(
token=args.get("access_token", None),
api_key=args.get("api_key", None)
)
async def _handle_identify(self, web_request: WebRequest) -> Dict[str, int]:
sc = web_request.get_client_connection()
assert sc is not None
if sc.identified:
raise self.server.error(
f"Connection already identified: {sc.client_data}"
)
try:
name = str(args["client_name"])
version = str(args["version"])
client_type: str = str(args["type"]).lower()
url = str(args["url"])
except KeyError as e:
missing_key = str(e).split(":")[-1].strip()
raise self.server.error(
f"No data for argument: {missing_key}"
) from None
name = web_request.get_str("client_name")
version = web_request.get_str("version")
client_type: str = web_request.get_str("type").lower()
url = web_request.get_str("url")
sc.authenticate(
token=web_request.get_str("access_token", None),
api_key=web_request.get_str("api_key", None)
)
if client_type not in CLIENT_TYPES:
raise self.server.error(f"Invalid Client Type: {client_type}")
sc.client_data = {