398 lines
13 KiB
Python
398 lines
13 KiB
Python
# linux shell command execution utility
|
|
#
|
|
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
|
|
#
|
|
# This file may be distributed under the terms of the GNU GPLv3 license.
|
|
|
|
from __future__ import annotations
|
|
import os
|
|
import shlex
|
|
import logging
|
|
import signal
|
|
import asyncio
|
|
from ..utils import ServerError
|
|
|
|
# Annotation imports
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Awaitable,
|
|
List,
|
|
Optional,
|
|
Callable,
|
|
Coroutine,
|
|
Dict,
|
|
Set,
|
|
)
|
|
if TYPE_CHECKING:
|
|
from ..confighelper import ConfigHelper
|
|
# Optional callable that receives process output as bytes.  When driven by
# ShellCommandProtocol it is invoked once per line, without the trailing
# newline.
OutputCallback = Optional[Callable[[bytes], None]]
|
|
|
|
class ShellCommandError(ServerError):
    """Error raised when a shell command does not finish successfully.

    Besides the message and status code forwarded to ``ServerError``, the
    exception carries the process return code (``None`` when no code was
    recorded) and any captured stdout/stderr bytes.
    """

    def __init__(
        self,
        message: str,
        return_code: Optional[int],
        stdout: Optional[bytes] = b"",
        stderr: Optional[bytes] = b"",
        status_code: int = 500
    ) -> None:
        super().__init__(message, status_code=status_code)
        # Missing or empty output is normalized to empty bytes so callers
        # can always treat these attributes as ``bytes``.
        self.stdout = stdout if stdout else b""
        self.stderr = stderr if stderr else b""
        self.return_code = return_code
|
|
|
|
class ShellCommandProtocol(asyncio.subprocess.SubprocessStreamProtocol):
    """Subprocess protocol that hands process output to line callbacks.

    Output arriving on stdout (fd 1) and stderr (fd 2) is not buffered in
    stream readers.  Each chunk is appended to whatever partial line was
    left over from the previous chunk and split on newlines.  Every
    complete, non-empty line goes to the matching callback, and the
    unterminated tail is held back until more data arrives or the pipe
    closes.  With ``log_stderr`` set, raw stderr chunks are also logged.
    """

    def __init__(
        self,
        limit: int,
        loop: asyncio.events.AbstractEventLoop,
        std_out_cb: OutputCallback = None,
        std_err_cb: OutputCallback = None,
        log_stderr: bool = False
    ) -> None:
        self._loop = loop
        self._pipe_fds: List[int] = []
        super().__init__(limit, loop)
        self.std_out_cb, self.std_err_cb = std_out_cb, std_err_cb
        self.log_stderr = log_stderr
        # Unterminated trailing bytes per pipe, indexed by ``fd - 1``
        # (slot 0 -> stdout, slot 1 -> stderr).
        self.pending_data: List[bytes] = [b"", b""]

    def _callback_for(self, fd: int) -> OutputCallback:
        """Return the output callback registered for pipe ``fd``, if any."""
        if fd == 1:
            return self.std_out_cb
        if fd == 2:
            return self.std_err_cb
        return None

    def connection_made(
        self, transport: asyncio.transports.BaseTransport
    ) -> None:
        """Record the transport and which child pipes are connected.

        Only stdin is wrapped in a stream (a ``StreamWriter``); stdout and
        stderr are consumed directly by :meth:`pipe_data_received`.
        """
        self._transport = transport
        assert isinstance(transport, asyncio.SubprocessTransport)
        for out_fd in (1, 2):
            if transport.get_pipe_transport(out_fd) is not None:
                self._pipe_fds.append(out_fd)
        writer_transport = transport.get_pipe_transport(0)
        if writer_transport is None:
            return
        self.stdin = asyncio.streams.StreamWriter(
            writer_transport,  # type: ignore
            protocol=self,
            reader=None,
            loop=self._loop
        )

    def pipe_data_received(self, fd: int, data: bytes | str) -> None:
        """Buffer ``data`` from pipe ``fd`` and emit every completed line."""
        if fd == 2 and self.log_stderr:
            text = (
                data.decode(errors='ignore') if isinstance(data, bytes)
                else data
            )
            logging.info(text)
        handler = self._callback_for(fd)
        if handler is None:
            return
        chunk = data.encode() if isinstance(data, str) else data
        slot = fd - 1
        # The pending tail never contains a newline, so joining before the
        # split yields the same lines as splitting first.
        *finished, tail = (self.pending_data[slot] + chunk).split(b'\n')
        self.pending_data[slot] = tail
        for line in finished:
            if line:
                handler(line)

    def pipe_connection_lost(
        self, fd: int, exc: Exception | None
    ) -> None:
        """Flush a buffered partial line before the pipe is torn down."""
        handler = self._callback_for(fd)
        if handler is not None:
            leftover = self.pending_data[fd - 1]
            if leftover:
                handler(leftover)
        super().pipe_connection_lost(fd, exc)
