app: streamline endpoint registration

Refactor endpoint registration to reduce duplicated code.
Rename some APIDefinition attributes for clarity.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2023-11-24 20:29:08 -05:00
parent 7de61eb113
commit 8b2d9b26f5
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
6 changed files with 144 additions and 188 deletions

View File

@ -126,20 +126,12 @@ class InternalTransport(APITransport):
def register_api_handler(self, api_def: APIDefinition) -> None: def register_api_handler(self, api_def: APIDefinition) -> None:
ep = api_def.endpoint ep = api_def.endpoint
cb = api_def.callback cb = api_def.callback
if cb is None: for req_type, rpc_method in api_def.rpc_methods.items():
# Request to Klippy self.callbacks[rpc_method] = (ep, req_type, cb)
method = api_def.jrpc_methods[0]
action = RequestType(0)
klippy: Klippy = self.server.lookup_component("klippy_connection")
cb = klippy.request
self.callbacks[method] = (ep, action, cb)
else:
for method, action in zip(api_def.jrpc_methods, api_def.request_types):
self.callbacks[method] = (ep, action, cb)
def remove_api_handler(self, api_def: APIDefinition) -> None: def remove_api_handler(self, api_def: APIDefinition) -> None:
for method in api_def.jrpc_methods: for rpc_method in api_def.rpc_methods.values():
self.callbacks.pop(method, None) self.callbacks.pop(rpc_method, None)
async def call_method(self, async def call_method(self,
method_name: str, method_name: str,
@ -159,7 +151,6 @@ class MoonrakerApp:
self.server = config.get_server() self.server = config.get_server()
self.http_server: Optional[HTTPServer] = None self.http_server: Optional[HTTPServer] = None
self.secure_server: Optional[HTTPServer] = None self.secure_server: Optional[HTTPServer] = None
self.api_cache: Dict[str, APIDefinition] = {}
self.template_cache: Dict[str, JinjaTemplate] = {} self.template_cache: Dict[str, JinjaTemplate] = {}
self.registered_base_handlers: List[str] = [] self.registered_base_handlers: List[str] = []
self.max_upload_size = config.getint('max_upload_size', 1024) self.max_upload_size = config.getint('max_upload_size', 1024)
@ -333,57 +324,38 @@ class MoonrakerApp:
self, trtype: TransportType, api_transport: APITransport self, trtype: TransportType, api_transport: APITransport
) -> Dict[str, APIDefinition]: ) -> Dict[str, APIDefinition]:
self.api_transports[trtype] = api_transport self.api_transports[trtype] = api_transport
return self.api_cache return APIDefinition.get_cache()
def register_remote_handler(self, endpoint: str) -> None: def register_endpoint(
api_def = self._create_api_definition(endpoint, RequestType.GET)
if api_def.uri in self.registered_base_handlers:
# reserved handler or already registered
return
logging.info(
f"Registering HTTP endpoint: ({api_def.request_types}) {api_def.uri}"
)
params: Dict[str, Any] = {}
params['methods'] = api_def.request_types
params['callback'] = api_def.endpoint
params['need_object_parser'] = api_def.need_object_parser
self.mutable_router.add_handler(
f"{self._route_prefix}{api_def.uri}", DynamicRequestHandler, params
)
self.registered_base_handlers.append(api_def.uri)
for api_transport in self.api_transports.values():
api_transport.register_api_handler(api_def)
def register_local_handler(
self, self,
uri: str, endpoint: str,
request_types: Union[List[str], RequestType], request_types: Union[List[str], RequestType],
callback: APICallback, callback: APICallback,
transports: Union[List[str], TransportType] = TransportType.all(), transports: Union[List[str], TransportType] = TransportType.all(),
wrap_result: bool = True, wrap_result: bool = True,
content_type: Optional[str] = None content_type: Optional[str] = None,
is_remote: bool = False
) -> None: ) -> None:
if uri in self.registered_base_handlers:
return
if isinstance(request_types, list): if isinstance(request_types, list):
request_types = RequestType.from_string_list(request_types) request_types = RequestType.from_string_list(request_types)
if isinstance(transports, list): if isinstance(transports, list):
transports = TransportType.from_string_list(transports) transports = TransportType.from_string_list(transports)
api_def = self._create_api_definition( api_def = APIDefinition.create(
uri, request_types, callback, transports=transports endpoint, request_types, callback, transports, is_remote
) )
http_path = api_def.http_path
if http_path in self.registered_base_handlers:
return
if TransportType.HTTP in transports: if TransportType.HTTP in transports:
logging.info(f"Registering HTTP Endpoint: ({request_types}) {uri}") logging.info(f"Registering HTTP Endpoint: ({request_types}) {http_path}")
params: dict[str, Any] = {} params: dict[str, Any] = {}
params['methods'] = request_types params["api_definition"] = api_def
params['callback'] = callback params["wrap_result"] = wrap_result
params['wrap_result'] = wrap_result params["content_type"] = content_type
params['is_remote'] = False
params['content_type'] = content_type
self.mutable_router.add_handler( self.mutable_router.add_handler(
f"{self._route_prefix}{uri}", DynamicRequestHandler, params f"{self._route_prefix}{http_path}", DynamicRequestHandler, params
) )
self.registered_base_handlers.append(uri) self.registered_base_handlers.append(http_path)
for trtype, api_transport in self.api_transports.items(): for trtype, api_transport in self.api_transports.items():
if trtype in transports: if trtype in transports:
api_transport.register_api_handler(api_def) api_transport.register_api_handler(api_def)
@ -420,9 +392,9 @@ class MoonrakerApp:
f"{self._route_prefix}{pattern}", FileUploadHandler, params f"{self._route_prefix}{pattern}", FileUploadHandler, params
) )
def register_debug_handler( def register_debug_endpoint(
self, self,
uri: str, endpoint: str,
request_types: Union[List[str], RequestType], request_types: Union[List[str], RequestType],
callback: APICallback, callback: APICallback,
transports: Union[List[str], TransportType] = TransportType.all(), transports: Union[List[str], TransportType] = TransportType.all(),
@ -430,69 +402,23 @@ class MoonrakerApp:
) -> None: ) -> None:
if not self.server.is_debug_enabled(): if not self.server.is_debug_enabled():
return return
if not uri.startswith("/debug"): if not endpoint.startswith("/debug"):
raise self.server.error( raise self.server.error(
"Debug Endpoints must be registerd in the '/debug' path" "Debug Endpoints must be registered in the '/debug' path"
) )
self.register_local_handler( self.register_endpoint(
uri, request_types, callback, transports, wrap_result endpoint, request_types, callback, transports, wrap_result
) )
def remove_handler(self, endpoint: str) -> None: def remove_endpoint(self, endpoint: str) -> None:
api_def = self.api_cache.pop(endpoint, None) api_def = APIDefinition.pop_cached_def(endpoint)
if api_def is not None: if api_def is not None:
self.mutable_router.remove_handler(api_def.uri) if api_def.http_path in self.registered_base_handlers:
self.registered_base_handlers.remove(api_def.http_path)
self.mutable_router.remove_handler(api_def.http_path)
for api_transport in self.api_transports.values(): for api_transport in self.api_transports.values():
api_transport.remove_api_handler(api_def) api_transport.remove_api_handler(api_def)
def _create_api_definition(
self,
endpoint: str,
request_types: Union[List[str], RequestType],
callback: Optional[APICallback] = None,
transports: Union[List[str], TransportType] = TransportType.all(),
) -> APIDefinition:
is_remote = callback is None
if endpoint in self.api_cache:
return self.api_cache[endpoint]
if isinstance(request_types, list):
request_types = RequestType.from_string_list(request_types)
if isinstance(transports, list):
transports = TransportType.from_string_list(transports)
if endpoint[0] == '/':
uri = endpoint
elif is_remote:
uri = "/printer/" + endpoint
else:
uri = "/server/" + endpoint
jrpc_methods = []
if is_remote:
# Remote requests accept both GET and POST requests. These
# requests execute the same callback, thus they resolve to
# only a single websocket method.
jrpc_methods.append(uri[1:].replace('/', '.'))
request_types = RequestType.GET | RequestType.POST
else:
name_parts = uri[1:].split('/')
if len(request_types) > 1:
for rtype in request_types:
func_name = rtype.name.lower() + "_" + name_parts[-1]
jrpc_methods.append(".".join(
name_parts[:-1] + [func_name]))
else:
jrpc_methods.append(".".join(name_parts))
if not is_remote and len(request_types) != len(jrpc_methods):
raise self.server.error(
"Invalid API definition. Number of websocket methods must "
"match the number of request methods")
need_object_parser = endpoint.startswith("objects/")
api_def = APIDefinition(
endpoint, uri, jrpc_methods, request_types,
transports, callback, need_object_parser
)
self.api_cache[endpoint] = api_def
return api_def
async def load_template(self, asset_name: str) -> JinjaTemplate: async def load_template(self, asset_name: str) -> JinjaTemplate:
if asset_name in self.template_cache: if asset_name in self.template_cache:
return self.template_cache[asset_name] return self.template_cache[asset_name]
@ -524,7 +450,8 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
def prepare(self) -> None: def prepare(self) -> None:
app: MoonrakerApp = self.server.lookup_component("application") app: MoonrakerApp = self.server.lookup_component("application")
self.endpoint = app.parse_endpoint(self.request.path or "") if not self.endpoint:
self.endpoint = app.parse_endpoint(self.request.path or "")
auth: AuthComp = self.server.lookup_component('authorization', None) auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None: if auth is not None:
self.current_user = auth.check_authorized(self.request, self.endpoint) self.current_user = auth.check_authorized(self.request, self.endpoint)
@ -618,22 +545,16 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler):
class DynamicRequestHandler(AuthorizedRequestHandler): class DynamicRequestHandler(AuthorizedRequestHandler):
def initialize( def initialize(
self, self,
callback: Union[str, Callable[[WebRequest], Coroutine]] = "", api_definition: Optional[APIDefinition] = None,
methods: RequestType = RequestType((0)),
need_object_parser: bool = False,
is_remote: bool = True,
wrap_result: bool = True, wrap_result: bool = True,
content_type: Optional[str] = None content_type: Optional[str] = None
) -> None: ) -> None:
super(DynamicRequestHandler, self).initialize() super(DynamicRequestHandler, self).initialize()
self.callback = callback assert api_definition is not None
self.methods = methods self.api_defintion = api_definition
self.wrap_result = wrap_result self.wrap_result = wrap_result
self._do_request = self._do_remote_request if is_remote \
else self._do_local_request
self._parse_query = self._object_parser if need_object_parser \
else self._default_parser
self.content_type = content_type self.content_type = content_type
self.endpoint = api_definition.endpoint
# Converts query string values with type hints # Converts query string values with type hints
def _convert_type(self, value: str, hint: str) -> Any: def _convert_type(self, value: str, hint: str) -> Any:
@ -681,7 +602,10 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
def parse_args(self) -> Dict[str, Any]: def parse_args(self) -> Dict[str, Any]:
try: try:
args = self._parse_query() if self.api_defintion.need_object_parser:
args: Dict[str, Any] = self._object_parser()
else:
args = self._default_parser()
except Exception: except Exception:
raise ServerError( raise ServerError(
"Error Parsing Request Arguments. " "Error Parsing Request Arguments. "
@ -720,31 +644,18 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
async def delete(self, *args, **kwargs) -> None: async def delete(self, *args, **kwargs) -> None:
await self._process_http_request(RequestType.DELETE) await self._process_http_request(RequestType.DELETE)
async def _do_local_request( async def _do_request(
self, args: Dict[str, Any], conn: Optional[WebSocket], req_type: RequestType self, args: Dict[str, Any], conn: Optional[WebSocket], req_type: RequestType
) -> Any: ) -> Any:
assert callable(self.callback) return await self.api_defintion.callback(
return await self.callback(
WebRequest( WebRequest(
self.endpoint, args, req_type, conn=conn, self.endpoint, args, req_type, conn=conn,
ip_addr=self.request.remote_ip or "", user=self.current_user ip_addr=self.request.remote_ip or "", user=self.current_user
) )
) )
async def _do_remote_request(
self, args: Dict[str, Any], conn: Optional[WebSocket], req_type: RequestType
) -> Any:
assert isinstance(self.callback, str)
klippy: Klippy = self.server.lookup_component("klippy_connection")
return await klippy.request(
WebRequest(
self.callback, args, req_type, conn=conn,
ip_addr=self.request.remote_ip or "", user=self.current_user
)
)
async def _process_http_request(self, req_type: RequestType) -> None: async def _process_http_request(self, req_type: RequestType) -> None:
if req_type not in self.methods: if req_type not in self.api_defintion.request_types:
raise tornado.web.HTTPError(405) raise tornado.web.HTTPError(405)
conn = self.get_associated_websocket() conn = self.get_associated_websocket()
args = self.parse_args() args = self.parse_args()

View File

@ -9,6 +9,7 @@ import sys
import ipaddress import ipaddress
import logging import logging
import copy import copy
import re
from enum import Enum, Flag, auto from enum import Enum, Flag, auto
from .utils import ServerError, Sentinel from .utils import ServerError, Sentinel
from .utils import json_wrapper as jsonw from .utils import json_wrapper as jsonw
@ -42,6 +43,8 @@ if TYPE_CHECKING:
RPCCallback = Callable[..., Coroutine] RPCCallback = Callable[..., Coroutine]
AuthComp = Optional[Authorization] AuthComp = Optional[Authorization]
ENDPOINT_PREFIXES = ["printer", "server", "machine", "access", "api", "debug"]
class ExtendedFlag(Flag): class ExtendedFlag(Flag):
@classmethod @classmethod
def from_string(cls: Type[_F], flag_name: str) -> _F: def from_string(cls: Type[_F], flag_name: str) -> _F:
@ -161,24 +164,88 @@ class Subscribable:
raise NotImplementedError raise NotImplementedError
class APIDefinition: class APIDefinition:
_cache: Dict[str, APIDefinition] = {}
def __init__( def __init__(
self, self,
endpoint: str, endpoint: str,
http_uri: str, http_path: str,
jrpc_methods: List[str], rpc_methods: Dict[RequestType, str],
request_types: RequestType, request_types: RequestType,
transports: TransportType, transports: TransportType,
callback: Optional[Callable[[WebRequest], Coroutine]], callback: Callable[[WebRequest], Coroutine],
need_object_parser: bool need_object_parser: bool
) -> None: ) -> None:
self.endpoint = endpoint self.endpoint = endpoint
self.uri = http_uri self.http_path = http_path
self.jrpc_methods = jrpc_methods self.rpc_methods = rpc_methods
self.request_types = request_types self.request_types = request_types
self.supported_transports = transports self.supported_transports = transports
self.callback = callback self.callback = callback
self.need_object_parser = need_object_parser self.need_object_parser = need_object_parser
@classmethod
def create(
cls,
endpoint: str,
request_types: Union[List[str], RequestType],
callback: Callable[[WebRequest], Coroutine],
transports: Union[List[str], TransportType] = TransportType.all(),
is_remote: bool = False
) -> APIDefinition:
if isinstance(request_types, list):
request_types = RequestType.from_string_list(request_types)
if isinstance(transports, list):
transports = TransportType.from_string_list(transports)
if endpoint in cls._cache:
return cls._cache[endpoint]
http_path = f"/printer/{endpoint.strip('/')}" if is_remote else endpoint
prf_match = re.match(r"/([^/]+)", http_path)
if TransportType.HTTP in transports:
# Validate the first path segment for definitions that support the
# HTTP transport. We want to restrict components from registering
# using unknown paths.
if prf_match is None or prf_match.group(1) not in ENDPOINT_PREFIXES:
prefixes = [f"/{prefix} " for prefix in ENDPOINT_PREFIXES]
raise ServerError(
f"Invalid endpoint name '{endpoint}', must start with one of "
f"the following: {prefixes}"
)
jrpc_methods: Dict[RequestType, str] = {}
if is_remote:
# Request Types have no meaning for remote requests. Therefore
# both GET and POST http requests are accepted. JRPC requests do
# not need an associated RequestType, so the unknown value is used.
request_types = RequestType.GET | RequestType.POST
jrpc_methods[RequestType(0)] = http_path[1:].replace('/', '.')
else:
name_parts = http_path[1:].split('/')
if len(request_types) > 1:
for rtype in request_types:
func_name = rtype.name.lower() + "_" + name_parts[-1]
jrpc_methods[rtype] = ".".join(name_parts[:-1] + [func_name])
else:
jrpc_methods[request_types] = ".".join(name_parts)
if len(request_types) != len(jrpc_methods):
raise ServerError(
"Invalid API definition. Number of websocket methods must "
"match the number of request methods"
)
need_object_parser = endpoint.startswith("objects/")
api_def = cls(
endpoint, http_path, jrpc_methods, request_types,
transports, callback, need_object_parser
)
cls._cache[endpoint] = api_def
return api_def
@classmethod
def pop_cached_def(cls, endpoint: str) -> Optional[APIDefinition]:
return cls._cache.pop(endpoint, None)
@classmethod
def get_cache(cls) -> Dict[str, APIDefinition]:
return cls._cache
class APITransport: class APITransport:
def register_api_handler(self, api_def: APIDefinition) -> None: def register_api_handler(self, api_def: APIDefinition) -> None:
raise NotImplementedError raise NotImplementedError

View File

@ -689,44 +689,31 @@ class MQTTClient(APITransport, Subscribable):
self.api_qos) self.api_qos)
def register_api_handler(self, api_def: APIDefinition) -> None: def register_api_handler(self, api_def: APIDefinition) -> None:
if api_def.callback is None: for req_type, rpc_method in api_def.rpc_methods.items():
# Remote API, uses RPC to reach out to Klippy rpc_cb = self._generate_rpc_callback(
mqtt_method = api_def.jrpc_methods[0] api_def.endpoint, req_type, api_def.callback
rpc_cb = self._generate_remote_callback(api_def.endpoint) )
self.json_rpc.register_method(mqtt_method, rpc_cb) self.json_rpc.register_method(rpc_method, rpc_cb)
else:
# Local API, uses local callback
req_types = api_def.request_types
for mqtt_method, req_type in zip(api_def.jrpc_methods, req_types):
rpc_cb = self._generate_local_callback(
api_def.endpoint, req_type, api_def.callback)
self.json_rpc.register_method(mqtt_method, rpc_cb)
logging.info( logging.info(
"Registering MQTT JSON-RPC methods: " "Registering MQTT JSON-RPC methods: "
f"{', '.join(api_def.jrpc_methods)}") f"{', '.join(api_def.rpc_methods.values())}")
def remove_api_handler(self, api_def: APIDefinition) -> None: def remove_api_handler(self, api_def: APIDefinition) -> None:
for jrpc_method in api_def.jrpc_methods: for jrpc_method in api_def.rpc_methods.values():
self.json_rpc.remove_method(jrpc_method) self.json_rpc.remove_method(jrpc_method)
def _generate_local_callback(self, def _generate_rpc_callback(
endpoint: str, self,
request_type: RequestType, endpoint: str,
callback: Callable[[WebRequest], Coroutine] request_type: RequestType,
) -> RPCCallback: callback: Callable[[WebRequest], Coroutine]
) -> RPCCallback:
async def func(args: Dict[str, Any]) -> Any: async def func(args: Dict[str, Any]) -> Any:
self._check_timestamp(args) self._check_timestamp(args)
result = await callback(WebRequest(endpoint, args, request_type)) result = await callback(WebRequest(endpoint, args, request_type))
return result return result
return func return func
def _generate_remote_callback(self, endpoint: str) -> RPCCallback:
async def func(args: Dict[str, Any]) -> Any:
self._check_timestamp(args)
result = await self.klippy.request(WebRequest(endpoint, args))
return result
return func
def _check_timestamp(self, args: Dict[str, Any]) -> None: def _check_timestamp(self, args: Dict[str, Any]) -> None:
ts = args.pop("mqtt_timestamp", None) ts = args.pop("mqtt_timestamp", None)
if ts is not None: if ts is not None:

