common: add RequestType and TransportType flags

These flags replace strings as constants used to register and
identify Request Types (ie: GET, POST) and API Transport
Types.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2023-06-28 16:30:24 -04:00
parent 42357891a3
commit 7deb9fac4c
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
4 changed files with 187 additions and 107 deletions

View File

@ -22,8 +22,14 @@ from tornado.escape import url_unescape, url_escape
from tornado.routing import Rule, PathMatches, AnyMatches
from tornado.http1connection import HTTP1Connection
from tornado.log import access_log
from .common import WebRequest, APIDefinition, APITransport
from .utils import ServerError, source_info
from .common import (
WebRequest,
APIDefinition,
APITransport,
TransportType,
RequestType
)
from .utils import json_wrapper as jsonw
from .websockets import (
WebsocketManager,
@ -69,7 +75,6 @@ MAX_WS_CONNS_DEFAULT = 50
EXCLUDED_ARGS = ["_", "token", "access_token", "connection_id"]
AUTHORIZED_EXTS = [".png", ".jpg"]
DEFAULT_KLIPPY_LOG_PATH = "/tmp/klippy.log"
ALL_TRANSPORTS = ["http", "websocket", "mqtt", "internal"]
class MutableRouter(tornado.web.ReversibleRuleRouter):
def __init__(self, application: MoonrakerApp) -> None:
@ -115,7 +120,7 @@ class MutableRouter(tornado.web.ReversibleRuleRouter):
class InternalTransport(APITransport):
def __init__(self, server: Server) -> None:
self.server = server
self.callbacks: Dict[str, Tuple[str, str, APICallback]] = {}
self.callbacks: Dict[str, Tuple[str, RequestType, APICallback]] = {}
def register_api_handler(self, api_def: APIDefinition) -> None:
ep = api_def.endpoint
@ -123,13 +128,12 @@ class InternalTransport(APITransport):
if cb is None:
# Request to Klippy
method = api_def.jrpc_methods[0]
action = ""
action = RequestType(0)
klippy: Klippy = self.server.lookup_component("klippy_connection")
cb = klippy.request
self.callbacks[method] = (ep, action, cb)
else:
for method, action in \
zip(api_def.jrpc_methods, api_def.request_methods):
for method, action in zip(api_def.jrpc_methods, api_def.request_types):
self.callbacks[method] = (ep, action, cb)
def remove_api_handler(self, api_def: APIDefinition) -> None:
@ -143,11 +147,11 @@ class InternalTransport(APITransport):
) -> Any:
if method_name not in self.callbacks:
raise self.server.error(f"No method {method_name} available")
ep, action, func = self.callbacks[method_name]
ep, req_type, func = self.callbacks[method_name]
# Request arguments can be suppplied either through a dict object
# or via keyword arugments
args = request_arguments or kwargs
return await func(WebRequest(ep, dict(args), action))
return await func(WebRequest(ep, dict(args), req_type))
class MoonrakerApp:
def __init__(self, config: ConfigHelper) -> None:
@ -188,9 +192,9 @@ class MoonrakerApp:
# Set Up Websocket and Authorization Managers
self.wsm = WebsocketManager(self.server)
self.internal_transport = InternalTransport(self.server)
self.api_transports: Dict[str, APITransport] = {
"websocket": self.wsm,
"internal": self.internal_transport
self.api_transports: Dict[TransportType, APITransport] = {
TransportType.WEBSOCKET: self.wsm,
TransportType.INTERNAL: self.internal_transport
}
mimetypes.add_type('text/plain', '.log')
@ -325,49 +329,52 @@ class MoonrakerApp:
await self.wsm.close()
def register_api_transport(
self, name: str, transport: APITransport
self, trtype: TransportType, api_transport: APITransport
) -> Dict[str, APIDefinition]:
self.api_transports[name] = transport
self.api_transports[trtype] = api_transport
return self.api_cache
def register_remote_handler(self, endpoint: str) -> None:
api_def = self._create_api_definition(endpoint)
api_def = self._create_api_definition(endpoint, RequestType.GET)
if api_def.uri in self.registered_base_handlers:
# reserved handler or already registered
return
logging.info(
f"Registering HTTP endpoint: "
f"({' '.join(api_def.request_methods)}) {api_def.uri}")
f"Registering HTTP endpoint: ({api_def.request_types}) {api_def.uri}"
)
params: Dict[str, Any] = {}
params['methods'] = api_def.request_methods
params['methods'] = api_def.request_types
params['callback'] = api_def.endpoint
params['need_object_parser'] = api_def.need_object_parser
self.mutable_router.add_handler(
f"{self._route_prefix}{api_def.uri}", DynamicRequestHandler, params
)
self.registered_base_handlers.append(api_def.uri)
for name, transport in self.api_transports.items():
transport.register_api_handler(api_def)
for api_transport in self.api_transports.values():
api_transport.register_api_handler(api_def)
def register_local_handler(
self,
uri: str,
request_methods: List[str],
request_types: Union[List[str], RequestType],
callback: APICallback,
transports: List[str] = ALL_TRANSPORTS,
transports: Union[List[str], TransportType] = TransportType.all(),
wrap_result: bool = True,
content_type: Optional[str] = None
) -> None:
if uri in self.registered_base_handlers:
return
if isinstance(request_types, list):
request_types = RequestType.from_string_list(request_types)
if isinstance(transports, list):
transports = TransportType.from_string_list(transports)
api_def = self._create_api_definition(
uri, request_methods, callback, transports=transports)
if "http" in transports:
logging.info(
f"Registering HTTP Endpoint: "
f"({' '.join(request_methods)}) {uri}")
uri, request_types, callback, transports=transports
)
if TransportType.HTTP in transports:
logging.info(f"Registering HTTP Endpoint: ({request_types}) {uri}")
params: dict[str, Any] = {}
params['methods'] = request_methods
params['methods'] = request_types
params['callback'] = callback
params['wrap_result'] = wrap_result
params['is_remote'] = False
@ -376,9 +383,9 @@ class MoonrakerApp:
f"{self._route_prefix}{uri}", DynamicRequestHandler, params
)
self.registered_base_handlers.append(uri)
for name, transport in self.api_transports.items():
if name in transports:
transport.register_api_handler(api_def)
for trtype, api_transport in self.api_transports.items():
if trtype in transports:
api_transport.register_api_handler(api_def)
def register_static_file_handler(
self, pattern: str, file_path: str, force: bool = False
@ -415,9 +422,9 @@ class MoonrakerApp:
def register_debug_handler(
self,
uri: str,
request_methods: List[str],
request_types: Union[List[str], RequestType],
callback: APICallback,
transports: List[str] = ALL_TRANSPORTS,
transports: Union[List[str], TransportType] = TransportType.all(),
wrap_result: bool = True
) -> None:
if not self.server.is_debug_enabled():
@ -427,26 +434,30 @@ class MoonrakerApp:
"Debug Endpoints must be registerd in the '/debug' path"
)
self.register_local_handler(
uri, request_methods, callback, transports, wrap_result
uri, request_types, callback, transports, wrap_result
)
def remove_handler(self, endpoint: str) -> None:
api_def = self.api_cache.pop(endpoint, None)
if api_def is not None:
self.mutable_router.remove_handler(api_def.uri)
for name, transport in self.api_transports.items():
transport.remove_api_handler(api_def)
for api_transport in self.api_transports.values():
api_transport.remove_api_handler(api_def)
def _create_api_definition(
self,
endpoint: str,
request_methods: List[str] = [],
request_types: Union[List[str], RequestType],
callback: Optional[APICallback] = None,
transports: List[str] = ALL_TRANSPORTS
transports: Union[List[str], TransportType] = TransportType.all(),
) -> APIDefinition:
is_remote = callback is None
if endpoint in self.api_cache:
return self.api_cache[endpoint]
if isinstance(request_types, list):
request_types = RequestType.from_string_list(request_types)
if isinstance(transports, list):
transports = TransportType.from_string_list(transports)
if endpoint[0] == '/':
uri = endpoint
elif is_remote:
@ -459,23 +470,25 @@ class MoonrakerApp:
# requests execute the same callback, thus they resolve to
# only a single websocket method.
jrpc_methods.append(uri[1:].replace('/', '.'))
request_methods = ['GET', 'POST']
request_types = RequestType.GET | RequestType.POST
else:
name_parts = uri[1:].split('/')
if len(request_methods) > 1:
for req_mthd in request_methods:
func_name = req_mthd.lower() + "_" + name_parts[-1]
if len(request_types) > 1:
for rtype in request_types:
func_name = rtype.name.lower() + "_" + name_parts[-1]
jrpc_methods.append(".".join(
name_parts[:-1] + [func_name]))
else:
jrpc_methods.append(".".join(name_parts))
if not is_remote and len(request_methods) != len(jrpc_methods):
if not is_remote and len(request_types) != len(jrpc_methods):
raise self.server.error(
"Invalid API definition. Number of websocket methods must "
"match the number of request methods")
need_object_parser = endpoint.startswith("objects/")
api_def = APIDefinition(endpoint, uri, jrpc_methods, request_methods,
transports, callback, need_object_parser)
api_def = APIDefinition(
endpoint, uri, jrpc_methods, request_types,
transports, callback, need_object_parser
)
self.api_cache[endpoint] = api_def
return api_def
@ -605,7 +618,7 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
def initialize(
self,
callback: Union[str, Callable[[WebRequest], Coroutine]] = "",
methods: List[str] = [],
methods: RequestType = RequestType((0)),
need_object_parser: bool = False,
is_remote: bool = True,
wrap_result: bool = True,
@ -698,44 +711,46 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
logging.debug(f"{header}::{resp}")
async def get(self, *args, **kwargs) -> None:
await self._process_http_request()
await self._process_http_request(RequestType.GET)
async def post(self, *args, **kwargs) -> None:
await self._process_http_request()
await self._process_http_request(RequestType.POST)
async def delete(self, *args, **kwargs) -> None:
await self._process_http_request()
await self._process_http_request(RequestType.DELETE)
async def _do_local_request(self,
args: Dict[str, Any],
conn: Optional[WebSocket]
) -> Any:
async def _do_local_request(
self, args: Dict[str, Any], conn: Optional[WebSocket], req_type: RequestType
) -> Any:
assert callable(self.callback)
return await self.callback(
WebRequest(self.endpoint, args, self.request.method,
conn=conn, ip_addr=self.request.remote_ip or "",
user=self.current_user))
WebRequest(
self.endpoint, args, req_type, conn=conn,
ip_addr=self.request.remote_ip or "", user=self.current_user
)
)
async def _do_remote_request(self,
args: Dict[str, Any],
conn: Optional[WebSocket]
) -> Any:
async def _do_remote_request(
self, args: Dict[str, Any], conn: Optional[WebSocket], req_type: RequestType
) -> Any:
assert isinstance(self.callback, str)
klippy: Klippy = self.server.lookup_component("klippy_connection")
return await klippy.request(
WebRequest(self.callback, args, conn=conn,
ip_addr=self.request.remote_ip or "",
user=self.current_user))
WebRequest(
self.callback, args, req_type, conn=conn,
ip_addr=self.request.remote_ip or "", user=self.current_user
)
)
async def _process_http_request(self) -> None:
if self.request.method not in self.methods:
async def _process_http_request(self, req_type: RequestType) -> None:
if req_type not in self.methods:
raise tornado.web.HTTPError(405)
conn = self.get_associated_websocket()
args = self.parse_args()
req = f"{self.request.method} {self.request.path}"
self._log_debug(f"HTTP Request::{req}", args)
try:
result = await self._do_request(args, conn)
result = await self._do_request(args, conn, req_type)
except ServerError as e:
raise tornado.web.HTTPError(
e.status_code, reason=str(e)) from e

View File

@ -5,9 +5,11 @@
# This file may be distributed under the terms of the GNU GPLv3 license
from __future__ import annotations
import sys
import ipaddress
import logging
import copy
from enum import Flag, auto
from .utils import ServerError, Sentinel
from .utils import json_wrapper as jsonw
@ -33,34 +35,82 @@ if TYPE_CHECKING:
from asyncio import Future
_T = TypeVar("_T")
_C = TypeVar("_C", str, bool, float, int)
_F = TypeVar("_F", bound="ExtendedFlag")
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
ConvType = Union[str, bool, float, int]
ArgVal = Union[None, int, float, bool, str]
RPCCallback = Callable[..., Coroutine]
AuthComp = Optional[Authorization]
class ExtendedFlag(Flag):
@classmethod
def from_string(cls: Type[_F], flag_name: str) -> _F:
str_name = flag_name.upper()
for name, member in cls.__members__.items():
if name == str_name:
return cls(member.value)
raise ValueError(f"No flag member named {flag_name}")
@classmethod
def from_string_list(cls: Type[_F], flag_list: List[str]) -> _F:
ret = cls(0)
for flag in flag_list:
flag = flag.upper()
ret |= cls.from_string(flag)
return ret
@classmethod
def all(cls: Type[_F]) -> _F:
return ~cls(0)
if sys.version_info < (3, 11):
def __len__(self) -> int:
return bin(self._value_).count("1")
def __iter__(self):
for i in range(self._value_.bit_length()):
val = 1 << i
if val & self._value_ == val:
yield self.__class__(val)
class RequestType(ExtendedFlag):
"""
The Request Type is also known as the "Request Method" for
HTTP/REST APIs. The use of "Request Method" nomenclature
is discouraged in Moonraker as it could be confused with
the JSON-RPC "method" field.
"""
GET = auto()
POST = auto()
DELETE = auto()
class TransportType(ExtendedFlag):
HTTP = auto()
WEBSOCKET = auto()
MQTT = auto()
INTERNAL = auto()
class Subscribable:
def send_status(self,
status: Dict[str, Any],
eventtime: float
) -> None:
def send_status(
self, status: Dict[str, Any], eventtime: float
) -> None:
raise NotImplementedError
class APIDefinition:
def __init__(self,
endpoint: str,
http_uri: str,
jrpc_methods: List[str],
request_methods: Union[str, List[str]],
transports: List[str],
callback: Optional[Callable[[WebRequest], Coroutine]],
need_object_parser: bool):
def __init__(
self,
endpoint: str,
http_uri: str,
jrpc_methods: List[str],
request_types: RequestType,
transports: TransportType,
callback: Optional[Callable[[WebRequest], Coroutine]],
need_object_parser: bool
) -> None:
self.endpoint = endpoint
self.uri = http_uri
self.jrpc_methods = jrpc_methods
if not isinstance(request_methods, list):
request_methods = [request_methods]
self.request_methods = request_methods
self.request_types = request_types
self.supported_transports = transports
self.callback = callback
self.need_object_parser = need_object_parser
@ -256,16 +306,17 @@ class BaseRemoteConnection(Subscribable):
class WebRequest:
def __init__(self,
endpoint: str,
args: Dict[str, Any],
action: Optional[str] = "",
conn: Optional[Subscribable] = None,
ip_addr: str = "",
user: Optional[Dict[str, Any]] = None
) -> None:
def __init__(
self,
endpoint: str,
args: Dict[str, Any],
request_type: RequestType = RequestType(0),
conn: Optional[Subscribable] = None,
ip_addr: str = "",
user: Optional[Dict[str, Any]] = None
) -> None:
self.endpoint = endpoint
self.action = action or ""
self.request_type = request_type
self.args = args
self.conn = conn
self.ip_addr: Optional[IPUnion] = None
@ -278,8 +329,11 @@ class WebRequest:
def get_endpoint(self) -> str:
return self.endpoint
def get_request_type(self) -> RequestType:
return self.request_type
def get_action(self) -> str:
return self.action
return self.request_type.name or ""
def get_args(self) -> Dict[str, Any]:
return self.args