|
|
|
|
|
|
class ShellCommand:
    """An external command that can be run, awaited and cancelled.

    A leading ``~`` in the command line is expanded, the line is tokenized
    with :func:`shlex.split`, and the result is executed directly (via
    ``subprocess_exec``; no shell is involved).  :meth:`run` streams output
    line-by-line to the optional callbacks, while :meth:`run_with_response`
    captures stdout and returns it.  ``run_lock`` serializes runs of the
    same instance, and the owning factory is told about the command while
    it runs.
    """

    # Indices into the signal escalation list used by ``cancel``
    IDX_SIGINT = 0
    IDX_SIGTERM = 1
    IDX_SIGKILL = 2

    def __init__(
        self,
        factory: ShellCommandFactory,
        cmd: str,
        std_out_callback: OutputCallback,
        std_err_callback: OutputCallback,
        env: Optional[Dict[str, str]] = None,
        log_stderr: bool = False,
        cwd: Optional[str] = None
    ) -> None:
        """Prepare (but do not start) a command.

        :param factory: owner notified via ``add_running_command`` /
            ``remove_running_command`` while the command runs.
        :param cmd: command line; also used as the display name in logs.
        :param std_out_callback: per-line stdout callback used by ``run``.
        :param std_err_callback: per-line stderr callback used by ``run``.
        :param env: passed through as the child's ``env``.
        :param log_stderr: log stderr output at INFO level.
        :param cwd: passed through as the child's working directory.
        """
        self.factory = factory
        self.name = cmd
        self.std_out_cb = std_out_callback
        self.std_err_cb = std_err_callback
        cmd = os.path.expanduser(cmd)
        self.command = shlex.split(cmd)
        self.log_stderr = log_stderr
        self.env = env
        self.cwd = cwd
        self.proc: Optional[asyncio.subprocess.Process] = None
        self.cancelled = False
        self.return_code: Optional[int] = None
        self.run_lock = asyncio.Lock()

    async def cancel(self, sig_idx: int = 1) -> None:
        """Terminate the running process, escalating signals as needed.

        Starting at ``sig_idx`` (clamped to the ``IDX_*`` range), SIGINT,
        SIGTERM and SIGKILL are sent in turn.  After each signal the process
        gets up to two seconds to exit before the next one is tried.  Only
        the first call per run has any effect; ``_reset_command_data``
        re-arms it.
        """
        if self.cancelled:
            return
        self.cancelled = True
        if self.proc is not None:
            exit_success = False
            sig_idx = min(2, max(0, sig_idx))
            sigs = [signal.SIGINT, signal.SIGTERM, signal.SIGKILL][sig_idx:]
            for sig in sigs:
                try:
                    self.proc.send_signal(sig)
                    ret = self.proc.wait()
                    await asyncio.wait_for(ret, timeout=2.)
                except asyncio.TimeoutError:
                    continue
                except ProcessLookupError:
                    # The process has already gone away; treat as exited
                    pass
                logging.debug(f"Command '{self.name}' exited with "
                              f"signal: {sig.name}")
                exit_success = True
                break
            if not exit_success:
                logging.info(f"WARNING: {self.name} did not cleanly exit")

    def get_return_code(self) -> Optional[int]:
        """Return the exit code recorded by the most recent run, if any."""
        return self.return_code

    def _reset_command_data(self) -> None:
        """Clear per-run state before a new run starts."""
        self.return_code = self.proc = None
        self.cancelled = False

    async def run(
        self,
        timeout: float = 2.,
        verbose: bool = True,
        log_complete: bool = True,
        sig_idx: int = 1,
        proc_input: Optional[str] = None,
        success_codes: Optional[List[int]] = None
    ) -> bool:
        """Run the command, optionally streaming its output to callbacks.

        :param timeout: seconds to wait before cancelling; a falsy value
            disables the timeout.
        :param verbose: deliver output line-by-line to the callbacks.  It is
            forced off when no callbacks are set and ``log_stderr`` is
            disabled.  Non-verbose output is discarded.
        :param log_complete: log a completion/failure message.
        :param sig_idx: first signal index ``cancel`` uses on timeout.
        :param proc_input: optional text written to the child's stdin.
        :param success_codes: return codes counted as success (default [0]).
        :returns: True if the command completed with a success code.
        """
        async with self.run_lock:
            self.factory.add_running_command(self)
            try:
                self._reset_command_data()
                if not timeout:
                    # Never timeout
                    timeout = 9999999999999999.
                if (
                    self.std_out_cb is None
                    and self.std_err_cb is None
                    and not self.log_stderr
                ):
                    # No callbacks set so output cannot be verbose
                    verbose = False
                # Output nobody reads must not go to a pipe.  Once the pipe
                # (and asyncio's read buffer) fills, the child blocks on
                # write and ``wait()`` never returns.
                created = await self._create_subprocess(
                    verbose, proc_input is not None,
                    discard_output=not verbose)
                if not created:
                    return False
                assert self.proc is not None
                timed_out = False
                try:
                    if proc_input is not None:
                        ret: Coroutine = self.proc.communicate(
                            input=proc_input.encode())
                    else:
                        ret = self.proc.wait()
                    await asyncio.wait_for(ret, timeout=timeout)
                except asyncio.TimeoutError:
                    complete = False
                    timed_out = True
                    await self.cancel(sig_idx)
                else:
                    complete = not self.cancelled
                return self._check_proc_success(
                    complete, log_complete, success_codes, timed_out
                )
            finally:
                # Always unregister, even if the awaiting task is cancelled
                self.factory.remove_running_command(self)

    async def run_with_response(
        self,
        timeout: float = 2.,
        retries: int = 1,
        log_complete: bool = True,
        sig_idx: int = 1,
        proc_input: Optional[str] = None,
        success_codes: Optional[List[int]] = None
    ) -> str:
        """Run the command and return its stdout, retrying on failure.

        Stdout is decoded (undecodable bytes dropped) and trailing newlines
        are stripped.  Unless a stderr callback or ``log_stderr`` is set,
        stderr is merged into stdout.  Up to ``retries`` attempts (at least
        one) are made, 0.5 seconds apart.  Retrying stops early if the
        command is cancelled externally.

        :raises ShellCommandError: if no attempt succeeds.  The error carries
            the last return code and captured output.
        """
        async with self.run_lock:
            self.factory.add_running_command(self)
            try:
                retries = max(1, retries)
                stdin: Optional[bytes] = None
                if proc_input is not None:
                    stdin = proc_input.encode()
                while retries > 0:
                    self._reset_command_data()
                    timed_out = False
                    stdout = stderr = b""
                    if await self._create_subprocess(
                        has_input=stdin is not None
                    ):
                        assert self.proc is not None
                        try:
                            ret = self.proc.communicate(input=stdin)
                            stdout, stderr = await asyncio.wait_for(
                                ret, timeout=timeout)
                        except asyncio.TimeoutError:
                            complete = False
                            timed_out = True
                            await self.cancel(sig_idx)
                        else:
                            complete = not self.cancelled
                            if self.log_stderr and stderr:
                                logging.info(
                                    f"{self.command[0]}: "
                                    f"{stderr.decode(errors='ignore')}")
                        if self._check_proc_success(
                            complete, log_complete, success_codes, timed_out
                        ):
                            return stdout.decode(errors='ignore').rstrip("\n")
                        if stdout:
                            logging.debug(
                                f"Shell command '{self.name}' output:"
                                f"\n{stdout.decode(errors='ignore')}")
                        if self.cancelled and not timed_out:
                            # Cancelled from outside: do not retry
                            break
                    retries -= 1
                    if retries > 0:
                        # Pause only when another attempt will follow
                        await asyncio.sleep(.5)
                raise ShellCommandError(
                    f"Error running shell command: '{self.name}'",
                    self.return_code, stdout, stderr)
            finally:
                self.factory.remove_running_command(self)

    async def _create_subprocess(
        self,
        use_callbacks: bool = False,
        has_input: bool = False,
        discard_output: bool = False
    ) -> bool:
        """Spawn the command and store the process in ``self.proc``.

        :param use_callbacks: route output through ``ShellCommandProtocol``
            so it is delivered line-by-line to the callbacks.
        :param has_input: open a pipe for the child's stdin.
        :param discard_output: send stdout and stderr to ``DEVNULL`` because
            nobody will read them.  Otherwise stdout is piped, and stderr is
            piped separately only when a stderr callback or ``log_stderr``
            is set (else it is merged into stdout).
        :returns: True on success.  Failures are logged and return False;
            cancellation propagates.
        """
        loop = asyncio.get_running_loop()

        def protocol_factory():
            return ShellCommandProtocol(
                limit=2**20, loop=loop, std_out_cb=self.std_out_cb,
                std_err_cb=self.std_err_cb, log_stderr=self.log_stderr
            )
        try:
            stdpipe: Optional[int] = None
            if has_input:
                stdpipe = asyncio.subprocess.PIPE
            if discard_output:
                outpipe = errpipe = asyncio.subprocess.DEVNULL
            else:
                outpipe = asyncio.subprocess.PIPE
                if self.std_err_cb is not None or self.log_stderr:
                    errpipe = asyncio.subprocess.PIPE
                else:
                    errpipe = asyncio.subprocess.STDOUT
            if use_callbacks:
                transport, protocol = await loop.subprocess_exec(
                    protocol_factory, *self.command, stdin=stdpipe,
                    stdout=outpipe, stderr=errpipe,
                    env=self.env, cwd=self.cwd)
                self.proc = asyncio.subprocess.Process(
                    transport, protocol, loop)
            else:
                self.proc = await asyncio.create_subprocess_exec(
                    *self.command, stdin=stdpipe,
                    stdout=outpipe, stderr=errpipe,
                    env=self.env, cwd=self.cwd)
        except asyncio.CancelledError:
            raise
        except Exception:
            logging.exception(
                f"shell_command: Command ({self.name}) failed")
            return False
        return True

    def _check_proc_success(
        self,
        complete: bool,
        log_complete: bool,
        success_codes: Optional[List[int]] = None,
        timed_out: bool = False
    ) -> bool:
        """Record the return code and decide whether the run succeeded.

        :param complete: the process finished without timing out or being
            cancelled.
        :param log_complete: log the outcome at INFO level.
        :param success_codes: return codes counted as success (default [0]).
        :param timed_out: the run hit its timeout.  This is checked before
            ``self.cancelled``, because a timeout also cancels the process
            and would otherwise be reported as a plain cancellation.
        :returns: True if complete and the return code is a success code.
        """
        assert self.proc is not None
        if success_codes is None:
            success_codes = [0]
        self.return_code = self.proc.returncode
        success = self.return_code in success_codes and complete
        if success:
            msg = f"Command ({self.name}) successfully finished"
        elif timed_out:
            msg = f"Command ({self.name}) timed out"
        elif self.cancelled:
            msg = f"Command ({self.name}) cancelled"
        elif not complete:
            msg = f"Command ({self.name}) timed out"
        else:
            msg = f"Command ({self.name}) exited with return code" \
                f" {self.return_code}"
        if log_complete:
            logging.info(msg)
        return success
