common: add RequestType and TransportType flags
These flags replace strings as constants used to register and identify Request Types (ie: GET, POST) and API Transport Types. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
42357891a3
commit
7deb9fac4c
145
moonraker/app.py
145
moonraker/app.py
|
@ -22,8 +22,14 @@ from tornado.escape import url_unescape, url_escape
|
|||
from tornado.routing import Rule, PathMatches, AnyMatches
|
||||
from tornado.http1connection import HTTP1Connection
|
||||
from tornado.log import access_log
|
||||
from .common import WebRequest, APIDefinition, APITransport
|
||||
from .utils import ServerError, source_info
|
||||
from .common import (
|
||||
WebRequest,
|
||||
APIDefinition,
|
||||
APITransport,
|
||||
TransportType,
|
||||
RequestType
|
||||
)
|
||||
from .utils import json_wrapper as jsonw
|
||||
from .websockets import (
|
||||
WebsocketManager,
|
||||
|
@ -69,7 +75,6 @@ MAX_WS_CONNS_DEFAULT = 50
|
|||
EXCLUDED_ARGS = ["_", "token", "access_token", "connection_id"]
|
||||
AUTHORIZED_EXTS = [".png", ".jpg"]
|
||||
DEFAULT_KLIPPY_LOG_PATH = "/tmp/klippy.log"
|
||||
ALL_TRANSPORTS = ["http", "websocket", "mqtt", "internal"]
|
||||
|
||||
class MutableRouter(tornado.web.ReversibleRuleRouter):
|
||||
def __init__(self, application: MoonrakerApp) -> None:
|
||||
|
@ -115,7 +120,7 @@ class MutableRouter(tornado.web.ReversibleRuleRouter):
|
|||
class InternalTransport(APITransport):
|
||||
def __init__(self, server: Server) -> None:
|
||||
self.server = server
|
||||
self.callbacks: Dict[str, Tuple[str, str, APICallback]] = {}
|
||||
self.callbacks: Dict[str, Tuple[str, RequestType, APICallback]] = {}
|
||||
|
||||
def register_api_handler(self, api_def: APIDefinition) -> None:
|
||||
ep = api_def.endpoint
|
||||
|
@ -123,13 +128,12 @@ class InternalTransport(APITransport):
|
|||
if cb is None:
|
||||
# Request to Klippy
|
||||
method = api_def.jrpc_methods[0]
|
||||
action = ""
|
||||
action = RequestType(0)
|
||||
klippy: Klippy = self.server.lookup_component("klippy_connection")
|
||||
cb = klippy.request
|
||||
self.callbacks[method] = (ep, action, cb)
|
||||
else:
|
||||
for method, action in \
|
||||
zip(api_def.jrpc_methods, api_def.request_methods):
|
||||
for method, action in zip(api_def.jrpc_methods, api_def.request_types):
|
||||
self.callbacks[method] = (ep, action, cb)
|
||||
|
||||
def remove_api_handler(self, api_def: APIDefinition) -> None:
|
||||
|
@ -143,11 +147,11 @@ class InternalTransport(APITransport):
|
|||
) -> Any:
|
||||
if method_name not in self.callbacks:
|
||||
raise self.server.error(f"No method {method_name} available")
|
||||
ep, action, func = self.callbacks[method_name]
|
||||
ep, req_type, func = self.callbacks[method_name]
|
||||
# Request arguments can be suppplied either through a dict object
|
||||
# or via keyword arugments
|
||||
args = request_arguments or kwargs
|
||||
return await func(WebRequest(ep, dict(args), action))
|
||||
return await func(WebRequest(ep, dict(args), req_type))
|
||||
|
||||
class MoonrakerApp:
|
||||
def __init__(self, config: ConfigHelper) -> None:
|
||||
|
@ -188,9 +192,9 @@ class MoonrakerApp:
|
|||
# Set Up Websocket and Authorization Managers
|
||||
self.wsm = WebsocketManager(self.server)
|
||||
self.internal_transport = InternalTransport(self.server)
|
||||
self.api_transports: Dict[str, APITransport] = {
|
||||
"websocket": self.wsm,
|
||||
"internal": self.internal_transport
|
||||
self.api_transports: Dict[TransportType, APITransport] = {
|
||||
TransportType.WEBSOCKET: self.wsm,
|
||||
TransportType.INTERNAL: self.internal_transport
|
||||
}
|
||||
|
||||
mimetypes.add_type('text/plain', '.log')
|
||||
|
@ -325,49 +329,52 @@ class MoonrakerApp:
|
|||
await self.wsm.close()
|
||||
|
||||
def register_api_transport(
|
||||
self, name: str, transport: APITransport
|
||||
self, trtype: TransportType, api_transport: APITransport
|
||||
) -> Dict[str, APIDefinition]:
|
||||
self.api_transports[name] = transport
|
||||
self.api_transports[trtype] = api_transport
|
||||
return self.api_cache
|
||||
|
||||
def register_remote_handler(self, endpoint: str) -> None:
|
||||
api_def = self._create_api_definition(endpoint)
|
||||
api_def = self._create_api_definition(endpoint, RequestType.GET)
|
||||
if api_def.uri in self.registered_base_handlers:
|
||||
# reserved handler or already registered
|
||||
return
|
||||
logging.info(
|
||||
f"Registering HTTP endpoint: "
|
||||
f"({' '.join(api_def.request_methods)}) {api_def.uri}")
|
||||
f"Registering HTTP endpoint: ({api_def.request_types}) {api_def.uri}"
|
||||
)
|
||||
params: Dict[str, Any] = {}
|
||||
params['methods'] = api_def.request_methods
|
||||
params['methods'] = api_def.request_types
|
||||
params['callback'] = api_def.endpoint
|
||||
params['need_object_parser'] = api_def.need_object_parser
|
||||
self.mutable_router.add_handler(
|
||||
f"{self._route_prefix}{api_def.uri}", DynamicRequestHandler, params
|
||||
)
|
||||
self.registered_base_handlers.append(api_def.uri)
|
||||
for name, transport in self.api_transports.items():
|
||||
transport.register_api_handler(api_def)
|
||||
for api_transport in self.api_transports.values():
|
||||
api_transport.register_api_handler(api_def)
|
||||
|
||||
def register_local_handler(
|
||||
self,
|
||||
uri: str,
|
||||
request_methods: List[str],
|
||||
request_types: Union[List[str], RequestType],
|
||||
callback: APICallback,
|
||||
transports: List[str] = ALL_TRANSPORTS,
|
||||
transports: Union[List[str], TransportType] = TransportType.all(),
|
||||
wrap_result: bool = True,
|
||||
content_type: Optional[str] = None
|
||||
) -> None:
|
||||
if uri in self.registered_base_handlers:
|
||||
return
|
||||
if isinstance(request_types, list):
|
||||
request_types = RequestType.from_string_list(request_types)
|
||||
if isinstance(transports, list):
|
||||
transports = TransportType.from_string_list(transports)
|
||||
api_def = self._create_api_definition(
|
||||
uri, request_methods, callback, transports=transports)
|
||||
if "http" in transports:
|
||||
logging.info(
|
||||
f"Registering HTTP Endpoint: "
|
||||
f"({' '.join(request_methods)}) {uri}")
|
||||
uri, request_types, callback, transports=transports
|
||||
)
|
||||
if TransportType.HTTP in transports:
|
||||
logging.info(f"Registering HTTP Endpoint: ({request_types}) {uri}")
|
||||
params: dict[str, Any] = {}
|
||||
params['methods'] = request_methods
|
||||
params['methods'] = request_types
|
||||
params['callback'] = callback
|
||||
params['wrap_result'] = wrap_result
|
||||
params['is_remote'] = False
|
||||
|
@ -376,9 +383,9 @@ class MoonrakerApp:
|
|||
f"{self._route_prefix}{uri}", DynamicRequestHandler, params
|
||||
)
|
||||
self.registered_base_handlers.append(uri)
|
||||
for name, transport in self.api_transports.items():
|
||||
if name in transports:
|
||||
transport.register_api_handler(api_def)
|
||||
for trtype, api_transport in self.api_transports.items():
|
||||
if trtype in transports:
|
||||
api_transport.register_api_handler(api_def)
|
||||
|
||||
def register_static_file_handler(
|
||||
self, pattern: str, file_path: str, force: bool = False
|
||||
|
@ -415,9 +422,9 @@ class MoonrakerApp:
|
|||
def register_debug_handler(
|
||||
self,
|
||||
uri: str,
|
||||
request_methods: List[str],
|
||||
request_types: Union[List[str], RequestType],
|
||||
callback: APICallback,
|
||||
transports: List[str] = ALL_TRANSPORTS,
|
||||
transports: Union[List[str], TransportType] = TransportType.all(),
|
||||
wrap_result: bool = True
|
||||
) -> None:
|
||||
if not self.server.is_debug_enabled():
|
||||
|
@ -427,26 +434,30 @@ class MoonrakerApp:
|
|||
"Debug Endpoints must be registerd in the '/debug' path"
|
||||
)
|
||||
self.register_local_handler(
|
||||
uri, request_methods, callback, transports, wrap_result
|
||||
uri, request_types, callback, transports, wrap_result
|
||||
)
|
||||
|
||||
def remove_handler(self, endpoint: str) -> None:
|
||||
api_def = self.api_cache.pop(endpoint, None)
|
||||
if api_def is not None:
|
||||
self.mutable_router.remove_handler(api_def.uri)
|
||||
for name, transport in self.api_transports.items():
|
||||
transport.remove_api_handler(api_def)
|
||||
for api_transport in self.api_transports.values():
|
||||
api_transport.remove_api_handler(api_def)
|
||||
|
||||
def _create_api_definition(
|
||||
self,
|
||||
endpoint: str,
|
||||
request_methods: List[str] = [],
|
||||
request_types: Union[List[str], RequestType],
|
||||
callback: Optional[APICallback] = None,
|
||||
transports: List[str] = ALL_TRANSPORTS
|
||||
transports: Union[List[str], TransportType] = TransportType.all(),
|
||||
) -> APIDefinition:
|
||||
is_remote = callback is None
|
||||
if endpoint in self.api_cache:
|
||||
return self.api_cache[endpoint]
|
||||
if isinstance(request_types, list):
|
||||
request_types = RequestType.from_string_list(request_types)
|
||||
if isinstance(transports, list):
|
||||
transports = TransportType.from_string_list(transports)
|
||||
if endpoint[0] == '/':
|
||||
uri = endpoint
|
||||
elif is_remote:
|
||||
|
@ -459,23 +470,25 @@ class MoonrakerApp:
|
|||
# requests execute the same callback, thus they resolve to
|
||||
# only a single websocket method.
|
||||
jrpc_methods.append(uri[1:].replace('/', '.'))
|
||||
request_methods = ['GET', 'POST']
|
||||
request_types = RequestType.GET | RequestType.POST
|
||||
else:
|
||||
name_parts = uri[1:].split('/')
|
||||
if len(request_methods) > 1:
|
||||
for req_mthd in request_methods:
|
||||
func_name = req_mthd.lower() + "_" + name_parts[-1]
|
||||
if len(request_types) > 1:
|
||||
for rtype in request_types:
|
||||
func_name = rtype.name.lower() + "_" + name_parts[-1]
|
||||
jrpc_methods.append(".".join(
|
||||
name_parts[:-1] + [func_name]))
|
||||
else:
|
||||
jrpc_methods.append(".".join(name_parts))
|
||||
if not is_remote and len(request_methods) != len(jrpc_methods):
|
||||
if not is_remote and len(request_types) != len(jrpc_methods):
|
||||
raise self.server.error(
|
||||
"Invalid API definition. Number of websocket methods must "
|
||||
"match the number of request methods")
|
||||
need_object_parser = endpoint.startswith("objects/")
|
||||
api_def = APIDefinition(endpoint, uri, jrpc_methods, request_methods,
|
||||
transports, callback, need_object_parser)
|
||||
api_def = APIDefinition(
|
||||
endpoint, uri, jrpc_methods, request_types,
|
||||
transports, callback, need_object_parser
|
||||
)
|
||||
self.api_cache[endpoint] = api_def
|
||||
return api_def
|
||||
|
||||
|
@ -605,7 +618,7 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
|
|||
def initialize(
|
||||
self,
|
||||
callback: Union[str, Callable[[WebRequest], Coroutine]] = "",
|
||||
methods: List[str] = [],
|
||||
methods: RequestType = RequestType((0)),
|
||||
need_object_parser: bool = False,
|
||||
is_remote: bool = True,
|
||||
wrap_result: bool = True,
|
||||
|
@ -698,44 +711,46 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
|
|||
logging.debug(f"{header}::{resp}")
|
||||
|
||||
async def get(self, *args, **kwargs) -> None:
|
||||
await self._process_http_request()
|
||||
await self._process_http_request(RequestType.GET)
|
||||
|
||||
async def post(self, *args, **kwargs) -> None:
|
||||
await self._process_http_request()
|
||||
await self._process_http_request(RequestType.POST)
|
||||
|
||||
async def delete(self, *args, **kwargs) -> None:
|
||||
await self._process_http_request()
|
||||
await self._process_http_request(RequestType.DELETE)
|
||||
|
||||
async def _do_local_request(self,
|
||||
args: Dict[str, Any],
|
||||
conn: Optional[WebSocket]
|
||||
) -> Any:
|
||||
async def _do_local_request(
|
||||
self, args: Dict[str, Any], conn: Optional[WebSocket], req_type: RequestType
|
||||
) -> Any:
|
||||
assert callable(self.callback)
|
||||
return await self.callback(
|
||||
WebRequest(self.endpoint, args, self.request.method,
|
||||
conn=conn, ip_addr=self.request.remote_ip or "",
|
||||
user=self.current_user))
|
||||
WebRequest(
|
||||
self.endpoint, args, req_type, conn=conn,
|
||||
ip_addr=self.request.remote_ip or "", user=self.current_user
|
||||
)
|
||||
)
|
||||
|
||||
async def _do_remote_request(self,
|
||||
args: Dict[str, Any],
|
||||
conn: Optional[WebSocket]
|
||||
) -> Any:
|
||||
async def _do_remote_request(
|
||||
self, args: Dict[str, Any], conn: Optional[WebSocket], req_type: RequestType
|
||||
) -> Any:
|
||||
assert isinstance(self.callback, str)
|
||||
klippy: Klippy = self.server.lookup_component("klippy_connection")
|
||||
return await klippy.request(
|
||||
WebRequest(self.callback, args, conn=conn,
|
||||
ip_addr=self.request.remote_ip or "",
|
||||
user=self.current_user))
|
||||
WebRequest(
|
||||
self.callback, args, req_type, conn=conn,
|
||||
ip_addr=self.request.remote_ip or "", user=self.current_user
|
||||
)
|
||||
)
|
||||
|
||||
async def _process_http_request(self) -> None:
|
||||
if self.request.method not in self.methods:
|
||||
async def _process_http_request(self, req_type: RequestType) -> None:
|
||||
if req_type not in self.methods:
|
||||
raise tornado.web.HTTPError(405)
|
||||
conn = self.get_associated_websocket()
|
||||
args = self.parse_args()
|
||||
req = f"{self.request.method} {self.request.path}"
|
||||
self._log_debug(f"HTTP Request::{req}", args)
|
||||
try:
|
||||
result = await self._do_request(args, conn)
|
||||
result = await self._do_request(args, conn, req_type)
|
||||
except ServerError as e:
|
||||
raise tornado.web.HTTPError(
|
||||
e.status_code, reason=str(e)) from e
|
||||
|
|
|
@ -5,9 +5,11 @@
|
|||
# This file may be distributed under the terms of the GNU GPLv3 license
|
||||
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
import ipaddress
|
||||
import logging
|
||||
import copy
|
||||
from enum import Flag, auto
|
||||
from .utils import ServerError, Sentinel
|
||||
from .utils import json_wrapper as jsonw
|
||||
|
||||
|
@ -33,34 +35,82 @@ if TYPE_CHECKING:
|
|||
from asyncio import Future
|
||||
_T = TypeVar("_T")
|
||||
_C = TypeVar("_C", str, bool, float, int)
|
||||
_F = TypeVar("_F", bound="ExtendedFlag")
|
||||
IPUnion = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||
ConvType = Union[str, bool, float, int]
|
||||
ArgVal = Union[None, int, float, bool, str]
|
||||
RPCCallback = Callable[..., Coroutine]
|
||||
AuthComp = Optional[Authorization]
|
||||
|
||||
class ExtendedFlag(Flag):
|
||||
@classmethod
|
||||
def from_string(cls: Type[_F], flag_name: str) -> _F:
|
||||
str_name = flag_name.upper()
|
||||
for name, member in cls.__members__.items():
|
||||
if name == str_name:
|
||||
return cls(member.value)
|
||||
raise ValueError(f"No flag member named {flag_name}")
|
||||
|
||||
@classmethod
|
||||
def from_string_list(cls: Type[_F], flag_list: List[str]) -> _F:
|
||||
ret = cls(0)
|
||||
for flag in flag_list:
|
||||
flag = flag.upper()
|
||||
ret |= cls.from_string(flag)
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def all(cls: Type[_F]) -> _F:
|
||||
return ~cls(0)
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
def __len__(self) -> int:
|
||||
return bin(self._value_).count("1")
|
||||
|
||||
def __iter__(self):
|
||||
for i in range(self._value_.bit_length()):
|
||||
val = 1 << i
|
||||
if val & self._value_ == val:
|
||||
yield self.__class__(val)
|
||||
|
||||
class RequestType(ExtendedFlag):
|
||||
"""
|
||||
The Request Type is also known as the "Request Method" for
|
||||
HTTP/REST APIs. The use of "Request Method" nomenclature
|
||||
is discouraged in Moonraker as it could be confused with
|
||||
the JSON-RPC "method" field.
|
||||
"""
|
||||
GET = auto()
|
||||
POST = auto()
|
||||
DELETE = auto()
|
||||
|
||||
class TransportType(ExtendedFlag):
|
||||
HTTP = auto()
|
||||
WEBSOCKET = auto()
|
||||
MQTT = auto()
|
||||
INTERNAL = auto()
|
||||
|
||||
class Subscribable:
|
||||
def send_status(self,
|
||||
status: Dict[str, Any],
|
||||
eventtime: float
|
||||
) -> None:
|
||||
def send_status(
|
||||
self, status: Dict[str, Any], eventtime: float
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
class APIDefinition:
|
||||
def __init__(self,
|
||||
endpoint: str,
|
||||
http_uri: str,
|
||||
jrpc_methods: List[str],
|
||||
request_methods: Union[str, List[str]],
|
||||
transports: List[str],
|
||||
callback: Optional[Callable[[WebRequest], Coroutine]],
|
||||
need_object_parser: bool):
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
http_uri: str,
|
||||
jrpc_methods: List[str],
|
||||
request_types: RequestType,
|
||||
transports: TransportType,
|
||||
callback: Optional[Callable[[WebRequest], Coroutine]],
|
||||
need_object_parser: bool
|
||||
) -> None:
|
||||
self.endpoint = endpoint
|
||||
self.uri = http_uri
|
||||
self.jrpc_methods = jrpc_methods
|
||||
if not isinstance(request_methods, list):
|
||||
request_methods = [request_methods]
|
||||
self.request_methods = request_methods
|
||||
self.request_types = request_types
|
||||
self.supported_transports = transports
|
||||
self.callback = callback
|
||||
self.need_object_parser = need_object_parser
|
||||
|
@ -256,16 +306,17 @@ class BaseRemoteConnection(Subscribable):
|
|||
|
||||
|
||||
class WebRequest:
|
||||
def __init__(self,
|
||||
endpoint: str,
|
||||
args: Dict[str, Any],
|
||||
action: Optional[str] = "",
|
||||
conn: Optional[Subscribable] = None,
|
||||
ip_addr: str = "",
|
||||
user: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
args: Dict[str, Any],
|
||||
request_type: RequestType = RequestType(0),
|
||||
conn: Optional[Subscribable] = None,
|
||||
ip_addr: str = "",
|
||||
user: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
self.endpoint = endpoint
|
||||
self.action = action or ""
|
||||
self.request_type = request_type
|
||||
self.args = args
|
||||
self.conn = conn
|
||||
self.ip_addr: Optional[IPUnion] = None
|
||||
|
@ -278,8 +329,11 @@ class WebRequest:
|
|||
def get_endpoint(self) -> str:
|
||||
return self.endpoint
|
||||
|
||||
def get_request_type(self) -> RequestType:
|
||||
return self.request_type
|
||||
|
||||
def get_action(self) -> str:
|
||||
return self.action
|
||||
return self.request_type.name or ""
|
||||
|
||||
def get_args(self) -> Dict[str, Any]:
|
||||
return self.args
|
||||
|
|
|
@ -12,7 +12,14 @@ import pathlib
|
|||
import ssl
|
||||
from collections import deque
|
||||
import paho.mqtt.client as paho_mqtt
|
||||
from ..common import Subscribable, WebRequest, APITransport, JsonRPC
|
||||
from ..common import (
|
||||
TransportType,
|
||||
RequestType,
|
||||
Subscribable,
|
||||
WebRequest,
|
||||
APITransport,
|
||||
JsonRPC
|
||||
)
|
||||
from ..utils import json_wrapper as jsonw
|
||||
|
||||
# Annotation imports
|
||||
|
@ -330,9 +337,9 @@ class MQTTClient(APITransport, Subscribable):
|
|||
self.timestamp_deque: Deque = deque(maxlen=20)
|
||||
self.api_qos = config.getint('api_qos', self.qos)
|
||||
if config.getboolean("enable_moonraker_api", True):
|
||||
api_cache = self.server.register_api_transport("mqtt", self)
|
||||
api_cache = self.server.register_api_transport(TransportType.MQTT, self)
|
||||
for api_def in api_cache.values():
|
||||
if "mqtt" in api_def.supported_transports:
|
||||
if TransportType.MQTT in api_def.supported_transports:
|
||||
self.register_api_handler(api_def)
|
||||
self.subscribe_topic(self.api_request_topic,
|
||||
self._process_api_request,
|
||||
|
@ -366,7 +373,8 @@ class MQTTClient(APITransport, Subscribable):
|
|||
args = {'objects': self.status_objs}
|
||||
try:
|
||||
await self.klippy.request(
|
||||
WebRequest("objects/subscribe", args, conn=self))
|
||||
WebRequest("objects/subscribe", args, conn=self)
|
||||
)
|
||||
except self.server.error:
|
||||
pass
|
||||
|
||||
|
@ -683,10 +691,10 @@ class MQTTClient(APITransport, Subscribable):
|
|||
self.json_rpc.register_method(mqtt_method, rpc_cb)
|
||||
else:
|
||||
# Local API, uses local callback
|
||||
for mqtt_method, req_method in \
|
||||
zip(api_def.jrpc_methods, api_def.request_methods):
|
||||
req_types = api_def.request_types
|
||||
for mqtt_method, req_type in zip(api_def.jrpc_methods, req_types):
|
||||
rpc_cb = self._generate_local_callback(
|
||||
api_def.endpoint, req_method, api_def.callback)
|
||||
api_def.endpoint, req_type, api_def.callback)
|
||||
self.json_rpc.register_method(mqtt_method, rpc_cb)
|
||||
logging.info(
|
||||
"Registering MQTT JSON-RPC methods: "
|
||||
|
@ -698,12 +706,12 @@ class MQTTClient(APITransport, Subscribable):
|
|||
|
||||
def _generate_local_callback(self,
|
||||
endpoint: str,
|
||||
request_method: str,
|
||||
request_type: RequestType,
|
||||
callback: Callable[[WebRequest], Coroutine]
|
||||
) -> RPCCallback:
|
||||
async def func(args: Dict[str, Any]) -> Any:
|
||||
self._check_timestamp(args)
|
||||
result = await callback(WebRequest(endpoint, args, request_method))
|
||||
result = await callback(WebRequest(endpoint, args, request_type))
|
||||
return result
|
||||
return func
|
||||
|
||||
|
|
|
@ -11,11 +11,12 @@ import asyncio
|
|||
from tornado.websocket import WebSocketHandler, WebSocketClosedError
|
||||
from tornado.web import HTTPError
|
||||
from .common import (
|
||||
RequestType,
|
||||
WebRequest,
|
||||
BaseRemoteConnection,
|
||||
APITransport,
|
||||
APIDefinition,
|
||||
JsonRPC
|
||||
JsonRPC,
|
||||
)
|
||||
from .utils import ServerError
|
||||
|
||||
|
@ -80,15 +81,14 @@ class WebsocketManager(APITransport):
|
|||
# Remote API, uses RPC to reach out to Klippy
|
||||
ws_method = api_def.jrpc_methods[0]
|
||||
rpc_cb = self._generate_callback(
|
||||
api_def.endpoint, "", klippy.request
|
||||
api_def.endpoint, RequestType(0), klippy.request
|
||||
)
|
||||
self.rpc.register_method(ws_method, rpc_cb)
|
||||
else:
|
||||
# Local API, uses local callback
|
||||
for ws_method, req_method in \
|
||||
zip(api_def.jrpc_methods, api_def.request_methods):
|
||||
for ws_method, req_type in zip(api_def.jrpc_methods, api_def.request_types):
|
||||
rpc_cb = self._generate_callback(
|
||||
api_def.endpoint, req_method, api_def.callback
|
||||
api_def.endpoint, req_type, api_def.callback
|
||||
)
|
||||
self.rpc.register_method(ws_method, rpc_cb)
|
||||
logging.info(
|
||||
|
@ -103,15 +103,18 @@ class WebsocketManager(APITransport):
|
|||
def _generate_callback(
|
||||
self,
|
||||
endpoint: str,
|
||||
request_method: str,
|
||||
request_type: RequestType,
|
||||
callback: Callable[[WebRequest], Coroutine]
|
||||
) -> RPCCallback:
|
||||
async def func(args: Dict[str, Any]) -> Any:
|
||||
sc: BaseRemoteConnection = args.pop("_socket_")
|
||||
sc.check_authenticated(path=endpoint)
|
||||
result = await callback(
|
||||
WebRequest(endpoint, args, request_method, sc,
|
||||
ip_addr=sc.ip_addr, user=sc.user_info))
|
||||
WebRequest(
|
||||
endpoint, args, request_type, sc,
|
||||
ip_addr=sc.ip_addr, user=sc.user_info
|
||||
)
|
||||
)
|
||||
return result
|
||||
return func
|
||||
|
||||
|
|
Loading…
Reference in New Issue