View File

@ -14,7 +14,7 @@ import asyncio
import pathlib import pathlib
from .utils import ServerError, get_unix_peer_credentials from .utils import ServerError, get_unix_peer_credentials
from .utils import json_wrapper as jsonw from .utils import json_wrapper as jsonw
from .common import KlippyState from .common import KlippyState, RequestType
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -32,7 +32,6 @@ from typing import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .server import Server from .server import Server
from .app import MoonrakerApp
from .common import WebRequest, Subscribable, BaseRemoteConnection from .common import WebRequest, Subscribable, BaseRemoteConnection
from .confighelper import ConfigHelper from .confighelper import ConfigHelper
from .components.klippy_apis import KlippyAPI from .components.klippy_apis import KlippyAPI
@ -352,10 +351,12 @@ class KlippyConnection:
if result is None: if result is None:
return return
endpoints = result.get('endpoints', []) endpoints = result.get('endpoints', [])
app: MoonrakerApp = self.server.lookup_component("application")
for ep in endpoints: for ep in endpoints:
if ep not in RESERVED_ENDPOINTS: if ep not in RESERVED_ENDPOINTS:
app.register_remote_handler(ep) self.server.register_endpoint(
ep, RequestType.GET | RequestType.POST, self.request,
is_remote=True
)
async def _request_initial_subscriptions(self) -> None: async def _request_initial_subscriptions(self) -> None:
try: try:

