authorization: authenticate over websocket

Register all of the "access" endpoints with the websocket.  Front
ends may now connect to the websocket without an oneshot token
and login.  If the front end already has a JWT for the user it
can be passed to the "identify" endpoint to authenticate directly.

Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-11-21 15:14:48 -05:00
parent c83714bfe8
commit 4ca39bec0a
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
4 changed files with 89 additions and 44 deletions

View File

@ -224,37 +224,40 @@ class Authorization:
self.permitted_paths.add("/access/info")
self.server.register_endpoint(
"/access/login", ['POST'], self._handle_login,
transports=['http'])
transports=['http', 'websocket'])
self.server.register_endpoint(
"/access/logout", ['POST'], self._handle_logout,
transports=['http'])
transports=['http', 'websocket'])
self.server.register_endpoint(
"/access/refresh_jwt", ['POST'], self._handle_refresh_jwt,
transports=['http'])
transports=['http', 'websocket'])
self.server.register_endpoint(
"/access/user", ['GET', 'POST', 'DELETE'],
self._handle_user_request, transports=['http'])
self._handle_user_request, transports=['http', 'websocket'])
self.server.register_endpoint(
"/access/users/list", ['GET'], self._handle_list_request,
transports=['http'])
transports=['http', 'websocket'])
self.server.register_endpoint(
"/access/user/password", ['POST'], self._handle_password_reset,
transports=['http'])
transports=['http', 'websocket'])
self.server.register_endpoint(
"/access/api_key", ['GET', 'POST'],
self._handle_apikey_request, transports=['http'])
self._handle_apikey_request, transports=['http', 'websocket'])
self.server.register_endpoint(
"/access/oneshot_token", ['GET'],
self._handle_oneshot_request, transports=['http'])
self._handle_oneshot_request, transports=['http', 'websocket'])
self.server.register_endpoint(
"/access/info", ['GET'],
self._handle_info_request, transports=['http'])
self._handle_info_request, transports=['http', 'websocket'])
self.server.register_notification("authorization:user_created")
self.server.register_notification("authorization:user_deleted")
def register_permited_path(self, path: str) -> None:
self.permitted_paths.add(path)
def is_path_permitted(self, path: str) -> bool:
return path in self.permitted_paths
def _sync_user(self, username: str) -> None:
self.user_db[username] = self.users[username]
@ -311,7 +314,7 @@ class Authorization:
) -> Dict[str, str]:
refresh_token: str = web_request.get_str('refresh_token')
try:
user_info = self._decode_jwt(refresh_token, token_type="refresh")
user_info = self.decode_jwt(refresh_token, token_type="refresh")
except Exception:
raise self.server.error("Invalid Refresh Token", 401)
username: str = user_info['username']
@ -474,12 +477,15 @@ class Authorization:
refresh_token = self._generate_jwt(
username, jwk_id, private_key, token_type="refresh",
exp_time=datetime.timedelta(days=self.login_timeout))
conn = web_request.get_client_connection()
if create:
event_loop = self.server.get_event_loop()
event_loop.delay_callback(
.005, self.server.send_event,
"authorization:user_created",
{'username': username})
elif conn is not None:
conn.user_info = user_info
return {
'username': username,
'token': token,
@ -541,10 +547,9 @@ class Authorization:
jwt_sig = base64url_encode(sig)
return b".".join([jwt_msg, jwt_sig]).decode()
def _decode_jwt(self,
token: str,
token_type: str = "access"
) -> Dict[str, Any]:
def decode_jwt(
self, token: str, token_type: str = "access"
) -> Dict[str, Any]:
message, sig = token.rsplit('.', maxsplit=1)
enc_header, enc_payload = message.split('.')
header: Dict[str, Any] = json.loads(base64url_decode(enc_header))
@ -653,7 +658,7 @@ class Authorization:
401, f"Invalid Authorization Header: {auth_token}")
if auth_token:
try:
return self._decode_jwt(auth_token)
return self.decode_jwt(auth_token)
except Exception:
logging.exception(f"JWT Decode Error {auth_token}")
raise HTTPError(401, f"Error decoding JWT: {auth_token}")

View File

@ -71,8 +71,8 @@ class ExtensionManager:
connection.send_notification("agent_event", [evt])
async def _handle_agent_event(self, web_request: WebRequest) -> str:
conn = web_request.get_connection()
if not isinstance(conn, BaseSocketClient):
conn = web_request.get_client_connection()
if conn is None:
raise self.server.error("No connection detected")
if conn.client_data["type"] != "agent":
raise self.server.error(

View File

@ -494,7 +494,7 @@ class KlippyConnection:
web_request: WebRequest
) -> Dict[str, Any]:
args = web_request.get_args()
conn = web_request.get_connection()
conn = web_request.get_subscribable()
# Build the subscription request from a superset of all client
# subscriptions

View File

@ -80,9 +80,14 @@ class WebRequest:
def get_args(self) -> Dict[str, Any]:
return self.args
def get_connection(self) -> Optional[Subscribable]:
def get_subscribable(self) -> Optional[Subscribable]:
return self.conn
def get_client_connection(self) -> Optional[BaseSocketClient]:
if isinstance(self.conn, BaseSocketClient):
return self.conn
return None
def get_ip_address(self) -> Optional[IPUnion]:
return self.ip_addr
@ -258,6 +263,8 @@ class JsonRPC:
code = e.status_code
if code == 404:
code = -32601
elif code == 401:
code = -32602
return self.build_error(code, str(e), req_id, True)
except Exception as e:
return self.build_error(-31000, str(e), req_id, True)
@ -326,39 +333,36 @@ class WebsocketManager(APITransport):
if api_def.callback is None:
# Remote API, uses RPC to reach out to Klippy
ws_method = api_def.jrpc_methods[0]
rpc_cb = self._generate_callback(api_def.endpoint)
rpc_cb = self._generate_callback(
api_def.endpoint, "", self.klippy.request
)
self.rpc.register_method(ws_method, rpc_cb)
else:
# Local API, uses local callback
for ws_method, req_method in \
zip(api_def.jrpc_methods, api_def.request_methods):
rpc_cb = self._generate_local_callback(
api_def.endpoint, req_method, api_def.callback)
rpc_cb = self._generate_callback(
api_def.endpoint, req_method, api_def.callback
)
self.rpc.register_method(ws_method, rpc_cb)
logging.info(
"Registering Websocket JSON-RPC methods: "
f"{', '.join(api_def.jrpc_methods)}")
f"{', '.join(api_def.jrpc_methods)}"
)
def remove_api_handler(self, api_def: APIDefinition) -> None:
for jrpc_method in api_def.jrpc_methods:
self.rpc.remove_method(jrpc_method)
def _generate_callback(self, endpoint: str) -> RPCCallback:
async def func(args: Dict[str, Any]) -> Any:
sc: BaseSocketClient = args.pop("_socket_")
result = await self.klippy.request(
WebRequest(endpoint, args, conn=sc, ip_addr=sc.ip_addr,
user=sc.user_info))
return result
return func
def _generate_local_callback(self,
endpoint: str,
request_method: str,
callback: Callable[[WebRequest], Coroutine]
) -> RPCCallback:
def _generate_callback(
self,
endpoint: str,
request_method: str,
callback: Callable[[WebRequest], Coroutine]
) -> RPCCallback:
async def func(args: Dict[str, Any]) -> Any:
sc: BaseSocketClient = args.pop("_socket_")
sc.authenticate(path=endpoint)
result = await callback(
WebRequest(endpoint, args, request_method, sc,
ip_addr=sc.ip_addr, user=sc.user_info))
@ -367,10 +371,12 @@ class WebsocketManager(APITransport):
async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
sc: BaseSocketClient = args["_socket_"]
sc.authenticate()
return {'websocket_id': sc.uid}
async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
sc: BaseSocketClient = args["_socket_"]
sc.authenticate(token=args.get("access_token", None))
if sc.identified:
raise self.server.error(
f"Connection already identified: {sc.client_data}"
@ -468,7 +474,7 @@ class WebsocketManager(APITransport):
if data:
msg['params'] = data
for sc in list(self.clients.values()):
if sc.uid in mask:
if sc.uid in mask or sc.need_auth:
continue
sc.queue_message(msg)
@ -507,10 +513,21 @@ class BaseSocketClient(Subscribable):
"type": "",
"url": ""
}
self._need_auth: bool = False
self._user_info: Optional[Dict[str, Any]] = None
@property
def user_info(self) -> Optional[Dict[str, Any]]:
return None
return self._user_info
@user_info.setter
def user_info(self, uinfo: Dict[str, Any]) -> None:
self._user_info = uinfo
self._need_auth = False
@property
def need_auth(self) -> bool:
return self._need_auth
@property
def uid(self) -> int:
@ -552,6 +569,25 @@ class BaseSocketClient(Subscribable):
self.queue_busy = True
self.eventloop.register_callback(self._write_messages)
def authenticate(self, path: str = "", token: Optional[str] = None) -> None:
if not self._need_auth:
return
auth: AuthComp = self.server.lookup_component("authorization", None)
if auth is None:
return
if token is not None:
try:
user_info = auth.decode_jwt(token)
except self.server.error:
raise
except Exception as e:
raise self.server.error(
f"Failed to decode JWT: {e}", 401
) from e
self.user_info = user_info
elif not auth.is_path_permitted(path):
raise self.server.error("Unauthorized", 401)
async def _write_messages(self):
if self.is_closed:
self.message_buf = []
@ -619,14 +655,13 @@ class WebSocket(WebSocketHandler, BaseSocketClient):
self.ip_addr: str = self.request.remote_ip or ""
self.last_pong_time: float = self.eventloop.get_loop_time()
@property
def user_info(self) -> Optional[Dict[str, Any]]:
return self.current_user
@property
def hostname(self) -> str:
return self.request.host_name
def get_current_user(self) -> Any:
return self._user_info
def open(self, *args, **kwargs) -> None:
self.set_nodelay(True)
self._connected_time = self.eventloop.get_loop_time()
@ -693,7 +728,12 @@ class WebSocket(WebSocketHandler, BaseSocketClient):
def prepare(self):
auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None:
self.current_user = auth.check_authorized(self.request)
try:
self._user_info = auth.check_authorized(self.request)
except Exception as e:
logging.info(f"Websocket Failed Authentication: {e}")
self._user_info = None
self._need_auth = True
def close_socket(self, code: int, reason: str) -> None:
self.close(code, reason)