extensions: serve JSON-RPC API over a unix socket
Support unix connections with full access to all JSON-RPC APIs. Internally these connections are treated as websocket connections, however the underlying transport protocol is simplfied. Packets are JSON encoded objects terminated with an ETX character. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
53129bef7e
commit
b2d109a840
|
@ -4,8 +4,12 @@
|
|||
#
|
||||
# This file may be distributed under the terms of the GNU GPLv3 license.
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import pathlib
|
||||
import logging
|
||||
import json
|
||||
from websockets import BaseSocketClient
|
||||
|
||||
from utils import get_unix_peer_credentials
|
||||
|
||||
# Annotation imports
|
||||
from typing import (
|
||||
|
@ -18,13 +22,17 @@ from typing import (
|
|||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from moonraker import Server
|
||||
from confighelper import ConfigHelper
|
||||
from websockets import WebRequest
|
||||
|
||||
UNIX_BUFFER_LIMIT = 20 * 1024 * 1024
|
||||
|
||||
class ExtensionManager:
|
||||
def __init__(self, config: ConfigHelper) -> None:
|
||||
self.server = config.get_server()
|
||||
self.agents: Dict[str, BaseSocketClient] = {}
|
||||
self.uds_server: Optional[asyncio.Server] = None
|
||||
self.server.register_endpoint(
|
||||
"/connection/send_event", ["POST"], self._handle_agent_event,
|
||||
transports=["websocket"]
|
||||
|
@ -103,5 +111,124 @@ class ExtensionManager:
|
|||
conn = self.agents[agent]
|
||||
return await conn.call_method(method, args)
|
||||
|
||||
async def start_unix_server(self) -> None:
|
||||
data_path = pathlib.Path(self.server.get_app_args()["data_path"])
|
||||
comms_path = data_path.joinpath("comms")
|
||||
if not comms_path.exists():
|
||||
comms_path.mkdir()
|
||||
sock_path = comms_path.joinpath("moonraker.sock")
|
||||
logging.info(f"Creating Unix Domain Socket at '{sock_path}'")
|
||||
self.uds_server = await asyncio.start_unix_server(
|
||||
self.on_unix_socket_connected, sock_path, limit=UNIX_BUFFER_LIMIT
|
||||
)
|
||||
|
||||
def on_unix_socket_connected(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
peercred = get_unix_peer_credentials(writer, "Unix Client Connection")
|
||||
UnixSocketClient(self.server, reader, writer, peercred)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.uds_server is not None:
|
||||
self.uds_server.close()
|
||||
await self.uds_server.wait_closed()
|
||||
self.uds_server = None
|
||||
|
||||
class UnixSocketClient(BaseSocketClient):
|
||||
def __init__(
|
||||
self,
|
||||
server: Server,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
peercred: Dict[str, int]
|
||||
) -> None:
|
||||
self.on_create(server)
|
||||
self.writer = writer
|
||||
self._peer_cred = peercred
|
||||
self._connected_time = self.eventloop.get_loop_time()
|
||||
pid = self._peer_cred.get("process_id")
|
||||
uid = self._peer_cred.get("user_id")
|
||||
gid = self._peer_cred.get("group_id")
|
||||
self.wsm.add_client(self)
|
||||
logging.info(
|
||||
f"Unix Socket Opened - Client ID: {self.uid}, "
|
||||
f"Process ID: {pid}, User ID: {uid}, Group ID: {gid}"
|
||||
)
|
||||
self.eventloop.register_callback(self._read_messages, reader)
|
||||
|
||||
async def _read_messages(self, reader: asyncio.StreamReader) -> None:
|
||||
errors_remaining: int = 10
|
||||
while not reader.at_eof():
|
||||
try:
|
||||
data = await reader.readuntil(b'\x03')
|
||||
decoded = data[:-1].decode(encoding="utf-8")
|
||||
except (ConnectionError, asyncio.IncompleteReadError):
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
logging.exception("Unix Client Stream Read Cancelled")
|
||||
raise
|
||||
except Exception:
|
||||
logging.exception("Unix Client Stream Read Error")
|
||||
errors_remaining -= 1
|
||||
if not errors_remaining or self.is_closed:
|
||||
break
|
||||
continue
|
||||
errors_remaining = 10
|
||||
self.eventloop.register_callback(self._process_message, decoded)
|
||||
logging.debug("Unix Socket Disconnection From _read_messages()")
|
||||
await self._on_close(reason="Read Exit")
|
||||
|
||||
async def write_to_socket(
|
||||
self, message: Union[str, Dict[str, Any]]
|
||||
) -> None:
|
||||
if isinstance(message, dict):
|
||||
data = json.dumps(message).encode() + b"\x03"
|
||||
else:
|
||||
data = message.encode() + b"\x03"
|
||||
try:
|
||||
self.writer.write(data)
|
||||
await self.writer.drain()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
logging.debug("Unix Socket Disconnection From write_to_socket()")
|
||||
await self._on_close(reason="Write Exception")
|
||||
|
||||
async def _on_close(
|
||||
self,
|
||||
code: Optional[int] = None,
|
||||
reason: Optional[str] = None
|
||||
) -> None:
|
||||
if self.is_closed:
|
||||
return
|
||||
self.is_closed = True
|
||||
if not self.writer.is_closing():
|
||||
self.writer.close()
|
||||
try:
|
||||
await self.writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
self.message_buf = []
|
||||
for resp in self.pending_responses.values():
|
||||
resp.set_exception(
|
||||
self.server.error("Client Socket Disconnected", 500)
|
||||
)
|
||||
self.pending_responses = {}
|
||||
logging.info(
|
||||
f"Unix Socket Closed: ID: {self.uid}, "
|
||||
f"Close Code: {code}, "
|
||||
f"Close Reason: {reason}"
|
||||
)
|
||||
if self._client_data["type"] == "agent":
|
||||
extensions: ExtensionManager
|
||||
extensions = self.server.lookup_component("extensions")
|
||||
extensions.remove_agent(self)
|
||||
self.wsm.remove_client(self)
|
||||
|
||||
def close_socket(self, code: int, reason: str) -> None:
|
||||
if not self.is_closed:
|
||||
self.eventloop.register_callback(self._on_close, code, reason)
|
||||
|
||||
|
||||
def load_component(config: ConfigHelper) -> ExtensionManager:
|
||||
return ExtensionManager(config)
|
||||
|
|
|
@ -41,6 +41,7 @@ if TYPE_CHECKING:
|
|||
from websockets import WebRequest, WebsocketManager
|
||||
from components.file_manager.file_manager import FileManager
|
||||
from components.machine import Machine
|
||||
from components.extensions import ExtensionManager
|
||||
FlexCallback = Callable[..., Optional[Coroutine]]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
@ -174,6 +175,10 @@ class Server:
|
|||
await self.start_server()
|
||||
|
||||
async def start_server(self, connect_to_klippy: bool = True) -> None:
|
||||
# Open Unix Socket Server
|
||||
extm: ExtensionManager = self.lookup_component("extensions")
|
||||
await extm.start_unix_server()
|
||||
|
||||
# Start HTTP Server
|
||||
logging.info(
|
||||
f"Starting Moonraker on ({self.host}, {self.port}), "
|
||||
|
|
Loading…
Reference in New Issue