websockets: require re-auth on user logout

Propagate user state changes to open websockets and unix sockets.
If a websocket's user is logged out require re-authentication.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2023-01-19 10:54:11 -05:00
parent 06ec5541e3
commit 80862799ed
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
2 changed files with 49 additions and 12 deletions

View File

@ -35,7 +35,7 @@ from typing import (
if TYPE_CHECKING:
from confighelper import ConfigHelper
from websockets import WebRequest
from websockets import WebRequest, WebsocketManager
from tornado.httputil import HTTPServerRequest
from tornado.web import RequestHandler
from .database import MoonrakerDatabase as DBComp
@ -251,8 +251,14 @@ class Authorization:
self.server.register_endpoint(
"/access/info", ['GET'],
self._handle_info_request, transports=['http', 'websocket'])
self.server.register_notification("authorization:user_created")
self.server.register_notification("authorization:user_deleted")
wsm: WebsocketManager = self.server.lookup_component("websockets")
wsm.register_notification("authorization:user_created")
wsm.register_notification(
"authorization:user_deleted", event_type="logout"
)
wsm.register_notification(
"authorization:user_logged_out", event_type="logout"
)
def register_permited_path(self, path: str) -> None:
self.permitted_paths.add(path)
@ -311,6 +317,11 @@ class Authorization:
jwk_id: str = self.users[username].pop("jwk_id", None)
self._sync_user(username)
self.public_jwks.pop(jwk_id, None)
eventloop = self.server.get_event_loop()
eventloop.delay_callback(
.005, self.server.send_event, "authorization:user_logged_out",
{'username': username}
)
return {
"username": username,
"action": "user_logged_out"

View File

@ -317,17 +317,22 @@ class WebsocketManager(APITransport):
self.rpc.register_method(
"server.connection.identify", self._handle_identify)
def register_notification(self,
event_name: str,
notify_name: Optional[str] = None
) -> None:
def register_notification(
self,
event_name: str,
notify_name: Optional[str] = None,
event_type: Optional[str] = None
) -> None:
if notify_name is None:
notify_name = event_name.split(':')[-1]
def notify_handler(*args):
self.notify_clients(notify_name, args)
self.server.register_event_handler(
event_name, notify_handler)
if event_type == "logout":
def notify_handler(*args):
self.notify_clients(notify_name, args)
self._process_logout(*args)
else:
def notify_handler(*args):
self.notify_clients(notify_name, args)
self.server.register_event_handler(event_name, notify_handler)
def register_api_handler(self, api_def: APIDefinition) -> None:
if api_def.callback is None:
@ -417,6 +422,13 @@ class WebsocketManager(APITransport):
self.server.send_event("websockets:client_identified", sc)
return {'connection_id': sc.uid}
def _process_logout(self, user: Dict[str, Any]) -> None:
if "username" not in user:
return
name = user["username"]
for sc in self.clients.values():
sc.on_user_logout(name)
def has_socket(self, ws_id: int) -> bool:
return ws_id in self.clients
@ -589,6 +601,14 @@ class BaseSocketClient(Subscribable):
elif not auth.is_path_permitted(path):
raise self.server.error("Unauthorized", 401)
def on_user_logout(self, user: str) -> bool:
if self._user_info is None:
return False
if user == self._user_info.get("username", ""):
self._user_info = None
return True
return False
async def _write_messages(self):
if self.is_closed:
self.message_buf = []
@ -729,6 +749,12 @@ class WebSocket(WebSocketHandler, BaseSocketClient):
return False
return True
def on_user_logout(self, user: str) -> bool:
if super().on_user_logout(user):
self._need_auth = True
return True
return False
# Check Authorized User
def prepare(self) -> None:
max_conns = self.settings["max_websocket_connections"]