shell_command: add annotations

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-05-13 20:16:26 -04:00
parent 41ddbb16a8
commit ce7495ecce
1 changed files with 86 additions and 31 deletions

View File

@ -3,6 +3,8 @@
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com> # Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
# #
# This file may be distributed under the terms of the GNU GPLv3 license. # This file may be distributed under the terms of the GNU GPLv3 license.
from __future__ import annotations
import os import os
import shlex import shlex
import logging import logging
@ -11,24 +13,48 @@ import asyncio
from tornado import gen from tornado import gen
from utils import ServerError from utils import ServerError
# Annotation imports
from typing import (
TYPE_CHECKING,
Optional,
Callable,
Coroutine,
Dict,
)
if TYPE_CHECKING:
from confighelper import ConfigHelper
from asyncio import BaseTransport
OutputCallback = Optional[Callable[[bytes], None]]
class ShellCommandError(ServerError): class ShellCommandError(ServerError):
def __init__(self, message, return_code, stdout=b"", def __init__(self,
stderr=b"", status_code=500): message: str,
return_code: Optional[int],
stdout: Optional[bytes] = b"",
stderr: Optional[bytes] = b"",
status_code: int = 500
) -> None:
super().__init__(message, status_code=status_code) super().__init__(message, status_code=status_code)
self.stdout = stdout or b"" self.stdout = stdout or b""
self.stderr = stderr or b"" self.stderr = stderr or b""
self.return_code = return_code self.return_code = return_code
class SCProcess(asyncio.subprocess.Process): class SCProcess(asyncio.subprocess.Process):
def initialize(self, program_name, std_out_cb, std_err_cb, log_stderr): def initialize(self,
program_name: str,
std_out_cb: OutputCallback,
std_err_cb: OutputCallback,
log_stderr: bool
) -> None:
self.program_name = program_name self.program_name = program_name
self.std_out_cb = std_out_cb self.std_out_cb = std_out_cb
self.std_err_cb = std_err_cb self.std_err_cb = std_err_cb
self.log_stderr = log_stderr self.log_stderr = log_stderr
self.cancel_requested = False self.cancel_requested = False
async def _read_stream_with_cb(self, fd): async def _read_stream_with_cb(self, fd: int) -> bytes:
transport = self._transport.get_pipe_transport(fd) transport: BaseTransport = \
self._transport.get_pipe_transport(fd) # type: ignore
if fd == 2: if fd == 2:
stream = self.stderr stream = self.stderr
cb = self.std_err_cb cb = self.std_err_cb
@ -36,6 +62,7 @@ class SCProcess(asyncio.subprocess.Process):
assert fd == 1 assert fd == 1
stream = self.stdout stream = self.stdout
cb = self.std_out_cb cb = self.std_out_cb
assert stream is not None
while not stream.at_eof(): while not stream.at_eof():
output = await stream.readline() output = await stream.readline()
if not output: if not output:
@ -48,7 +75,7 @@ class SCProcess(asyncio.subprocess.Process):
transport.close() transport.close()
return output return output
async def cancel(self, sig_idx=1): async def cancel(self, sig_idx: int = 1) -> None:
if self.cancel_requested: if self.cancel_requested:
return return
self.cancel_requested = True self.cancel_requested = True
@ -73,30 +100,38 @@ class SCProcess(asyncio.subprocess.Process):
if self.stderr is not None: if self.stderr is not None:
self.stderr.feed_eof() self.stderr.feed_eof()
async def communicate_with_cb(self, input=None): async def communicate_with_cb(self,
input: Optional[bytes] = None
) -> None:
if input is not None: if input is not None:
stdin = self._feed_stdin(input) stdin: Coroutine = self._feed_stdin(input) # type: ignore
else: else:
stdin = self._noop() stdin = self._noop() # type: ignore
if self.stdout is not None and self.std_out_cb is not None: if self.stdout is not None and self.std_out_cb is not None:
stdout = self._read_stream_with_cb(1) stdout: Coroutine = self._read_stream_with_cb(1)
else: else:
stdout = self._noop() stdout = self._noop() # type: ignore
has_err_output = self.std_err_cb is not None or self.log_stderr has_err_output = self.std_err_cb is not None or self.log_stderr
if self.stderr is not None and has_err_output: if self.stderr is not None and has_err_output:
stderr = self._read_stream_with_cb(2) stderr: Coroutine = self._read_stream_with_cb(2)
else: else:
stderr = self._noop() stderr = self._noop() # type: ignore
stdin, stdout, stderr = await asyncio.tasks.gather( stdin, stdout, stderr = await asyncio.tasks.gather(
stdin, stdout, stderr, loop=self._loop) stdin, stdout, stderr, loop=self._loop) # type: ignore
await self.wait() await self.wait()
class ShellCommand: class ShellCommand:
IDX_SIGINT = 0 IDX_SIGINT = 0
IDX_SIGTERM = 1 IDX_SIGTERM = 1
IDX_SIGKILL = 2 IDX_SIGKILL = 2
def __init__(self, cmd, std_out_callback, std_err_callback, def __init__(self,
env=None, log_stderr=False, cwd=None): cmd: str,
std_out_callback: OutputCallback,
std_err_callback: OutputCallback,
env: Optional[Dict[str, str]] = None,
log_stderr: bool = False,
cwd: Optional[str] = None
) -> None:
self.name = cmd self.name = cmd
self.std_out_cb = std_out_callback self.std_out_cb = std_out_callback
self.std_err_cb = std_err_callback self.std_err_cb = std_err_callback
@ -105,24 +140,28 @@ class ShellCommand:
self.log_stderr = log_stderr self.log_stderr = log_stderr
self.env = env self.env = env
self.cwd = cwd self.cwd = cwd
self.proc = None self.proc: Optional[SCProcess] = None
self.cancelled = False self.cancelled = False
self.return_code = None self.return_code: Optional[int] = None
async def cancel(self, sig_idx=1): async def cancel(self, sig_idx: int = 1) -> None:
self.cancelled = True self.cancelled = True
if self.proc is not None: if self.proc is not None:
await self.proc.cancel(sig_idx) await self.proc.cancel(sig_idx)
def get_return_code(self): def get_return_code(self) -> Optional[int]:
return self.return_code return self.return_code
def _reset_command_data(self): def _reset_command_data(self) -> None:
self.return_code = self.proc = None self.return_code = self.proc = None
self.cancelled = False self.cancelled = False
async def run(self, timeout=2., verbose=True, log_complete=True, async def run(self,
sig_idx=1): timeout: float = 2.,
verbose: bool = True,
log_complete: bool = True,
sig_idx: int = 1
) -> bool:
self._reset_command_data() self._reset_command_data()
if not timeout: if not timeout:
# Never timeout # Never timeout
@ -133,9 +172,10 @@ class ShellCommand:
verbose = False verbose = False
if not await self._create_subprocess(): if not await self._create_subprocess():
return False return False
assert self.proc is not None
try: try:
if verbose: if verbose:
ret = self.proc.communicate_with_cb() ret: Coroutine = self.proc.communicate_with_cb()
else: else:
ret = self.proc.wait() ret = self.proc.wait()
await asyncio.wait_for(ret, timeout=timeout) await asyncio.wait_for(ret, timeout=timeout)
@ -146,13 +186,18 @@ class ShellCommand:
complete = not self.cancelled complete = not self.cancelled
return self._check_proc_success(complete, log_complete) return self._check_proc_success(complete, log_complete)
async def run_with_response(self, timeout=2., retries=1, async def run_with_response(self,
log_complete=True, sig_idx=1): timeout: float = 2.,
retries: int = 1,
log_complete: bool = True,
sig_idx: int = 1
) -> str:
self._reset_command_data() self._reset_command_data()
retries = max(1, retries) retries = max(1, retries)
while retries > 0: while retries > 0:
stdout = stderr = b"" stdout = stderr = b""
if await self._create_subprocess(): if await self._create_subprocess():
assert self.proc is not None
try: try:
ret = self.proc.communicate() ret = self.proc.communicate()
stdout, stderr = await asyncio.wait_for( stdout, stderr = await asyncio.wait_for(
@ -176,7 +221,7 @@ class ShellCommand:
f"Error running shell command: '{self.command}'", f"Error running shell command: '{self.command}'",
self.return_code, stdout, stderr) self.return_code, stdout, stderr)
async def _create_subprocess(self): async def _create_subprocess(self) -> bool:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
def protocol_factory(): def protocol_factory():
@ -200,7 +245,11 @@ class ShellCommand:
return False return False
return True return True
def _check_proc_success(self, complete, log_complete): def _check_proc_success(self,
complete: bool,
log_complete: bool
) -> bool:
assert self.proc is not None
self.return_code = self.proc.returncode self.return_code = self.proc.returncode
success = self.return_code == 0 and complete success = self.return_code == 0 and complete
if success: if success:
@ -218,10 +267,16 @@ class ShellCommand:
class ShellCommandFactory: class ShellCommandFactory:
error = ShellCommandError error = ShellCommandError
def build_shell_command(self, cmd, callback=None, std_err_callback=None, def build_shell_command(self,
env=None, log_stderr=False, cwd=None): cmd: str,
callback: OutputCallback = None,
std_err_callback: OutputCallback = None,
env: Optional[Dict[str, str]] = None,
log_stderr: bool = False,
cwd: Optional[str] = None
) -> ShellCommand:
return ShellCommand(cmd, callback, std_err_callback, env, return ShellCommand(cmd, callback, std_err_callback, env,
log_stderr, cwd) log_stderr, cwd)
def load_component(config): def load_component(config: ConfigHelper) -> ShellCommandFactory:
return ShellCommandFactory() return ShellCommandFactory()