authorization: add annotations

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-05-13 10:48:57 -04:00
parent 410db750c6
commit 41ddbb16a8
1 changed files with 130 additions and 69 deletions

View File

@ -3,6 +3,8 @@
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
#
# This file may be distributed under the terms of the GNU GPLv3 license
from __future__ import annotations
import base64
import uuid
import hashlib
@ -18,7 +20,28 @@ import socket
import logging
from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.web import HTTPError
from utils import ServerError
# Annotation imports
from typing import (
TYPE_CHECKING,
Any,
Tuple,
Set,
Optional,
Union,
Dict,
List,
)
if TYPE_CHECKING:
from confighelper import ConfigHelper
from websockets import WebRequest
from tornado.httputil import HTTPServerRequest
from tornado.web import RequestHandler
from . import database
DBComp = database.MoonrakerDatabase
IPAddr = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
OneshotToken = Tuple[IPAddr, Optional[Dict[str, Any]], object]
ONESHOT_TIMEOUT = 5
TRUSTED_CONNECTION_TIMEOUT = 3600
@ -35,23 +58,23 @@ JWT_HEADER = {
}
# Helpers for base64url encoding and decoding
def base64url_encode(data):
def base64url_encode(data: bytes) -> bytes:
return base64.urlsafe_b64encode(data).rstrip(b"=")
def base64url_decode(data):
def base64url_decode(data) -> bytes:
pad_cnt = len(data) % 4
if pad_cnt:
data += b"=" * (4 - pad_cnt)
return base64.urlsafe_b64decode(data)
class Authorization:
def __init__(self, config):
def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server()
self.login_timeout = config.getint('login_timeout', 90)
database = self.server.lookup_component('database')
database: DBComp = self.server.lookup_component('database')
database.register_local_namespace('authorized_users', forbidden=True)
self.users = database.wrap_namespace('authorized_users')
api_user = self.users.get(API_USER, None)
api_user: Optional[Dict[str, Any]] = self.users.get(API_USER, None)
if api_user is None:
self.api_key = uuid.uuid4().hex
self.users[API_USER] = {
@ -61,14 +84,14 @@ class Authorization:
}
else:
self.api_key = api_user['api_key']
self.trusted_users = {}
self.oneshot_tokens = {}
self.permitted_paths = set()
self.trusted_users: Dict[IPAddr, Any] = {}
self.oneshot_tokens: Dict[str, OneshotToken] = {}
self.permitted_paths: Set[str] = set()
# Get allowed cors domains
self.cors_domains = []
self.cors_domains: List[str] = []
cors_cfg = config.get('cors_domains', "").strip()
cds = [d.strip() for d in cors_cfg.split('\n')if d.strip()]
cds = [d.strip() for d in cors_cfg.split('\n') if d.strip()]
for domain in cds:
bad_match = re.search(r"^.+\.[^:]*\*", domain)
if bad_match is not None:
@ -79,12 +102,11 @@ class Authorization:
domain.replace(".", "\\.").replace("*", ".*"))
# Get Trusted Clients
self.trusted_ips = []
self.trusted_ranges = []
self.trusted_domains = []
trusted_clients = config.get('trusted_clients', "")
trusted_clients = [c.strip() for c in trusted_clients.split('\n')
if c.strip()]
self.trusted_ips: List[IPAddr] = []
self.trusted_ranges: List[IPNetwork] = []
self.trusted_domains: List[str] = []
tcs = config.get('trusted_clients', "")
trusted_clients = [c.strip() for c in tcs.split('\n') if c.strip()]
for val in trusted_clients:
# Check IP address
try:
@ -150,26 +172,27 @@ class Authorization:
self.server.register_notification("authorization:user_created")
self.server.register_notification("authorization:user_deleted")
async def _handle_apikey_request(self, web_request):
async def _handle_apikey_request(self, web_request: WebRequest) -> str:
action = web_request.get_action()
if action.upper() == 'POST':
self.api_key = uuid.uuid4().hex
self.users[f'{API_USER}.api_key'] = self.api_key
return self.api_key
async def _handle_token_request(self, web_request):
async def _handle_token_request(self, web_request: WebRequest) -> str:
ip = web_request.get_ip_address()
assert ip is not None
user_info = web_request.get_current_user()
return self.get_oneshot_token(ip, user_info)
async def _handle_login(self, web_request):
async def _handle_login(self, web_request: WebRequest) -> Dict[str, Any]:
return self._login_jwt_user(web_request)
async def _handle_logout(self, web_request):
async def _handle_logout(self, web_request: WebRequest) -> Dict[str, str]:
user_info = web_request.get_current_user()
if user_info is None:
raise self.server.error("No user logged in")
username = user_info['username']
username: str = user_info['username']
if username in RESERVED_USERS:
raise self.server.error(
f"Invalid log out request for user {username}")
@ -179,10 +202,12 @@ class Authorization:
"action": "user_logged_out"
}
async def _handle_refresh_jwt(self, web_request):
refresh_token = web_request.get_str('refresh_token')
async def _handle_refresh_jwt(self,
web_request: WebRequest
) -> Dict[str, str]:
refresh_token: str = web_request.get_str('refresh_token')
user_info = self._decode_jwt(refresh_token, token_type="refresh")
username = user_info['username']
username: str = user_info['username']
secret = bytes.fromhex(user_info['jwt_secret'])
token = self._generate_jwt(username, secret)
return {
@ -191,7 +216,9 @@ class Authorization:
'action': 'user_jwt_refresh'
}
async def _handle_user_request(self, web_request):
async def _handle_user_request(self,
web_request: WebRequest
) -> Dict[str, Any]:
action = web_request.get_action()
if action == "GET":
user = web_request.get_current_user()
@ -211,8 +238,11 @@ class Authorization:
elif action == "DELETE":
# Delete User
return self._delete_jwt_user(web_request)
raise self.server.error("Invalid Request Method")
async def _handle_list_request(self, web_request):
async def _handle_list_request(self,
web_request: WebRequest
) -> Dict[str, List[Dict[str, Any]]]:
user_list = []
for user in self.users.values():
if user['username'] == API_USER:
@ -225,9 +255,11 @@ class Authorization:
'users': user_list
}
async def _handle_password_reset(self, web_request):
password = web_request.get_str('password')
new_pass = web_request.get_str('new_password')
async def _handle_password_reset(self,
web_request: WebRequest
) -> Dict[str, str]:
password: str = web_request.get_str('password')
new_pass: str = web_request.get_str('new_password')
user_info = web_request.get_current_user()
if user_info is None:
raise self.server.error("No Current User")
@ -248,9 +280,13 @@ class Authorization:
'action': "user_password_reset"
}
def _login_jwt_user(self, web_request, create=False):
username = web_request.get_str('username')
password = web_request.get_str('password')
def _login_jwt_user(self,
web_request: WebRequest,
create: bool = False
) -> Dict[str, Any]:
username: str = web_request.get_str('username')
password: str = web_request.get_str('password')
user_info: Dict[str, Any]
if username in RESERVED_USERS:
raise self.server.error(
f"Invalid Request for user {username}")
@ -278,13 +314,13 @@ class Authorization:
action = "user_logged_in"
if hashed_pass != user_info['password']:
raise self.server.error("Invalid Password")
jwt_secret = user_info.get('jwt_secret', None)
if jwt_secret is None:
jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None)
if jwt_secret_hex is None:
jwt_secret = secrets.token_bytes(32)
user_info['jwt_secret'] = jwt_secret.hex()
self.users[username] = user_info
else:
jwt_secret = bytes.fromhex(jwt_secret)
jwt_secret = bytes.fromhex(jwt_secret_hex)
token = self._generate_jwt(username, jwt_secret)
refresh_token = self._generate_jwt(
username, jwt_secret, token_type="refresh",
@ -301,8 +337,8 @@ class Authorization:
'action': action
}
def _delete_jwt_user(self, web_request):
username = web_request.get_str('username')
def _delete_jwt_user(self, web_request: WebRequest) -> Dict[str, str]:
username: str = web_request.get_str('username')
current_user = web_request.get_current_user()
if current_user is not None:
curname = current_user.get('username', None)
@ -312,7 +348,7 @@ class Authorization:
if username in RESERVED_USERS:
raise self.server.error(
f"Invalid Request for reserved user {username}")
user_info = self.users.get(username)
user_info: Optional[Dict[str, Any]] = self.users.get(username)
if user_info is None:
raise self.server.error(f"No registered user: {username}")
del self.users[username]
@ -325,8 +361,12 @@ class Authorization:
"action": "user_deleted"
}
def _generate_jwt(self, username, secret, token_type="auth",
exp_time=JWT_EXP_TIME):
def _generate_jwt(self,
username: str,
secret: bytes,
token_type: str = "auth",
exp_time: datetime.timedelta = JWT_EXP_TIME
) -> str:
curtime = time.time()
payload = {
'iss': "Moonraker",
@ -342,27 +382,30 @@ class Authorization:
message += b"." + signature
return message.decode()
def _decode_jwt(self, jwt, token_type="auth"):
def _decode_jwt(self,
jwt: str,
token_type: str = "auth"
) -> Dict[str, Any]:
parts = jwt.encode().split(b".")
if len(parts) != 3:
raise self.server.error(f"Invalid JWT length of {len(parts)}")
header = json.loads(base64url_decode(parts[0]))
payload = json.loads(base64url_decode(parts[1]))
header: Dict[str, Any] = json.loads(base64url_decode(parts[0]))
payload: Dict[str, Any] = json.loads(base64url_decode(parts[1]))
if header != JWT_HEADER:
raise self.server.error("Invalid JWT header")
recd_type = payload.get('token_type', "")
recd_type: str = payload.get('token_type', "")
if token_type != recd_type:
raise self.server.error(
f"JWT Token type mismatch: Expected {token_type}, "
f"Recd: {recd_type}", 401)
if time.time() > payload['exp']:
raise self.server.error("JWT expired", 401)
username = payload.get('username')
user_info = self.users.get(username, None)
username: str = payload['username']
user_info: Dict[str, Any] = self.users.get(username, None)
if user_info is None:
raise self.server.error(
f"Invalid JWT, no registered user {username}", 401)
jwt_secret = user_info.get('jwt_secret', None)
jwt_secret: Optional[str] = user_info.get('jwt_secret', None)
if jwt_secret is None:
raise self.server.error(
f"Invalid JWT, user {username} not logged in", 401)
@ -375,10 +418,10 @@ class Authorization:
raise self.server.error("Invalid JWT signature")
return user_info
def _prune_conn_handler(self):
def _prune_conn_handler(self) -> None:
cur_time = time.time()
for ip, user_info in list(self.trusted_users.items()):
exp_time = user_info['expires_at']
exp_time: float = user_info['expires_at']
if cur_time >= exp_time:
self.trusted_users.pop(ip, None)
logging.info(
@ -387,7 +430,10 @@ class Authorization:
def _oneshot_token_expire_handler(self, token):
self.oneshot_tokens.pop(token, None)
def get_oneshot_token(self, ip_addr, user):
def get_oneshot_token(self,
ip_addr: IPAddr,
user: Optional[Dict[str, Any]]
) -> str:
token = base64.b32encode(os.urandom(20)).decode()
ioloop = IOLoop.current()
hdl = ioloop.call_later(
@ -395,8 +441,10 @@ class Authorization:
self.oneshot_tokens[token] = (ip_addr, user, hdl)
return token
def _check_json_web_token(self, request):
auth_token = request.headers.get("Authorization")
def _check_json_web_token(self,
request: HTTPServerRequest
) -> Optional[Dict[str, Any]]:
auth_token: Optional[str] = request.headers.get("Authorization")
if auth_token is None:
auth_token = request.headers.get("X-Access-Token")
if auth_token and auth_token.startswith("Bearer "):
@ -407,7 +455,7 @@ class Authorization:
raise HTTPError(401, str(e))
return None
def _check_authorized_ip(self, ip):
def _check_authorized_ip(self, ip: IPAddr) -> bool:
if ip in self.trusted_ips:
return True
for rng in self.trusted_ranges:
@ -418,7 +466,9 @@ class Authorization:
return True
return False
def _check_trusted_connection(self, ip):
def _check_trusted_connection(self,
ip: Optional[IPAddr]
) -> Optional[Dict[str, Any]]:
if ip is not None:
curtime = time.time()
exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT
@ -437,7 +487,10 @@ class Authorization:
return self.trusted_users[ip]
return None
def _check_oneshot_token(self, token, cur_ip):
def _check_oneshot_token(self,
token: str,
cur_ip: IPAddr
) -> Optional[Dict[str, Any]]:
if token in self.oneshot_tokens:
ip_addr, user, hdl = self.oneshot_tokens.pop(token)
IOLoop.current().remove_timeout(hdl)
@ -449,7 +502,9 @@ class Authorization:
else:
return None
def check_authorized(self, request):
def check_authorized(self,
request: HTTPServerRequest
) -> Optional[Dict[str, Any]]:
if request.path in self.permitted_paths or \
request.method == "OPTIONS":
return None
@ -467,14 +522,14 @@ class Authorization:
ip = None
# Check oneshot access token
ost = request.arguments.get('token', None)
ost: Optional[List[bytes]] = request.arguments.get('token', None)
if ost is not None:
ost_user = self._check_oneshot_token(ost[-1].decode(), ip)
if ost_user is not None:
return ost_user
# Check API Key Header
key = request.headers.get("X-Api-Key")
key: Optional[str] = request.headers.get("X-Api-Key")
if key and key == self.api_key:
return self.users[API_USER]
@ -485,7 +540,10 @@ class Authorization:
raise HTTPError(401, "Unauthorized")
def check_cors(self, origin, request=None):
def check_cors(self,
origin: Optional[str],
req_hdlr: Optional[RequestHandler] = None
) -> bool:
if origin is None or not self.cors_domains:
return False
for regex in self.cors_domains:
@ -494,7 +552,7 @@ class Authorization:
if match.group() == origin:
logging.debug(f"CORS Pattern Matched, origin: {origin} "
f" | pattern: {regex}")
self._set_cors_headers(origin, request)
self._set_cors_headers(origin, req_hdlr)
return True
else:
logging.debug(f"Partial Cors Match: {match.group()}")
@ -512,28 +570,31 @@ class Authorization:
if self._check_authorized_ip(ipaddr):
logging.debug(
f"Cors request matched trusted IP: {ip}")
self._set_cors_headers(origin, request)
self._set_cors_headers(origin, req_hdlr)
return True
logging.debug(f"No CORS match for origin: {origin}\n"
f"Patterns: {self.cors_domains}")
return False
def _set_cors_headers(self, origin, request):
if request is None:
def _set_cors_headers(self,
origin: str,
req_hdlr: Optional[RequestHandler]
) -> None:
if req_hdlr is None:
return
request.set_header("Access-Control-Allow-Origin", origin)
request.set_header(
req_hdlr.set_header("Access-Control-Allow-Origin", origin)
req_hdlr.set_header(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, OPTIONS")
request.set_header(
req_hdlr.set_header(
"Access-Control-Allow-Headers",
"Origin, Accept, Content-Type, X-Requested-With, "
"X-CRSF-Token, Authorization, X-Access-Token, "
"X-Api-Key")
def close(self):
def close(self) -> None:
self.prune_handler.stop()
def load_component(config):
def load_component(config: ConfigHelper) -> Authorization:
return Authorization(config)