View File

@ -92,8 +92,8 @@ class Server:
# Tornado Application/Server # Tornado Application/Server
self.moonraker_app = app = MoonrakerApp(config) self.moonraker_app = app = MoonrakerApp(config)
self.register_endpoint = app.register_local_handler self.register_endpoint = app.register_endpoint
self.register_debug_endpoint = app.register_debug_handler self.register_debug_endpoint = app.register_debug_endpoint
self.register_static_file_handler = app.register_static_file_handler self.register_static_file_handler = app.register_static_file_handler
self.register_upload_handler = app.register_upload_handler self.register_upload_handler = app.register_upload_handler
self.register_api_transport = app.register_api_transport self.register_api_transport = app.register_api_transport

View File

@ -76,31 +76,21 @@ class WebsocketManager(APITransport):
self.server.register_event_handler(event_name, notify_handler) self.server.register_event_handler(event_name, notify_handler)
def register_api_handler(self, api_def: APIDefinition) -> None: def register_api_handler(self, api_def: APIDefinition) -> None:
klippy: Klippy = self.server.lookup_component("klippy_connection") for req_type, rpc_method in api_def.rpc_methods.items():
if api_def.callback is None: rpc_cb = self._generate_rpc_callback(
# Remote API, uses RPC to reach out to Klippy api_def.endpoint, req_type, api_def.callback
ws_method = api_def.jrpc_methods[0]
rpc_cb = self._generate_callback(
api_def.endpoint, RequestType(0), klippy.request
) )
self.rpc.register_method(ws_method, rpc_cb) self.rpc.register_method(rpc_method, rpc_cb)
else:
# Local API, uses local callback
for ws_method, req_type in zip(api_def.jrpc_methods, api_def.request_types):
rpc_cb = self._generate_callback(
api_def.endpoint, req_type, api_def.callback
)
self.rpc.register_method(ws_method, rpc_cb)
logging.info( logging.info(
"Registering Websocket JSON-RPC methods: " "Registering Websocket JSON-RPC methods: "
f"{', '.join(api_def.jrpc_methods)}" f"{', '.join(api_def.rpc_methods.values())}"
) )
def remove_api_handler(self, api_def: APIDefinition) -> None: def remove_api_handler(self, api_def: APIDefinition) -> None:
for jrpc_method in api_def.jrpc_methods: for rpc_method in api_def.rpc_methods.values():
self.rpc.remove_method(jrpc_method) self.rpc.remove_method(rpc_method)
def _generate_callback( def _generate_rpc_callback(
self, self,
endpoint: str, endpoint: str,
request_type: RequestType, request_type: RequestType,