authorization: authenticate over websocket
Register all of the "access" endpoints with the websocket. Front ends may now connect to the websocket without an oneshot token and login. If the front end already has a JWT for the user it can be passed to the "identify" endpoint to authenticate directly. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
c83714bfe8
commit
4ca39bec0a
|
@ -224,37 +224,40 @@ class Authorization:
|
|||
self.permitted_paths.add("/access/info")
|
||||
self.server.register_endpoint(
|
||||
"/access/login", ['POST'], self._handle_login,
|
||||
transports=['http'])
|
||||
transports=['http', 'websocket'])
|
||||
self.server.register_endpoint(
|
||||
"/access/logout", ['POST'], self._handle_logout,
|
||||
transports=['http'])
|
||||
transports=['http', 'websocket'])
|
||||
self.server.register_endpoint(
|
||||
"/access/refresh_jwt", ['POST'], self._handle_refresh_jwt,
|
||||
transports=['http'])
|
||||
transports=['http', 'websocket'])
|
||||
self.server.register_endpoint(
|
||||
"/access/user", ['GET', 'POST', 'DELETE'],
|
||||
self._handle_user_request, transports=['http'])
|
||||
self._handle_user_request, transports=['http', 'websocket'])
|
||||
self.server.register_endpoint(
|
||||
"/access/users/list", ['GET'], self._handle_list_request,
|
||||
transports=['http'])
|
||||
transports=['http', 'websocket'])
|
||||
self.server.register_endpoint(
|
||||
"/access/user/password", ['POST'], self._handle_password_reset,
|
||||
transports=['http'])
|
||||
transports=['http', 'websocket'])
|
||||
self.server.register_endpoint(
|
||||
"/access/api_key", ['GET', 'POST'],
|
||||
self._handle_apikey_request, transports=['http'])
|
||||
self._handle_apikey_request, transports=['http', 'websocket'])
|
||||
self.server.register_endpoint(
|
||||
"/access/oneshot_token", ['GET'],
|
||||
self._handle_oneshot_request, transports=['http'])
|
||||
self._handle_oneshot_request, transports=['http', 'websocket'])
|
||||
self.server.register_endpoint(
|
||||
"/access/info", ['GET'],
|
||||
self._handle_info_request, transports=['http'])
|
||||
self._handle_info_request, transports=['http', 'websocket'])
|
||||
self.server.register_notification("authorization:user_created")
|
||||
self.server.register_notification("authorization:user_deleted")
|
||||
|
||||
def register_permited_path(self, path: str) -> None:
|
||||
self.permitted_paths.add(path)
|
||||
|
||||
def is_path_permitted(self, path: str) -> bool:
|
||||
return path in self.permitted_paths
|
||||
|
||||
def _sync_user(self, username: str) -> None:
|
||||
self.user_db[username] = self.users[username]
|
||||
|
||||
|
@ -311,7 +314,7 @@ class Authorization:
|
|||
) -> Dict[str, str]:
|
||||
refresh_token: str = web_request.get_str('refresh_token')
|
||||
try:
|
||||
user_info = self._decode_jwt(refresh_token, token_type="refresh")
|
||||
user_info = self.decode_jwt(refresh_token, token_type="refresh")
|
||||
except Exception:
|
||||
raise self.server.error("Invalid Refresh Token", 401)
|
||||
username: str = user_info['username']
|
||||
|
@ -474,12 +477,15 @@ class Authorization:
|
|||
refresh_token = self._generate_jwt(
|
||||
username, jwk_id, private_key, token_type="refresh",
|
||||
exp_time=datetime.timedelta(days=self.login_timeout))
|
||||
conn = web_request.get_client_connection()
|
||||
if create:
|
||||
event_loop = self.server.get_event_loop()
|
||||
event_loop.delay_callback(
|
||||
.005, self.server.send_event,
|
||||
"authorization:user_created",
|
||||
{'username': username})
|
||||
elif conn is not None:
|
||||
conn.user_info = user_info
|
||||
return {
|
||||
'username': username,
|
||||
'token': token,
|
||||
|
@ -541,10 +547,9 @@ class Authorization:
|
|||
jwt_sig = base64url_encode(sig)
|
||||
return b".".join([jwt_msg, jwt_sig]).decode()
|
||||
|
||||
def _decode_jwt(self,
|
||||
token: str,
|
||||
token_type: str = "access"
|
||||
) -> Dict[str, Any]:
|
||||
def decode_jwt(
|
||||
self, token: str, token_type: str = "access"
|
||||
) -> Dict[str, Any]:
|
||||
message, sig = token.rsplit('.', maxsplit=1)
|
||||
enc_header, enc_payload = message.split('.')
|
||||
header: Dict[str, Any] = json.loads(base64url_decode(enc_header))
|
||||
|
@ -653,7 +658,7 @@ class Authorization:
|
|||
401, f"Invalid Authorization Header: {auth_token}")
|
||||
if auth_token:
|
||||
try:
|
||||
return self._decode_jwt(auth_token)
|
||||
return self.decode_jwt(auth_token)
|
||||
except Exception:
|
||||
logging.exception(f"JWT Decode Error {auth_token}")
|
||||
raise HTTPError(401, f"Error decoding JWT: {auth_token}")
|
||||
|
|
|
@ -71,8 +71,8 @@ class ExtensionManager:
|
|||
connection.send_notification("agent_event", [evt])
|
||||
|
||||
async def _handle_agent_event(self, web_request: WebRequest) -> str:
|
||||
conn = web_request.get_connection()
|
||||
if not isinstance(conn, BaseSocketClient):
|
||||
conn = web_request.get_client_connection()
|
||||
if conn is None:
|
||||
raise self.server.error("No connection detected")
|
||||
if conn.client_data["type"] != "agent":
|
||||
raise self.server.error(
|
||||
|
|
|
@ -494,7 +494,7 @@ class KlippyConnection:
|
|||
web_request: WebRequest
|
||||
) -> Dict[str, Any]:
|
||||
args = web_request.get_args()
|
||||
conn = web_request.get_connection()
|
||||
conn = web_request.get_subscribable()
|
||||
|
||||
# Build the subscription request from a superset of all client
|
||||
# subscriptions
|
||||
|
|
|
@ -80,9 +80,14 @@ class WebRequest:
|
|||
def get_args(self) -> Dict[str, Any]:
|
||||
return self.args
|
||||
|
||||
def get_connection(self) -> Optional[Subscribable]:
|
||||
def get_subscribable(self) -> Optional[Subscribable]:
|
||||
return self.conn
|
||||
|
||||
def get_client_connection(self) -> Optional[BaseSocketClient]:
|
||||
if isinstance(self.conn, BaseSocketClient):
|
||||
return self.conn
|
||||
return None
|
||||
|
||||
def get_ip_address(self) -> Optional[IPUnion]:
|
||||
return self.ip_addr
|
||||
|
||||
|
@ -258,6 +263,8 @@ class JsonRPC:
|
|||
code = e.status_code
|
||||
if code == 404:
|
||||
code = -32601
|
||||
elif code == 401:
|
||||
code = -32602
|
||||
return self.build_error(code, str(e), req_id, True)
|
||||
except Exception as e:
|
||||
return self.build_error(-31000, str(e), req_id, True)
|
||||
|
@ -326,39 +333,36 @@ class WebsocketManager(APITransport):
|
|||
if api_def.callback is None:
|
||||
# Remote API, uses RPC to reach out to Klippy
|
||||
ws_method = api_def.jrpc_methods[0]
|
||||
rpc_cb = self._generate_callback(api_def.endpoint)
|
||||
rpc_cb = self._generate_callback(
|
||||
api_def.endpoint, "", self.klippy.request
|
||||
)
|
||||
self.rpc.register_method(ws_method, rpc_cb)
|
||||
else:
|
||||
# Local API, uses local callback
|
||||
for ws_method, req_method in \
|
||||
zip(api_def.jrpc_methods, api_def.request_methods):
|
||||
rpc_cb = self._generate_local_callback(
|
||||
api_def.endpoint, req_method, api_def.callback)
|
||||
rpc_cb = self._generate_callback(
|
||||
api_def.endpoint, req_method, api_def.callback
|
||||
)
|
||||
self.rpc.register_method(ws_method, rpc_cb)
|
||||
logging.info(
|
||||
"Registering Websocket JSON-RPC methods: "
|
||||
f"{', '.join(api_def.jrpc_methods)}")
|
||||
f"{', '.join(api_def.jrpc_methods)}"
|
||||
)
|
||||
|
||||
def remove_api_handler(self, api_def: APIDefinition) -> None:
|
||||
for jrpc_method in api_def.jrpc_methods:
|
||||
self.rpc.remove_method(jrpc_method)
|
||||
|
||||
def _generate_callback(self, endpoint: str) -> RPCCallback:
|
||||
async def func(args: Dict[str, Any]) -> Any:
|
||||
sc: BaseSocketClient = args.pop("_socket_")
|
||||
result = await self.klippy.request(
|
||||
WebRequest(endpoint, args, conn=sc, ip_addr=sc.ip_addr,
|
||||
user=sc.user_info))
|
||||
return result
|
||||
return func
|
||||
|
||||
def _generate_local_callback(self,
|
||||
endpoint: str,
|
||||
request_method: str,
|
||||
callback: Callable[[WebRequest], Coroutine]
|
||||
) -> RPCCallback:
|
||||
def _generate_callback(
|
||||
self,
|
||||
endpoint: str,
|
||||
request_method: str,
|
||||
callback: Callable[[WebRequest], Coroutine]
|
||||
) -> RPCCallback:
|
||||
async def func(args: Dict[str, Any]) -> Any:
|
||||
sc: BaseSocketClient = args.pop("_socket_")
|
||||
sc.authenticate(path=endpoint)
|
||||
result = await callback(
|
||||
WebRequest(endpoint, args, request_method, sc,
|
||||
ip_addr=sc.ip_addr, user=sc.user_info))
|
||||
|
@ -367,10 +371,12 @@ class WebsocketManager(APITransport):
|
|||
|
||||
async def _handle_id_request(self, args: Dict[str, Any]) -> Dict[str, int]:
|
||||
sc: BaseSocketClient = args["_socket_"]
|
||||
sc.authenticate()
|
||||
return {'websocket_id': sc.uid}
|
||||
|
||||
async def _handle_identify(self, args: Dict[str, Any]) -> Dict[str, int]:
|
||||
sc: BaseSocketClient = args["_socket_"]
|
||||
sc.authenticate(token=args.get("access_token", None))
|
||||
if sc.identified:
|
||||
raise self.server.error(
|
||||
f"Connection already identified: {sc.client_data}"
|
||||
|
@ -468,7 +474,7 @@ class WebsocketManager(APITransport):
|
|||
if data:
|
||||
msg['params'] = data
|
||||
for sc in list(self.clients.values()):
|
||||
if sc.uid in mask:
|
||||
if sc.uid in mask or sc.need_auth:
|
||||
continue
|
||||
sc.queue_message(msg)
|
||||
|
||||
|
@ -507,10 +513,21 @@ class BaseSocketClient(Subscribable):
|
|||
"type": "",
|
||||
"url": ""
|
||||
}
|
||||
self._need_auth: bool = False
|
||||
self._user_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
@property
|
||||
def user_info(self) -> Optional[Dict[str, Any]]:
|
||||
return None
|
||||
return self._user_info
|
||||
|
||||
@user_info.setter
|
||||
def user_info(self, uinfo: Dict[str, Any]) -> None:
|
||||
self._user_info = uinfo
|
||||
self._need_auth = False
|
||||
|
||||
@property
|
||||
def need_auth(self) -> bool:
|
||||
return self._need_auth
|
||||
|
||||
@property
|
||||
def uid(self) -> int:
|
||||
|
@ -552,6 +569,25 @@ class BaseSocketClient(Subscribable):
|
|||
self.queue_busy = True
|
||||
self.eventloop.register_callback(self._write_messages)
|
||||
|
||||
def authenticate(self, path: str = "", token: Optional[str] = None) -> None:
|
||||
if not self._need_auth:
|
||||
return
|
||||
auth: AuthComp = self.server.lookup_component("authorization", None)
|
||||
if auth is None:
|
||||
return
|
||||
if token is not None:
|
||||
try:
|
||||
user_info = auth.decode_jwt(token)
|
||||
except self.server.error:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise self.server.error(
|
||||
f"Failed to decode JWT: {e}", 401
|
||||
) from e
|
||||
self.user_info = user_info
|
||||
elif not auth.is_path_permitted(path):
|
||||
raise self.server.error("Unauthorized", 401)
|
||||
|
||||
async def _write_messages(self):
|
||||
if self.is_closed:
|
||||
self.message_buf = []
|
||||
|
@ -619,14 +655,13 @@ class WebSocket(WebSocketHandler, BaseSocketClient):
|
|||
self.ip_addr: str = self.request.remote_ip or ""
|
||||
self.last_pong_time: float = self.eventloop.get_loop_time()
|
||||
|
||||
@property
|
||||
def user_info(self) -> Optional[Dict[str, Any]]:
|
||||
return self.current_user
|
||||
|
||||
@property
|
||||
def hostname(self) -> str:
|
||||
return self.request.host_name
|
||||
|
||||
def get_current_user(self) -> Any:
|
||||
return self._user_info
|
||||
|
||||
def open(self, *args, **kwargs) -> None:
|
||||
self.set_nodelay(True)
|
||||
self._connected_time = self.eventloop.get_loop_time()
|
||||
|
@ -693,7 +728,12 @@ class WebSocket(WebSocketHandler, BaseSocketClient):
|
|||
def prepare(self):
|
||||
auth: AuthComp = self.server.lookup_component('authorization', None)
|
||||
if auth is not None:
|
||||
self.current_user = auth.check_authorized(self.request)
|
||||
try:
|
||||
self._user_info = auth.check_authorized(self.request)
|
||||
except Exception as e:
|
||||
logging.info(f"Websocket Failed Authentication: {e}")
|
||||
self._user_info = None
|
||||
self._need_auth = True
|
||||
|
||||
def close_socket(self, code: int, reason: str) -> None:
|
||||
self.close(code, reason)
|
||||
|
|
Loading…
Reference in New Issue