shell_command: use a custom protocol for callbacks
Rather than override the Process class instead create a custom protocol that forwards data over callbacks. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
9911b5c7dd
commit
8546cd6ac5
|
@ -15,6 +15,7 @@ from utils import ServerError
|
||||||
# Annotation imports
|
# Annotation imports
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Callable,
|
Callable,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
|
@ -23,7 +24,6 @@ from typing import (
|
||||||
)
|
)
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from confighelper import ConfigHelper
|
from confighelper import ConfigHelper
|
||||||
from asyncio import BaseTransport
|
|
||||||
OutputCallback = Optional[Callable[[bytes], None]]
|
OutputCallback = Optional[Callable[[bytes], None]]
|
||||||
|
|
||||||
class ShellCommandError(ServerError):
|
class ShellCommandError(ServerError):
|
||||||
|
@ -39,90 +39,85 @@ class ShellCommandError(ServerError):
|
||||||
self.stderr = stderr or b""
|
self.stderr = stderr or b""
|
||||||
self.return_code = return_code
|
self.return_code = return_code
|
||||||
|
|
||||||
class SCProcess(asyncio.subprocess.Process):
|
class ShellCommandProtocol(asyncio.subprocess.SubprocessStreamProtocol):
|
||||||
def initialize(self,
|
def __init__(self,
|
||||||
program_name: str,
|
limit: int,
|
||||||
std_out_cb: OutputCallback,
|
loop: asyncio.events.AbstractEventLoop,
|
||||||
std_err_cb: OutputCallback,
|
program_name: str = "",
|
||||||
log_stderr: bool
|
std_out_cb: OutputCallback = None,
|
||||||
) -> None:
|
std_err_cb: OutputCallback = None,
|
||||||
|
log_stderr: bool = False
|
||||||
|
) -> None:
|
||||||
|
self._loop = loop
|
||||||
|
self._pipe_fds: List[int] = []
|
||||||
|
super().__init__(limit, loop)
|
||||||
self.program_name = program_name
|
self.program_name = program_name
|
||||||
self.std_out_cb = std_out_cb
|
self.std_out_cb = std_out_cb
|
||||||
self.std_err_cb = std_err_cb
|
self.std_err_cb = std_err_cb
|
||||||
self.log_stderr = log_stderr
|
self.log_stderr = log_stderr
|
||||||
self.cancel_requested = False
|
self.pending_data: List[bytes] = [b"", b""]
|
||||||
|
|
||||||
async def _read_stream_with_cb(self, fd: int) -> bytes:
|
def connection_made(self,
|
||||||
transport: BaseTransport = \
|
transport: asyncio.transports.BaseTransport
|
||||||
self._transport.get_pipe_transport(fd) # type: ignore
|
) -> None:
|
||||||
if fd == 2:
|
self._transport = transport
|
||||||
stream = self.stderr
|
assert isinstance(transport, asyncio.SubprocessTransport)
|
||||||
cb = self.std_err_cb
|
stdout_transport = transport.get_pipe_transport(1)
|
||||||
else:
|
if stdout_transport is not None:
|
||||||
assert fd == 1
|
self._pipe_fds.append(1)
|
||||||
stream = self.stdout
|
|
||||||
|
stderr_transport = transport.get_pipe_transport(2)
|
||||||
|
if stderr_transport is not None:
|
||||||
|
self._pipe_fds.append(2)
|
||||||
|
|
||||||
|
stdin_transport = transport.get_pipe_transport(0)
|
||||||
|
if stdin_transport is not None:
|
||||||
|
self.stdin = asyncio.streams.StreamWriter(
|
||||||
|
stdin_transport,
|
||||||
|
protocol=self,
|
||||||
|
reader=None,
|
||||||
|
loop=self._loop)
|
||||||
|
|
||||||
|
def pipe_data_received(self, fd: int, data: bytes | str) -> None:
|
||||||
|
cb = None
|
||||||
|
data_idx = fd - 1
|
||||||
|
if fd == 1:
|
||||||
cb = self.std_out_cb
|
cb = self.std_out_cb
|
||||||
assert stream is not None
|
elif fd == 2:
|
||||||
while not stream.at_eof():
|
cb = self.std_err_cb
|
||||||
output = await stream.readline()
|
if self.log_stderr:
|
||||||
if not output:
|
if isinstance(data, bytes):
|
||||||
break
|
msg = data.decode(errors='ignore')
|
||||||
if fd == 2 and self.log_stderr:
|
else:
|
||||||
logging.info(
|
msg = data
|
||||||
f"{self.program_name}: "
|
logging.info(f"{self.program_name}: {msg}")
|
||||||
f"{output.decode(errors='ignore')}")
|
if cb is not None:
|
||||||
output = output.rstrip(b'\n')
|
if isinstance(data, str):
|
||||||
if output and cb is not None:
|
data = data.encode()
|
||||||
cb(output)
|
lines = data.split(b'\n')
|
||||||
transport.close()
|
lines[0] = self.pending_data[data_idx] + lines[0]
|
||||||
return output
|
self.pending_data[data_idx] = lines.pop()
|
||||||
|
for line in lines:
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
cb(line)
|
||||||
|
|
||||||
async def cancel(self, sig_idx: int = 1) -> None:
|
def pipe_connection_lost(self,
|
||||||
if self.cancel_requested:
|
fd: int,
|
||||||
return
|
exc: Exception | None
|
||||||
self.cancel_requested = True
|
) -> None:
|
||||||
exit_success = False
|
cb = None
|
||||||
sig_idx = min(2, max(0, sig_idx))
|
pending = b""
|
||||||
sigs = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL][sig_idx:]
|
if fd == 1:
|
||||||
for sig in sigs:
|
cb = self.std_out_cb
|
||||||
try:
|
pending = self.pending_data[0]
|
||||||
self.send_signal(sig)
|
elif fd == 2:
|
||||||
ret = self.wait()
|
cb = self.std_err_cb
|
||||||
await asyncio.wait_for(ret, timeout=2.)
|
pending = self.pending_data[1]
|
||||||
except asyncio.TimeoutError:
|
if pending and cb is not None:
|
||||||
continue
|
cb(pending)
|
||||||
except ProcessLookupError:
|
super().pipe_connection_lost(fd, exc)
|
||||||
pass
|
|
||||||
logging.debug(f"Command '{self.program_name}' exited with "
|
|
||||||
f"signal: {sig.name}")
|
|
||||||
exit_success = True
|
|
||||||
break
|
|
||||||
if not exit_success:
|
|
||||||
logging.info(f"WARNING: {self.program_name} did not cleanly exit")
|
|
||||||
if self.stdout is not None:
|
|
||||||
self.stdout.feed_eof()
|
|
||||||
if self.stderr is not None:
|
|
||||||
self.stderr.feed_eof()
|
|
||||||
|
|
||||||
async def communicate_with_cb(self,
|
|
||||||
input: Optional[bytes] = None
|
|
||||||
) -> None:
|
|
||||||
if input is not None:
|
|
||||||
stdin: Coroutine = self._feed_stdin(input) # type: ignore
|
|
||||||
else:
|
|
||||||
stdin = self._noop() # type: ignore
|
|
||||||
if self.stdout is not None and self.std_out_cb is not None:
|
|
||||||
stdout: Coroutine = self._read_stream_with_cb(1)
|
|
||||||
else:
|
|
||||||
stdout = self._noop() # type: ignore
|
|
||||||
has_err_output = self.std_err_cb is not None or self.log_stderr
|
|
||||||
if self.stderr is not None and has_err_output:
|
|
||||||
stderr: Coroutine = self._read_stream_with_cb(2)
|
|
||||||
else:
|
|
||||||
stderr = self._noop() # type: ignore
|
|
||||||
stdin, stdout, stderr = await asyncio.tasks.gather(
|
|
||||||
stdin, stdout, stderr)
|
|
||||||
await self.wait()
|
|
||||||
|
|
||||||
class ShellCommand:
|
class ShellCommand:
|
||||||
IDX_SIGINT = 0
|
IDX_SIGINT = 0
|
||||||
|
@ -146,15 +141,34 @@ class ShellCommand:
|
||||||
self.log_stderr = log_stderr
|
self.log_stderr = log_stderr
|
||||||
self.env = env
|
self.env = env
|
||||||
self.cwd = cwd
|
self.cwd = cwd
|
||||||
self.proc: Optional[SCProcess] = None
|
self.proc: Optional[asyncio.subprocess.Process] = None
|
||||||
self.cancelled = False
|
self.cancelled = False
|
||||||
self.return_code: Optional[int] = None
|
self.return_code: Optional[int] = None
|
||||||
self.run_lock = asyncio.Lock()
|
self.run_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def cancel(self, sig_idx: int = 1) -> None:
|
async def cancel(self, sig_idx: int = 1) -> None:
|
||||||
|
if self.cancelled:
|
||||||
|
return
|
||||||
self.cancelled = True
|
self.cancelled = True
|
||||||
if self.proc is not None:
|
if self.proc is not None:
|
||||||
await self.proc.cancel(sig_idx)
|
exit_success = False
|
||||||
|
sig_idx = min(2, max(0, sig_idx))
|
||||||
|
sigs = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL][sig_idx:]
|
||||||
|
for sig in sigs:
|
||||||
|
try:
|
||||||
|
self.proc.send_signal(sig)
|
||||||
|
ret = self.proc.wait()
|
||||||
|
await asyncio.wait_for(ret, timeout=2.)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
logging.debug(f"Command '{self.name}' exited with "
|
||||||
|
f"signal: {sig.name}")
|
||||||
|
exit_success = True
|
||||||
|
break
|
||||||
|
if not exit_success:
|
||||||
|
logging.info(f"WARNING: {self.name} did not cleanly exit")
|
||||||
|
|
||||||
def get_return_code(self) -> Optional[int]:
|
def get_return_code(self) -> Optional[int]:
|
||||||
return self.return_code
|
return self.return_code
|
||||||
|
@ -175,23 +189,23 @@ class ShellCommand:
|
||||||
if not timeout:
|
if not timeout:
|
||||||
# Never timeout
|
# Never timeout
|
||||||
timeout = 9999999999999999.
|
timeout = 9999999999999999.
|
||||||
if self.std_out_cb is None and self.std_err_cb is None and \
|
if (
|
||||||
not self.log_stderr:
|
self.std_out_cb is None
|
||||||
|
and self.std_err_cb is None and
|
||||||
|
not self.log_stderr
|
||||||
|
):
|
||||||
# No callbacks set so output cannot be verbose
|
# No callbacks set so output cannot be verbose
|
||||||
verbose = False
|
verbose = False
|
||||||
if not await self._create_subprocess():
|
if not await self._create_subprocess(use_callbacks=verbose):
|
||||||
self.factory.remove_running_command(self)
|
self.factory.remove_running_command(self)
|
||||||
return False
|
return False
|
||||||
assert self.proc is not None
|
assert self.proc is not None
|
||||||
try:
|
try:
|
||||||
if verbose:
|
ret = self.proc.wait()
|
||||||
ret: Coroutine = self.proc.communicate_with_cb()
|
|
||||||
else:
|
|
||||||
ret = self.proc.wait()
|
|
||||||
await asyncio.wait_for(ret, timeout=timeout)
|
await asyncio.wait_for(ret, timeout=timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
complete = False
|
complete = False
|
||||||
await self.proc.cancel(sig_idx)
|
await self.cancel(sig_idx)
|
||||||
else:
|
else:
|
||||||
complete = not self.cancelled
|
complete = not self.cancelled
|
||||||
self.factory.remove_running_command(self)
|
self.factory.remove_running_command(self)
|
||||||
|
@ -219,7 +233,7 @@ class ShellCommand:
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
complete = False
|
complete = False
|
||||||
timed_out = True
|
timed_out = True
|
||||||
await self.proc.cancel(sig_idx)
|
await self.cancel(sig_idx)
|
||||||
else:
|
else:
|
||||||
complete = not self.cancelled
|
complete = not self.cancelled
|
||||||
if self.log_stderr and stderr:
|
if self.log_stderr and stderr:
|
||||||
|
@ -242,24 +256,30 @@ class ShellCommand:
|
||||||
f"Error running shell command: '{self.command}'",
|
f"Error running shell command: '{self.command}'",
|
||||||
self.return_code, stdout, stderr)
|
self.return_code, stdout, stderr)
|
||||||
|
|
||||||
async def _create_subprocess(self) -> bool:
|
async def _create_subprocess(self, use_callbacks: bool = False) -> bool:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
def protocol_factory():
|
def protocol_factory():
|
||||||
return asyncio.subprocess.SubprocessStreamProtocol(
|
return ShellCommandProtocol(
|
||||||
limit=2**20, loop=loop)
|
limit=2**20, loop=loop, program_name=self.command[0],
|
||||||
|
std_out_cb=self.std_out_cb, std_err_cb=self.std_err_cb,
|
||||||
|
log_stderr=self.log_stderr)
|
||||||
try:
|
try:
|
||||||
if self.std_err_cb is not None or self.log_stderr:
|
if self.std_err_cb is not None or self.log_stderr:
|
||||||
errpipe = asyncio.subprocess.PIPE
|
errpipe = asyncio.subprocess.PIPE
|
||||||
else:
|
else:
|
||||||
errpipe = asyncio.subprocess.STDOUT
|
errpipe = asyncio.subprocess.STDOUT
|
||||||
transport, protocol = await loop.subprocess_exec(
|
if use_callbacks:
|
||||||
protocol_factory, *self.command,
|
transport, protocol = await loop.subprocess_exec(
|
||||||
stdout=asyncio.subprocess.PIPE,
|
protocol_factory, *self.command,
|
||||||
stderr=errpipe, env=self.env, cwd=self.cwd)
|
stdout=asyncio.subprocess.PIPE,
|
||||||
self.proc = SCProcess(transport, protocol, loop)
|
stderr=errpipe, env=self.env, cwd=self.cwd)
|
||||||
self.proc.initialize(self.command[0], self.std_out_cb,
|
self.proc = asyncio.subprocess.Process(
|
||||||
self.std_err_cb, self.log_stderr)
|
transport, protocol, loop)
|
||||||
|
else:
|
||||||
|
self.proc = await asyncio.create_subprocess_exec(
|
||||||
|
*self.command, stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=errpipe, env=self.env, cwd=self.cwd)
|
||||||
except Exception:
|
except Exception:
|
||||||
logging.exception(
|
logging.exception(
|
||||||
f"shell_command: Command ({self.name}) failed")
|
f"shell_command: Command ({self.name}) failed")
|
||||||
|
|
Loading…
Reference in New Issue