authorization: use python_jose dependency for jwt management

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-05-23 08:11:46 -04:00
parent e4e58f6d97
commit ce7f659a32
1 changed files with 21 additions and 46 deletions

View File

@ -8,16 +8,15 @@ from __future__ import annotations
import base64 import base64
import uuid import uuid
import hashlib import hashlib
import hmac
import secrets import secrets
import os import os
import time import time
import datetime import datetime
import ipaddress import ipaddress
import json
import re import re
import socket import socket
import logging import logging
from jose import jwt
from tornado.ioloop import IOLoop, PeriodicCallback from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.web import HTTPError from tornado.web import HTTPError
@ -57,16 +56,6 @@ JWT_HEADER = {
'typ': "JWT" 'typ': "JWT"
} }
# Helpers for base64url encoding and decoding
def base64url_encode(data: bytes) -> bytes:
return base64.urlsafe_b64encode(data).rstrip(b"=")
def base64url_decode(data) -> bytes:
pad_cnt = len(data) % 4
if pad_cnt:
data += b"=" * (4 - pad_cnt)
return base64.urlsafe_b64decode(data)
class Authorization: class Authorization:
def __init__(self, config: ConfigHelper) -> None: def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server() self.server = config.get_server()
@ -88,6 +77,8 @@ class Authorization:
self.trusted_users: Dict[IPAddr, Any] = {} self.trusted_users: Dict[IPAddr, Any] = {}
self.oneshot_tokens: Dict[str, OneshotToken] = {} self.oneshot_tokens: Dict[str, OneshotToken] = {}
self.permitted_paths: Set[str] = set() self.permitted_paths: Set[str] = set()
host_name, port = self.server.get_host_info()
self.issuer = f"http://{host_name}:{port}"
# Get allowed cors domains # Get allowed cors domains
self.cors_domains: List[str] = [] self.cors_domains: List[str] = []
@ -209,7 +200,7 @@ class Authorization:
refresh_token: str = web_request.get_str('refresh_token') refresh_token: str = web_request.get_str('refresh_token')
user_info = self._decode_jwt(refresh_token, token_type="refresh") user_info = self._decode_jwt(refresh_token, token_type="refresh")
username: str = user_info['username'] username: str = user_info['username']
secret = bytes.fromhex(user_info['jwt_secret']) secret = user_info['jwt_secret']
token = self._generate_jwt(username, secret) token = self._generate_jwt(username, secret)
return { return {
'username': username, 'username': username,
@ -315,13 +306,11 @@ class Authorization:
action = "user_logged_in" action = "user_logged_in"
if hashed_pass != user_info['password']: if hashed_pass != user_info['password']:
raise self.server.error("Invalid Password") raise self.server.error("Invalid Password")
jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None) jwt_secret: Optional[str] = user_info.get('jwt_secret', None)
if jwt_secret_hex is None: if jwt_secret is None:
jwt_secret = secrets.token_bytes(32) jwt_secret = secrets.token_bytes(32).hex()
user_info['jwt_secret'] = jwt_secret.hex() user_info['jwt_secret'] = jwt_secret
self.users[username] = user_info self.users[username] = user_info
else:
jwt_secret = bytes.fromhex(jwt_secret_hex)
token = self._generate_jwt(username, jwt_secret) token = self._generate_jwt(username, jwt_secret)
refresh_token = self._generate_jwt( refresh_token = self._generate_jwt(
username, jwt_secret, token_type="refresh", username, jwt_secret, token_type="refresh",
@ -364,34 +353,27 @@ class Authorization:
def _generate_jwt(self, def _generate_jwt(self,
username: str, username: str,
secret: bytes, secret: str,
token_type: str = "auth", token_type: str = "access",
exp_time: datetime.timedelta = JWT_EXP_TIME exp_time: datetime.timedelta = JWT_EXP_TIME
) -> str: ) -> str:
curtime = time.time() curtime = datetime.datetime.utcnow()
payload = { payload = {
'iss': "Moonraker", 'iss': self.issuer,
'aud': "Moonraker",
'iat': curtime, 'iat': curtime,
'exp': curtime + exp_time.total_seconds(), 'exp': curtime + exp_time,
'username': username, 'username': username,
'token_type': token_type 'token_type': token_type
} }
enc_header = base64url_encode(json.dumps(JWT_HEADER).encode()) return jwt.encode(payload, secret, headers=JWT_HEADER)
enc_payload = base64url_encode(json.dumps(payload).encode())
message = enc_header + b"." + enc_payload
signature = base64url_encode(hmac.digest(secret, message, "sha256"))
message += b"." + signature
return message.decode()
def _decode_jwt(self, def _decode_jwt(self,
jwt: str, token: str,
token_type: str = "auth" token_type: str = "access"
) -> Dict[str, Any]: ) -> Dict[str, Any]:
parts = jwt.encode().split(b".") header: Dict[str, Any] = jwt.get_unverified_header(token)
if len(parts) != 3: payload: Dict[str, Any] = jwt.get_unverified_claims(token)
raise self.server.error(f"Invalid JWT length of {len(parts)}")
header: Dict[str, Any] = json.loads(base64url_decode(parts[0]))
payload: Dict[str, Any] = json.loads(base64url_decode(parts[1]))
if header != JWT_HEADER: if header != JWT_HEADER:
raise self.server.error("Invalid JWT header") raise self.server.error("Invalid JWT header")
recd_type: str = payload.get('token_type', "") recd_type: str = payload.get('token_type', "")
@ -399,8 +381,6 @@ class Authorization:
raise self.server.error( raise self.server.error(
f"JWT Token type mismatch: Expected {token_type}, " f"JWT Token type mismatch: Expected {token_type}, "
f"Recd: {recd_type}", 401) f"Recd: {recd_type}", 401)
if time.time() > payload['exp']:
raise self.server.error("JWT expired", 401)
username: str = payload['username'] username: str = payload['username']
user_info: Dict[str, Any] = self.users.get(username, None) user_info: Dict[str, Any] = self.users.get(username, None)
if user_info is None: if user_info is None:
@ -410,13 +390,8 @@ class Authorization:
if jwt_secret is None: if jwt_secret is None:
raise self.server.error( raise self.server.error(
f"Invalid JWT, user {username} not logged in", 401) f"Invalid JWT, user {username} not logged in", 401)
secret = bytes.fromhex(jwt_secret) jwt.decode(token, jwt_secret, algorithms=['HS256'],
# Decode and verify signature audience="Moonraker")
signature = base64url_decode(parts[2])
calc_sig = hmac.digest(
secret, parts[0] + b"." + parts[1], "sha256")
if signature != calc_sig:
raise self.server.error("Invalid JWT signature")
return user_info return user_info
def _prune_conn_handler(self) -> None: def _prune_conn_handler(self) -> None: