jsonrpc: share one instance among all transports
This change refactors the APIDefiniton into a dataclass, allowing defs to be shared directly among HTTP and RPC requests. In addition, all transports now share one instance of JSONRPC, removing duplicate registration. API Defintiions are registered with the RPC Dispatcher, and it validates the Transport type. In addition tranports may perform their own validation prior to request execution. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
8b2d9b26f5
commit
bfeb096f31
|
@ -24,6 +24,7 @@ from tornado.http1connection import HTTP1Connection
|
|||
from tornado.log import access_log
|
||||
from .utils import ServerError, source_info
|
||||
from .common import (
|
||||
JsonRPC,
|
||||
WebRequest,
|
||||
APIDefinition,
|
||||
APITransport,
|
||||
|
@ -50,7 +51,6 @@ from typing import (
|
|||
Union,
|
||||
Dict,
|
||||
List,
|
||||
Tuple,
|
||||
AsyncGenerator,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
|
@ -121,34 +121,26 @@ class MutableRouter(tornado.web.ReversibleRuleRouter):
|
|||
class InternalTransport(APITransport):
|
||||
def __init__(self, server: Server) -> None:
|
||||
self.server = server
|
||||
self.callbacks: Dict[str, Tuple[str, RequestType, APICallback]] = {}
|
||||
|
||||
def register_api_handler(self, api_def: APIDefinition) -> None:
|
||||
ep = api_def.endpoint
|
||||
cb = api_def.callback
|
||||
for req_type, rpc_method in api_def.rpc_methods.items():
|
||||
self.callbacks[rpc_method] = (ep, req_type, cb)
|
||||
|
||||
def remove_api_handler(self, api_def: APIDefinition) -> None:
|
||||
for rpc_method in api_def.rpc_methods.values():
|
||||
self.callbacks.pop(rpc_method, None)
|
||||
|
||||
async def call_method(self,
|
||||
method_name: str,
|
||||
request_arguments: Dict[str, Any] = {},
|
||||
**kwargs
|
||||
) -> Any:
|
||||
if method_name not in self.callbacks:
|
||||
rpc: JsonRPC = self.server.lookup_component("jsonrpc")
|
||||
method_info = rpc.get_method(method_name)
|
||||
if method_info is None:
|
||||
raise self.server.error(f"No method {method_name} available")
|
||||
req_type, api_definition = method_info
|
||||
if TransportType.INTERNAL not in api_definition.transports:
|
||||
raise self.server.error(f"No method {method_name} available")
|
||||
ep, req_type, func = self.callbacks[method_name]
|
||||
# Request arguments can be suppplied either through a dict object
|
||||
# or via keyword arugments
|
||||
args = request_arguments or kwargs
|
||||
return await func(WebRequest(ep, dict(args), req_type))
|
||||
return await api_definition.request(args, req_type, self)
|
||||
|
||||
class MoonrakerApp:
|
||||
def __init__(self, config: ConfigHelper) -> None:
|
||||
self.server = config.get_server()
|
||||
self.json_rpc = JsonRPC(self.server)
|
||||
self.http_server: Optional[HTTPServer] = None
|
||||
self.secure_server: Optional[HTTPServer] = None
|
||||
self.template_cache: Dict[str, JinjaTemplate] = {}
|
||||
|
@ -180,14 +172,7 @@ class MoonrakerApp:
|
|||
)
|
||||
self._route_prefix = f"/{rp}"
|
||||
home_pattern = f"{self._route_prefix}/?"
|
||||
|
||||
# Set Up Websocket and Authorization Managers
|
||||
self.wsm = WebsocketManager(self.server)
|
||||
self.internal_transport = InternalTransport(self.server)
|
||||
self.api_transports: Dict[TransportType, APITransport] = {
|
||||
TransportType.WEBSOCKET: self.wsm,
|
||||
TransportType.INTERNAL: self.internal_transport
|
||||
}
|
||||
|
||||
mimetypes.add_type('text/plain', '.log')
|
||||
mimetypes.add_type('text/plain', '.gcode')
|
||||
|
@ -228,9 +213,8 @@ class MoonrakerApp:
|
|||
|
||||
# Register Server Components
|
||||
self.server.register_component("application", self)
|
||||
self.server.register_component("websockets", self.wsm)
|
||||
self.server.register_component("internal_transport",
|
||||
self.internal_transport)
|
||||
self.server.register_component("jsonrpc", self.json_rpc)
|
||||
self.server.register_component("internal_transport", self.internal_transport)
|
||||
|
||||
def _get_path_option(
|
||||
self, config: ConfigHelper, option: str
|
||||
|
@ -318,13 +302,6 @@ class MoonrakerApp:
|
|||
if self.secure_server is not None:
|
||||
self.secure_server.stop()
|
||||
await self.secure_server.close_all_connections()
|
||||
await self.wsm.close()
|
||||
|
||||
def register_api_transport(
|
||||
self, trtype: TransportType, api_transport: APITransport
|
||||
) -> Dict[str, APIDefinition]:
|
||||
self.api_transports[trtype] = api_transport
|
||||
return APIDefinition.get_cache()
|
||||
|
||||
def register_endpoint(
|
||||
self,
|
||||
|
@ -356,9 +333,10 @@ class MoonrakerApp:
|
|||
f"{self._route_prefix}{http_path}", DynamicRequestHandler, params
|
||||
)
|
||||
self.registered_base_handlers.append(http_path)
|
||||
for trtype, api_transport in self.api_transports.items():
|
||||
if trtype in transports:
|
||||
api_transport.register_api_handler(api_def)
|
||||
for request_type, method_name in api_def.rpc_items():
|
||||
transports = api_def.transports & ~TransportType.HTTP
|
||||
logging.info(f"Registering RPC Method: ({transports}) {method_name}")
|
||||
self.json_rpc.register_method(method_name, request_type, api_def)
|
||||
|
||||
def register_static_file_handler(
|
||||
self, pattern: str, file_path: str, force: bool = False
|
||||
|
@ -416,8 +394,8 @@ class MoonrakerApp:
|
|||
if api_def.http_path in self.registered_base_handlers:
|
||||
self.registered_base_handlers.remove(api_def.http_path)
|
||||
self.mutable_router.remove_handler(api_def.http_path)
|
||||
for api_transport in self.api_transports.values():
|
||||
api_transport.remove_api_handler(api_def)
|
||||
for method_name in api_def.rpc_methods:
|
||||
self.json_rpc.remove_method(method_name)
|
||||
|
||||
async def load_template(self, asset_name: str) -> JinjaTemplate:
|
||||
if asset_name in self.template_cache:
|
||||
|
@ -475,8 +453,7 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
|
|||
except Exception:
|
||||
pass
|
||||
else:
|
||||
wsm: WebsocketManager = self.server.lookup_component(
|
||||
"websockets")
|
||||
wsm: WebsocketManager = self.server.lookup_component("websockets")
|
||||
conn = wsm.get_client(conn_id)
|
||||
if not isinstance(conn, WebSocket):
|
||||
return None
|
||||
|
@ -644,25 +621,18 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
|
|||
async def delete(self, *args, **kwargs) -> None:
|
||||
await self._process_http_request(RequestType.DELETE)
|
||||
|
||||
async def _do_request(
|
||||
self, args: Dict[str, Any], conn: Optional[WebSocket], req_type: RequestType
|
||||
) -> Any:
|
||||
return await self.api_defintion.callback(
|
||||
WebRequest(
|
||||
self.endpoint, args, req_type, conn=conn,
|
||||
ip_addr=self.request.remote_ip or "", user=self.current_user
|
||||
)
|
||||
)
|
||||
|
||||
async def _process_http_request(self, req_type: RequestType) -> None:
|
||||
if req_type not in self.api_defintion.request_types:
|
||||
raise tornado.web.HTTPError(405)
|
||||
conn = self.get_associated_websocket()
|
||||
args = self.parse_args()
|
||||
transport = self.get_associated_websocket()
|
||||
req = f"{self.request.method} {self.request.path}"
|
||||
self._log_debug(f"HTTP Request::{req}", args)
|
||||
try:
|
||||
result = await self._do_request(args, conn, req_type)
|
||||
ip = self.request.remote_ip or ""
|
||||
result = await self.api_defintion.request(
|
||||
args, req_type, transport, ip, self.current_user
|
||||
)
|
||||
except ServerError as e:
|
||||
raise tornado.web.HTTPError(
|
||||
e.status_code, reason=str(e)) from e
|
||||
|
|
|
@ -11,6 +11,7 @@ import logging
|
|||
import copy
|
||||
import re
|
||||
from enum import Enum, Flag, auto
|
||||
from dataclasses import dataclass
|
||||
from .utils import ServerError, Sentinel
|
||||
from .utils import json_wrapper as jsonw
|
||||
|
||||
|
@ -26,7 +27,9 @@ from typing import (
|
|||
Union,
|
||||
Dict,
|
||||
List,
|
||||
Awaitable
|
||||
Awaitable,
|
||||
ClassVar,
|
||||
Tuple
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -157,31 +160,34 @@ class KlippyState(ExtendedEnum):
|
|||
def startup_complete(self) -> bool:
|
||||
return self.value > 2
|
||||
|
||||
class Subscribable:
|
||||
def send_status(
|
||||
self, status: Dict[str, Any], eventtime: float
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class APIDefinition:
|
||||
_cache: Dict[str, APIDefinition] = {}
|
||||
def __init__(
|
||||
endpoint: str
|
||||
http_path: str
|
||||
rpc_methods: List[str]
|
||||
request_types: RequestType
|
||||
transports: TransportType
|
||||
callback: Callable[[WebRequest], Coroutine]
|
||||
_cache: ClassVar[Dict[str, APIDefinition]] = {}
|
||||
|
||||
def request(
|
||||
self,
|
||||
endpoint: str,
|
||||
http_path: str,
|
||||
rpc_methods: Dict[RequestType, str],
|
||||
request_types: RequestType,
|
||||
transports: TransportType,
|
||||
callback: Callable[[WebRequest], Coroutine],
|
||||
need_object_parser: bool
|
||||
) -> None:
|
||||
self.endpoint = endpoint
|
||||
self.http_path = http_path
|
||||
self.rpc_methods = rpc_methods
|
||||
self.request_types = request_types
|
||||
self.supported_transports = transports
|
||||
self.callback = callback
|
||||
self.need_object_parser = need_object_parser
|
||||
args: Dict[str, Any],
|
||||
request_type: RequestType,
|
||||
transport: Optional[APITransport] = None,
|
||||
ip_addr: str = "",
|
||||
user: Optional[Dict[str, Any]] = None
|
||||
) -> Coroutine:
|
||||
return self.callback(
|
||||
WebRequest(self.endpoint, args, request_type, transport, ip_addr, user)
|
||||
)
|
||||
|
||||
@property
|
||||
def need_object_parser(self) -> bool:
|
||||
return self.endpoint.startswith("objects/")
|
||||
|
||||
def rpc_items(self) -> zip[Tuple[RequestType, str]]:
|
||||
return zip(self.request_types, self.rpc_methods)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
|
@ -210,30 +216,30 @@ class APIDefinition:
|
|||
f"Invalid endpoint name '{endpoint}', must start with one of "
|
||||
f"the following: {prefixes}"
|
||||
)
|
||||
jrpc_methods: Dict[RequestType, str] = {}
|
||||
rpc_methods: List[str] = []
|
||||
if is_remote:
|
||||
# Request Types have no meaning for remote requests. Therefore
|
||||
# both GET and POST http requests are accepted. JRPC requests do
|
||||
# not need an associated RequestType, so the unknown value is used.
|
||||
request_types = RequestType.GET | RequestType.POST
|
||||
jrpc_methods[RequestType(0)] = http_path[1:].replace('/', '.')
|
||||
else:
|
||||
rpc_methods.append(http_path[1:].replace('/', '.'))
|
||||
elif transports != TransportType.HTTP:
|
||||
name_parts = http_path[1:].split('/')
|
||||
if len(request_types) > 1:
|
||||
for rtype in request_types:
|
||||
func_name = rtype.name.lower() + "_" + name_parts[-1]
|
||||
jrpc_methods[rtype] = ".".join(name_parts[:-1] + [func_name])
|
||||
rpc_methods.append(".".join(name_parts[:-1] + [func_name]))
|
||||
else:
|
||||
jrpc_methods[request_types] = ".".join(name_parts)
|
||||
if len(request_types) != len(jrpc_methods):
|
||||
rpc_methods.append(".".join(name_parts))
|
||||
if len(request_types) != len(rpc_methods):
|
||||
raise ServerError(
|
||||
"Invalid API definition. Number of websocket methods must "
|
||||
"match the number of request methods"
|
||||
)
|
||||
need_object_parser = endpoint.startswith("objects/")
|
||||
|
||||
api_def = cls(
|
||||
endpoint, http_path, jrpc_methods, request_types,
|
||||
transports, callback, need_object_parser
|
||||
endpoint, http_path, rpc_methods, request_types,
|
||||
transports, callback
|
||||
)
|
||||
cls._cache[endpoint] = api_def
|
||||
return api_def
|
||||
|
@ -247,18 +253,26 @@ class APIDefinition:
|
|||
return cls._cache
|
||||
|
||||
class APITransport:
|
||||
def register_api_handler(self, api_def: APIDefinition) -> None:
|
||||
@property
|
||||
def transport_type(self) -> TransportType:
|
||||
return TransportType.INTERNAL
|
||||
|
||||
def screen_rpc_request(
|
||||
self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
def send_status(
|
||||
self, status: Dict[str, Any], eventtime: float
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_api_handler(self, api_def: APIDefinition) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
class BaseRemoteConnection(Subscribable):
|
||||
class BaseRemoteConnection(APITransport):
|
||||
def on_create(self, server: Server) -> None:
|
||||
self.server = server
|
||||
self.eventloop = server.get_event_loop()
|
||||
self.wsm: WebsocketManager = self.server.lookup_component("websockets")
|
||||
self.rpc = self.wsm.rpc
|
||||
self.rpc: JsonRPC = self.server.lookup_component("jsonrpc")
|
||||
self._uid = id(self)
|
||||
self.ip_addr = ""
|
||||
self.is_closed: bool = False
|
||||
|
@ -314,6 +328,15 @@ class BaseRemoteConnection(Subscribable):
|
|||
self._client_data = data
|
||||
self._identified = True
|
||||
|
||||
@property
|
||||
def transport_type(self) -> TransportType:
|
||||
return TransportType.WEBSOCKET
|
||||
|
||||
def screen_rpc_request(
|
||||
self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
|
||||
) -> None:
|
||||
self.check_authenticated(api_def.endpoint)
|
||||
|
||||
async def _process_message(self, message: str) -> None:
|
||||
try:
|
||||
response = await self.rpc.dispatch(message, self)
|
||||
|
@ -442,14 +465,14 @@ class WebRequest:
|
|||
endpoint: str,
|
||||
args: Dict[str, Any],
|
||||
request_type: RequestType = RequestType(0),
|
||||
conn: Optional[Subscribable] = None,
|
||||
transport: Optional[APITransport] = None,
|
||||
ip_addr: str = "",
|
||||
user: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
self.endpoint = endpoint
|
||||
self.request_type = request_type
|
||||
self.args = args
|
||||
self.conn = conn
|
||||
self.transport = transport
|
||||
self.request_type = request_type
|
||||
self.ip_addr: Optional[IPUnion] = None
|
||||
try:
|
||||
self.ip_addr = ipaddress.ip_address(ip_addr)
|
||||
|
@ -469,12 +492,12 @@ class WebRequest:
|
|||
def get_args(self) -> Dict[str, Any]:
|
||||
return self.args
|
||||
|
||||
def get_subscribable(self) -> Optional[Subscribable]:
|
||||
return self.conn
|
||||
def get_subscribable(self) -> Optional[APITransport]:
|
||||
return self.transport
|
||||
|
||||
def get_client_connection(self) -> Optional[BaseRemoteConnection]:
|
||||
if isinstance(self.conn, BaseRemoteConnection):
|
||||
return self.conn
|
||||
if isinstance(self.transport, BaseRemoteConnection):
|
||||
return self.transport
|
||||
return None
|
||||
|
||||
def get_ip_address(self) -> Optional[IPUnion]:
|
||||
|
@ -595,15 +618,12 @@ class WebRequest:
|
|||
|
||||
|
||||
class JsonRPC:
|
||||
def __init__(
|
||||
self, server: Server, transport: str = "Websocket"
|
||||
) -> None:
|
||||
self.methods: Dict[str, RPCCallback] = {}
|
||||
self.transport = transport
|
||||
def __init__(self, server: Server) -> None:
|
||||
self.methods: Dict[str, Tuple[RequestType, APIDefinition]] = {}
|
||||
self.sanitize_response = False
|
||||
self.verbose = server.is_verbose_enabled()
|
||||
|
||||
def _log_request(self, rpc_obj: Dict[str, Any], ) -> None:
|
||||
def _log_request(self, rpc_obj: Dict[str, Any], trtype: TransportType) -> None:
|
||||
if not self.verbose:
|
||||
return
|
||||
self.sanitize_response = False
|
||||
|
@ -624,9 +644,11 @@ class JsonRPC:
|
|||
for field in ["access_token", "api_key"]:
|
||||
if field in params:
|
||||
output["params"][field] = "<sanitized>"
|
||||
logging.debug(f"{self.transport} Received::{jsonw.dumps(output).decode()}")
|
||||
logging.debug(f"{trtype} Received::{jsonw.dumps(output).decode()}")
|
||||
|
||||
def _log_response(self, resp_obj: Optional[Dict[str, Any]]) -> None:
|
||||
def _log_response(
|
||||
self, resp_obj: Optional[Dict[str, Any]], trtype: TransportType
|
||||
) -> None:
|
||||
if not self.verbose:
|
||||
return
|
||||
if resp_obj is None:
|
||||
|
@ -636,67 +658,84 @@ class JsonRPC:
|
|||
output = copy.deepcopy(resp_obj)
|
||||
output["result"] = "<sanitized>"
|
||||
self.sanitize_response = False
|
||||
logging.debug(f"{self.transport} Response::{jsonw.dumps(output).decode()}")
|
||||
logging.debug(f"{trtype} Response::{jsonw.dumps(output).decode()}")
|
||||
|
||||
def register_method(self,
|
||||
name: str,
|
||||
method: RPCCallback
|
||||
) -> None:
|
||||
self.methods[name] = method
|
||||
def register_method(
|
||||
self,
|
||||
name: str,
|
||||
request_type: RequestType,
|
||||
api_definition: APIDefinition
|
||||
) -> None:
|
||||
self.methods[name] = (request_type, api_definition)
|
||||
|
||||
def get_method(self, name: str) -> Optional[Tuple[RequestType, APIDefinition]]:
|
||||
return self.methods.get(name, None)
|
||||
|
||||
def remove_method(self, name: str) -> None:
|
||||
self.methods.pop(name, None)
|
||||
|
||||
async def dispatch(self,
|
||||
data: str,
|
||||
conn: Optional[BaseRemoteConnection] = None
|
||||
) -> Optional[bytes]:
|
||||
async def dispatch(
|
||||
self,
|
||||
data: Union[str, bytes],
|
||||
transport: APITransport
|
||||
) -> Optional[bytes]:
|
||||
transport_type = transport.transport_type
|
||||
try:
|
||||
obj: Union[Dict[str, Any], List[dict]] = jsonw.loads(data)
|
||||
except Exception:
|
||||
msg = f"{self.transport} data not json: {data}"
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode()
|
||||
msg = f"{transport_type} data not valid json: {data}"
|
||||
logging.exception(msg)
|
||||
err = self.build_error(-32700, "Parse error")
|
||||
return jsonw.dumps(err)
|
||||
if isinstance(obj, list):
|
||||
responses: List[Dict[str, Any]] = []
|
||||
for item in obj:
|
||||
self._log_request(item)
|
||||
resp = await self.process_object(item, conn)
|
||||
self._log_request(item, transport_type)
|
||||
resp = await self.process_object(item, transport)
|
||||
if resp is not None:
|
||||
self._log_response(resp)
|
||||
self._log_response(resp, transport_type)
|
||||
responses.append(resp)
|
||||
if responses:
|
||||
return jsonw.dumps(responses)
|
||||
else:
|
||||
self._log_request(obj)
|
||||
response = await self.process_object(obj, conn)
|
||||
self._log_request(obj, transport_type)
|
||||
response = await self.process_object(obj, transport)
|
||||
if response is not None:
|
||||
self._log_response(response)
|
||||
self._log_response(response, transport_type)
|
||||
return jsonw.dumps(response)
|
||||
return None
|
||||
|
||||
async def process_object(self,
|
||||
obj: Dict[str, Any],
|
||||
conn: Optional[BaseRemoteConnection]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
async def process_object(
|
||||
self,
|
||||
obj: Dict[str, Any],
|
||||
transport: APITransport
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
req_id: Optional[int] = obj.get('id', None)
|
||||
rpc_version: str = obj.get('jsonrpc', "")
|
||||
if rpc_version != "2.0":
|
||||
return self.build_error(-32600, "Invalid Request", req_id)
|
||||
method_name = obj.get('method', Sentinel.MISSING)
|
||||
if method_name is Sentinel.MISSING:
|
||||
self.process_response(obj, conn)
|
||||
self.process_response(obj, transport)
|
||||
return None
|
||||
if not isinstance(method_name, str):
|
||||
return self.build_error(
|
||||
-32600, "Invalid Request", req_id, method_name=str(method_name)
|
||||
)
|
||||
method = self.methods.get(method_name, None)
|
||||
if method is None:
|
||||
method_info = self.methods.get(method_name, None)
|
||||
if method_info is None:
|
||||
return self.build_error(
|
||||
-32601, "Method not found", req_id, method_name=method_name
|
||||
)
|
||||
request_type, api_definition = method_info
|
||||
transport_type = transport.transport_type
|
||||
if transport_type not in api_definition.transports:
|
||||
return self.build_error(
|
||||
-32601, f"Method not found for transport {transport_type.name}",
|
||||
req_id, method_name=method_name
|
||||
)
|
||||
params: Dict[str, Any] = {}
|
||||
if 'params' in obj:
|
||||
params = obj['params']
|
||||
|
@ -704,12 +743,14 @@ class JsonRPC:
|
|||
return self.build_error(
|
||||
-32602, "Invalid params:", req_id, method_name=method_name
|
||||
)
|
||||
return await self.execute_method(method_name, method, req_id, conn, params)
|
||||
return await self.execute_method(
|
||||
method_name, request_type, api_definition, req_id, transport, params
|
||||
)
|
||||
|
||||
def process_response(
|
||||
self, obj: Dict[str, Any], conn: Optional[BaseRemoteConnection]
|
||||
self, obj: Dict[str, Any], conn: APITransport
|
||||
) -> None:
|
||||
if conn is None:
|
||||
if not isinstance(conn, BaseRemoteConnection):
|
||||
logging.debug(f"RPC Response to non-socket request: {obj}")
|
||||
return
|
||||
response_id = obj.get("id")
|
||||
|
@ -734,15 +775,21 @@ class JsonRPC:
|
|||
async def execute_method(
|
||||
self,
|
||||
method_name: str,
|
||||
callback: RPCCallback,
|
||||
request_type: RequestType,
|
||||
api_definition: APIDefinition,
|
||||
req_id: Optional[int],
|
||||
conn: Optional[BaseRemoteConnection],
|
||||
transport: APITransport,
|
||||
params: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if conn is not None:
|
||||
params["_socket_"] = conn
|
||||
try:
|
||||
result = await callback(params)
|
||||
transport.screen_rpc_request(api_definition, request_type, params)
|
||||
if isinstance(transport, BaseRemoteConnection):
|
||||
result = await api_definition.request(
|
||||
params, request_type, transport, transport.ip_addr,
|
||||
transport.user_info
|
||||
)
|
||||
else:
|
||||
result = await api_definition.request(params, request_type, transport)
|
||||
except TypeError as e:
|
||||
return self.build_error(
|
||||
-32602, f"Invalid params:\n{e}", req_id, True, method_name
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
from ..utils import Sentinel
|
||||
from ..common import WebRequest, Subscribable, RequestType
|
||||
from ..common import WebRequest, APITransport, RequestType
|
||||
|
||||
# Annotation imports
|
||||
from typing import (
|
||||
|
@ -38,7 +38,7 @@ STATUS_ENDPOINT = "objects/query"
|
|||
OBJ_LIST_ENDPOINT = "objects/list"
|
||||
REG_METHOD_ENDPOINT = "register_remote_method"
|
||||
|
||||
class KlippyAPI(Subscribable):
|
||||
class KlippyAPI(APITransport):
|
||||
def __init__(self, config: ConfigHelper) -> None:
|
||||
self.server = config.get_server()
|
||||
self.klippy: Klippy = self.server.lookup_component("klippy_connection")
|
||||
|
@ -103,7 +103,7 @@ class KlippyAPI(Subscribable):
|
|||
default: Any = Sentinel.MISSING
|
||||
) -> Any:
|
||||
try:
|
||||
req = WebRequest(method, params, conn=self)
|
||||
req = WebRequest(method, params, transport=self)
|
||||
result = await self.klippy.request(req)
|
||||
except self.server.error:
|
||||
if default is Sentinel.MISSING:
|
||||
|
|
|
@ -15,10 +15,8 @@ import paho.mqtt.client as paho_mqtt
|
|||
from ..common import (
|
||||
TransportType,
|
||||
RequestType,
|
||||
Subscribable,
|
||||
WebRequest,
|
||||
APITransport,
|
||||
JsonRPC,
|
||||
KlippyState
|
||||
)
|
||||
from ..utils import json_wrapper as jsonw
|
||||
|
@ -38,9 +36,9 @@ from typing import (
|
|||
Deque,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
from ..app import APIDefinition
|
||||
from ..confighelper import ConfigHelper
|
||||
from ..klippy_connection import KlippyConnection as Klippy
|
||||
from ..common import JsonRPC, APIDefinition
|
||||
FlexCallback = Callable[[bytes], Optional[Coroutine]]
|
||||
RPCCallback = Callable[..., Coroutine]
|
||||
|
||||
|
@ -249,7 +247,7 @@ class AIOHelper:
|
|||
logging.info("MQTT Misc Loop Complete")
|
||||
|
||||
|
||||
class MQTTClient(APITransport, Subscribable):
|
||||
class MQTTClient(APITransport):
|
||||
def __init__(self, config: ConfigHelper) -> None:
|
||||
self.server = config.get_server()
|
||||
self.eventloop = self.server.get_event_loop()
|
||||
|
@ -318,7 +316,6 @@ class MQTTClient(APITransport, Subscribable):
|
|||
)
|
||||
|
||||
# Subscribe to API requests
|
||||
self.json_rpc = JsonRPC(self.server, transport="MQTT")
|
||||
self.api_request_topic = f"{self.instance_name}/moonraker/api/request"
|
||||
self.api_resp_topic = f"{self.instance_name}/moonraker/api/response"
|
||||
self.klipper_status_topic = f"{self.instance_name}/klipper/status"
|
||||
|
@ -342,10 +339,6 @@ class MQTTClient(APITransport, Subscribable):
|
|||
self.timestamp_deque: Deque = deque(maxlen=20)
|
||||
self.api_qos = config.getint('api_qos', self.qos)
|
||||
if config.getboolean("enable_moonraker_api", True):
|
||||
api_cache = self.server.register_api_transport(TransportType.MQTT, self)
|
||||
for api_def in api_cache.values():
|
||||
if TransportType.MQTT in api_def.supported_transports:
|
||||
self.register_api_handler(api_def)
|
||||
self.subscribe_topic(self.api_request_topic,
|
||||
self._process_api_request,
|
||||
self.api_qos)
|
||||
|
@ -378,7 +371,7 @@ class MQTTClient(APITransport, Subscribable):
|
|||
args = {'objects': self.status_objs}
|
||||
try:
|
||||
await self.klippy.request(
|
||||
WebRequest("objects/subscribe", args, conn=self)
|
||||
WebRequest("objects/subscribe", args, transport=self)
|
||||
)
|
||||
except self.server.error:
|
||||
pass
|
||||
|
@ -683,38 +676,19 @@ class MQTTClient(APITransport, Subscribable):
|
|||
}
|
||||
|
||||
async def _process_api_request(self, payload: bytes) -> None:
|
||||
response = await self.json_rpc.dispatch(payload.decode())
|
||||
rpc: JsonRPC = self.server.lookup_component("jsonrpc")
|
||||
response = await rpc.dispatch(payload, self)
|
||||
if response is not None:
|
||||
await self.publish_topic(self.api_resp_topic, response,
|
||||
self.api_qos)
|
||||
|
||||
def register_api_handler(self, api_def: APIDefinition) -> None:
|
||||
for req_type, rpc_method in api_def.rpc_methods.items():
|
||||
rpc_cb = self._generate_rpc_callback(
|
||||
api_def.endpoint, req_type, api_def.callback
|
||||
)
|
||||
self.json_rpc.register_method(rpc_method, rpc_cb)
|
||||
logging.info(
|
||||
"Registering MQTT JSON-RPC methods: "
|
||||
f"{', '.join(api_def.rpc_methods.values())}")
|
||||
@property
|
||||
def transport_type(self) -> TransportType:
|
||||
return TransportType.MQTT
|
||||
|
||||
def remove_api_handler(self, api_def: APIDefinition) -> None:
|
||||
for jrpc_method in api_def.rpc_methods.values():
|
||||
self.json_rpc.remove_method(jrpc_method)
|
||||
|
||||
def _generate_rpc_callback(
|
||||
self,
|
||||
endpoint: str,
|
||||
request_type: RequestType,
|
||||
callback: Callable[[WebRequest], Coroutine]
|
||||
) -> RPCCallback:
|
||||
async def func(args: Dict[str, Any]) -> Any:
|
||||
self._check_timestamp(args)
|
||||
result = await callback(WebRequest(endpoint, args, request_type))
|
||||
return result
|
||||
return func
|
||||
|
||||
def _check_timestamp(self, args: Dict[str, Any]) -> None:
|
||||
def screen_rpc_request(
|
||||
self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
|
||||
) -> None:
|
||||
ts = args.pop("mqtt_timestamp", None)
|
||||
if ts is not None:
|
||||
if ts in self.timestamp_deque:
|
||||
|
|
|
@ -17,7 +17,7 @@ import logging.handlers
|
|||
import tempfile
|
||||
from queue import SimpleQueue
|
||||
from ..loghelper import LocalQueueHandler
|
||||
from ..common import Subscribable, WebRequest, JobEvent, KlippyState
|
||||
from ..common import APITransport, WebRequest, JobEvent, KlippyState
|
||||
from ..utils import json_wrapper as jsonw
|
||||
|
||||
from typing import (
|
||||
|
@ -58,7 +58,7 @@ PRE_SETUP_EVENTS = [
|
|||
"ping"
|
||||
]
|
||||
|
||||
class SimplyPrint(Subscribable):
|
||||
class SimplyPrint(APITransport):
|
||||
def __init__(self, config: ConfigHelper) -> None:
|
||||
self.server = config.get_server()
|
||||
self._logger = ProtoLogger(config)
|
||||
|
@ -585,7 +585,8 @@ class SimplyPrint(Subscribable):
|
|||
klippy = self.server.lookup_component("klippy_connection")
|
||||
try:
|
||||
resp: Dict[str, Dict[str, Any]] = await klippy.request(
|
||||
WebRequest("objects/subscribe", args, conn=self))
|
||||
WebRequest("objects/subscribe", args, transport=self)
|
||||
)
|
||||
status: Dict[str, Any] = resp.get("status", {})
|
||||
except self.server.error:
|
||||
status = {}
|
||||
|
|
|
@ -32,7 +32,7 @@ from typing import (
|
|||
)
|
||||
if TYPE_CHECKING:
|
||||
from .server import Server
|
||||
from .common import WebRequest, Subscribable, BaseRemoteConnection
|
||||
from .common import WebRequest, APITransport, BaseRemoteConnection
|
||||
from .confighelper import ConfigHelper
|
||||
from .components.klippy_apis import KlippyAPI
|
||||
from .components.file_manager.file_manager import FileManager
|
||||
|
@ -80,7 +80,7 @@ class KlippyConnection:
|
|||
self.init_attempts: int = 0
|
||||
self._state: KlippyState = KlippyState.DISCONNECTED
|
||||
self._state.set_message("Klippy Disconnected")
|
||||
self.subscriptions: Dict[Subscribable, Subscription] = {}
|
||||
self.subscriptions: Dict[APITransport, Subscription] = {}
|
||||
self.subscription_cache: Dict[str, Dict[str, Any]] = {}
|
||||
# Setup remote methods accessable to Klippy. Note that all
|
||||
# registered remote methods should be of the notification type,
|
||||
|
@ -657,7 +657,7 @@ class KlippyConnection:
|
|||
finally:
|
||||
self.pending_requests.pop(base_request.id, None)
|
||||
|
||||
def remove_subscription(self, conn: Subscribable) -> None:
|
||||
def remove_subscription(self, conn: APITransport) -> None:
|
||||
self.subscriptions.pop(conn, None)
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
|
|
|
@ -26,6 +26,7 @@ from .klippy_connection import KlippyConnection
|
|||
from .utils import ServerError, Sentinel, get_software_info, json_wrapper
|
||||
from .loghelper import LogManager
|
||||
from .common import RequestType
|
||||
from .websockets import WebsocketManager
|
||||
|
||||
# Annotation imports
|
||||
from typing import (
|
||||
|
@ -42,7 +43,6 @@ from typing import (
|
|||
)
|
||||
if TYPE_CHECKING:
|
||||
from .common import WebRequest
|
||||
from .websockets import WebsocketManager
|
||||
from .components.file_manager.file_manager import FileManager
|
||||
from .components.machine import Machine
|
||||
from .components.extensions import ExtensionManager
|
||||
|
@ -96,8 +96,8 @@ class Server:
|
|||
self.register_debug_endpoint = app.register_debug_endpoint
|
||||
self.register_static_file_handler = app.register_static_file_handler
|
||||
self.register_upload_handler = app.register_upload_handler
|
||||
self.register_api_transport = app.register_api_transport
|
||||
self.log_manager.set_server(self)
|
||||
self.websocket_manager = WebsocketManager(config)
|
||||
|
||||
for warning in args.get("startup_warnings", []):
|
||||
self.add_warning(warning)
|
||||
|
@ -309,8 +309,7 @@ class Server:
|
|||
def register_notification(
|
||||
self, event_name: str, notify_name: Optional[str] = None
|
||||
) -> None:
|
||||
wsm: WebsocketManager = self.lookup_component("websockets")
|
||||
wsm.register_notification(event_name, notify_name)
|
||||
self.websocket_manager.register_notification(event_name, notify_name)
|
||||
|
||||
def register_event_handler(
|
||||
self, event: str, callback: FlexCallback
|
||||
|
@ -391,6 +390,7 @@ class Server:
|
|||
await asyncio.sleep(.1)
|
||||
try:
|
||||
await self.moonraker_app.close()
|
||||
await self.websocket_manager.close()
|
||||
except Exception:
|
||||
logging.exception("Error Closing App")
|
||||
|
||||
|
@ -434,7 +434,6 @@ class Server:
|
|||
reg_dirs = []
|
||||
if file_manager is not None:
|
||||
reg_dirs = file_manager.get_registered_dirs()
|
||||
wsm: WebsocketManager = self.lookup_component('websockets')
|
||||
mreqs = self.klippy_connection.missing_requirements
|
||||
if raw:
|
||||
warnings = list(self.warnings.values())
|
||||
|
@ -449,7 +448,7 @@ class Server:
|
|||
'failed_components': self.failed_components,
|
||||
'registered_directories': reg_dirs,
|
||||
'warnings': warnings,
|
||||
'websocket_count': wsm.get_count(),
|
||||
'websocket_count': self.websocket_manager.get_count(),
|
||||
'moonraker_version': self.app_args['software_version'],
|
||||
'missing_klippy_requirements': mreqs,
|
||||
'api_version': API_VERSION,
|
||||
|
|
|
@ -14,9 +14,7 @@ from .common import (
|
|||
RequestType,
|
||||
WebRequest,
|
||||
BaseRemoteConnection,
|
||||
APITransport,
|
||||
APIDefinition,
|
||||
JsonRPC,
|
||||
TransportType,
|
||||
)
|
||||
from .utils import ServerError
|
||||
|
||||
|
@ -36,6 +34,7 @@ from typing import (
|
|||
if TYPE_CHECKING:
|
||||
from .server import Server
|
||||
from .klippy_connection import KlippyConnection as Klippy
|
||||
from .confighelper import ConfigHelper
|
||||
from .components.extensions import ExtensionManager
|
||||
from .components.authorization import Authorization
|
||||
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||
|
@ -46,17 +45,21 @@ if TYPE_CHECKING:
|
|||
|
||||
CLIENT_TYPES = ["web", "mobile", "desktop", "display", "bot", "agent", "other"]
|
||||
|
||||
class WebsocketManager(APITransport):
|
||||
def __init__(self, server: Server) -> None:
|
||||
self.server = server
|
||||
class WebsocketManager:
|
||||
def __init__(self, config: ConfigHelper) -> None:
|
||||
self.server = config.get_server()
|
||||
self.clients: Dict[int, BaseRemoteConnection] = {}
|
||||
self.bridge_connections: Dict[int, BridgeSocket] = {}
|
||||
self.rpc = JsonRPC(server)
|
||||
self.closed_event: Optional[asyncio.Event] = None
|
||||
|
||||
self.rpc.register_method("server.websocket.id", self._handle_id_request)
|
||||
self.rpc.register_method(
|
||||
"server.connection.identify", self._handle_identify)
|
||||
self.server.register_endpoint(
|
||||
"/server/websocket/id", RequestType.GET, self._handle_id_request,
|
||||
TransportType.WEBSOCKET
|
||||
)
|
||||
self.server.register_endpoint(
|
||||
"/server/connection/identify", RequestType.POST, self._handle_identify,
|
||||
TransportType.WEBSOCKET
|
||||
)
|
||||
self.server.register_component("websockets", self)
|
||||
|
||||
def register_notification(
|
||||
self,
|
||||
|
@ -75,64 +78,27 @@ class WebsocketManager(APITransport):
|
|||
self.notify_clients(notify_name, args)
|
||||
self.server.register_event_handler(event_name, notify_handler)
|
||||
|
||||
def register_api_handler(self, api_def: APIDefinition) -> None:
|
||||
for req_type, rpc_method in api_def.rpc_methods.items():
|
||||
rpc_cb = self._generate_rpc_callback(
|
||||
api_def.endpoint, req_type, api_def.callback
|
||||
)
|
||||
self.rpc.register_method(rpc_method, rpc_cb)
|
||||
logging.info(
|
||||
"Registering Websocket JSON-RPC methods: "
|
||||
f"{', '.join(api_def.rpc_methods.values())}"
|
||||
)
|
||||
|
||||
def remove_api_handler(self, api_def: APIDefinition) -> None:
|
||||
for rpc_method in api_def.rpc_methods.values():
|
||||
self.rpc.remove_method(rpc_method)
|
||||
|
||||
def _generate_rpc_callback(
|
||||
self,
|
||||
endpoint: str,
|
||||
request_type: RequestType,
|
||||
callback: Callable[[WebRequest], Coroutine]
|
||||
) -> RPCCallback:
|
||||
async def func(args: Dict[str, Any]) -> Any:
|
||||
sc: BaseRemoteConnection = args.pop("_socket_")
|
||||
sc.check_authenticated(path=endpoint)
|
||||
result = await callback(
|
||||
WebRequest(
|
||||
endpoint, args, request_type, sc,
|
||||
ip_addr=sc.ip_addr, user=sc.user_info
|
||||
)
|
||||
)
|
||||
return result
|
||||
return func
|
||||
|
||||
async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
|
||||
sc: BaseRemoteConnection = args["_socket_"]
|
||||
sc.check_authenticated()
|
||||
async def _handle_id_request(self, web_request: WebRequest) -> Dict[str, int]:
|
||||
sc = web_request.get_client_connection()
|
||||
assert sc is not None
|
||||
return {'websocket_id': sc.uid}
|
||||
|
||||
async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
|
||||
sc: BaseRemoteConnection = args["_socket_"]
|
||||
sc.authenticate(
|
||||
token=args.get("access_token", None),
|
||||
api_key=args.get("api_key", None)
|
||||
)
|
||||
async def _handle_identify(self, web_request: WebRequest) -> Dict[str, int]:
|
||||
sc = web_request.get_client_connection()
|
||||
assert sc is not None
|
||||
if sc.identified:
|
||||
raise self.server.error(
|
||||
f"Connection already identified: {sc.client_data}"
|
||||
)
|
||||
try:
|
||||
name = str(args["client_name"])
|
||||
version = str(args["version"])
|
||||
client_type: str = str(args["type"]).lower()
|
||||
url = str(args["url"])
|
||||
except KeyError as e:
|
||||
missing_key = str(e).split(":")[-1].strip()
|
||||
raise self.server.error(
|
||||
f"No data for argument: {missing_key}"
|
||||
) from None
|
||||
name = web_request.get_str("client_name")
|
||||
version = web_request.get_str("version")
|
||||
client_type: str = web_request.get_str("type").lower()
|
||||
url = web_request.get_str("url")
|
||||
sc.authenticate(
|
||||
token=web_request.get_str("access_token", None),
|
||||
api_key=web_request.get_str("api_key", None)
|
||||
)
|
||||
|
||||
if client_type not in CLIENT_TYPES:
|
||||
raise self.server.error(f"Invalid Client Type: {client_type}")
|
||||
sc.client_data = {
|
||||
|
|
Loading…
Reference in New Issue