View File

@ -12,7 +12,14 @@ import pathlib
import ssl
from collections import deque
import paho.mqtt.client as paho_mqtt
from ..common import Subscribable, WebRequest, APITransport, JsonRPC
from ..common import (
TransportType,
RequestType,
Subscribable,
WebRequest,
APITransport,
JsonRPC
)
from ..utils import json_wrapper as jsonw
# Annotation imports
@ -330,9 +337,9 @@ class MQTTClient(APITransport, Subscribable):
self.timestamp_deque: Deque = deque(maxlen=20)
self.api_qos = config.getint('api_qos', self.qos)
if config.getboolean("enable_moonraker_api", True):
api_cache = self.server.register_api_transport("mqtt", self)
api_cache = self.server.register_api_transport(TransportType.MQTT, self)
for api_def in api_cache.values():
if "mqtt" in api_def.supported_transports:
if TransportType.MQTT in api_def.supported_transports:
self.register_api_handler(api_def)
self.subscribe_topic(self.api_request_topic,
self._process_api_request,
@ -366,7 +373,8 @@ class MQTTClient(APITransport, Subscribable):
args = {'objects': self.status_objs}
try:
await self.klippy.request(
WebRequest("objects/subscribe", args, conn=self))
WebRequest("objects/subscribe", args, conn=self)
)
except self.server.error:
pass
@ -683,10 +691,10 @@ class MQTTClient(APITransport, Subscribable):
self.json_rpc.register_method(mqtt_method, rpc_cb)
else:
# Local API, uses local callback
for mqtt_method, req_method in \
zip(api_def.jrpc_methods, api_def.request_methods):
req_types = api_def.request_types
for mqtt_method, req_type in zip(api_def.jrpc_methods, req_types):
rpc_cb = self._generate_local_callback(
api_def.endpoint, req_method, api_def.callback)
api_def.endpoint, req_type, api_def.callback)
self.json_rpc.register_method(mqtt_method, rpc_cb)
logging.info(
"Registering MQTT JSON-RPC methods: "
@ -698,12 +706,12 @@ class MQTTClient(APITransport, Subscribable):
def _generate_local_callback(self,
endpoint: str,
request_method: str,
request_type: RequestType,
callback: Callable[[WebRequest], Coroutine]
) -> RPCCallback:
async def func(args: Dict[str, Any]) -> Any:
self._check_timestamp(args)
result = await callback(WebRequest(endpoint, args, request_method))
result = await callback(WebRequest(endpoint, args, request_type))
return result
return func

View File

@ -11,11 +11,12 @@ import asyncio
from tornado.websocket import WebSocketHandler, WebSocketClosedError
from tornado.web import HTTPError
from .common import (
RequestType,
WebRequest,
BaseRemoteConnection,
APITransport,
APIDefinition,
JsonRPC
JsonRPC,
)
from .utils import ServerError
@ -80,15 +81,14 @@ class WebsocketManager(APITransport):
# Remote API, uses RPC to reach out to Klippy
ws_method = api_def.jrpc_methods[0]
rpc_cb = self._generate_callback(
api_def.endpoint, "", klippy.request
api_def.endpoint, RequestType(0), klippy.request
)
self.rpc.register_method(ws_method, rpc_cb)
else:
# Local API, uses local callback
for ws_method, req_method in \
zip(api_def.jrpc_methods, api_def.request_methods):
for ws_method, req_type in zip(api_def.jrpc_methods, api_def.request_types):
rpc_cb = self._generate_callback(
api_def.endpoint, req_method, api_def.callback
api_def.endpoint, req_type, api_def.callback
)
self.rpc.register_method(ws_method, rpc_cb)
logging.info(
@ -103,15 +103,18 @@ class WebsocketManager(APITransport):
def _generate_callback(
self,
endpoint: str,
request_method: str,
request_type: RequestType,
callback: Callable[[WebRequest], Coroutine]
) -> RPCCallback:
async def func(args: Dict[str, Any]) -> Any:
sc: BaseRemoteConnection = args.pop("_socket_")
sc.check_authenticated(path=endpoint)
result = await callback(
WebRequest(endpoint, args, request_method, sc,
ip_addr=sc.ip_addr, user=sc.user_info))
WebRequest(
endpoint, args, request_type, sc,
ip_addr=sc.ip_addr, user=sc.user_info
)
)
return result
return func