From 8546cd6ac5ab4704746367eaab991a3d8b26258b Mon Sep 17 00:00:00 2001 From: Eric Callahan Date: Thu, 13 Jan 2022 12:34:06 -0500 Subject: [PATCH] shell_command: use a custom protocol for callbacks Rather than override the Process class instead create a custom protocol that forwards data over callbacks. Signed-off-by: Eric Callahan --- moonraker/components/shell_command.py | 216 ++++++++++++++------------ 1 file changed, 118 insertions(+), 98 deletions(-) diff --git a/moonraker/components/shell_command.py b/moonraker/components/shell_command.py index 3ff380c..3c7f62a 100644 --- a/moonraker/components/shell_command.py +++ b/moonraker/components/shell_command.py @@ -15,6 +15,7 @@ from utils import ServerError # Annotation imports from typing import ( TYPE_CHECKING, + List, Optional, Callable, Coroutine, @@ -23,7 +24,6 @@ from typing import ( ) if TYPE_CHECKING: from confighelper import ConfigHelper - from asyncio import BaseTransport OutputCallback = Optional[Callable[[bytes], None]] class ShellCommandError(ServerError): @@ -39,90 +39,85 @@ class ShellCommandError(ServerError): self.stderr = stderr or b"" self.return_code = return_code -class SCProcess(asyncio.subprocess.Process): - def initialize(self, - program_name: str, - std_out_cb: OutputCallback, - std_err_cb: OutputCallback, - log_stderr: bool - ) -> None: +class ShellCommandProtocol(asyncio.subprocess.SubprocessStreamProtocol): + def __init__(self, + limit: int, + loop: asyncio.events.AbstractEventLoop, + program_name: str = "", + std_out_cb: OutputCallback = None, + std_err_cb: OutputCallback = None, + log_stderr: bool = False + ) -> None: + self._loop = loop + self._pipe_fds: List[int] = [] + super().__init__(limit, loop) self.program_name = program_name self.std_out_cb = std_out_cb self.std_err_cb = std_err_cb self.log_stderr = log_stderr - self.cancel_requested = False + self.pending_data: List[bytes] = [b"", b""] - async def _read_stream_with_cb(self, fd: int) -> bytes: - transport: BaseTransport = \ - self._transport.get_pipe_transport(fd) # type: ignore - if fd == 2: - stream = self.stderr - cb = self.std_err_cb - else: - assert fd == 1 - stream = self.stdout + def connection_made(self, + transport: asyncio.transports.BaseTransport + ) -> None: + self._transport = transport + assert isinstance(transport, asyncio.SubprocessTransport) + stdout_transport = transport.get_pipe_transport(1) + if stdout_transport is not None: + self._pipe_fds.append(1) + + stderr_transport = transport.get_pipe_transport(2) + if stderr_transport is not None: + self._pipe_fds.append(2) + + stdin_transport = transport.get_pipe_transport(0) + if stdin_transport is not None: + self.stdin = asyncio.streams.StreamWriter( + stdin_transport, + protocol=self, + reader=None, + loop=self._loop) + + def pipe_data_received(self, fd: int, data: bytes | str) -> None: + cb = None + data_idx = fd - 1 + if fd == 1: cb = self.std_out_cb - assert stream is not None - while not stream.at_eof(): - output = await stream.readline() - if not output: - break - if fd == 2 and self.log_stderr: - logging.info( - f"{self.program_name}: " - f"{output.decode(errors='ignore')}") - output = output.rstrip(b'\n') - if output and cb is not None: - cb(output) - transport.close() - return output + elif fd == 2: + cb = self.std_err_cb + if self.log_stderr: + if isinstance(data, bytes): + msg = data.decode(errors='ignore') + else: + msg = data + logging.info(f"{self.program_name}: {msg}") + if cb is not None: + if isinstance(data, str): + data = data.encode() + lines = data.split(b'\n') + lines[0] = self.pending_data[data_idx] + lines[0] + self.pending_data[data_idx] = lines.pop() + for line in lines: + if not line: + continue + cb(line) - async def cancel(self, sig_idx: int = 1) -> None: - if self.cancel_requested: - return - self.cancel_requested = True - exit_success = False - sig_idx = min(2, max(0, sig_idx)) - sigs = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL][sig_idx:] - for sig in sigs: - try: - self.send_signal(sig) - ret = self.wait() - await asyncio.wait_for(ret, timeout=2.) - except asyncio.TimeoutError: - continue - except ProcessLookupError: - pass - logging.debug(f"Command '{self.program_name}' exited with " - f"signal: {sig.name}") - exit_success = True - break - if not exit_success: - logging.info(f"WARNING: {self.program_name} did not cleanly exit") - if self.stdout is not None: - self.stdout.feed_eof() - if self.stderr is not None: - self.stderr.feed_eof() + def pipe_connection_lost(self, + fd: int, + exc: Exception | None + ) -> None: + cb = None + pending = b"" + if fd == 1: + cb = self.std_out_cb + pending = self.pending_data[0] + elif fd == 2: + cb = self.std_err_cb + pending = self.pending_data[1] + if pending and cb is not None: + cb(pending) + super().pipe_connection_lost(fd, exc) - async def communicate_with_cb(self, - input: Optional[bytes] = None - ) -> None: - if input is not None: - stdin: Coroutine = self._feed_stdin(input) # type: ignore - else: - stdin = self._noop() # type: ignore - if self.stdout is not None and self.std_out_cb is not None: - stdout: Coroutine = self._read_stream_with_cb(1) - else: - stdout = self._noop() # type: ignore - has_err_output = self.std_err_cb is not None or self.log_stderr - if self.stderr is not None and has_err_output: - stderr: Coroutine = self._read_stream_with_cb(2) - else: - stderr = self._noop() # type: ignore - stdin, stdout, stderr = await asyncio.tasks.gather( - stdin, stdout, stderr) - await self.wait() class ShellCommand: IDX_SIGINT = 0 @@ -146,15 +141,34 @@ class ShellCommand: self.log_stderr = log_stderr self.env = env self.cwd = cwd - self.proc: Optional[SCProcess] = None + self.proc: Optional[asyncio.subprocess.Process] = None self.cancelled = False self.return_code: Optional[int] = None self.run_lock = asyncio.Lock() async def cancel(self, sig_idx: int = 1) -> None: + if self.cancelled: + return self.cancelled = True if self.proc is not None: - await self.proc.cancel(sig_idx) + exit_success = False + sig_idx = min(2, max(0, sig_idx)) + sigs = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL][sig_idx:] + for sig in sigs: + try: + self.proc.send_signal(sig) + ret = self.proc.wait() + await asyncio.wait_for(ret, timeout=2.) + except asyncio.TimeoutError: + continue + except ProcessLookupError: + pass + logging.debug(f"Command '{self.name}' exited with " + f"signal: {sig.name}") + exit_success = True + break + if not exit_success: + logging.info(f"WARNING: {self.name} did not cleanly exit") def get_return_code(self) -> Optional[int]: return self.return_code @@ -175,23 +189,23 @@ class ShellCommand: if not timeout: # Never timeout timeout = 9999999999999999. - if self.std_out_cb is None and self.std_err_cb is None and \ - not self.log_stderr: + if ( + self.std_out_cb is None + and self.std_err_cb is None and + not self.log_stderr + ): # No callbacks set so output cannot be verbose verbose = False - if not await self._create_subprocess(): + if not await self._create_subprocess(use_callbacks=verbose): self.factory.remove_running_command(self) return False assert self.proc is not None try: - if verbose: - ret: Coroutine = self.proc.communicate_with_cb() - else: - ret = self.proc.wait() + ret = self.proc.wait() await asyncio.wait_for(ret, timeout=timeout) except asyncio.TimeoutError: complete = False - await self.proc.cancel(sig_idx) + await self.cancel(sig_idx) else: complete = not self.cancelled self.factory.remove_running_command(self) @@ -219,7 +233,7 @@ class ShellCommand: except asyncio.TimeoutError: complete = False timed_out = True - await self.proc.cancel(sig_idx) + await self.cancel(sig_idx) else: complete = not self.cancelled if self.log_stderr and stderr: @@ -242,24 +256,30 @@ class ShellCommand: f"Error running shell command: '{self.command}'", self.return_code, stdout, stderr) - async def _create_subprocess(self) -> bool: + async def _create_subprocess(self, use_callbacks: bool = False) -> bool: loop = asyncio.get_running_loop() def protocol_factory(): - return asyncio.subprocess.SubprocessStreamProtocol( - limit=2**20, loop=loop) + return ShellCommandProtocol( + limit=2**20, loop=loop, program_name=self.command[0], + std_out_cb=self.std_out_cb, std_err_cb=self.std_err_cb, + log_stderr=self.log_stderr) try: if self.std_err_cb is not None or self.log_stderr: errpipe = asyncio.subprocess.PIPE else: errpipe = asyncio.subprocess.STDOUT - transport, protocol = await loop.subprocess_exec( - protocol_factory, *self.command, - stdout=asyncio.subprocess.PIPE, - stderr=errpipe, env=self.env, cwd=self.cwd) - self.proc = SCProcess(transport, protocol, loop) - self.proc.initialize(self.command[0], self.std_out_cb, - self.std_err_cb, self.log_stderr) + if use_callbacks: + transport, protocol = await loop.subprocess_exec( + protocol_factory, *self.command, + stdout=asyncio.subprocess.PIPE, + stderr=errpipe, env=self.env, cwd=self.cwd) + self.proc = asyncio.subprocess.Process( + transport, protocol, loop) + else: + self.proc = await asyncio.create_subprocess_exec( + *self.command, stdout=asyncio.subprocess.PIPE, + stderr=errpipe, env=self.env, cwd=self.cwd) except Exception: logging.exception( f"shell_command: Command ({self.name}) failed")