authorization: add annotations

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-05-13 10:48:57 -04:00
parent 410db750c6
commit 41ddbb16a8
1 changed files with 130 additions and 69 deletions

View File

@ -3,6 +3,8 @@
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com> # Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
# #
# This file may be distributed under the terms of the GNU GPLv3 license # This file may be distributed under the terms of the GNU GPLv3 license
from __future__ import annotations
import base64 import base64
import uuid import uuid
import hashlib import hashlib
@ -18,7 +20,28 @@ import socket
import logging import logging
from tornado.ioloop import IOLoop, PeriodicCallback from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.web import HTTPError from tornado.web import HTTPError
from utils import ServerError
# Annotation imports
from typing import (
TYPE_CHECKING,
Any,
Tuple,
Set,
Optional,
Union,
Dict,
List,
)
if TYPE_CHECKING:
from confighelper import ConfigHelper
from websockets import WebRequest
from tornado.httputil import HTTPServerRequest
from tornado.web import RequestHandler
from . import database
DBComp = database.MoonrakerDatabase
IPAddr = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
OneshotToken = Tuple[IPAddr, Optional[Dict[str, Any]], object]
ONESHOT_TIMEOUT = 5 ONESHOT_TIMEOUT = 5
TRUSTED_CONNECTION_TIMEOUT = 3600 TRUSTED_CONNECTION_TIMEOUT = 3600
@ -35,23 +58,23 @@ JWT_HEADER = {
} }
# Helpers for base64url encoding and decoding # Helpers for base64url encoding and decoding
def base64url_encode(data): def base64url_encode(data: bytes) -> bytes:
return base64.urlsafe_b64encode(data).rstrip(b"=") return base64.urlsafe_b64encode(data).rstrip(b"=")
def base64url_decode(data): def base64url_decode(data) -> bytes:
pad_cnt = len(data) % 4 pad_cnt = len(data) % 4
if pad_cnt: if pad_cnt:
data += b"=" * (4 - pad_cnt) data += b"=" * (4 - pad_cnt)
return base64.urlsafe_b64decode(data) return base64.urlsafe_b64decode(data)
class Authorization: class Authorization:
def __init__(self, config): def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server() self.server = config.get_server()
self.login_timeout = config.getint('login_timeout', 90) self.login_timeout = config.getint('login_timeout', 90)
database = self.server.lookup_component('database') database: DBComp = self.server.lookup_component('database')
database.register_local_namespace('authorized_users', forbidden=True) database.register_local_namespace('authorized_users', forbidden=True)
self.users = database.wrap_namespace('authorized_users') self.users = database.wrap_namespace('authorized_users')
api_user = self.users.get(API_USER, None) api_user: Optional[Dict[str, Any]] = self.users.get(API_USER, None)
if api_user is None: if api_user is None:
self.api_key = uuid.uuid4().hex self.api_key = uuid.uuid4().hex
self.users[API_USER] = { self.users[API_USER] = {
@ -61,14 +84,14 @@ class Authorization:
} }
else: else:
self.api_key = api_user['api_key'] self.api_key = api_user['api_key']
self.trusted_users = {} self.trusted_users: Dict[IPAddr, Any] = {}
self.oneshot_tokens = {} self.oneshot_tokens: Dict[str, OneshotToken] = {}
self.permitted_paths = set() self.permitted_paths: Set[str] = set()
# Get allowed cors domains # Get allowed cors domains
self.cors_domains = [] self.cors_domains: List[str] = []
cors_cfg = config.get('cors_domains', "").strip() cors_cfg = config.get('cors_domains', "").strip()
cds = [d.strip() for d in cors_cfg.split('\n')if d.strip()] cds = [d.strip() for d in cors_cfg.split('\n') if d.strip()]
for domain in cds: for domain in cds:
bad_match = re.search(r"^.+\.[^:]*\*", domain) bad_match = re.search(r"^.+\.[^:]*\*", domain)
if bad_match is not None: if bad_match is not None:
@ -79,12 +102,11 @@ class Authorization:
domain.replace(".", "\\.").replace("*", ".*")) domain.replace(".", "\\.").replace("*", ".*"))
# Get Trusted Clients # Get Trusted Clients
self.trusted_ips = [] self.trusted_ips: List[IPAddr] = []
self.trusted_ranges = [] self.trusted_ranges: List[IPNetwork] = []
self.trusted_domains = [] self.trusted_domains: List[str] = []
trusted_clients = config.get('trusted_clients', "") tcs = config.get('trusted_clients', "")
trusted_clients = [c.strip() for c in trusted_clients.split('\n') trusted_clients = [c.strip() for c in tcs.split('\n') if c.strip()]
if c.strip()]
for val in trusted_clients: for val in trusted_clients:
# Check IP address # Check IP address
try: try:
@ -150,26 +172,27 @@ class Authorization:
self.server.register_notification("authorization:user_created") self.server.register_notification("authorization:user_created")
self.server.register_notification("authorization:user_deleted") self.server.register_notification("authorization:user_deleted")
async def _handle_apikey_request(self, web_request): async def _handle_apikey_request(self, web_request: WebRequest) -> str:
action = web_request.get_action() action = web_request.get_action()
if action.upper() == 'POST': if action.upper() == 'POST':
self.api_key = uuid.uuid4().hex self.api_key = uuid.uuid4().hex
self.users[f'{API_USER}.api_key'] = self.api_key self.users[f'{API_USER}.api_key'] = self.api_key
return self.api_key return self.api_key
async def _handle_token_request(self, web_request): async def _handle_token_request(self, web_request: WebRequest) -> str:
ip = web_request.get_ip_address() ip = web_request.get_ip_address()
assert ip is not None
user_info = web_request.get_current_user() user_info = web_request.get_current_user()
return self.get_oneshot_token(ip, user_info) return self.get_oneshot_token(ip, user_info)
async def _handle_login(self, web_request): async def _handle_login(self, web_request: WebRequest) -> Dict[str, Any]:
return self._login_jwt_user(web_request) return self._login_jwt_user(web_request)
async def _handle_logout(self, web_request): async def _handle_logout(self, web_request: WebRequest) -> Dict[str, str]:
user_info = web_request.get_current_user() user_info = web_request.get_current_user()
if user_info is None: if user_info is None:
raise self.server.error("No user logged in") raise self.server.error("No user logged in")
username = user_info['username'] username: str = user_info['username']
if username in RESERVED_USERS: if username in RESERVED_USERS:
raise self.server.error( raise self.server.error(
f"Invalid log out request for user {username}") f"Invalid log out request for user {username}")
@ -179,10 +202,12 @@ class Authorization:
"action": "user_logged_out" "action": "user_logged_out"
} }
async def _handle_refresh_jwt(self, web_request): async def _handle_refresh_jwt(self,
refresh_token = web_request.get_str('refresh_token') web_request: WebRequest
) -> Dict[str, str]:
refresh_token: str = web_request.get_str('refresh_token')
user_info = self._decode_jwt(refresh_token, token_type="refresh") user_info = self._decode_jwt(refresh_token, token_type="refresh")
username = user_info['username'] username: str = user_info['username']
secret = bytes.fromhex(user_info['jwt_secret']) secret = bytes.fromhex(user_info['jwt_secret'])
token = self._generate_jwt(username, secret) token = self._generate_jwt(username, secret)
return { return {
@ -191,7 +216,9 @@ class Authorization:
'action': 'user_jwt_refresh' 'action': 'user_jwt_refresh'
} }
async def _handle_user_request(self, web_request): async def _handle_user_request(self,
web_request: WebRequest
) -> Dict[str, Any]:
action = web_request.get_action() action = web_request.get_action()
if action == "GET": if action == "GET":
user = web_request.get_current_user() user = web_request.get_current_user()
@ -211,8 +238,11 @@ class Authorization:
elif action == "DELETE": elif action == "DELETE":
# Delete User # Delete User
return self._delete_jwt_user(web_request) return self._delete_jwt_user(web_request)
raise self.server.error("Invalid Request Method")
async def _handle_list_request(self, web_request): async def _handle_list_request(self,
web_request: WebRequest
) -> Dict[str, List[Dict[str, Any]]]:
user_list = [] user_list = []
for user in self.users.values(): for user in self.users.values():
if user['username'] == API_USER: if user['username'] == API_USER:
@ -225,9 +255,11 @@ class Authorization:
'users': user_list 'users': user_list
} }
async def _handle_password_reset(self, web_request): async def _handle_password_reset(self,
password = web_request.get_str('password') web_request: WebRequest
new_pass = web_request.get_str('new_password') ) -> Dict[str, str]:
password: str = web_request.get_str('password')
new_pass: str = web_request.get_str('new_password')
user_info = web_request.get_current_user() user_info = web_request.get_current_user()
if user_info is None: if user_info is None:
raise self.server.error("No Current User") raise self.server.error("No Current User")
@ -248,9 +280,13 @@ class Authorization:
'action': "user_password_reset" 'action': "user_password_reset"
} }
def _login_jwt_user(self, web_request, create=False): def _login_jwt_user(self,
username = web_request.get_str('username') web_request: WebRequest,
password = web_request.get_str('password') create: bool = False
) -> Dict[str, Any]:
username: str = web_request.get_str('username')
password: str = web_request.get_str('password')
user_info: Dict[str, Any]
if username in RESERVED_USERS: if username in RESERVED_USERS:
raise self.server.error( raise self.server.error(
f"Invalid Request for user {username}") f"Invalid Request for user {username}")
@ -278,13 +314,13 @@ class Authorization:
action = "user_logged_in" action = "user_logged_in"
if hashed_pass != user_info['password']: if hashed_pass != user_info['password']:
raise self.server.error("Invalid Password") raise self.server.error("Invalid Password")
jwt_secret = user_info.get('jwt_secret', None) jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None)
if jwt_secret is None: if jwt_secret_hex is None:
jwt_secret = secrets.token_bytes(32) jwt_secret = secrets.token_bytes(32)
user_info['jwt_secret'] = jwt_secret.hex() user_info['jwt_secret'] = jwt_secret.hex()
self.users[username] = user_info self.users[username] = user_info
else: else:
jwt_secret = bytes.fromhex(jwt_secret) jwt_secret = bytes.fromhex(jwt_secret_hex)
token = self._generate_jwt(username, jwt_secret) token = self._generate_jwt(username, jwt_secret)
refresh_token = self._generate_jwt( refresh_token = self._generate_jwt(
username, jwt_secret, token_type="refresh", username, jwt_secret, token_type="refresh",
@ -301,8 +337,8 @@ class Authorization:
'action': action 'action': action
} }
def _delete_jwt_user(self, web_request): def _delete_jwt_user(self, web_request: WebRequest) -> Dict[str, str]:
username = web_request.get_str('username') username: str = web_request.get_str('username')
current_user = web_request.get_current_user() current_user = web_request.get_current_user()
if current_user is not None: if current_user is not None:
curname = current_user.get('username', None) curname = current_user.get('username', None)
@ -312,7 +348,7 @@ class Authorization:
if username in RESERVED_USERS: if username in RESERVED_USERS:
raise self.server.error( raise self.server.error(
f"Invalid Request for reserved user {username}") f"Invalid Request for reserved user {username}")
user_info = self.users.get(username) user_info: Optional[Dict[str, Any]] = self.users.get(username)
if user_info is None: if user_info is None:
raise self.server.error(f"No registered user: {username}") raise self.server.error(f"No registered user: {username}")
del self.users[username] del self.users[username]
@ -325,8 +361,12 @@ class Authorization:
"action": "user_deleted" "action": "user_deleted"
} }
def _generate_jwt(self, username, secret, token_type="auth", def _generate_jwt(self,
exp_time=JWT_EXP_TIME): username: str,
secret: bytes,
token_type: str = "auth",
exp_time: datetime.timedelta = JWT_EXP_TIME
) -> str:
curtime = time.time() curtime = time.time()
payload = { payload = {
'iss': "Moonraker", 'iss': "Moonraker",
@ -342,27 +382,30 @@ class Authorization:
message += b"." + signature message += b"." + signature
return message.decode() return message.decode()
def _decode_jwt(self, jwt, token_type="auth"): def _decode_jwt(self,
jwt: str,
token_type: str = "auth"
) -> Dict[str, Any]:
parts = jwt.encode().split(b".") parts = jwt.encode().split(b".")
if len(parts) != 3: if len(parts) != 3:
raise self.server.error(f"Invalid JWT length of {len(parts)}") raise self.server.error(f"Invalid JWT length of {len(parts)}")
header = json.loads(base64url_decode(parts[0])) header: Dict[str, Any] = json.loads(base64url_decode(parts[0]))
payload = json.loads(base64url_decode(parts[1])) payload: Dict[str, Any] = json.loads(base64url_decode(parts[1]))
if header != JWT_HEADER: if header != JWT_HEADER:
raise self.server.error("Invalid JWT header") raise self.server.error("Invalid JWT header")
recd_type = payload.get('token_type', "") recd_type: str = payload.get('token_type', "")
if token_type != recd_type: if token_type != recd_type:
raise self.server.error( raise self.server.error(
f"JWT Token type mismatch: Expected {token_type}, " f"JWT Token type mismatch: Expected {token_type}, "
f"Recd: {recd_type}", 401) f"Recd: {recd_type}", 401)
if time.time() > payload['exp']: if time.time() > payload['exp']:
raise self.server.error("JWT expired", 401) raise self.server.error("JWT expired", 401)
username = payload.get('username') username: str = payload['username']
user_info = self.users.get(username, None) user_info: Dict[str, Any] = self.users.get(username, None)
if user_info is None: if user_info is None:
raise self.server.error( raise self.server.error(
f"Invalid JWT, no registered user {username}", 401) f"Invalid JWT, no registered user {username}", 401)
jwt_secret = user_info.get('jwt_secret', None) jwt_secret: Optional[str] = user_info.get('jwt_secret', None)
if jwt_secret is None: if jwt_secret is None:
raise self.server.error( raise self.server.error(
f"Invalid JWT, user {username} not logged in", 401) f"Invalid JWT, user {username} not logged in", 401)
@ -375,10 +418,10 @@ class Authorization:
raise self.server.error("Invalid JWT signature") raise self.server.error("Invalid JWT signature")
return user_info return user_info
def _prune_conn_handler(self): def _prune_conn_handler(self) -> None:
cur_time = time.time() cur_time = time.time()
for ip, user_info in list(self.trusted_users.items()): for ip, user_info in list(self.trusted_users.items()):
exp_time = user_info['expires_at'] exp_time: float = user_info['expires_at']
if cur_time >= exp_time: if cur_time >= exp_time:
self.trusted_users.pop(ip, None) self.trusted_users.pop(ip, None)
logging.info( logging.info(
@ -387,7 +430,10 @@ class Authorization:
def _oneshot_token_expire_handler(self, token): def _oneshot_token_expire_handler(self, token):
self.oneshot_tokens.pop(token, None) self.oneshot_tokens.pop(token, None)
def get_oneshot_token(self, ip_addr, user): def get_oneshot_token(self,
ip_addr: IPAddr,
user: Optional[Dict[str, Any]]
) -> str:
token = base64.b32encode(os.urandom(20)).decode() token = base64.b32encode(os.urandom(20)).decode()
ioloop = IOLoop.current() ioloop = IOLoop.current()
hdl = ioloop.call_later( hdl = ioloop.call_later(
@ -395,8 +441,10 @@ class Authorization:
self.oneshot_tokens[token] = (ip_addr, user, hdl) self.oneshot_tokens[token] = (ip_addr, user, hdl)
return token return token
def _check_json_web_token(self, request): def _check_json_web_token(self,
auth_token = request.headers.get("Authorization") request: HTTPServerRequest
) -> Optional[Dict[str, Any]]:
auth_token: Optional[str] = request.headers.get("Authorization")
if auth_token is None: if auth_token is None:
auth_token = request.headers.get("X-Access-Token") auth_token = request.headers.get("X-Access-Token")
if auth_token and auth_token.startswith("Bearer "): if auth_token and auth_token.startswith("Bearer "):
@ -407,7 +455,7 @@ class Authorization:
raise HTTPError(401, str(e)) raise HTTPError(401, str(e))
return None return None
def _check_authorized_ip(self, ip): def _check_authorized_ip(self, ip: IPAddr) -> bool:
if ip in self.trusted_ips: if ip in self.trusted_ips:
return True return True
for rng in self.trusted_ranges: for rng in self.trusted_ranges:
@ -418,7 +466,9 @@ class Authorization:
return True return True
return False return False
def _check_trusted_connection(self, ip): def _check_trusted_connection(self,
ip: Optional[IPAddr]
) -> Optional[Dict[str, Any]]:
if ip is not None: if ip is not None:
curtime = time.time() curtime = time.time()
exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT
@ -437,7 +487,10 @@ class Authorization:
return self.trusted_users[ip] return self.trusted_users[ip]
return None return None
def _check_oneshot_token(self, token, cur_ip): def _check_oneshot_token(self,
token: str,
cur_ip: IPAddr
) -> Optional[Dict[str, Any]]:
if token in self.oneshot_tokens: if token in self.oneshot_tokens:
ip_addr, user, hdl = self.oneshot_tokens.pop(token) ip_addr, user, hdl = self.oneshot_tokens.pop(token)
IOLoop.current().remove_timeout(hdl) IOLoop.current().remove_timeout(hdl)
@ -449,7 +502,9 @@ class Authorization:
else: else:
return None return None
def check_authorized(self, request): def check_authorized(self,
request: HTTPServerRequest
) -> Optional[Dict[str, Any]]:
if request.path in self.permitted_paths or \ if request.path in self.permitted_paths or \
request.method == "OPTIONS": request.method == "OPTIONS":
return None return None
@ -467,14 +522,14 @@ class Authorization:
ip = None ip = None
# Check oneshot access token # Check oneshot access token
ost = request.arguments.get('token', None) ost: Optional[List[bytes]] = request.arguments.get('token', None)
if ost is not None: if ost is not None:
ost_user = self._check_oneshot_token(ost[-1].decode(), ip) ost_user = self._check_oneshot_token(ost[-1].decode(), ip)
if ost_user is not None: if ost_user is not None:
return ost_user return ost_user
# Check API Key Header # Check API Key Header
key = request.headers.get("X-Api-Key") key: Optional[str] = request.headers.get("X-Api-Key")
if key and key == self.api_key: if key and key == self.api_key:
return self.users[API_USER] return self.users[API_USER]
@ -485,7 +540,10 @@ class Authorization:
raise HTTPError(401, "Unauthorized") raise HTTPError(401, "Unauthorized")
def check_cors(self, origin, request=None): def check_cors(self,
origin: Optional[str],
req_hdlr: Optional[RequestHandler] = None
) -> bool:
if origin is None or not self.cors_domains: if origin is None or not self.cors_domains:
return False return False
for regex in self.cors_domains: for regex in self.cors_domains:
@ -494,7 +552,7 @@ class Authorization:
if match.group() == origin: if match.group() == origin:
logging.debug(f"CORS Pattern Matched, origin: {origin} " logging.debug(f"CORS Pattern Matched, origin: {origin} "
f" | pattern: {regex}") f" | pattern: {regex}")
self._set_cors_headers(origin, request) self._set_cors_headers(origin, req_hdlr)
return True return True
else: else:
logging.debug(f"Partial Cors Match: {match.group()}") logging.debug(f"Partial Cors Match: {match.group()}")
@ -512,28 +570,31 @@ class Authorization:
if self._check_authorized_ip(ipaddr): if self._check_authorized_ip(ipaddr):
logging.debug( logging.debug(
f"Cors request matched trusted IP: {ip}") f"Cors request matched trusted IP: {ip}")
self._set_cors_headers(origin, request) self._set_cors_headers(origin, req_hdlr)
return True return True
logging.debug(f"No CORS match for origin: {origin}\n" logging.debug(f"No CORS match for origin: {origin}\n"
f"Patterns: {self.cors_domains}") f"Patterns: {self.cors_domains}")
return False return False
def _set_cors_headers(self, origin, request): def _set_cors_headers(self,
if request is None: origin: str,
req_hdlr: Optional[RequestHandler]
) -> None:
if req_hdlr is None:
return return
request.set_header("Access-Control-Allow-Origin", origin) req_hdlr.set_header("Access-Control-Allow-Origin", origin)
request.set_header( req_hdlr.set_header(
"Access-Control-Allow-Methods", "Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, OPTIONS") "GET, POST, PUT, DELETE, OPTIONS")
request.set_header( req_hdlr.set_header(
"Access-Control-Allow-Headers", "Access-Control-Allow-Headers",
"Origin, Accept, Content-Type, X-Requested-With, " "Origin, Accept, Content-Type, X-Requested-With, "
"X-CRSF-Token, Authorization, X-Access-Token, " "X-CRSF-Token, Authorization, X-Access-Token, "
"X-Api-Key") "X-Api-Key")
def close(self): def close(self) -> None:
self.prune_handler.stop() self.prune_handler.stop()
def load_component(config): def load_component(config: ConfigHelper) -> Authorization:
return Authorization(config) return Authorization(config)