authorization: use python_jose dependency for jwt management
Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
e4e58f6d97
commit
ce7f659a32
|
@ -8,16 +8,15 @@ from __future__ import annotations
|
||||||
import base64
|
import base64
|
||||||
import uuid
|
import uuid
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
|
||||||
import secrets
|
import secrets
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import datetime
|
import datetime
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import json
|
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
import logging
|
import logging
|
||||||
|
from jose import jwt
|
||||||
from tornado.ioloop import IOLoop, PeriodicCallback
|
from tornado.ioloop import IOLoop, PeriodicCallback
|
||||||
from tornado.web import HTTPError
|
from tornado.web import HTTPError
|
||||||
|
|
||||||
|
@ -57,16 +56,6 @@ JWT_HEADER = {
|
||||||
'typ': "JWT"
|
'typ': "JWT"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Helpers for base64url encoding and decoding
|
|
||||||
def base64url_encode(data: bytes) -> bytes:
|
|
||||||
return base64.urlsafe_b64encode(data).rstrip(b"=")
|
|
||||||
|
|
||||||
def base64url_decode(data) -> bytes:
|
|
||||||
pad_cnt = len(data) % 4
|
|
||||||
if pad_cnt:
|
|
||||||
data += b"=" * (4 - pad_cnt)
|
|
||||||
return base64.urlsafe_b64decode(data)
|
|
||||||
|
|
||||||
class Authorization:
|
class Authorization:
|
||||||
def __init__(self, config: ConfigHelper) -> None:
|
def __init__(self, config: ConfigHelper) -> None:
|
||||||
self.server = config.get_server()
|
self.server = config.get_server()
|
||||||
|
@ -88,6 +77,8 @@ class Authorization:
|
||||||
self.trusted_users: Dict[IPAddr, Any] = {}
|
self.trusted_users: Dict[IPAddr, Any] = {}
|
||||||
self.oneshot_tokens: Dict[str, OneshotToken] = {}
|
self.oneshot_tokens: Dict[str, OneshotToken] = {}
|
||||||
self.permitted_paths: Set[str] = set()
|
self.permitted_paths: Set[str] = set()
|
||||||
|
host_name, port = self.server.get_host_info()
|
||||||
|
self.issuer = f"http://{host_name}:{port}"
|
||||||
|
|
||||||
# Get allowed cors domains
|
# Get allowed cors domains
|
||||||
self.cors_domains: List[str] = []
|
self.cors_domains: List[str] = []
|
||||||
|
@ -209,7 +200,7 @@ class Authorization:
|
||||||
refresh_token: str = web_request.get_str('refresh_token')
|
refresh_token: str = web_request.get_str('refresh_token')
|
||||||
user_info = self._decode_jwt(refresh_token, token_type="refresh")
|
user_info = self._decode_jwt(refresh_token, token_type="refresh")
|
||||||
username: str = user_info['username']
|
username: str = user_info['username']
|
||||||
secret = bytes.fromhex(user_info['jwt_secret'])
|
secret = user_info['jwt_secret']
|
||||||
token = self._generate_jwt(username, secret)
|
token = self._generate_jwt(username, secret)
|
||||||
return {
|
return {
|
||||||
'username': username,
|
'username': username,
|
||||||
|
@ -315,13 +306,11 @@ class Authorization:
|
||||||
action = "user_logged_in"
|
action = "user_logged_in"
|
||||||
if hashed_pass != user_info['password']:
|
if hashed_pass != user_info['password']:
|
||||||
raise self.server.error("Invalid Password")
|
raise self.server.error("Invalid Password")
|
||||||
jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None)
|
jwt_secret: Optional[str] = user_info.get('jwt_secret', None)
|
||||||
if jwt_secret_hex is None:
|
if jwt_secret is None:
|
||||||
jwt_secret = secrets.token_bytes(32)
|
jwt_secret = secrets.token_bytes(32).hex()
|
||||||
user_info['jwt_secret'] = jwt_secret.hex()
|
user_info['jwt_secret'] = jwt_secret
|
||||||
self.users[username] = user_info
|
self.users[username] = user_info
|
||||||
else:
|
|
||||||
jwt_secret = bytes.fromhex(jwt_secret_hex)
|
|
||||||
token = self._generate_jwt(username, jwt_secret)
|
token = self._generate_jwt(username, jwt_secret)
|
||||||
refresh_token = self._generate_jwt(
|
refresh_token = self._generate_jwt(
|
||||||
username, jwt_secret, token_type="refresh",
|
username, jwt_secret, token_type="refresh",
|
||||||
|
@ -364,34 +353,27 @@ class Authorization:
|
||||||
|
|
||||||
def _generate_jwt(self,
|
def _generate_jwt(self,
|
||||||
username: str,
|
username: str,
|
||||||
secret: bytes,
|
secret: str,
|
||||||
token_type: str = "auth",
|
token_type: str = "access",
|
||||||
exp_time: datetime.timedelta = JWT_EXP_TIME
|
exp_time: datetime.timedelta = JWT_EXP_TIME
|
||||||
) -> str:
|
) -> str:
|
||||||
curtime = time.time()
|
curtime = datetime.datetime.utcnow()
|
||||||
payload = {
|
payload = {
|
||||||
'iss': "Moonraker",
|
'iss': self.issuer,
|
||||||
|
'aud': "Moonraker",
|
||||||
'iat': curtime,
|
'iat': curtime,
|
||||||
'exp': curtime + exp_time.total_seconds(),
|
'exp': curtime + exp_time,
|
||||||
'username': username,
|
'username': username,
|
||||||
'token_type': token_type
|
'token_type': token_type
|
||||||
}
|
}
|
||||||
enc_header = base64url_encode(json.dumps(JWT_HEADER).encode())
|
return jwt.encode(payload, secret, headers=JWT_HEADER)
|
||||||
enc_payload = base64url_encode(json.dumps(payload).encode())
|
|
||||||
message = enc_header + b"." + enc_payload
|
|
||||||
signature = base64url_encode(hmac.digest(secret, message, "sha256"))
|
|
||||||
message += b"." + signature
|
|
||||||
return message.decode()
|
|
||||||
|
|
||||||
def _decode_jwt(self,
|
def _decode_jwt(self,
|
||||||
jwt: str,
|
token: str,
|
||||||
token_type: str = "auth"
|
token_type: str = "access"
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
parts = jwt.encode().split(b".")
|
header: Dict[str, Any] = jwt.get_unverified_header(token)
|
||||||
if len(parts) != 3:
|
payload: Dict[str, Any] = jwt.get_unverified_claims(token)
|
||||||
raise self.server.error(f"Invalid JWT length of {len(parts)}")
|
|
||||||
header: Dict[str, Any] = json.loads(base64url_decode(parts[0]))
|
|
||||||
payload: Dict[str, Any] = json.loads(base64url_decode(parts[1]))
|
|
||||||
if header != JWT_HEADER:
|
if header != JWT_HEADER:
|
||||||
raise self.server.error("Invalid JWT header")
|
raise self.server.error("Invalid JWT header")
|
||||||
recd_type: str = payload.get('token_type', "")
|
recd_type: str = payload.get('token_type', "")
|
||||||
|
@ -399,8 +381,6 @@ class Authorization:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"JWT Token type mismatch: Expected {token_type}, "
|
f"JWT Token type mismatch: Expected {token_type}, "
|
||||||
f"Recd: {recd_type}", 401)
|
f"Recd: {recd_type}", 401)
|
||||||
if time.time() > payload['exp']:
|
|
||||||
raise self.server.error("JWT expired", 401)
|
|
||||||
username: str = payload['username']
|
username: str = payload['username']
|
||||||
user_info: Dict[str, Any] = self.users.get(username, None)
|
user_info: Dict[str, Any] = self.users.get(username, None)
|
||||||
if user_info is None:
|
if user_info is None:
|
||||||
|
@ -410,13 +390,8 @@ class Authorization:
|
||||||
if jwt_secret is None:
|
if jwt_secret is None:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"Invalid JWT, user {username} not logged in", 401)
|
f"Invalid JWT, user {username} not logged in", 401)
|
||||||
secret = bytes.fromhex(jwt_secret)
|
jwt.decode(token, jwt_secret, algorithms=['HS256'],
|
||||||
# Decode and verify signature
|
audience="Moonraker")
|
||||||
signature = base64url_decode(parts[2])
|
|
||||||
calc_sig = hmac.digest(
|
|
||||||
secret, parts[0] + b"." + parts[1], "sha256")
|
|
||||||
if signature != calc_sig:
|
|
||||||
raise self.server.error("Invalid JWT signature")
|
|
||||||
return user_info
|
return user_info
|
||||||
|
|
||||||
def _prune_conn_handler(self) -> None:
|
def _prune_conn_handler(self) -> None:
|
||||||
|
|
Loading…
Reference in New Issue