authorization: switch to EdDSA signatures

This removes the cryptography dependency in favor of libsodium.  Also removed is python-jose, as we must generate our own JWTs for use with EdDSA.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>

use libnacl instead of pynacl
This commit is contained in:
Eric Callahan 2021-06-01 20:06:53 -04:00
parent 628c0193f3
commit 39343f984a
1 changed files with 100 additions and 51 deletions

View File

@ -16,11 +16,10 @@ import ipaddress
import re import re
import socket import socket
import logging import logging
from jose import jwt import json
from tornado.ioloop import IOLoop, PeriodicCallback from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.web import HTTPError from tornado.web import HTTPError
import cryptography.hazmat.primitives.asymmetric.ec as ec from libnacl.sign import Signer, Verifier
import cryptography.hazmat.primitives.serialization as serialization
# Annotation imports # Annotation imports
from typing import ( from typing import (
@ -32,7 +31,6 @@ from typing import (
Union, Union,
Dict, Dict,
List, List,
cast,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from confighelper import ConfigHelper from confighelper import ConfigHelper
@ -45,7 +43,16 @@ if TYPE_CHECKING:
IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network] IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
OneshotToken = Tuple[IPAddr, Optional[Dict[str, Any]], object] OneshotToken = Tuple[IPAddr, Optional[Dict[str, Any]], object]
ECPrivateKey = ec.EllipticCurvePrivateKeyWithSerialization # Helpers for base64url encoding and decoding
def base64url_encode(data: bytes) -> bytes:
return base64.urlsafe_b64encode(data).rstrip(b"=")
def base64url_decode(data: str) -> bytes:
pad_cnt = len(data) % 4
if pad_cnt:
data += "=" * (4 - pad_cnt)
return base64.urlsafe_b64decode(data)
ONESHOT_TIMEOUT = 5 ONESHOT_TIMEOUT = 5
TRUSTED_CONNECTION_TIMEOUT = 3600 TRUSTED_CONNECTION_TIMEOUT = 3600
@ -57,7 +64,7 @@ TRUSTED_USER = "_TRUSTED_USER_"
RESERVED_USERS = [API_USER, TRUSTED_USER] RESERVED_USERS = [API_USER, TRUSTED_USER]
JWT_EXP_TIME = datetime.timedelta(hours=1) JWT_EXP_TIME = datetime.timedelta(hours=1)
JWT_HEADER = { JWT_HEADER = {
'alg': "ES256", 'alg': "EdDSA",
'typ': "JWT" 'typ': "JWT"
} }
@ -81,19 +88,21 @@ class Authorization:
self.api_key = api_user['api_key'] self.api_key = api_user['api_key']
host_name, port = self.server.get_host_info() host_name, port = self.server.get_host_info()
self.issuer = f"http://{host_name}:{port}" self.issuer = f"http://{host_name}:{port}"
self.public_keys: Dict[str, ec.EllipticCurvePublicKey] = {} self.public_jwks: Dict[str, Dict[str, Any]] = {}
for username, user_info in list(self.users.items()): for username, user_info in list(self.users.items()):
if username == API_USER: if username == API_USER:
continue continue
if 'jwt_secret' in user_info: if 'jwt_secret' in user_info:
try: try:
priv_key = self._load_private_key(user_info['jwt_secret']) priv_key = self._load_private_key(user_info['jwt_secret'])
jwk_id = user_info['jwk_id']
except self.server.error: except self.server.error:
logging.info("Invalid key found for user, removing") logging.info("Invalid key found for user, removing")
user_info.pop('jwt_secret', None) user_info.pop('jwt_secret', None)
user_info.pop('jwk_id', None)
self.users[username] = user_info self.users[username] = user_info
continue continue
self.public_keys[username] = priv_key.public_key() self.public_jwks[jwk_id] = self._generate_public_jwk(priv_key)
self.trusted_users: Dict[IPAddr, Any] = {} self.trusted_users: Dict[IPAddr, Any] = {}
self.oneshot_tokens: Dict[str, OneshotToken] = {} self.oneshot_tokens: Dict[str, OneshotToken] = {}
@ -209,7 +218,8 @@ class Authorization:
raise self.server.error( raise self.server.error(
f"Invalid log out request for user {username}") f"Invalid log out request for user {username}")
self.users.pop(f"{username}.jwt_secret", None) self.users.pop(f"{username}.jwt_secret", None)
self.public_keys.pop(username, None) jwk_id: str = self.users.pop(f"{username}.jwk_id", "")
self.public_jwks.pop(jwk_id, None)
return { return {
"username": username, "username": username,
"action": "user_logged_out" "action": "user_logged_out"
@ -219,13 +229,16 @@ class Authorization:
web_request: WebRequest web_request: WebRequest
) -> Dict[str, str]: ) -> Dict[str, str]:
refresh_token: str = web_request.get_str('refresh_token') refresh_token: str = web_request.get_str('refresh_token')
user_info = self._decode_jwt(refresh_token, token_type="refresh") try:
user_info = self._decode_jwt(refresh_token, token_type="refresh")
except Exception:
raise self.server.error("Invalid Refresh Token", 401)
username: str = user_info['username'] username: str = user_info['username']
secret: Optional[str] = user_info.get('jwt_secret', None) if 'jwt_secret' not in user_info or "jwk_id" not in user_info:
if secret is None:
raise self.server.error("User not logged in", 401) raise self.server.error("User not logged in", 401)
private_key = self._load_private_key(secret) private_key = self._load_private_key(user_info['jwt_secret'])
token = self._generate_jwt(username, private_key) jwk_id: str = user_info['jwk_id']
token = self._generate_jwt(username, jwk_id, private_key)
return { return {
'username': username, 'username': username,
'token': token, 'token': token,
@ -332,18 +345,18 @@ class Authorization:
raise self.server.error("Invalid Password") raise self.server.error("Invalid Password")
jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None) jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None)
if jwt_secret_hex is None: if jwt_secret_hex is None:
private_key = ec.generate_private_key(ec.SECP256R1()) private_key = Signer()
serialized: bytes = cast(ECPrivateKey, private_key).private_bytes( jwk_id = base64url_encode(secrets.token_bytes()).decode()
serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, user_info['jwt_secret'] = private_key.hex_seed().decode()
serialization.NoEncryption()) user_info['jwk_id'] = jwk_id
user_info['jwt_secret'] = serialized.hex()
self.users[username] = user_info self.users[username] = user_info
self.public_jwks[jwk_id] = self._generate_public_jwk(private_key)
else: else:
private_key = self._load_private_key(jwt_secret_hex) private_key = self._load_private_key(jwt_secret_hex)
self.public_keys[username] = private_key.public_key() jwk_id = user_info['jwk_id']
token = self._generate_jwt(username, private_key) token = self._generate_jwt(username, jwk_id, private_key)
refresh_token = self._generate_jwt( refresh_token = self._generate_jwt(
username, private_key, token_type="refresh", username, jwk_id, private_key, token_type="refresh",
exp_time=datetime.timedelta(days=self.login_timeout)) exp_time=datetime.timedelta(days=self.login_timeout))
if create: if create:
IOLoop.current().call_later( IOLoop.current().call_later(
@ -371,8 +384,8 @@ class Authorization:
user_info: Optional[Dict[str, Any]] = self.users.get(username) user_info: Optional[Dict[str, Any]] = self.users.get(username)
if user_info is None: if user_info is None:
raise self.server.error(f"No registered user: {username}") raise self.server.error(f"No registered user: {username}")
self.public_jwks.pop(self.users.get(f"{username}.jwk_id"), None)
del self.users[username] del self.users[username]
self.public_keys.pop(username, None)
IOLoop.current().call_later( IOLoop.current().call_later(
.005, self.server.send_event, .005, self.server.send_event,
"authorization:user_deleted", "authorization:user_deleted",
@ -384,60 +397,96 @@ class Authorization:
def _generate_jwt(self, def _generate_jwt(self,
username: str, username: str,
secret: ec.EllipticCurvePrivateKey, jwk_id: str,
private_key: Signer,
token_type: str = "access", token_type: str = "access",
exp_time: datetime.timedelta = JWT_EXP_TIME exp_time: datetime.timedelta = JWT_EXP_TIME
) -> str: ) -> str:
curtime = datetime.datetime.utcnow() curtime = int(time.time())
payload = { payload = {
'iss': self.issuer, 'iss': self.issuer,
'aud': "Moonraker", 'aud': "Moonraker",
'iat': curtime, 'iat': curtime,
'exp': curtime + exp_time, 'exp': curtime + int(exp_time.total_seconds()),
'username': username, 'username': username,
'token_type': token_type 'token_type': token_type
} }
return jwt.encode(payload, secret, algorithm="ES256", header = {'kid': jwk_id}
headers=JWT_HEADER) header.update(JWT_HEADER)
jwt_header = base64url_encode(json.dumps(header).encode())
jwt_payload = base64url_encode(json.dumps(payload).encode())
jwt_msg = b".".join([jwt_header, jwt_payload])
sig = private_key.signature(jwt_msg)
jwt_sig = base64url_encode(sig)
return b".".join([jwt_msg, jwt_sig]).decode()
def _decode_jwt(self, def _decode_jwt(self,
token: str, token: str,
token_type: str = "access" token_type: str = "access"
) -> Dict[str, Any]: ) -> Dict[str, Any]:
header: Dict[str, Any] = jwt.get_unverified_header(token) message, sig = token.rsplit('.', maxsplit=1)
payload: Dict[str, Any] = jwt.get_unverified_claims(token) enc_header, enc_payload = message.split('.')
if header != JWT_HEADER: header: Dict[str, Any] = json.loads(base64url_decode(enc_header))
sig_bytes = base64url_decode(sig)
# verify header
if header.get('typ') != "JWT" or header.get('alg') != "EdDSA":
raise self.server.error("Invalid JWT header") raise self.server.error("Invalid JWT header")
recd_type: str = payload.get('token_type', "") jwk_id = header.get('kid')
if token_type != recd_type: if jwk_id not in self.public_jwks:
raise self.server.error("Invalid key ID")
# validate signature
public_key = self._public_key_from_jwk(self.public_jwks[jwk_id])
public_key.verify(sig_bytes + message.encode())
# validate claims
payload: Dict[str, Any] = json.loads(base64url_decode(enc_payload))
if payload['token_type'] != token_type:
raise self.server.error( raise self.server.error(
f"JWT Token type mismatch: Expected {token_type}, " f"JWT Token type mismatch: Expected {token_type}, "
f"Recd: {recd_type}", 401) f"Recd: {payload['token_type']}", 401)
username: str = payload['username'] if payload['iss'] != self.issuer:
user_info: Dict[str, Any] = self.users.get(username, None) raise self.server.error("Invalid JWT Issuer", 401)
if payload['aud'] != "Moonraker":
raise self.server.error("Invalid JWT Audience", 401)
if payload['exp'] < int(time.time()):
raise self.server.error("JWT Expired", 401)
# get user
user_info: Optional[Dict[str, Any]] = self.users.get(
payload.get('username', ""), None)
if user_info is None: if user_info is None:
raise self.server.error( raise self.server.error("Unknown user", 401)
f"Invalid JWT, no registered user {username}", 401)
public_key = self.public_keys.get(username, None)
if public_key is None:
raise self.server.error(
f"Invalid JWT, user {username} not logged in", 401)
try:
jwt.decode(token, [public_key], algorithms=['ES256'],
audience="Moonraker")
except jwt.JWTError as e:
raise self.server.error(str(e), 401) from None
return user_info return user_info
def _load_private_key(self, secret: str) -> ec.EllipticCurvePrivateKey: def _load_private_key(self, secret: str) -> Signer:
try: try:
key = serialization.load_pem_private_key( key = Signer(bytes.fromhex(secret))
bytes.fromhex(secret), None)
except Exception: except Exception:
raise self.server.error( raise self.server.error(
"Error decoding private key, user data may" "Error decoding private key, user data may"
" be corrupt", 500) from None " be corrupt", 500) from None
return cast(ec.EllipticCurvePrivateKey, key) return key
def _generate_public_jwk(self, private_key: Signer) -> Dict[str, Any]:
public_key = private_key.vk
return {
'x': base64url_encode(public_key).decode(),
'kty': "OKP",
'crv': "Ed25519",
'use': "sig"
}
def _public_key_from_jwk(self, jwk: Dict[str, Any]) -> Verifier:
if jwk.get('kty') != "OKP":
raise self.server.error("Not an Octet Key Pair")
if jwk.get('crv') != "Ed25519":
raise self.server.error("Invalid Curve")
if 'x' not in jwk:
raise self.server.error("No 'x' argument in jwk")
key = base64url_decode(jwk['x'])
return Verifier(key.hex().encode())
def _prune_conn_handler(self) -> None: def _prune_conn_handler(self) -> None:
cur_time = time.time() cur_time = time.time()