shell_command: use a custom protocol for callbacks

Rather than override the Process class instead create a custom
protocol that forwards data over callbacks.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-01-13 12:34:06 -05:00 committed by Eric Callahan
parent 9911b5c7dd
commit 8546cd6ac5
1 changed files with 118 additions and 98 deletions

View File

@ -15,6 +15,7 @@ from utils import ServerError
# Annotation imports # Annotation imports
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
List,
Optional, Optional,
Callable, Callable,
Coroutine, Coroutine,
@ -23,7 +24,6 @@ from typing import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from confighelper import ConfigHelper from confighelper import ConfigHelper
from asyncio import BaseTransport
OutputCallback = Optional[Callable[[bytes], None]] OutputCallback = Optional[Callable[[bytes], None]]
class ShellCommandError(ServerError): class ShellCommandError(ServerError):
@ -39,90 +39,85 @@ class ShellCommandError(ServerError):
self.stderr = stderr or b"" self.stderr = stderr or b""
self.return_code = return_code self.return_code = return_code
class SCProcess(asyncio.subprocess.Process): class ShellCommandProtocol(asyncio.subprocess.SubprocessStreamProtocol):
def initialize(self, def __init__(self,
program_name: str, limit: int,
std_out_cb: OutputCallback, loop: asyncio.events.AbstractEventLoop,
std_err_cb: OutputCallback, program_name: str = "",
log_stderr: bool std_out_cb: OutputCallback = None,
std_err_cb: OutputCallback = None,
log_stderr: bool = False
) -> None: ) -> None:
self._loop = loop
self._pipe_fds: List[int] = []
super().__init__(limit, loop)
self.program_name = program_name self.program_name = program_name
self.std_out_cb = std_out_cb self.std_out_cb = std_out_cb
self.std_err_cb = std_err_cb self.std_err_cb = std_err_cb
self.log_stderr = log_stderr self.log_stderr = log_stderr
self.cancel_requested = False self.pending_data: List[bytes] = [b"", b""]
async def _read_stream_with_cb(self, fd: int) -> bytes: def connection_made(self,
transport: BaseTransport = \ transport: asyncio.transports.BaseTransport
self._transport.get_pipe_transport(fd) # type: ignore
if fd == 2:
stream = self.stderr
cb = self.std_err_cb
else:
assert fd == 1
stream = self.stdout
cb = self.std_out_cb
assert stream is not None
while not stream.at_eof():
output = await stream.readline()
if not output:
break
if fd == 2 and self.log_stderr:
logging.info(
f"{self.program_name}: "
f"{output.decode(errors='ignore')}")
output = output.rstrip(b'\n')
if output and cb is not None:
cb(output)
transport.close()
return output
async def cancel(self, sig_idx: int = 1) -> None:
if self.cancel_requested:
return
self.cancel_requested = True
exit_success = False
sig_idx = min(2, max(0, sig_idx))
sigs = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL][sig_idx:]
for sig in sigs:
try:
self.send_signal(sig)
ret = self.wait()
await asyncio.wait_for(ret, timeout=2.)
except asyncio.TimeoutError:
continue
except ProcessLookupError:
pass
logging.debug(f"Command '{self.program_name}' exited with "
f"signal: {sig.name}")
exit_success = True
break
if not exit_success:
logging.info(f"WARNING: {self.program_name} did not cleanly exit")
if self.stdout is not None:
self.stdout.feed_eof()
if self.stderr is not None:
self.stderr.feed_eof()
async def communicate_with_cb(self,
input: Optional[bytes] = None
) -> None: ) -> None:
if input is not None: self._transport = transport
stdin: Coroutine = self._feed_stdin(input) # type: ignore assert isinstance(transport, asyncio.SubprocessTransport)
stdout_transport = transport.get_pipe_transport(1)
if stdout_transport is not None:
self._pipe_fds.append(1)
stderr_transport = transport.get_pipe_transport(2)
if stderr_transport is not None:
self._pipe_fds.append(2)
stdin_transport = transport.get_pipe_transport(0)
if stdin_transport is not None:
self.stdin = asyncio.streams.StreamWriter(
stdin_transport,
protocol=self,
reader=None,
loop=self._loop)
def pipe_data_received(self, fd: int, data: bytes | str) -> None:
cb = None
data_idx = fd - 1
if fd == 1:
cb = self.std_out_cb
elif fd == 2:
cb = self.std_err_cb
if self.log_stderr:
if isinstance(data, bytes):
msg = data.decode(errors='ignore')
else: else:
stdin = self._noop() # type: ignore msg = data
if self.stdout is not None and self.std_out_cb is not None: logging.info(f"{self.program_name}: {msg}")
stdout: Coroutine = self._read_stream_with_cb(1) if cb is not None:
else: if isinstance(data, str):
stdout = self._noop() # type: ignore data = data.encode()
has_err_output = self.std_err_cb is not None or self.log_stderr lines = data.split(b'\n')
if self.stderr is not None and has_err_output: lines[0] = self.pending_data[data_idx] + lines[0]
stderr: Coroutine = self._read_stream_with_cb(2) self.pending_data[data_idx] = lines.pop()
else: for line in lines:
stderr = self._noop() # type: ignore if not line:
stdin, stdout, stderr = await asyncio.tasks.gather( continue
stdin, stdout, stderr) cb(line)
await self.wait()
def pipe_connection_lost(self,
fd: int,
exc: Exception | None
) -> None:
cb = None
pending = b""
if fd == 1:
cb = self.std_out_cb
pending = self.pending_data[0]
elif fd == 2:
cb = self.std_err_cb
pending = self.pending_data[1]
if pending and cb is not None:
cb(pending)
super().pipe_connection_lost(fd, exc)
class ShellCommand: class ShellCommand:
IDX_SIGINT = 0 IDX_SIGINT = 0
@ -146,15 +141,34 @@ class ShellCommand:
self.log_stderr = log_stderr self.log_stderr = log_stderr
self.env = env self.env = env
self.cwd = cwd self.cwd = cwd
self.proc: Optional[SCProcess] = None self.proc: Optional[asyncio.subprocess.Process] = None
self.cancelled = False self.cancelled = False
self.return_code: Optional[int] = None self.return_code: Optional[int] = None
self.run_lock = asyncio.Lock() self.run_lock = asyncio.Lock()
async def cancel(self, sig_idx: int = 1) -> None: async def cancel(self, sig_idx: int = 1) -> None:
if self.cancelled:
return
self.cancelled = True self.cancelled = True
if self.proc is not None: if self.proc is not None:
await self.proc.cancel(sig_idx) exit_success = False
sig_idx = min(2, max(0, sig_idx))
sigs = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL][sig_idx:]
for sig in sigs:
try:
self.proc.send_signal(sig)
ret = self.proc.wait()
await asyncio.wait_for(ret, timeout=2.)
except asyncio.TimeoutError:
continue
except ProcessLookupError:
pass
logging.debug(f"Command '{self.name}' exited with "
f"signal: {sig.name}")
exit_success = True
break
if not exit_success:
logging.info(f"WARNING: {self.name} did not cleanly exit")
def get_return_code(self) -> Optional[int]: def get_return_code(self) -> Optional[int]:
return self.return_code return self.return_code
@ -175,23 +189,23 @@ class ShellCommand:
if not timeout: if not timeout:
# Never timeout # Never timeout
timeout = 9999999999999999. timeout = 9999999999999999.
if self.std_out_cb is None and self.std_err_cb is None and \ if (
not self.log_stderr: self.std_out_cb is None
and self.std_err_cb is None and
not self.log_stderr
):
# No callbacks set so output cannot be verbose # No callbacks set so output cannot be verbose
verbose = False verbose = False
if not await self._create_subprocess(): if not await self._create_subprocess(use_callbacks=verbose):
self.factory.remove_running_command(self) self.factory.remove_running_command(self)
return False return False
assert self.proc is not None assert self.proc is not None
try: try:
if verbose:
ret: Coroutine = self.proc.communicate_with_cb()
else:
ret = self.proc.wait() ret = self.proc.wait()
await asyncio.wait_for(ret, timeout=timeout) await asyncio.wait_for(ret, timeout=timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
complete = False complete = False
await self.proc.cancel(sig_idx) await self.cancel(sig_idx)
else: else:
complete = not self.cancelled complete = not self.cancelled
self.factory.remove_running_command(self) self.factory.remove_running_command(self)
@ -219,7 +233,7 @@ class ShellCommand:
except asyncio.TimeoutError: except asyncio.TimeoutError:
complete = False complete = False
timed_out = True timed_out = True
await self.proc.cancel(sig_idx) await self.cancel(sig_idx)
else: else:
complete = not self.cancelled complete = not self.cancelled
if self.log_stderr and stderr: if self.log_stderr and stderr:
@ -242,24 +256,30 @@ class ShellCommand:
f"Error running shell command: '{self.command}'", f"Error running shell command: '{self.command}'",
self.return_code, stdout, stderr) self.return_code, stdout, stderr)
async def _create_subprocess(self) -> bool: async def _create_subprocess(self, use_callbacks: bool = False) -> bool:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
def protocol_factory(): def protocol_factory():
return asyncio.subprocess.SubprocessStreamProtocol( return ShellCommandProtocol(
limit=2**20, loop=loop) limit=2**20, loop=loop, program_name=self.command[0],
std_out_cb=self.std_out_cb, std_err_cb=self.std_err_cb,
log_stderr=self.log_stderr)
try: try:
if self.std_err_cb is not None or self.log_stderr: if self.std_err_cb is not None or self.log_stderr:
errpipe = asyncio.subprocess.PIPE errpipe = asyncio.subprocess.PIPE
else: else:
errpipe = asyncio.subprocess.STDOUT errpipe = asyncio.subprocess.STDOUT
if use_callbacks:
transport, protocol = await loop.subprocess_exec( transport, protocol = await loop.subprocess_exec(
protocol_factory, *self.command, protocol_factory, *self.command,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=errpipe, env=self.env, cwd=self.cwd) stderr=errpipe, env=self.env, cwd=self.cwd)
self.proc = SCProcess(transport, protocol, loop) self.proc = asyncio.subprocess.Process(
self.proc.initialize(self.command[0], self.std_out_cb, transport, protocol, loop)
self.std_err_cb, self.log_stderr) else:
self.proc = await asyncio.create_subprocess_exec(
*self.command, stdout=asyncio.subprocess.PIPE,
stderr=errpipe, env=self.env, cwd=self.cwd)
except Exception: except Exception:
logging.exception( logging.exception(
f"shell_command: Command ({self.name}) failed") f"shell_command: Command ({self.name}) failed")