diff --git a/moonraker/components/authorization.py b/moonraker/components/authorization.py index 659eccd..5861351 100644 --- a/moonraker/components/authorization.py +++ b/moonraker/components/authorization.py @@ -3,6 +3,8 @@ # Copyright (C) 2020 Eric Callahan # # This file may be distributed under the terms of the GNU GPLv3 license + +from __future__ import annotations import base64 import uuid import hashlib @@ -18,7 +20,28 @@ import socket import logging from tornado.ioloop import IOLoop, PeriodicCallback from tornado.web import HTTPError -from utils import ServerError + +# Annotation imports +from typing import ( + TYPE_CHECKING, + Any, + Tuple, + Set, + Optional, + Union, + Dict, + List, +) +if TYPE_CHECKING: + from confighelper import ConfigHelper + from websockets import WebRequest + from tornado.httputil import HTTPServerRequest + from tornado.web import RequestHandler + from . import database + DBComp = database.MoonrakerDatabase + IPAddr = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] + IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network] + OneshotToken = Tuple[IPAddr, Optional[Dict[str, Any]], object] ONESHOT_TIMEOUT = 5 TRUSTED_CONNECTION_TIMEOUT = 3600 @@ -35,23 +58,23 @@ JWT_HEADER = { } # Helpers for base64url encoding and decoding -def base64url_encode(data): +def base64url_encode(data: bytes) -> bytes: return base64.urlsafe_b64encode(data).rstrip(b"=") -def base64url_decode(data): +def base64url_decode(data) -> bytes: pad_cnt = len(data) % 4 if pad_cnt: data += b"=" * (4 - pad_cnt) return base64.urlsafe_b64decode(data) class Authorization: - def __init__(self, config): + def __init__(self, config: ConfigHelper) -> None: self.server = config.get_server() self.login_timeout = config.getint('login_timeout', 90) - database = self.server.lookup_component('database') + database: DBComp = self.server.lookup_component('database') database.register_local_namespace('authorized_users', forbidden=True) self.users = database.wrap_namespace('authorized_users') - api_user = self.users.get(API_USER, None) + api_user: Optional[Dict[str, Any]] = self.users.get(API_USER, None) if api_user is None: self.api_key = uuid.uuid4().hex self.users[API_USER] = { @@ -61,14 +84,14 @@ class Authorization: } else: self.api_key = api_user['api_key'] - self.trusted_users = {} - self.oneshot_tokens = {} - self.permitted_paths = set() + self.trusted_users: Dict[IPAddr, Any] = {} + self.oneshot_tokens: Dict[str, OneshotToken] = {} + self.permitted_paths: Set[str] = set() # Get allowed cors domains - self.cors_domains = [] + self.cors_domains: List[str] = [] cors_cfg = config.get('cors_domains', "").strip() - cds = [d.strip() for d in cors_cfg.split('\n')if d.strip()] + cds = [d.strip() for d in cors_cfg.split('\n') if d.strip()] for domain in cds: bad_match = re.search(r"^.+\.[^:]*\*", domain) if bad_match is not None: @@ -79,12 +102,11 @@ class Authorization: domain.replace(".", "\\.").replace("*", ".*")) # Get Trusted Clients - self.trusted_ips = [] - self.trusted_ranges = [] - self.trusted_domains = [] - trusted_clients = config.get('trusted_clients', "") - trusted_clients = [c.strip() for c in trusted_clients.split('\n') - if c.strip()] + self.trusted_ips: List[IPAddr] = [] + self.trusted_ranges: List[IPNetwork] = [] + self.trusted_domains: List[str] = [] + tcs = config.get('trusted_clients', "") + trusted_clients = [c.strip() for c in tcs.split('\n') if c.strip()] for val in trusted_clients: # Check IP address try: @@ -150,26 +172,27 @@ class Authorization: self.server.register_notification("authorization:user_created") self.server.register_notification("authorization:user_deleted") - async def _handle_apikey_request(self, web_request): + async def _handle_apikey_request(self, web_request: WebRequest) -> str: action = web_request.get_action() if action.upper() == 'POST': self.api_key = uuid.uuid4().hex self.users[f'{API_USER}.api_key'] = self.api_key return self.api_key - async def _handle_token_request(self, web_request): + async def _handle_token_request(self, web_request: WebRequest) -> str: ip = web_request.get_ip_address() + assert ip is not None user_info = web_request.get_current_user() return self.get_oneshot_token(ip, user_info) - async def _handle_login(self, web_request): + async def _handle_login(self, web_request: WebRequest) -> Dict[str, Any]: return self._login_jwt_user(web_request) - async def _handle_logout(self, web_request): + async def _handle_logout(self, web_request: WebRequest) -> Dict[str, str]: user_info = web_request.get_current_user() if user_info is None: raise self.server.error("No user logged in") - username = user_info['username'] + username: str = user_info['username'] if username in RESERVED_USERS: raise self.server.error( f"Invalid log out request for user {username}") @@ -179,10 +202,12 @@ class Authorization: "action": "user_logged_out" } - async def _handle_refresh_jwt(self, web_request): - refresh_token = web_request.get_str('refresh_token') + async def _handle_refresh_jwt(self, + web_request: WebRequest + ) -> Dict[str, str]: + refresh_token: str = web_request.get_str('refresh_token') user_info = self._decode_jwt(refresh_token, token_type="refresh") - username = user_info['username'] + username: str = user_info['username'] secret = bytes.fromhex(user_info['jwt_secret']) token = self._generate_jwt(username, secret) return { @@ -191,7 +216,9 @@ class Authorization: 'action': 'user_jwt_refresh' } - async def _handle_user_request(self, web_request): + async def _handle_user_request(self, + web_request: WebRequest + ) -> Dict[str, Any]: action = web_request.get_action() if action == "GET": user = web_request.get_current_user() @@ -211,8 +238,11 @@ class Authorization: elif action == "DELETE": # Delete User return self._delete_jwt_user(web_request) + raise self.server.error("Invalid Request Method") - async def _handle_list_request(self, web_request): + async def _handle_list_request(self, + web_request: WebRequest + ) -> Dict[str, List[Dict[str, Any]]]: user_list = [] for user in self.users.values(): if user['username'] == API_USER: @@ -225,9 +255,11 @@ class Authorization: 'users': user_list } - async def _handle_password_reset(self, web_request): - password = web_request.get_str('password') - new_pass = web_request.get_str('new_password') + async def _handle_password_reset(self, + web_request: WebRequest + ) -> Dict[str, str]: + password: str = web_request.get_str('password') + new_pass: str = web_request.get_str('new_password') user_info = web_request.get_current_user() if user_info is None: raise self.server.error("No Current User") @@ -248,9 +280,13 @@ class Authorization: 'action': "user_password_reset" } - def _login_jwt_user(self, web_request, create=False): - username = web_request.get_str('username') - password = web_request.get_str('password') + def _login_jwt_user(self, + web_request: WebRequest, + create: bool = False + ) -> Dict[str, Any]: + username: str = web_request.get_str('username') + password: str = web_request.get_str('password') + user_info: Dict[str, Any] if username in RESERVED_USERS: raise self.server.error( f"Invalid Request for user {username}") @@ -278,13 +314,13 @@ class Authorization: action = "user_logged_in" if hashed_pass != user_info['password']: raise self.server.error("Invalid Password") - jwt_secret = user_info.get('jwt_secret', None) - if jwt_secret is None: + jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None) + if jwt_secret_hex is None: jwt_secret = secrets.token_bytes(32) user_info['jwt_secret'] = jwt_secret.hex() self.users[username] = user_info else: - jwt_secret = bytes.fromhex(jwt_secret) + jwt_secret = bytes.fromhex(jwt_secret_hex) token = self._generate_jwt(username, jwt_secret) refresh_token = self._generate_jwt( username, jwt_secret, token_type="refresh", @@ -301,8 +337,8 @@ class Authorization: 'action': action } - def _delete_jwt_user(self, web_request): - username = web_request.get_str('username') + def _delete_jwt_user(self, web_request: WebRequest) -> Dict[str, str]: + username: str = web_request.get_str('username') current_user = web_request.get_current_user() if current_user is not None: curname = current_user.get('username', None) @@ -312,7 +348,7 @@ class Authorization: if username in RESERVED_USERS: raise self.server.error( f"Invalid Request for reserved user {username}") - user_info = self.users.get(username) + user_info: Optional[Dict[str, Any]] = self.users.get(username) if user_info is None: raise self.server.error(f"No registered user: {username}") del self.users[username] @@ -325,8 +361,12 @@ class Authorization: "action": "user_deleted" } - def _generate_jwt(self, username, secret, token_type="auth", - exp_time=JWT_EXP_TIME): + def _generate_jwt(self, + username: str, + secret: bytes, + token_type: str = "auth", + exp_time: datetime.timedelta = JWT_EXP_TIME + ) -> str: curtime = time.time() payload = { 'iss': "Moonraker", @@ -342,27 +382,30 @@ class Authorization: message += b"." + signature return message.decode() - def _decode_jwt(self, jwt, token_type="auth"): + def _decode_jwt(self, + jwt: str, + token_type: str = "auth" + ) -> Dict[str, Any]: parts = jwt.encode().split(b".") if len(parts) != 3: raise self.server.error(f"Invalid JWT length of {len(parts)}") - header = json.loads(base64url_decode(parts[0])) - payload = json.loads(base64url_decode(parts[1])) + header: Dict[str, Any] = json.loads(base64url_decode(parts[0])) + payload: Dict[str, Any] = json.loads(base64url_decode(parts[1])) if header != JWT_HEADER: raise self.server.error("Invalid JWT header") - recd_type = payload.get('token_type', "") + recd_type: str = payload.get('token_type', "") if token_type != recd_type: raise self.server.error( f"JWT Token type mismatch: Expected {token_type}, " f"Recd: {recd_type}", 401) if time.time() > payload['exp']: raise self.server.error("JWT expired", 401) - username = payload.get('username') - user_info = self.users.get(username, None) + username: str = payload['username'] + user_info: Dict[str, Any] = self.users.get(username, None) if user_info is None: raise self.server.error( f"Invalid JWT, no registered user {username}", 401) - jwt_secret = user_info.get('jwt_secret', None) + jwt_secret: Optional[str] = user_info.get('jwt_secret', None) if jwt_secret is None: raise self.server.error( f"Invalid JWT, user {username} not logged in", 401) @@ -375,10 +418,10 @@ class Authorization: raise self.server.error("Invalid JWT signature") return user_info - def _prune_conn_handler(self): + def _prune_conn_handler(self) -> None: cur_time = time.time() for ip, user_info in list(self.trusted_users.items()): - exp_time = user_info['expires_at'] + exp_time: float = user_info['expires_at'] if cur_time >= exp_time: self.trusted_users.pop(ip, None) logging.info( @@ -387,7 +430,10 @@ class Authorization: def _oneshot_token_expire_handler(self, token): self.oneshot_tokens.pop(token, None) - def get_oneshot_token(self, ip_addr, user): + def get_oneshot_token(self, + ip_addr: IPAddr, + user: Optional[Dict[str, Any]] + ) -> str: token = base64.b32encode(os.urandom(20)).decode() ioloop = IOLoop.current() hdl = ioloop.call_later( @@ -395,8 +441,10 @@ class Authorization: self.oneshot_tokens[token] = (ip_addr, user, hdl) return token - def _check_json_web_token(self, request): - auth_token = request.headers.get("Authorization") + def _check_json_web_token(self, + request: HTTPServerRequest + ) -> Optional[Dict[str, Any]]: + auth_token: Optional[str] = request.headers.get("Authorization") if auth_token is None: auth_token = request.headers.get("X-Access-Token") if auth_token and auth_token.startswith("Bearer "): @@ -407,7 +455,7 @@ class Authorization: raise HTTPError(401, str(e)) return None - def _check_authorized_ip(self, ip): + def _check_authorized_ip(self, ip: IPAddr) -> bool: if ip in self.trusted_ips: return True for rng in self.trusted_ranges: @@ -418,7 +466,9 @@ class Authorization: return True return False - def _check_trusted_connection(self, ip): + def _check_trusted_connection(self, + ip: Optional[IPAddr] + ) -> Optional[Dict[str, Any]]: if ip is not None: curtime = time.time() exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT @@ -437,7 +487,10 @@ class Authorization: return self.trusted_users[ip] return None - def _check_oneshot_token(self, token, cur_ip): + def _check_oneshot_token(self, + token: str, + cur_ip: IPAddr + ) -> Optional[Dict[str, Any]]: if token in self.oneshot_tokens: ip_addr, user, hdl = self.oneshot_tokens.pop(token) IOLoop.current().remove_timeout(hdl) @@ -449,7 +502,9 @@ class Authorization: else: return None - def check_authorized(self, request): + def check_authorized(self, + request: HTTPServerRequest + ) -> Optional[Dict[str, Any]]: if request.path in self.permitted_paths or \ request.method == "OPTIONS": return None @@ -467,14 +522,14 @@ class Authorization: ip = None # Check oneshot access token - ost = request.arguments.get('token', None) + ost: Optional[List[bytes]] = request.arguments.get('token', None) if ost is not None: ost_user = self._check_oneshot_token(ost[-1].decode(), ip) if ost_user is not None: return ost_user # Check API Key Header - key = request.headers.get("X-Api-Key") + key: Optional[str] = request.headers.get("X-Api-Key") if key and key == self.api_key: return self.users[API_USER] @@ -485,7 +540,10 @@ class Authorization: raise HTTPError(401, "Unauthorized") - def check_cors(self, origin, request=None): + def check_cors(self, + origin: Optional[str], + req_hdlr: Optional[RequestHandler] = None + ) -> bool: if origin is None or not self.cors_domains: return False for regex in self.cors_domains: @@ -494,7 +552,7 @@ class Authorization: if match.group() == origin: logging.debug(f"CORS Pattern Matched, origin: {origin} " f" | pattern: {regex}") - self._set_cors_headers(origin, request) + self._set_cors_headers(origin, req_hdlr) return True else: logging.debug(f"Partial Cors Match: {match.group()}") @@ -512,28 +570,31 @@ class Authorization: if self._check_authorized_ip(ipaddr): logging.debug( f"Cors request matched trusted IP: {ip}") - self._set_cors_headers(origin, request) + self._set_cors_headers(origin, req_hdlr) return True logging.debug(f"No CORS match for origin: {origin}\n" f"Patterns: {self.cors_domains}") return False - def _set_cors_headers(self, origin, request): - if request is None: + def _set_cors_headers(self, + origin: str, + req_hdlr: Optional[RequestHandler] + ) -> None: + if req_hdlr is None: return - request.set_header("Access-Control-Allow-Origin", origin) - request.set_header( + req_hdlr.set_header("Access-Control-Allow-Origin", origin) + req_hdlr.set_header( "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - request.set_header( + req_hdlr.set_header( "Access-Control-Allow-Headers", "Origin, Accept, Content-Type, X-Requested-With, " "X-CRSF-Token, Authorization, X-Access-Token, " "X-Api-Key") - def close(self): + def close(self) -> None: self.prune_handler.stop() -def load_component(config): +def load_component(config: ConfigHelper) -> Authorization: return Authorization(config)