websockets: require re-auth on user logout

Propagate user state changes to open websockets and unix sockets.
If a websocket's user is logged out require re-authentication.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2023-01-19 10:54:11 -05:00
parent 06ec5541e3
commit 80862799ed
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
2 changed files with 49 additions and 12 deletions

View File

@ -35,7 +35,7 @@ from typing import (
if TYPE_CHECKING: if TYPE_CHECKING:
from confighelper import ConfigHelper from confighelper import ConfigHelper
from websockets import WebRequest from websockets import WebRequest, WebsocketManager
from tornado.httputil import HTTPServerRequest from tornado.httputil import HTTPServerRequest
from tornado.web import RequestHandler from tornado.web import RequestHandler
from .database import MoonrakerDatabase as DBComp from .database import MoonrakerDatabase as DBComp
@ -251,8 +251,14 @@ class Authorization:
self.server.register_endpoint( self.server.register_endpoint(
"/access/info", ['GET'], "/access/info", ['GET'],
self._handle_info_request, transports=['http', 'websocket']) self._handle_info_request, transports=['http', 'websocket'])
self.server.register_notification("authorization:user_created") wsm: WebsocketManager = self.server.lookup_component("websockets")
self.server.register_notification("authorization:user_deleted") wsm.register_notification("authorization:user_created")
wsm.register_notification(
"authorization:user_deleted", event_type="logout"
)
wsm.register_notification(
"authorization:user_logged_out", event_type="logout"
)
def register_permited_path(self, path: str) -> None: def register_permited_path(self, path: str) -> None:
self.permitted_paths.add(path) self.permitted_paths.add(path)
@ -311,6 +317,11 @@ class Authorization:
jwk_id: str = self.users[username].pop("jwk_id", None) jwk_id: str = self.users[username].pop("jwk_id", None)
self._sync_user(username) self._sync_user(username)
self.public_jwks.pop(jwk_id, None) self.public_jwks.pop(jwk_id, None)
eventloop = self.server.get_event_loop()
eventloop.delay_callback(
.005, self.server.send_event, "authorization:user_logged_out",
{'username': username}
)
return { return {
"username": username, "username": username,
"action": "user_logged_out" "action": "user_logged_out"

View File

@ -317,17 +317,22 @@ class WebsocketManager(APITransport):
self.rpc.register_method( self.rpc.register_method(
"server.connection.identify", self._handle_identify) "server.connection.identify", self._handle_identify)
def register_notification(self, def register_notification(
event_name: str, self,
notify_name: Optional[str] = None event_name: str,
) -> None: notify_name: Optional[str] = None,
event_type: Optional[str] = None
) -> None:
if notify_name is None: if notify_name is None:
notify_name = event_name.split(':')[-1] notify_name = event_name.split(':')[-1]
if event_type == "logout":
def notify_handler(*args): def notify_handler(*args):
self.notify_clients(notify_name, args) self.notify_clients(notify_name, args)
self.server.register_event_handler( self._process_logout(*args)
event_name, notify_handler) else:
def notify_handler(*args):
self.notify_clients(notify_name, args)
self.server.register_event_handler(event_name, notify_handler)
def register_api_handler(self, api_def: APIDefinition) -> None: def register_api_handler(self, api_def: APIDefinition) -> None:
if api_def.callback is None: if api_def.callback is None:
@ -417,6 +422,13 @@ class WebsocketManager(APITransport):
self.server.send_event("websockets:client_identified", sc) self.server.send_event("websockets:client_identified", sc)
return {'connection_id': sc.uid} return {'connection_id': sc.uid}
def _process_logout(self, user: Dict[str, Any]) -> None:
if "username" not in user:
return
name = user["username"]
for sc in self.clients.values():
sc.on_user_logout(name)
def has_socket(self, ws_id: int) -> bool: def has_socket(self, ws_id: int) -> bool:
return ws_id in self.clients return ws_id in self.clients
@ -589,6 +601,14 @@ class BaseSocketClient(Subscribable):
elif not auth.is_path_permitted(path): elif not auth.is_path_permitted(path):
raise self.server.error("Unauthorized", 401) raise self.server.error("Unauthorized", 401)
def on_user_logout(self, user: str) -> bool:
if self._user_info is None:
return False
if user == self._user_info.get("username", ""):
self._user_info = None
return True
return False
async def _write_messages(self): async def _write_messages(self):
if self.is_closed: if self.is_closed:
self.message_buf = [] self.message_buf = []
@ -729,6 +749,12 @@ class WebSocket(WebSocketHandler, BaseSocketClient):
return False return False
return True return True
def on_user_logout(self, user: str) -> bool:
if super().on_user_logout(user):
self._need_auth = True
return True
return False
# Check Authorized User # Check Authorized User
def prepare(self) -> None: def prepare(self) -> None:
max_conns = self.settings["max_websocket_connections"] max_conns = self.settings["max_websocket_connections"]