|
|
|
|
class ShellCommandFactory:
    """Builds ``ShellCommand`` objects and tracks the ones that are running.

    Commands register themselves here while they run, so :meth:`close`
    can cancel anything still in flight.
    """
    error = ShellCommandError

    def __init__(self, config: ConfigHelper) -> None:
        # ``config`` is accepted for the component interface but unused here
        self.running_commands: Set[ShellCommand] = set()

    def add_running_command(self, cmd: ShellCommand) -> None:
        """Register ``cmd`` as currently running."""
        self.running_commands.add(cmd)

    def remove_running_command(self, cmd: ShellCommand) -> None:
        """Unregister ``cmd``; a no-op if it is not registered."""
        self.running_commands.discard(cmd)

    def build_shell_command(
        self,
        cmd: str,
        callback: OutputCallback = None,
        std_err_callback: OutputCallback = None,
        env: Optional[Dict[str, str]] = None,
        log_stderr: bool = False,
        cwd: Optional[str] = None
    ) -> ShellCommand:
        """Create (but do not start) a ``ShellCommand`` for ``cmd``.

        ``callback`` receives stdout lines and ``std_err_callback`` receives
        stderr lines when the command is run with verbose output.
        """
        return ShellCommand(
            self, cmd, callback, std_err_callback, env, log_stderr, cwd
        )

    def exec_cmd(
        self,
        cmd: str,
        timeout: float = 2.,
        retries: int = 1,
        sig_idx: int = 1,
        proc_input: Optional[str] = None,
        log_complete: bool = True,
        log_stderr: bool = False,
        env: Optional[Dict[str, str]] = None,
        cwd: Optional[str] = None,
        success_codes: Optional[List[int]] = None
    ) -> Awaitable:
        """Run ``cmd`` once in a new task and return that task.

        The task wraps ``ShellCommand.run_with_response``.  Awaiting it
        yields the command's decoded stdout, or raises ``ShellCommandError``
        when every attempt fails.
        """
        scmd = ShellCommand(self, cmd, None, None, env,
                            log_stderr, cwd)
        coro = scmd.run_with_response(
            timeout, retries, log_complete, sig_idx,
            proc_input, success_codes
        )
        return asyncio.create_task(coro)

    async def close(self) -> None:
        """Cancel every command that is still running."""
        # Iterate over a snapshot.  Each ``cancel`` awaits, letting the
        # command's run coroutine resume and unregister itself, and mutating
        # a set while iterating over it raises RuntimeError.
        for cmd in list(self.running_commands):
            await cmd.cancel()
|
|
|
|
def load_component(config: ConfigHelper) -> ShellCommandFactory:
    """Create this component's ``ShellCommandFactory`` from ``config``.

    Presumably invoked by the server's component loader; confirm the
    caller.
    """
    factory = ShellCommandFactory(config)
    return factory
|