authorization: add annotations
Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
410db750c6
commit
41ddbb16a8
|
@ -3,6 +3,8 @@
|
|||
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
|
||||
#
|
||||
# This file may be distributed under the terms of the GNU GPLv3 license
|
||||
|
||||
from __future__ import annotations
|
||||
import base64
|
||||
import uuid
|
||||
import hashlib
|
||||
|
@ -18,7 +20,28 @@ import socket
|
|||
import logging
|
||||
from tornado.ioloop import IOLoop, PeriodicCallback
|
||||
from tornado.web import HTTPError
|
||||
from utils import ServerError
|
||||
|
||||
# Annotation imports
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Tuple,
|
||||
Set,
|
||||
Optional,
|
||||
Union,
|
||||
Dict,
|
||||
List,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
from confighelper import ConfigHelper
|
||||
from websockets import WebRequest
|
||||
from tornado.httputil import HTTPServerRequest
|
||||
from tornado.web import RequestHandler
|
||||
from . import database
|
||||
DBComp = database.MoonrakerDatabase
|
||||
IPAddr = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||
IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
|
||||
OneshotToken = Tuple[IPAddr, Optional[Dict[str, Any]], object]
|
||||
|
||||
ONESHOT_TIMEOUT = 5
|
||||
TRUSTED_CONNECTION_TIMEOUT = 3600
|
||||
|
@ -35,23 +58,23 @@ JWT_HEADER = {
|
|||
}
|
||||
|
||||
# Helpers for base64url encoding and decoding
|
||||
def base64url_encode(data):
|
||||
def base64url_encode(data: bytes) -> bytes:
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=")
|
||||
|
||||
def base64url_decode(data):
|
||||
def base64url_decode(data) -> bytes:
|
||||
pad_cnt = len(data) % 4
|
||||
if pad_cnt:
|
||||
data += b"=" * (4 - pad_cnt)
|
||||
return base64.urlsafe_b64decode(data)
|
||||
|
||||
class Authorization:
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: ConfigHelper) -> None:
|
||||
self.server = config.get_server()
|
||||
self.login_timeout = config.getint('login_timeout', 90)
|
||||
database = self.server.lookup_component('database')
|
||||
database: DBComp = self.server.lookup_component('database')
|
||||
database.register_local_namespace('authorized_users', forbidden=True)
|
||||
self.users = database.wrap_namespace('authorized_users')
|
||||
api_user = self.users.get(API_USER, None)
|
||||
api_user: Optional[Dict[str, Any]] = self.users.get(API_USER, None)
|
||||
if api_user is None:
|
||||
self.api_key = uuid.uuid4().hex
|
||||
self.users[API_USER] = {
|
||||
|
@ -61,14 +84,14 @@ class Authorization:
|
|||
}
|
||||
else:
|
||||
self.api_key = api_user['api_key']
|
||||
self.trusted_users = {}
|
||||
self.oneshot_tokens = {}
|
||||
self.permitted_paths = set()
|
||||
self.trusted_users: Dict[IPAddr, Any] = {}
|
||||
self.oneshot_tokens: Dict[str, OneshotToken] = {}
|
||||
self.permitted_paths: Set[str] = set()
|
||||
|
||||
# Get allowed cors domains
|
||||
self.cors_domains = []
|
||||
self.cors_domains: List[str] = []
|
||||
cors_cfg = config.get('cors_domains', "").strip()
|
||||
cds = [d.strip() for d in cors_cfg.split('\n')if d.strip()]
|
||||
cds = [d.strip() for d in cors_cfg.split('\n') if d.strip()]
|
||||
for domain in cds:
|
||||
bad_match = re.search(r"^.+\.[^:]*\*", domain)
|
||||
if bad_match is not None:
|
||||
|
@ -79,12 +102,11 @@ class Authorization:
|
|||
domain.replace(".", "\\.").replace("*", ".*"))
|
||||
|
||||
# Get Trusted Clients
|
||||
self.trusted_ips = []
|
||||
self.trusted_ranges = []
|
||||
self.trusted_domains = []
|
||||
trusted_clients = config.get('trusted_clients', "")
|
||||
trusted_clients = [c.strip() for c in trusted_clients.split('\n')
|
||||
if c.strip()]
|
||||
self.trusted_ips: List[IPAddr] = []
|
||||
self.trusted_ranges: List[IPNetwork] = []
|
||||
self.trusted_domains: List[str] = []
|
||||
tcs = config.get('trusted_clients', "")
|
||||
trusted_clients = [c.strip() for c in tcs.split('\n') if c.strip()]
|
||||
for val in trusted_clients:
|
||||
# Check IP address
|
||||
try:
|
||||
|
@ -150,26 +172,27 @@ class Authorization:
|
|||
self.server.register_notification("authorization:user_created")
|
||||
self.server.register_notification("authorization:user_deleted")
|
||||
|
||||
async def _handle_apikey_request(self, web_request):
|
||||
async def _handle_apikey_request(self, web_request: WebRequest) -> str:
|
||||
action = web_request.get_action()
|
||||
if action.upper() == 'POST':
|
||||
self.api_key = uuid.uuid4().hex
|
||||
self.users[f'{API_USER}.api_key'] = self.api_key
|
||||
return self.api_key
|
||||
|
||||
async def _handle_token_request(self, web_request):
|
||||
async def _handle_token_request(self, web_request: WebRequest) -> str:
|
||||
ip = web_request.get_ip_address()
|
||||
assert ip is not None
|
||||
user_info = web_request.get_current_user()
|
||||
return self.get_oneshot_token(ip, user_info)
|
||||
|
||||
async def _handle_login(self, web_request):
|
||||
async def _handle_login(self, web_request: WebRequest) -> Dict[str, Any]:
|
||||
return self._login_jwt_user(web_request)
|
||||
|
||||
async def _handle_logout(self, web_request):
|
||||
async def _handle_logout(self, web_request: WebRequest) -> Dict[str, str]:
|
||||
user_info = web_request.get_current_user()
|
||||
if user_info is None:
|
||||
raise self.server.error("No user logged in")
|
||||
username = user_info['username']
|
||||
username: str = user_info['username']
|
||||
if username in RESERVED_USERS:
|
||||
raise self.server.error(
|
||||
f"Invalid log out request for user {username}")
|
||||
|
@ -179,10 +202,12 @@ class Authorization:
|
|||
"action": "user_logged_out"
|
||||
}
|
||||
|
||||
async def _handle_refresh_jwt(self, web_request):
|
||||
refresh_token = web_request.get_str('refresh_token')
|
||||
async def _handle_refresh_jwt(self,
|
||||
web_request: WebRequest
|
||||
) -> Dict[str, str]:
|
||||
refresh_token: str = web_request.get_str('refresh_token')
|
||||
user_info = self._decode_jwt(refresh_token, token_type="refresh")
|
||||
username = user_info['username']
|
||||
username: str = user_info['username']
|
||||
secret = bytes.fromhex(user_info['jwt_secret'])
|
||||
token = self._generate_jwt(username, secret)
|
||||
return {
|
||||
|
@ -191,7 +216,9 @@ class Authorization:
|
|||
'action': 'user_jwt_refresh'
|
||||
}
|
||||
|
||||
async def _handle_user_request(self, web_request):
|
||||
async def _handle_user_request(self,
|
||||
web_request: WebRequest
|
||||
) -> Dict[str, Any]:
|
||||
action = web_request.get_action()
|
||||
if action == "GET":
|
||||
user = web_request.get_current_user()
|
||||
|
@ -211,8 +238,11 @@ class Authorization:
|
|||
elif action == "DELETE":
|
||||
# Delete User
|
||||
return self._delete_jwt_user(web_request)
|
||||
raise self.server.error("Invalid Request Method")
|
||||
|
||||
async def _handle_list_request(self, web_request):
|
||||
async def _handle_list_request(self,
|
||||
web_request: WebRequest
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
user_list = []
|
||||
for user in self.users.values():
|
||||
if user['username'] == API_USER:
|
||||
|
@ -225,9 +255,11 @@ class Authorization:
|
|||
'users': user_list
|
||||
}
|
||||
|
||||
async def _handle_password_reset(self, web_request):
|
||||
password = web_request.get_str('password')
|
||||
new_pass = web_request.get_str('new_password')
|
||||
async def _handle_password_reset(self,
|
||||
web_request: WebRequest
|
||||
) -> Dict[str, str]:
|
||||
password: str = web_request.get_str('password')
|
||||
new_pass: str = web_request.get_str('new_password')
|
||||
user_info = web_request.get_current_user()
|
||||
if user_info is None:
|
||||
raise self.server.error("No Current User")
|
||||
|
@ -248,9 +280,13 @@ class Authorization:
|
|||
'action': "user_password_reset"
|
||||
}
|
||||
|
||||
def _login_jwt_user(self, web_request, create=False):
|
||||
username = web_request.get_str('username')
|
||||
password = web_request.get_str('password')
|
||||
def _login_jwt_user(self,
|
||||
web_request: WebRequest,
|
||||
create: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
username: str = web_request.get_str('username')
|
||||
password: str = web_request.get_str('password')
|
||||
user_info: Dict[str, Any]
|
||||
if username in RESERVED_USERS:
|
||||
raise self.server.error(
|
||||
f"Invalid Request for user {username}")
|
||||
|
@ -278,13 +314,13 @@ class Authorization:
|
|||
action = "user_logged_in"
|
||||
if hashed_pass != user_info['password']:
|
||||
raise self.server.error("Invalid Password")
|
||||
jwt_secret = user_info.get('jwt_secret', None)
|
||||
if jwt_secret is None:
|
||||
jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None)
|
||||
if jwt_secret_hex is None:
|
||||
jwt_secret = secrets.token_bytes(32)
|
||||
user_info['jwt_secret'] = jwt_secret.hex()
|
||||
self.users[username] = user_info
|
||||
else:
|
||||
jwt_secret = bytes.fromhex(jwt_secret)
|
||||
jwt_secret = bytes.fromhex(jwt_secret_hex)
|
||||
token = self._generate_jwt(username, jwt_secret)
|
||||
refresh_token = self._generate_jwt(
|
||||
username, jwt_secret, token_type="refresh",
|
||||
|
@ -301,8 +337,8 @@ class Authorization:
|
|||
'action': action
|
||||
}
|
||||
|
||||
def _delete_jwt_user(self, web_request):
|
||||
username = web_request.get_str('username')
|
||||
def _delete_jwt_user(self, web_request: WebRequest) -> Dict[str, str]:
|
||||
username: str = web_request.get_str('username')
|
||||
current_user = web_request.get_current_user()
|
||||
if current_user is not None:
|
||||
curname = current_user.get('username', None)
|
||||
|
@ -312,7 +348,7 @@ class Authorization:
|
|||
if username in RESERVED_USERS:
|
||||
raise self.server.error(
|
||||
f"Invalid Request for reserved user {username}")
|
||||
user_info = self.users.get(username)
|
||||
user_info: Optional[Dict[str, Any]] = self.users.get(username)
|
||||
if user_info is None:
|
||||
raise self.server.error(f"No registered user: {username}")
|
||||
del self.users[username]
|
||||
|
@ -325,8 +361,12 @@ class Authorization:
|
|||
"action": "user_deleted"
|
||||
}
|
||||
|
||||
def _generate_jwt(self, username, secret, token_type="auth",
|
||||
exp_time=JWT_EXP_TIME):
|
||||
def _generate_jwt(self,
|
||||
username: str,
|
||||
secret: bytes,
|
||||
token_type: str = "auth",
|
||||
exp_time: datetime.timedelta = JWT_EXP_TIME
|
||||
) -> str:
|
||||
curtime = time.time()
|
||||
payload = {
|
||||
'iss': "Moonraker",
|
||||
|
@ -342,27 +382,30 @@ class Authorization:
|
|||
message += b"." + signature
|
||||
return message.decode()
|
||||
|
||||
def _decode_jwt(self, jwt, token_type="auth"):
|
||||
def _decode_jwt(self,
|
||||
jwt: str,
|
||||
token_type: str = "auth"
|
||||
) -> Dict[str, Any]:
|
||||
parts = jwt.encode().split(b".")
|
||||
if len(parts) != 3:
|
||||
raise self.server.error(f"Invalid JWT length of {len(parts)}")
|
||||
header = json.loads(base64url_decode(parts[0]))
|
||||
payload = json.loads(base64url_decode(parts[1]))
|
||||
header: Dict[str, Any] = json.loads(base64url_decode(parts[0]))
|
||||
payload: Dict[str, Any] = json.loads(base64url_decode(parts[1]))
|
||||
if header != JWT_HEADER:
|
||||
raise self.server.error("Invalid JWT header")
|
||||
recd_type = payload.get('token_type', "")
|
||||
recd_type: str = payload.get('token_type', "")
|
||||
if token_type != recd_type:
|
||||
raise self.server.error(
|
||||
f"JWT Token type mismatch: Expected {token_type}, "
|
||||
f"Recd: {recd_type}", 401)
|
||||
if time.time() > payload['exp']:
|
||||
raise self.server.error("JWT expired", 401)
|
||||
username = payload.get('username')
|
||||
user_info = self.users.get(username, None)
|
||||
username: str = payload['username']
|
||||
user_info: Dict[str, Any] = self.users.get(username, None)
|
||||
if user_info is None:
|
||||
raise self.server.error(
|
||||
f"Invalid JWT, no registered user {username}", 401)
|
||||
jwt_secret = user_info.get('jwt_secret', None)
|
||||
jwt_secret: Optional[str] = user_info.get('jwt_secret', None)
|
||||
if jwt_secret is None:
|
||||
raise self.server.error(
|
||||
f"Invalid JWT, user {username} not logged in", 401)
|
||||
|
@ -375,10 +418,10 @@ class Authorization:
|
|||
raise self.server.error("Invalid JWT signature")
|
||||
return user_info
|
||||
|
||||
def _prune_conn_handler(self):
|
||||
def _prune_conn_handler(self) -> None:
|
||||
cur_time = time.time()
|
||||
for ip, user_info in list(self.trusted_users.items()):
|
||||
exp_time = user_info['expires_at']
|
||||
exp_time: float = user_info['expires_at']
|
||||
if cur_time >= exp_time:
|
||||
self.trusted_users.pop(ip, None)
|
||||
logging.info(
|
||||
|
@ -387,7 +430,10 @@ class Authorization:
|
|||
def _oneshot_token_expire_handler(self, token):
|
||||
self.oneshot_tokens.pop(token, None)
|
||||
|
||||
def get_oneshot_token(self, ip_addr, user):
|
||||
def get_oneshot_token(self,
|
||||
ip_addr: IPAddr,
|
||||
user: Optional[Dict[str, Any]]
|
||||
) -> str:
|
||||
token = base64.b32encode(os.urandom(20)).decode()
|
||||
ioloop = IOLoop.current()
|
||||
hdl = ioloop.call_later(
|
||||
|
@ -395,8 +441,10 @@ class Authorization:
|
|||
self.oneshot_tokens[token] = (ip_addr, user, hdl)
|
||||
return token
|
||||
|
||||
def _check_json_web_token(self, request):
|
||||
auth_token = request.headers.get("Authorization")
|
||||
def _check_json_web_token(self,
|
||||
request: HTTPServerRequest
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
auth_token: Optional[str] = request.headers.get("Authorization")
|
||||
if auth_token is None:
|
||||
auth_token = request.headers.get("X-Access-Token")
|
||||
if auth_token and auth_token.startswith("Bearer "):
|
||||
|
@ -407,7 +455,7 @@ class Authorization:
|
|||
raise HTTPError(401, str(e))
|
||||
return None
|
||||
|
||||
def _check_authorized_ip(self, ip):
|
||||
def _check_authorized_ip(self, ip: IPAddr) -> bool:
|
||||
if ip in self.trusted_ips:
|
||||
return True
|
||||
for rng in self.trusted_ranges:
|
||||
|
@ -418,7 +466,9 @@ class Authorization:
|
|||
return True
|
||||
return False
|
||||
|
||||
def _check_trusted_connection(self, ip):
|
||||
def _check_trusted_connection(self,
|
||||
ip: Optional[IPAddr]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if ip is not None:
|
||||
curtime = time.time()
|
||||
exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT
|
||||
|
@ -437,7 +487,10 @@ class Authorization:
|
|||
return self.trusted_users[ip]
|
||||
return None
|
||||
|
||||
def _check_oneshot_token(self, token, cur_ip):
|
||||
def _check_oneshot_token(self,
|
||||
token: str,
|
||||
cur_ip: IPAddr
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if token in self.oneshot_tokens:
|
||||
ip_addr, user, hdl = self.oneshot_tokens.pop(token)
|
||||
IOLoop.current().remove_timeout(hdl)
|
||||
|
@ -449,7 +502,9 @@ class Authorization:
|
|||
else:
|
||||
return None
|
||||
|
||||
def check_authorized(self, request):
|
||||
def check_authorized(self,
|
||||
request: HTTPServerRequest
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if request.path in self.permitted_paths or \
|
||||
request.method == "OPTIONS":
|
||||
return None
|
||||
|
@ -467,14 +522,14 @@ class Authorization:
|
|||
ip = None
|
||||
|
||||
# Check oneshot access token
|
||||
ost = request.arguments.get('token', None)
|
||||
ost: Optional[List[bytes]] = request.arguments.get('token', None)
|
||||
if ost is not None:
|
||||
ost_user = self._check_oneshot_token(ost[-1].decode(), ip)
|
||||
if ost_user is not None:
|
||||
return ost_user
|
||||
|
||||
# Check API Key Header
|
||||
key = request.headers.get("X-Api-Key")
|
||||
key: Optional[str] = request.headers.get("X-Api-Key")
|
||||
if key and key == self.api_key:
|
||||
return self.users[API_USER]
|
||||
|
||||
|
@ -485,7 +540,10 @@ class Authorization:
|
|||
|
||||
raise HTTPError(401, "Unauthorized")
|
||||
|
||||
def check_cors(self, origin, request=None):
|
||||
def check_cors(self,
|
||||
origin: Optional[str],
|
||||
req_hdlr: Optional[RequestHandler] = None
|
||||
) -> bool:
|
||||
if origin is None or not self.cors_domains:
|
||||
return False
|
||||
for regex in self.cors_domains:
|
||||
|
@ -494,7 +552,7 @@ class Authorization:
|
|||
if match.group() == origin:
|
||||
logging.debug(f"CORS Pattern Matched, origin: {origin} "
|
||||
f" | pattern: {regex}")
|
||||
self._set_cors_headers(origin, request)
|
||||
self._set_cors_headers(origin, req_hdlr)
|
||||
return True
|
||||
else:
|
||||
logging.debug(f"Partial Cors Match: {match.group()}")
|
||||
|
@ -512,28 +570,31 @@ class Authorization:
|
|||
if self._check_authorized_ip(ipaddr):
|
||||
logging.debug(
|
||||
f"Cors request matched trusted IP: {ip}")
|
||||
self._set_cors_headers(origin, request)
|
||||
self._set_cors_headers(origin, req_hdlr)
|
||||
return True
|
||||
logging.debug(f"No CORS match for origin: {origin}\n"
|
||||
f"Patterns: {self.cors_domains}")
|
||||
return False
|
||||
|
||||
def _set_cors_headers(self, origin, request):
|
||||
if request is None:
|
||||
def _set_cors_headers(self,
|
||||
origin: str,
|
||||
req_hdlr: Optional[RequestHandler]
|
||||
) -> None:
|
||||
if req_hdlr is None:
|
||||
return
|
||||
request.set_header("Access-Control-Allow-Origin", origin)
|
||||
request.set_header(
|
||||
req_hdlr.set_header("Access-Control-Allow-Origin", origin)
|
||||
req_hdlr.set_header(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET, POST, PUT, DELETE, OPTIONS")
|
||||
request.set_header(
|
||||
req_hdlr.set_header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Origin, Accept, Content-Type, X-Requested-With, "
|
||||
"X-CRSF-Token, Authorization, X-Access-Token, "
|
||||
"X-Api-Key")
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self.prune_handler.stop()
|
||||
|
||||
|
||||
def load_component(config):
|
||||
def load_component(config: ConfigHelper) -> Authorization:
|
||||
return Authorization(config)
|
||||
|
|
Loading…
Reference in New Issue