shell_command: allow optional success codes

Some commands may return codes other than zero on a successful
return.  Allow callers to specify an optional list of return codes that
will return success without raising an exception.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2023-02-07 17:56:07 -05:00
parent c69441955a
commit d9579c9374
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
1 changed files with 103 additions and 81 deletions

View File

@ -28,27 +28,29 @@ if TYPE_CHECKING:
OutputCallback = Optional[Callable[[bytes], None]] OutputCallback = Optional[Callable[[bytes], None]]
class ShellCommandError(ServerError): class ShellCommandError(ServerError):
def __init__(self, def __init__(
message: str, self,
return_code: Optional[int], message: str,
stdout: Optional[bytes] = b"", return_code: Optional[int],
stderr: Optional[bytes] = b"", stdout: Optional[bytes] = b"",
status_code: int = 500 stderr: Optional[bytes] = b"",
) -> None: status_code: int = 500
) -> None:
super().__init__(message, status_code=status_code) super().__init__(message, status_code=status_code)
self.stdout = stdout or b"" self.stdout = stdout or b""
self.stderr = stderr or b"" self.stderr = stderr or b""
self.return_code = return_code self.return_code = return_code
class ShellCommandProtocol(asyncio.subprocess.SubprocessStreamProtocol): class ShellCommandProtocol(asyncio.subprocess.SubprocessStreamProtocol):
def __init__(self, def __init__(
limit: int, self,
loop: asyncio.events.AbstractEventLoop, limit: int,
program_name: str = "", loop: asyncio.events.AbstractEventLoop,
std_out_cb: OutputCallback = None, program_name: str = "",
std_err_cb: OutputCallback = None, std_out_cb: OutputCallback = None,
log_stderr: bool = False std_err_cb: OutputCallback = None,
) -> None: log_stderr: bool = False
) -> None:
self._loop = loop self._loop = loop
self._pipe_fds: List[int] = [] self._pipe_fds: List[int] = []
super().__init__(limit, loop) super().__init__(limit, loop)
@ -58,9 +60,9 @@ class ShellCommandProtocol(asyncio.subprocess.SubprocessStreamProtocol):
self.log_stderr = log_stderr self.log_stderr = log_stderr
self.pending_data: List[bytes] = [b"", b""] self.pending_data: List[bytes] = [b"", b""]
def connection_made(self, def connection_made(
transport: asyncio.transports.BaseTransport self, transport: asyncio.transports.BaseTransport
) -> None: ) -> None:
self._transport = transport self._transport = transport
assert isinstance(transport, asyncio.SubprocessTransport) assert isinstance(transport, asyncio.SubprocessTransport)
stdout_transport = transport.get_pipe_transport(1) stdout_transport = transport.get_pipe_transport(1)
@ -74,10 +76,11 @@ class ShellCommandProtocol(asyncio.subprocess.SubprocessStreamProtocol):
stdin_transport = transport.get_pipe_transport(0) stdin_transport = transport.get_pipe_transport(0)
if stdin_transport is not None: if stdin_transport is not None:
self.stdin = asyncio.streams.StreamWriter( self.stdin = asyncio.streams.StreamWriter(
stdin_transport, stdin_transport, # type: ignore
protocol=self, protocol=self,
reader=None, reader=None,
loop=self._loop) loop=self._loop
)
def pipe_data_received(self, fd: int, data: bytes | str) -> None: def pipe_data_received(self, fd: int, data: bytes | str) -> None:
cb = None cb = None
@ -103,10 +106,9 @@ class ShellCommandProtocol(asyncio.subprocess.SubprocessStreamProtocol):
continue continue
cb(line) cb(line)
def pipe_connection_lost(self, def pipe_connection_lost(
fd: int, self, fd: int, exc: Exception | None
exc: Exception | None ) -> None:
) -> None:
cb = None cb = None
pending = b"" pending = b""
if fd == 1: if fd == 1:
@ -124,15 +126,16 @@ class ShellCommand:
IDX_SIGINT = 0 IDX_SIGINT = 0
IDX_SIGTERM = 1 IDX_SIGTERM = 1
IDX_SIGKILL = 2 IDX_SIGKILL = 2
def __init__(self, def __init__(
factory: ShellCommandFactory, self,
cmd: str, factory: ShellCommandFactory,
std_out_callback: OutputCallback, cmd: str,
std_err_callback: OutputCallback, std_out_callback: OutputCallback,
env: Optional[Dict[str, str]] = None, std_err_callback: OutputCallback,
log_stderr: bool = False, env: Optional[Dict[str, str]] = None,
cwd: Optional[str] = None log_stderr: bool = False,
) -> None: cwd: Optional[str] = None
) -> None:
self.factory = factory self.factory = factory
self.name = cmd self.name = cmd
self.std_out_cb = std_out_callback self.std_out_cb = std_out_callback
@ -178,13 +181,15 @@ class ShellCommand:
self.return_code = self.proc = None self.return_code = self.proc = None
self.cancelled = False self.cancelled = False
async def run(self, async def run(
timeout: float = 2., self,
verbose: bool = True, timeout: float = 2.,
log_complete: bool = True, verbose: bool = True,
sig_idx: int = 1, log_complete: bool = True,
proc_input: Optional[str] = None sig_idx: int = 1,
) -> bool: proc_input: Optional[str] = None,
success_codes: Optional[List[int]] = None
) -> bool:
async with self.run_lock: async with self.run_lock:
self.factory.add_running_command(self) self.factory.add_running_command(self)
self._reset_command_data() self._reset_command_data()
@ -217,15 +222,19 @@ class ShellCommand:
else: else:
complete = not self.cancelled complete = not self.cancelled
self.factory.remove_running_command(self) self.factory.remove_running_command(self)
return self._check_proc_success(complete, log_complete) return self._check_proc_success(
complete, log_complete, success_codes
)
async def run_with_response(self, async def run_with_response(
timeout: float = 2., self,
retries: int = 1, timeout: float = 2.,
log_complete: bool = True, retries: int = 1,
sig_idx: int = 1, log_complete: bool = True,
proc_input: Optional[str] = None sig_idx: int = 1,
) -> str: proc_input: Optional[str] = None,
success_codes: Optional[List[int]] = None
) -> str:
async with self.run_lock: async with self.run_lock:
self.factory.add_running_command(self) self.factory.add_running_command(self)
retries = max(1, retries) retries = max(1, retries)
@ -252,7 +261,9 @@ class ShellCommand:
logging.info( logging.info(
f"{self.command[0]}: " f"{self.command[0]}: "
f"{stderr.decode(errors='ignore')}") f"{stderr.decode(errors='ignore')}")
if self._check_proc_success(complete, log_complete): if self._check_proc_success(
complete, log_complete, success_codes
):
self.factory.remove_running_command(self) self.factory.remove_running_command(self)
return stdout.decode(errors='ignore').rstrip("\n") return stdout.decode(errors='ignore').rstrip("\n")
if stdout: if stdout:
@ -268,10 +279,11 @@ class ShellCommand:
f"Error running shell command: '{self.name}'", f"Error running shell command: '{self.name}'",
self.return_code, stdout, stderr) self.return_code, stdout, stderr)
async def _create_subprocess(self, async def _create_subprocess(
use_callbacks: bool = False, self,
has_input: bool = False use_callbacks: bool = False,
) -> bool: has_input: bool = False
) -> bool:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
def protocol_factory(): def protocol_factory():
@ -305,13 +317,17 @@ class ShellCommand:
return False return False
return True return True
def _check_proc_success(self, def _check_proc_success(
complete: bool, self,
log_complete: bool complete: bool,
) -> bool: log_complete: bool,
success_codes: Optional[List[int]] = None
) -> bool:
assert self.proc is not None assert self.proc is not None
if success_codes is None:
success_codes = [0]
self.return_code = self.proc.returncode self.return_code = self.proc.returncode
success = self.return_code == 0 and complete success = self.return_code in success_codes and complete
if success: if success:
msg = f"Command ({self.name}) successfully finished" msg = f"Command ({self.name}) successfully finished"
elif self.cancelled: elif self.cancelled:
@ -339,32 +355,38 @@ class ShellCommandFactory:
except KeyError: except KeyError:
pass pass
def build_shell_command(self, def build_shell_command(
cmd: str, self,
callback: OutputCallback = None, cmd: str,
std_err_callback: OutputCallback = None, callback: OutputCallback = None,
env: Optional[Dict[str, str]] = None, std_err_callback: OutputCallback = None,
log_stderr: bool = False, env: Optional[Dict[str, str]] = None,
cwd: Optional[str] = None log_stderr: bool = False,
) -> ShellCommand: cwd: Optional[str] = None
return ShellCommand(self, cmd, callback, std_err_callback, env, ) -> ShellCommand:
log_stderr, cwd) return ShellCommand(
self, cmd, callback, std_err_callback, env, log_stderr, cwd
)
def exec_cmd(self, def exec_cmd(
cmd: str, self,
timeout: float = 2., cmd: str,
retries: int = 1, timeout: float = 2.,
sig_idx: int = 1, retries: int = 1,
proc_input: Optional[str] = None, sig_idx: int = 1,
log_complete: bool = True, proc_input: Optional[str] = None,
log_stderr: bool = False, log_complete: bool = True,
env: Optional[Dict[str, str]] = None, log_stderr: bool = False,
cwd: Optional[str] = None env: Optional[Dict[str, str]] = None,
) -> Awaitable: cwd: Optional[str] = None,
success_codes: Optional[List[int]] = None
) -> Awaitable:
scmd = ShellCommand(self, cmd, None, None, env, scmd = ShellCommand(self, cmd, None, None, env,
log_stderr, cwd) log_stderr, cwd)
coro = scmd.run_with_response(timeout, retries, log_complete, coro = scmd.run_with_response(
sig_idx, proc_input) timeout, retries, log_complete, sig_idx,
proc_input, success_codes
)
return asyncio.create_task(coro) return asyncio.create_task(coro)
async def close(self) -> None: async def close(self) -> None: