machine: introduce custom allow list for service control

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-12-29 06:29:38 -05:00
parent b3a9447392
commit 690f841768
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
1 changed files with 55 additions and 16 deletions

View File

@ -54,10 +54,14 @@ if TYPE_CHECKING:
SudoReturn = Union[Awaitable[Tuple[str, bool]], Tuple[str, bool]] SudoReturn = Union[Awaitable[Tuple[str, bool]], Tuple[str, bool]]
SudoCallback = Callable[[], SudoReturn] SudoCallback = Callable[[], SudoReturn]
ALLOWED_SERVICES = [ DEFAULT_ALLOWED_SERVICES = [
"moonraker", "klipper", "webcamd", "MoonCord", "klipper_mcu",
"KlipperScreen", "moonraker-telegram-bot", "webcamd",
"sonar", "crowsnest" "MoonCord",
"KlipperScreen",
"moonraker-telegram-bot",
"sonar",
"crowsnest"
] ]
CGROUP_PATH = "/proc/1/cgroup" CGROUP_PATH = "/proc/1/cgroup"
SCHED_PATH = "/proc/1/sched" SCHED_PATH = "/proc/1/sched"
@ -80,6 +84,8 @@ SERVICE_PROPERTIES = [
class Machine: class Machine:
def __init__(self, config: ConfigHelper) -> None: def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server() self.server = config.get_server()
self._allowed_services: List[str] = []
self._init_allowed_services()
dist_info: Dict[str, Any] dist_info: Dict[str, Any]
dist_info = {'name': distro.name(pretty=True)} dist_info = {'name': distro.name(pretty=True)}
dist_info.update(distro.info()) dist_info.update(distro.info())
@ -161,15 +167,39 @@ class Machine:
self.iwgetid_cmd = shell_cmd.build_shell_command(iwgetbin) self.iwgetid_cmd = shell_cmd.build_shell_command(iwgetbin)
self.init_evt = asyncio.Event() self.init_evt = asyncio.Event()
def _init_allowed_services(self) -> None:
app_args = self.server.get_app_args()
data_path = app_args["data_path"]
fpath = pathlib.Path(data_path).joinpath("moonraker.asvc")
fm: FileManager = self.server.lookup_component("file_manager")
fm.add_reserved_path("allowed_services", fpath, False)
try:
if not fpath.exists():
fpath.write_text("\n".join(DEFAULT_ALLOWED_SERVICES))
data = fpath.read_text()
except Exception:
logging.exception("Failed to read allowed_services.txt")
self._allowed_services = DEFAULT_ALLOWED_SERVICES
else:
svcs = [svc.strip() for svc in data.split("\n") if svc.strip()]
for svc in svcs:
if svc.endswith(".service"):
svc = svc.rsplit(".", 1)[0]
if svc not in self._allowed_services:
self._allowed_services.append(svc)
def _update_log_rollover(self, log: bool = False) -> None: def _update_log_rollover(self, log: bool = False) -> None:
sys_info_msg = "\nSystem Info:" sys_info_msg = "\nSystem Info:"
for header, info in self.system_info.items(): for header, info in self.system_info.items():
sys_info_msg += f"\n\n***{header}***" sys_info_msg += f"\n\n***{header}***"
if not isinstance(info, dict): if not isinstance(info, dict):
sys_info_msg += f"\n {repr(info)}" sys_info_msg += f"\n {repr(info)}"
else: else:
for key, val in info.items(): for key, val in info.items():
sys_info_msg += f"\n {key}: {val}" sys_info_msg += f"\n {key}: {val}"
sys_info_msg += f"\n\n***Allowed Services***"
for svc in self._allowed_services:
sys_info_msg += f"\n {svc}"
self.server.add_log_rollover_item('system_info', sys_info_msg, log=log) self.server.add_log_rollover_item('system_info', sys_info_msg, log=log)
@property @property
@ -182,6 +212,13 @@ class Machine:
unit_name = svc_info.get("unit_name", "moonraker.service") unit_name = svc_info.get("unit_name", "moonraker.service")
return unit_name.split(".", 1)[0] return unit_name.split(".", 1)[0]
def is_service_allowed(self, service: str) -> bool:
return (
service in self._allowed_services or
re.match(r"moonraker[_-]?\d*", service) is not None or
re.match(r"klipper[_-]?\d*", service) is not None
)
def validation_enabled(self) -> bool: def validation_enabled(self) -> bool:
return self.validator.validation_enabled return self.validator.validation_enabled
@ -270,7 +307,7 @@ class Machine:
elif self.sys_provider.is_service_available(name): elif self.sys_provider.is_service_available(name):
await self.do_service_action(action, name) await self.do_service_action(action, name)
else: else:
if name in ALLOWED_SERVICES: if name in self._allowed_services:
raise self.server.error(f"Service '{name}' not installed") raise self.server.error(f"Service '{name}' not installed")
raise self.server.error( raise self.server.error(
f"Service '{name}' not allowed") f"Service '{name}' not allowed")
@ -822,7 +859,8 @@ class SystemdCliProvider(BaseProvider):
'virt_identifier': virt_id 'virt_identifier': virt_id
} }
async def _detect_active_services(self): async def _detect_active_services(self) -> None:
machine: Machine = self.server.lookup_component("machine")
try: try:
resp: str = await self.shell_cmd.exec_cmd( resp: str = await self.shell_cmd.exec_cmd(
"systemctl list-units --all --type=service --plain" "systemctl list-units --all --type=service --plain"
@ -834,12 +872,11 @@ class SystemdCliProvider(BaseProvider):
services = [] services = []
for svc in services: for svc in services:
sname = svc.rsplit('.', 1)[0] sname = svc.rsplit('.', 1)[0]
for allowed in ALLOWED_SERVICES: if machine.is_service_allowed(sname):
if sname.startswith(allowed): self.available_services[sname] = {
self.available_services[sname] = { 'active_state': "unknown",
'active_state': "unknown", 'sub_state': "unknown"
'sub_state': "unknown" }
}
async def _update_service_status(self, async def _update_service_status(self,
sequence: int, sequence: int,
@ -1050,11 +1087,13 @@ class SystemdDbusProvider(BaseProvider):
async def _detect_active_services(self) -> None: async def _detect_active_services(self) -> None:
# Get loaded service # Get loaded service
mgr = self.systemd_mgr mgr = self.systemd_mgr
patterns = [f"{svc}*.service" for svc in ALLOWED_SERVICES] machine: Machine = self.server.lookup_component("machine")
units = await mgr.call_list_units_by_patterns( # type: ignore units: List[str]
["loaded"], patterns) units = await mgr.call_list_units_filtered(["loaded"]) # type: ignore
for unit in units: for unit in units:
name: str = unit[0].split('.')[0] name: str = unit[0].split('.')[0]
if not machine.is_service_allowed(name):
continue
state: str = unit[3] state: str = unit[3]
substate: str = unit[4] substate: str = unit[4]
dbus_path: str = unit[6] dbus_path: str = unit[6]