diff --git a/moonraker/components/authorization.py b/moonraker/components/authorization.py index 0f62d0e..babd145 100644 --- a/moonraker/components/authorization.py +++ b/moonraker/components/authorization.py @@ -16,11 +16,10 @@ import ipaddress import re import socket import logging -from jose import jwt +import json from tornado.ioloop import IOLoop, PeriodicCallback from tornado.web import HTTPError -import cryptography.hazmat.primitives.asymmetric.ec as ec -import cryptography.hazmat.primitives.serialization as serialization +from libnacl.sign import Signer, Verifier # Annotation imports from typing import ( @@ -32,7 +31,6 @@ from typing import ( Union, Dict, List, - cast, ) if TYPE_CHECKING: from confighelper import ConfigHelper @@ -45,7 +43,16 @@ if TYPE_CHECKING: IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network] OneshotToken = Tuple[IPAddr, Optional[Dict[str, Any]], object] -ECPrivateKey = ec.EllipticCurvePrivateKeyWithSerialization +# Helpers for base64url encoding and decoding +def base64url_encode(data: bytes) -> bytes: + return base64.urlsafe_b64encode(data).rstrip(b"=") + +def base64url_decode(data: str) -> bytes: + pad_cnt = len(data) % 4 + if pad_cnt: + data += "=" * (4 - pad_cnt) + return base64.urlsafe_b64decode(data) + ONESHOT_TIMEOUT = 5 TRUSTED_CONNECTION_TIMEOUT = 3600 @@ -57,7 +64,7 @@ TRUSTED_USER = "_TRUSTED_USER_" RESERVED_USERS = [API_USER, TRUSTED_USER] JWT_EXP_TIME = datetime.timedelta(hours=1) JWT_HEADER = { - 'alg': "ES256", + 'alg': "EdDSA", 'typ': "JWT" } @@ -81,19 +88,21 @@ class Authorization: self.api_key = api_user['api_key'] host_name, port = self.server.get_host_info() self.issuer = f"http://{host_name}:{port}" - self.public_keys: Dict[str, ec.EllipticCurvePublicKey] = {} + self.public_jwks: Dict[str, Dict[str, Any]] = {} for username, user_info in list(self.users.items()): if username == API_USER: continue if 'jwt_secret' in user_info: try: priv_key = self._load_private_key(user_info['jwt_secret']) + jwk_id = user_info['jwk_id'] except self.server.error: logging.info("Invalid key found for user, removing") user_info.pop('jwt_secret', None) + user_info.pop('jwk_id', None) self.users[username] = user_info continue - self.public_keys[username] = priv_key.public_key() + self.public_jwks[jwk_id] = self._generate_public_jwk(priv_key) self.trusted_users: Dict[IPAddr, Any] = {} self.oneshot_tokens: Dict[str, OneshotToken] = {} @@ -209,7 +218,8 @@ class Authorization: raise self.server.error( f"Invalid log out request for user {username}") self.users.pop(f"{username}.jwt_secret", None) - self.public_keys.pop(username, None) + jwk_id: str = self.users.pop(f"{username}.jwk_id", "") + self.public_jwks.pop(jwk_id, None) return { "username": username, "action": "user_logged_out" @@ -219,13 +229,16 @@ class Authorization: web_request: WebRequest ) -> Dict[str, str]: refresh_token: str = web_request.get_str('refresh_token') - user_info = self._decode_jwt(refresh_token, token_type="refresh") + try: + user_info = self._decode_jwt(refresh_token, token_type="refresh") + except Exception: + raise self.server.error("Invalid Refresh Token", 401) username: str = user_info['username'] - secret: Optional[str] = user_info.get('jwt_secret', None) - if secret is None: + if 'jwt_secret' not in user_info or "jwk_id" not in user_info: raise self.server.error("User not logged in", 401) - private_key = self._load_private_key(secret) - token = self._generate_jwt(username, private_key) + private_key = self._load_private_key(user_info['jwt_secret']) + jwk_id: str = user_info['jwk_id'] + token = self._generate_jwt(username, jwk_id, private_key) return { 'username': username, 'token': token, @@ -332,18 +345,18 @@ class Authorization: raise self.server.error("Invalid Password") jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None) if jwt_secret_hex is None: - private_key = ec.generate_private_key(ec.SECP256R1()) - serialized: bytes = cast(ECPrivateKey, private_key).private_bytes( - serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, - serialization.NoEncryption()) - user_info['jwt_secret'] = serialized.hex() + private_key = Signer() + jwk_id = base64url_encode(secrets.token_bytes()).decode() + user_info['jwt_secret'] = private_key.hex_seed().decode() + user_info['jwk_id'] = jwk_id self.users[username] = user_info + self.public_jwks[jwk_id] = self._generate_public_jwk(private_key) else: private_key = self._load_private_key(jwt_secret_hex) - self.public_keys[username] = private_key.public_key() - token = self._generate_jwt(username, private_key) + jwk_id = user_info['jwk_id'] + token = self._generate_jwt(username, jwk_id, private_key) refresh_token = self._generate_jwt( - username, private_key, token_type="refresh", + username, jwk_id, private_key, token_type="refresh", exp_time=datetime.timedelta(days=self.login_timeout)) if create: IOLoop.current().call_later( @@ -371,8 +384,8 @@ class Authorization: user_info: Optional[Dict[str, Any]] = self.users.get(username) if user_info is None: raise self.server.error(f"No registered user: {username}") + self.public_jwks.pop(self.users.get(f"{username}.jwk_id"), None) del self.users[username] - self.public_keys.pop(username, None) IOLoop.current().call_later( .005, self.server.send_event, "authorization:user_deleted", @@ -384,60 +397,96 @@ class Authorization: def _generate_jwt(self, username: str, - secret: ec.EllipticCurvePrivateKey, + jwk_id: str, + private_key: Signer, token_type: str = "access", exp_time: datetime.timedelta = JWT_EXP_TIME ) -> str: - curtime = datetime.datetime.utcnow() + curtime = int(time.time()) payload = { 'iss': self.issuer, 'aud': "Moonraker", 'iat': curtime, - 'exp': curtime + exp_time, + 'exp': curtime + int(exp_time.total_seconds()), 'username': username, 'token_type': token_type } - return jwt.encode(payload, secret, algorithm="ES256", - headers=JWT_HEADER) + header = {'kid': jwk_id} + header.update(JWT_HEADER) + jwt_header = base64url_encode(json.dumps(header).encode()) + jwt_payload = base64url_encode(json.dumps(payload).encode()) + jwt_msg = b".".join([jwt_header, jwt_payload]) + sig = private_key.signature(jwt_msg) + jwt_sig = base64url_encode(sig) + return b".".join([jwt_msg, jwt_sig]).decode() def _decode_jwt(self, token: str, token_type: str = "access" ) -> Dict[str, Any]: - header: Dict[str, Any] = jwt.get_unverified_header(token) - payload: Dict[str, Any] = jwt.get_unverified_claims(token) - if header != JWT_HEADER: + message, sig = token.rsplit('.', maxsplit=1) + enc_header, enc_payload = message.split('.') + header: Dict[str, Any] = json.loads(base64url_decode(enc_header)) + sig_bytes = base64url_decode(sig) + + # verify header + if header.get('typ') != "JWT" or header.get('alg') != "EdDSA": raise self.server.error("Invalid JWT header") - recd_type: str = payload.get('token_type', "") - if token_type != recd_type: + jwk_id = header.get('kid') + if jwk_id not in self.public_jwks: + raise self.server.error("Invalid key ID") + + # validate signature + public_key = self._public_key_from_jwk(self.public_jwks[jwk_id]) + public_key.verify(sig_bytes + message.encode()) + + # validate claims + payload: Dict[str, Any] = json.loads(base64url_decode(enc_payload)) + if payload['token_type'] != token_type: raise self.server.error( f"JWT Token type mismatch: Expected {token_type}, " - f"Recd: {recd_type}", 401) - username: str = payload['username'] - user_info: Dict[str, Any] = self.users.get(username, None) + f"Recd: {payload['token_type']}", 401) + if payload['iss'] != self.issuer: + raise self.server.error("Invalid JWT Issuer", 401) + if payload['aud'] != "Moonraker": + raise self.server.error("Invalid JWT Audience", 401) + if payload['exp'] < int(time.time()): + raise self.server.error("JWT Expired", 401) + + # get user + user_info: Optional[Dict[str, Any]] = self.users.get( + payload.get('username', ""), None) if user_info is None: - raise self.server.error( - f"Invalid JWT, no registered user {username}", 401) - public_key = self.public_keys.get(username, None) - if public_key is None: - raise self.server.error( - f"Invalid JWT, user {username} not logged in", 401) - try: - jwt.decode(token, [public_key], algorithms=['ES256'], - audience="Moonraker") - except jwt.JWTError as e: - raise self.server.error(str(e), 401) from None + raise self.server.error("Unknown user", 401) return user_info - def _load_private_key(self, secret: str) -> ec.EllipticCurvePrivateKey: + def _load_private_key(self, secret: str) -> Signer: try: - key = serialization.load_pem_private_key( - bytes.fromhex(secret), None) + key = Signer(bytes.fromhex(secret)) except Exception: raise self.server.error( "Error decoding private key, user data may" " be corrupt", 500) from None - return cast(ec.EllipticCurvePrivateKey, key) + return key + + def _generate_public_jwk(self, private_key: Signer) -> Dict[str, Any]: + public_key = private_key.vk + return { + 'x': base64url_encode(public_key).decode(), + 'kty': "OKP", + 'crv': "Ed25519", + 'use': "sig" + } + + def _public_key_from_jwk(self, jwk: Dict[str, Any]) -> Verifier: + if jwk.get('kty') != "OKP": + raise self.server.error("Not an Octet Key Pair") + if jwk.get('crv') != "Ed25519": + raise self.server.error("Invalid Curve") + if 'x' not in jwk: + raise self.server.error("No 'x' argument in jwk") + key = base64url_decode(jwk['x']) + return Verifier(key.hex().encode()) def _prune_conn_handler(self) -> None: cur_time = time.time()