authorization: add annotations
Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
410db750c6
commit
41ddbb16a8
|
@ -3,6 +3,8 @@
|
||||||
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
|
# Copyright (C) 2020 Eric Callahan <arksine.code@gmail.com>
|
||||||
#
|
#
|
||||||
# This file may be distributed under the terms of the GNU GPLv3 license
|
# This file may be distributed under the terms of the GNU GPLv3 license
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
import base64
|
import base64
|
||||||
import uuid
|
import uuid
|
||||||
import hashlib
|
import hashlib
|
||||||
|
@ -18,7 +20,28 @@ import socket
|
||||||
import logging
|
import logging
|
||||||
from tornado.ioloop import IOLoop, PeriodicCallback
|
from tornado.ioloop import IOLoop, PeriodicCallback
|
||||||
from tornado.web import HTTPError
|
from tornado.web import HTTPError
|
||||||
from utils import ServerError
|
|
||||||
|
# Annotation imports
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Tuple,
|
||||||
|
Set,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
)
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from confighelper import ConfigHelper
|
||||||
|
from websockets import WebRequest
|
||||||
|
from tornado.httputil import HTTPServerRequest
|
||||||
|
from tornado.web import RequestHandler
|
||||||
|
from . import database
|
||||||
|
DBComp = database.MoonrakerDatabase
|
||||||
|
IPAddr = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||||
|
IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network]
|
||||||
|
OneshotToken = Tuple[IPAddr, Optional[Dict[str, Any]], object]
|
||||||
|
|
||||||
ONESHOT_TIMEOUT = 5
|
ONESHOT_TIMEOUT = 5
|
||||||
TRUSTED_CONNECTION_TIMEOUT = 3600
|
TRUSTED_CONNECTION_TIMEOUT = 3600
|
||||||
|
@ -35,23 +58,23 @@ JWT_HEADER = {
|
||||||
}
|
}
|
||||||
|
|
||||||
# Helpers for base64url encoding and decoding
|
# Helpers for base64url encoding and decoding
|
||||||
def base64url_encode(data):
|
def base64url_encode(data: bytes) -> bytes:
|
||||||
return base64.urlsafe_b64encode(data).rstrip(b"=")
|
return base64.urlsafe_b64encode(data).rstrip(b"=")
|
||||||
|
|
||||||
def base64url_decode(data):
|
def base64url_decode(data) -> bytes:
|
||||||
pad_cnt = len(data) % 4
|
pad_cnt = len(data) % 4
|
||||||
if pad_cnt:
|
if pad_cnt:
|
||||||
data += b"=" * (4 - pad_cnt)
|
data += b"=" * (4 - pad_cnt)
|
||||||
return base64.urlsafe_b64decode(data)
|
return base64.urlsafe_b64decode(data)
|
||||||
|
|
||||||
class Authorization:
|
class Authorization:
|
||||||
def __init__(self, config):
|
def __init__(self, config: ConfigHelper) -> None:
|
||||||
self.server = config.get_server()
|
self.server = config.get_server()
|
||||||
self.login_timeout = config.getint('login_timeout', 90)
|
self.login_timeout = config.getint('login_timeout', 90)
|
||||||
database = self.server.lookup_component('database')
|
database: DBComp = self.server.lookup_component('database')
|
||||||
database.register_local_namespace('authorized_users', forbidden=True)
|
database.register_local_namespace('authorized_users', forbidden=True)
|
||||||
self.users = database.wrap_namespace('authorized_users')
|
self.users = database.wrap_namespace('authorized_users')
|
||||||
api_user = self.users.get(API_USER, None)
|
api_user: Optional[Dict[str, Any]] = self.users.get(API_USER, None)
|
||||||
if api_user is None:
|
if api_user is None:
|
||||||
self.api_key = uuid.uuid4().hex
|
self.api_key = uuid.uuid4().hex
|
||||||
self.users[API_USER] = {
|
self.users[API_USER] = {
|
||||||
|
@ -61,14 +84,14 @@ class Authorization:
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
self.api_key = api_user['api_key']
|
self.api_key = api_user['api_key']
|
||||||
self.trusted_users = {}
|
self.trusted_users: Dict[IPAddr, Any] = {}
|
||||||
self.oneshot_tokens = {}
|
self.oneshot_tokens: Dict[str, OneshotToken] = {}
|
||||||
self.permitted_paths = set()
|
self.permitted_paths: Set[str] = set()
|
||||||
|
|
||||||
# Get allowed cors domains
|
# Get allowed cors domains
|
||||||
self.cors_domains = []
|
self.cors_domains: List[str] = []
|
||||||
cors_cfg = config.get('cors_domains', "").strip()
|
cors_cfg = config.get('cors_domains', "").strip()
|
||||||
cds = [d.strip() for d in cors_cfg.split('\n')if d.strip()]
|
cds = [d.strip() for d in cors_cfg.split('\n') if d.strip()]
|
||||||
for domain in cds:
|
for domain in cds:
|
||||||
bad_match = re.search(r"^.+\.[^:]*\*", domain)
|
bad_match = re.search(r"^.+\.[^:]*\*", domain)
|
||||||
if bad_match is not None:
|
if bad_match is not None:
|
||||||
|
@ -79,12 +102,11 @@ class Authorization:
|
||||||
domain.replace(".", "\\.").replace("*", ".*"))
|
domain.replace(".", "\\.").replace("*", ".*"))
|
||||||
|
|
||||||
# Get Trusted Clients
|
# Get Trusted Clients
|
||||||
self.trusted_ips = []
|
self.trusted_ips: List[IPAddr] = []
|
||||||
self.trusted_ranges = []
|
self.trusted_ranges: List[IPNetwork] = []
|
||||||
self.trusted_domains = []
|
self.trusted_domains: List[str] = []
|
||||||
trusted_clients = config.get('trusted_clients', "")
|
tcs = config.get('trusted_clients', "")
|
||||||
trusted_clients = [c.strip() for c in trusted_clients.split('\n')
|
trusted_clients = [c.strip() for c in tcs.split('\n') if c.strip()]
|
||||||
if c.strip()]
|
|
||||||
for val in trusted_clients:
|
for val in trusted_clients:
|
||||||
# Check IP address
|
# Check IP address
|
||||||
try:
|
try:
|
||||||
|
@ -150,26 +172,27 @@ class Authorization:
|
||||||
self.server.register_notification("authorization:user_created")
|
self.server.register_notification("authorization:user_created")
|
||||||
self.server.register_notification("authorization:user_deleted")
|
self.server.register_notification("authorization:user_deleted")
|
||||||
|
|
||||||
async def _handle_apikey_request(self, web_request):
|
async def _handle_apikey_request(self, web_request: WebRequest) -> str:
|
||||||
action = web_request.get_action()
|
action = web_request.get_action()
|
||||||
if action.upper() == 'POST':
|
if action.upper() == 'POST':
|
||||||
self.api_key = uuid.uuid4().hex
|
self.api_key = uuid.uuid4().hex
|
||||||
self.users[f'{API_USER}.api_key'] = self.api_key
|
self.users[f'{API_USER}.api_key'] = self.api_key
|
||||||
return self.api_key
|
return self.api_key
|
||||||
|
|
||||||
async def _handle_token_request(self, web_request):
|
async def _handle_token_request(self, web_request: WebRequest) -> str:
|
||||||
ip = web_request.get_ip_address()
|
ip = web_request.get_ip_address()
|
||||||
|
assert ip is not None
|
||||||
user_info = web_request.get_current_user()
|
user_info = web_request.get_current_user()
|
||||||
return self.get_oneshot_token(ip, user_info)
|
return self.get_oneshot_token(ip, user_info)
|
||||||
|
|
||||||
async def _handle_login(self, web_request):
|
async def _handle_login(self, web_request: WebRequest) -> Dict[str, Any]:
|
||||||
return self._login_jwt_user(web_request)
|
return self._login_jwt_user(web_request)
|
||||||
|
|
||||||
async def _handle_logout(self, web_request):
|
async def _handle_logout(self, web_request: WebRequest) -> Dict[str, str]:
|
||||||
user_info = web_request.get_current_user()
|
user_info = web_request.get_current_user()
|
||||||
if user_info is None:
|
if user_info is None:
|
||||||
raise self.server.error("No user logged in")
|
raise self.server.error("No user logged in")
|
||||||
username = user_info['username']
|
username: str = user_info['username']
|
||||||
if username in RESERVED_USERS:
|
if username in RESERVED_USERS:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"Invalid log out request for user {username}")
|
f"Invalid log out request for user {username}")
|
||||||
|
@ -179,10 +202,12 @@ class Authorization:
|
||||||
"action": "user_logged_out"
|
"action": "user_logged_out"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_refresh_jwt(self, web_request):
|
async def _handle_refresh_jwt(self,
|
||||||
refresh_token = web_request.get_str('refresh_token')
|
web_request: WebRequest
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
refresh_token: str = web_request.get_str('refresh_token')
|
||||||
user_info = self._decode_jwt(refresh_token, token_type="refresh")
|
user_info = self._decode_jwt(refresh_token, token_type="refresh")
|
||||||
username = user_info['username']
|
username: str = user_info['username']
|
||||||
secret = bytes.fromhex(user_info['jwt_secret'])
|
secret = bytes.fromhex(user_info['jwt_secret'])
|
||||||
token = self._generate_jwt(username, secret)
|
token = self._generate_jwt(username, secret)
|
||||||
return {
|
return {
|
||||||
|
@ -191,7 +216,9 @@ class Authorization:
|
||||||
'action': 'user_jwt_refresh'
|
'action': 'user_jwt_refresh'
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_user_request(self, web_request):
|
async def _handle_user_request(self,
|
||||||
|
web_request: WebRequest
|
||||||
|
) -> Dict[str, Any]:
|
||||||
action = web_request.get_action()
|
action = web_request.get_action()
|
||||||
if action == "GET":
|
if action == "GET":
|
||||||
user = web_request.get_current_user()
|
user = web_request.get_current_user()
|
||||||
|
@ -211,8 +238,11 @@ class Authorization:
|
||||||
elif action == "DELETE":
|
elif action == "DELETE":
|
||||||
# Delete User
|
# Delete User
|
||||||
return self._delete_jwt_user(web_request)
|
return self._delete_jwt_user(web_request)
|
||||||
|
raise self.server.error("Invalid Request Method")
|
||||||
|
|
||||||
async def _handle_list_request(self, web_request):
|
async def _handle_list_request(self,
|
||||||
|
web_request: WebRequest
|
||||||
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
user_list = []
|
user_list = []
|
||||||
for user in self.users.values():
|
for user in self.users.values():
|
||||||
if user['username'] == API_USER:
|
if user['username'] == API_USER:
|
||||||
|
@ -225,9 +255,11 @@ class Authorization:
|
||||||
'users': user_list
|
'users': user_list
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _handle_password_reset(self, web_request):
|
async def _handle_password_reset(self,
|
||||||
password = web_request.get_str('password')
|
web_request: WebRequest
|
||||||
new_pass = web_request.get_str('new_password')
|
) -> Dict[str, str]:
|
||||||
|
password: str = web_request.get_str('password')
|
||||||
|
new_pass: str = web_request.get_str('new_password')
|
||||||
user_info = web_request.get_current_user()
|
user_info = web_request.get_current_user()
|
||||||
if user_info is None:
|
if user_info is None:
|
||||||
raise self.server.error("No Current User")
|
raise self.server.error("No Current User")
|
||||||
|
@ -248,9 +280,13 @@ class Authorization:
|
||||||
'action': "user_password_reset"
|
'action': "user_password_reset"
|
||||||
}
|
}
|
||||||
|
|
||||||
def _login_jwt_user(self, web_request, create=False):
|
def _login_jwt_user(self,
|
||||||
username = web_request.get_str('username')
|
web_request: WebRequest,
|
||||||
password = web_request.get_str('password')
|
create: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
username: str = web_request.get_str('username')
|
||||||
|
password: str = web_request.get_str('password')
|
||||||
|
user_info: Dict[str, Any]
|
||||||
if username in RESERVED_USERS:
|
if username in RESERVED_USERS:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"Invalid Request for user {username}")
|
f"Invalid Request for user {username}")
|
||||||
|
@ -278,13 +314,13 @@ class Authorization:
|
||||||
action = "user_logged_in"
|
action = "user_logged_in"
|
||||||
if hashed_pass != user_info['password']:
|
if hashed_pass != user_info['password']:
|
||||||
raise self.server.error("Invalid Password")
|
raise self.server.error("Invalid Password")
|
||||||
jwt_secret = user_info.get('jwt_secret', None)
|
jwt_secret_hex: Optional[str] = user_info.get('jwt_secret', None)
|
||||||
if jwt_secret is None:
|
if jwt_secret_hex is None:
|
||||||
jwt_secret = secrets.token_bytes(32)
|
jwt_secret = secrets.token_bytes(32)
|
||||||
user_info['jwt_secret'] = jwt_secret.hex()
|
user_info['jwt_secret'] = jwt_secret.hex()
|
||||||
self.users[username] = user_info
|
self.users[username] = user_info
|
||||||
else:
|
else:
|
||||||
jwt_secret = bytes.fromhex(jwt_secret)
|
jwt_secret = bytes.fromhex(jwt_secret_hex)
|
||||||
token = self._generate_jwt(username, jwt_secret)
|
token = self._generate_jwt(username, jwt_secret)
|
||||||
refresh_token = self._generate_jwt(
|
refresh_token = self._generate_jwt(
|
||||||
username, jwt_secret, token_type="refresh",
|
username, jwt_secret, token_type="refresh",
|
||||||
|
@ -301,8 +337,8 @@ class Authorization:
|
||||||
'action': action
|
'action': action
|
||||||
}
|
}
|
||||||
|
|
||||||
def _delete_jwt_user(self, web_request):
|
def _delete_jwt_user(self, web_request: WebRequest) -> Dict[str, str]:
|
||||||
username = web_request.get_str('username')
|
username: str = web_request.get_str('username')
|
||||||
current_user = web_request.get_current_user()
|
current_user = web_request.get_current_user()
|
||||||
if current_user is not None:
|
if current_user is not None:
|
||||||
curname = current_user.get('username', None)
|
curname = current_user.get('username', None)
|
||||||
|
@ -312,7 +348,7 @@ class Authorization:
|
||||||
if username in RESERVED_USERS:
|
if username in RESERVED_USERS:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"Invalid Request for reserved user {username}")
|
f"Invalid Request for reserved user {username}")
|
||||||
user_info = self.users.get(username)
|
user_info: Optional[Dict[str, Any]] = self.users.get(username)
|
||||||
if user_info is None:
|
if user_info is None:
|
||||||
raise self.server.error(f"No registered user: {username}")
|
raise self.server.error(f"No registered user: {username}")
|
||||||
del self.users[username]
|
del self.users[username]
|
||||||
|
@ -325,8 +361,12 @@ class Authorization:
|
||||||
"action": "user_deleted"
|
"action": "user_deleted"
|
||||||
}
|
}
|
||||||
|
|
||||||
def _generate_jwt(self, username, secret, token_type="auth",
|
def _generate_jwt(self,
|
||||||
exp_time=JWT_EXP_TIME):
|
username: str,
|
||||||
|
secret: bytes,
|
||||||
|
token_type: str = "auth",
|
||||||
|
exp_time: datetime.timedelta = JWT_EXP_TIME
|
||||||
|
) -> str:
|
||||||
curtime = time.time()
|
curtime = time.time()
|
||||||
payload = {
|
payload = {
|
||||||
'iss': "Moonraker",
|
'iss': "Moonraker",
|
||||||
|
@ -342,27 +382,30 @@ class Authorization:
|
||||||
message += b"." + signature
|
message += b"." + signature
|
||||||
return message.decode()
|
return message.decode()
|
||||||
|
|
||||||
def _decode_jwt(self, jwt, token_type="auth"):
|
def _decode_jwt(self,
|
||||||
|
jwt: str,
|
||||||
|
token_type: str = "auth"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
parts = jwt.encode().split(b".")
|
parts = jwt.encode().split(b".")
|
||||||
if len(parts) != 3:
|
if len(parts) != 3:
|
||||||
raise self.server.error(f"Invalid JWT length of {len(parts)}")
|
raise self.server.error(f"Invalid JWT length of {len(parts)}")
|
||||||
header = json.loads(base64url_decode(parts[0]))
|
header: Dict[str, Any] = json.loads(base64url_decode(parts[0]))
|
||||||
payload = json.loads(base64url_decode(parts[1]))
|
payload: Dict[str, Any] = json.loads(base64url_decode(parts[1]))
|
||||||
if header != JWT_HEADER:
|
if header != JWT_HEADER:
|
||||||
raise self.server.error("Invalid JWT header")
|
raise self.server.error("Invalid JWT header")
|
||||||
recd_type = payload.get('token_type', "")
|
recd_type: str = payload.get('token_type', "")
|
||||||
if token_type != recd_type:
|
if token_type != recd_type:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"JWT Token type mismatch: Expected {token_type}, "
|
f"JWT Token type mismatch: Expected {token_type}, "
|
||||||
f"Recd: {recd_type}", 401)
|
f"Recd: {recd_type}", 401)
|
||||||
if time.time() > payload['exp']:
|
if time.time() > payload['exp']:
|
||||||
raise self.server.error("JWT expired", 401)
|
raise self.server.error("JWT expired", 401)
|
||||||
username = payload.get('username')
|
username: str = payload['username']
|
||||||
user_info = self.users.get(username, None)
|
user_info: Dict[str, Any] = self.users.get(username, None)
|
||||||
if user_info is None:
|
if user_info is None:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"Invalid JWT, no registered user {username}", 401)
|
f"Invalid JWT, no registered user {username}", 401)
|
||||||
jwt_secret = user_info.get('jwt_secret', None)
|
jwt_secret: Optional[str] = user_info.get('jwt_secret', None)
|
||||||
if jwt_secret is None:
|
if jwt_secret is None:
|
||||||
raise self.server.error(
|
raise self.server.error(
|
||||||
f"Invalid JWT, user {username} not logged in", 401)
|
f"Invalid JWT, user {username} not logged in", 401)
|
||||||
|
@ -375,10 +418,10 @@ class Authorization:
|
||||||
raise self.server.error("Invalid JWT signature")
|
raise self.server.error("Invalid JWT signature")
|
||||||
return user_info
|
return user_info
|
||||||
|
|
||||||
def _prune_conn_handler(self):
|
def _prune_conn_handler(self) -> None:
|
||||||
cur_time = time.time()
|
cur_time = time.time()
|
||||||
for ip, user_info in list(self.trusted_users.items()):
|
for ip, user_info in list(self.trusted_users.items()):
|
||||||
exp_time = user_info['expires_at']
|
exp_time: float = user_info['expires_at']
|
||||||
if cur_time >= exp_time:
|
if cur_time >= exp_time:
|
||||||
self.trusted_users.pop(ip, None)
|
self.trusted_users.pop(ip, None)
|
||||||
logging.info(
|
logging.info(
|
||||||
|
@ -387,7 +430,10 @@ class Authorization:
|
||||||
def _oneshot_token_expire_handler(self, token):
|
def _oneshot_token_expire_handler(self, token):
|
||||||
self.oneshot_tokens.pop(token, None)
|
self.oneshot_tokens.pop(token, None)
|
||||||
|
|
||||||
def get_oneshot_token(self, ip_addr, user):
|
def get_oneshot_token(self,
|
||||||
|
ip_addr: IPAddr,
|
||||||
|
user: Optional[Dict[str, Any]]
|
||||||
|
) -> str:
|
||||||
token = base64.b32encode(os.urandom(20)).decode()
|
token = base64.b32encode(os.urandom(20)).decode()
|
||||||
ioloop = IOLoop.current()
|
ioloop = IOLoop.current()
|
||||||
hdl = ioloop.call_later(
|
hdl = ioloop.call_later(
|
||||||
|
@ -395,8 +441,10 @@ class Authorization:
|
||||||
self.oneshot_tokens[token] = (ip_addr, user, hdl)
|
self.oneshot_tokens[token] = (ip_addr, user, hdl)
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def _check_json_web_token(self, request):
|
def _check_json_web_token(self,
|
||||||
auth_token = request.headers.get("Authorization")
|
request: HTTPServerRequest
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
auth_token: Optional[str] = request.headers.get("Authorization")
|
||||||
if auth_token is None:
|
if auth_token is None:
|
||||||
auth_token = request.headers.get("X-Access-Token")
|
auth_token = request.headers.get("X-Access-Token")
|
||||||
if auth_token and auth_token.startswith("Bearer "):
|
if auth_token and auth_token.startswith("Bearer "):
|
||||||
|
@ -407,7 +455,7 @@ class Authorization:
|
||||||
raise HTTPError(401, str(e))
|
raise HTTPError(401, str(e))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _check_authorized_ip(self, ip):
|
def _check_authorized_ip(self, ip: IPAddr) -> bool:
|
||||||
if ip in self.trusted_ips:
|
if ip in self.trusted_ips:
|
||||||
return True
|
return True
|
||||||
for rng in self.trusted_ranges:
|
for rng in self.trusted_ranges:
|
||||||
|
@ -418,7 +466,9 @@ class Authorization:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _check_trusted_connection(self, ip):
|
def _check_trusted_connection(self,
|
||||||
|
ip: Optional[IPAddr]
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
if ip is not None:
|
if ip is not None:
|
||||||
curtime = time.time()
|
curtime = time.time()
|
||||||
exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT
|
exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT
|
||||||
|
@ -437,7 +487,10 @@ class Authorization:
|
||||||
return self.trusted_users[ip]
|
return self.trusted_users[ip]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _check_oneshot_token(self, token, cur_ip):
|
def _check_oneshot_token(self,
|
||||||
|
token: str,
|
||||||
|
cur_ip: IPAddr
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
if token in self.oneshot_tokens:
|
if token in self.oneshot_tokens:
|
||||||
ip_addr, user, hdl = self.oneshot_tokens.pop(token)
|
ip_addr, user, hdl = self.oneshot_tokens.pop(token)
|
||||||
IOLoop.current().remove_timeout(hdl)
|
IOLoop.current().remove_timeout(hdl)
|
||||||
|
@ -449,7 +502,9 @@ class Authorization:
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def check_authorized(self, request):
|
def check_authorized(self,
|
||||||
|
request: HTTPServerRequest
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
if request.path in self.permitted_paths or \
|
if request.path in self.permitted_paths or \
|
||||||
request.method == "OPTIONS":
|
request.method == "OPTIONS":
|
||||||
return None
|
return None
|
||||||
|
@ -467,14 +522,14 @@ class Authorization:
|
||||||
ip = None
|
ip = None
|
||||||
|
|
||||||
# Check oneshot access token
|
# Check oneshot access token
|
||||||
ost = request.arguments.get('token', None)
|
ost: Optional[List[bytes]] = request.arguments.get('token', None)
|
||||||
if ost is not None:
|
if ost is not None:
|
||||||
ost_user = self._check_oneshot_token(ost[-1].decode(), ip)
|
ost_user = self._check_oneshot_token(ost[-1].decode(), ip)
|
||||||
if ost_user is not None:
|
if ost_user is not None:
|
||||||
return ost_user
|
return ost_user
|
||||||
|
|
||||||
# Check API Key Header
|
# Check API Key Header
|
||||||
key = request.headers.get("X-Api-Key")
|
key: Optional[str] = request.headers.get("X-Api-Key")
|
||||||
if key and key == self.api_key:
|
if key and key == self.api_key:
|
||||||
return self.users[API_USER]
|
return self.users[API_USER]
|
||||||
|
|
||||||
|
@ -485,7 +540,10 @@ class Authorization:
|
||||||
|
|
||||||
raise HTTPError(401, "Unauthorized")
|
raise HTTPError(401, "Unauthorized")
|
||||||
|
|
||||||
def check_cors(self, origin, request=None):
|
def check_cors(self,
|
||||||
|
origin: Optional[str],
|
||||||
|
req_hdlr: Optional[RequestHandler] = None
|
||||||
|
) -> bool:
|
||||||
if origin is None or not self.cors_domains:
|
if origin is None or not self.cors_domains:
|
||||||
return False
|
return False
|
||||||
for regex in self.cors_domains:
|
for regex in self.cors_domains:
|
||||||
|
@ -494,7 +552,7 @@ class Authorization:
|
||||||
if match.group() == origin:
|
if match.group() == origin:
|
||||||
logging.debug(f"CORS Pattern Matched, origin: {origin} "
|
logging.debug(f"CORS Pattern Matched, origin: {origin} "
|
||||||
f" | pattern: {regex}")
|
f" | pattern: {regex}")
|
||||||
self._set_cors_headers(origin, request)
|
self._set_cors_headers(origin, req_hdlr)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logging.debug(f"Partial Cors Match: {match.group()}")
|
logging.debug(f"Partial Cors Match: {match.group()}")
|
||||||
|
@ -512,28 +570,31 @@ class Authorization:
|
||||||
if self._check_authorized_ip(ipaddr):
|
if self._check_authorized_ip(ipaddr):
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Cors request matched trusted IP: {ip}")
|
f"Cors request matched trusted IP: {ip}")
|
||||||
self._set_cors_headers(origin, request)
|
self._set_cors_headers(origin, req_hdlr)
|
||||||
return True
|
return True
|
||||||
logging.debug(f"No CORS match for origin: {origin}\n"
|
logging.debug(f"No CORS match for origin: {origin}\n"
|
||||||
f"Patterns: {self.cors_domains}")
|
f"Patterns: {self.cors_domains}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _set_cors_headers(self, origin, request):
|
def _set_cors_headers(self,
|
||||||
if request is None:
|
origin: str,
|
||||||
|
req_hdlr: Optional[RequestHandler]
|
||||||
|
) -> None:
|
||||||
|
if req_hdlr is None:
|
||||||
return
|
return
|
||||||
request.set_header("Access-Control-Allow-Origin", origin)
|
req_hdlr.set_header("Access-Control-Allow-Origin", origin)
|
||||||
request.set_header(
|
req_hdlr.set_header(
|
||||||
"Access-Control-Allow-Methods",
|
"Access-Control-Allow-Methods",
|
||||||
"GET, POST, PUT, DELETE, OPTIONS")
|
"GET, POST, PUT, DELETE, OPTIONS")
|
||||||
request.set_header(
|
req_hdlr.set_header(
|
||||||
"Access-Control-Allow-Headers",
|
"Access-Control-Allow-Headers",
|
||||||
"Origin, Accept, Content-Type, X-Requested-With, "
|
"Origin, Accept, Content-Type, X-Requested-With, "
|
||||||
"X-CRSF-Token, Authorization, X-Access-Token, "
|
"X-CRSF-Token, Authorization, X-Access-Token, "
|
||||||
"X-Api-Key")
|
"X-Api-Key")
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
self.prune_handler.stop()
|
self.prune_handler.stop()
|
||||||
|
|
||||||
|
|
||||||
def load_component(config):
|
def load_component(config: ConfigHelper) -> Authorization:
|
||||||
return Authorization(config)
|
return Authorization(config)
|
||||||
|
|
Loading…
Reference in New Issue