authorization: remove "permitted_paths" attribute

Track authentication requirements in the API Definition.  This
eliminates the need to look up the authentication component
to disable auth on an endpoint.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2023-11-26 16:26:22 -05:00
parent eed759e111
commit b3b60757aa
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
6 changed files with 51 additions and 62 deletions

View File

@ -311,6 +311,7 @@ class MoonrakerApp:
transports: Union[List[str], TransportType] = TransportType.all(), transports: Union[List[str], TransportType] = TransportType.all(),
wrap_result: bool = True, wrap_result: bool = True,
content_type: Optional[str] = None, content_type: Optional[str] = None,
auth_required: bool = True,
is_remote: bool = False is_remote: bool = False
) -> None: ) -> None:
if isinstance(request_types, list): if isinstance(request_types, list):
@ -318,7 +319,7 @@ class MoonrakerApp:
if isinstance(transports, list): if isinstance(transports, list):
transports = TransportType.from_string_list(transports) transports = TransportType.from_string_list(transports)
api_def = APIDefinition.create( api_def = APIDefinition.create(
endpoint, request_types, callback, transports, is_remote endpoint, request_types, callback, transports, auth_required, is_remote
) )
http_path = api_def.http_path http_path = api_def.http_path
if http_path in self.registered_base_handlers: if http_path in self.registered_base_handlers:
@ -414,7 +415,7 @@ class MoonrakerApp:
class AuthorizedRequestHandler(tornado.web.RequestHandler): class AuthorizedRequestHandler(tornado.web.RequestHandler):
def initialize(self) -> None: def initialize(self) -> None:
self.server: Server = self.settings['server'] self.server: Server = self.settings['server']
self.endpoint: str = "" self.auth_required: bool = True
def set_default_headers(self) -> None: def set_default_headers(self) -> None:
origin: Optional[str] = self.request.headers.get("Origin") origin: Optional[str] = self.request.headers.get("Origin")
@ -427,12 +428,11 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
self.cors_enabled = auth.check_cors(origin, self) self.cors_enabled = auth.check_cors(origin, self)
def prepare(self) -> None: def prepare(self) -> None:
app: MoonrakerApp = self.server.lookup_component("application")
if not self.endpoint:
self.endpoint = app.parse_endpoint(self.request.path or "")
auth: AuthComp = self.server.lookup_component('authorization', None) auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None: if auth is not None:
self.current_user = auth.check_authorized(self.request, self.endpoint) self.current_user = auth.authenticate_request(
self.request, self.auth_required
)
def options(self, *args, **kwargs) -> None: def options(self, *args, **kwargs) -> None:
# Enable CORS if configured # Enable CORS if configured
@ -476,7 +476,6 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler):
) -> None: ) -> None:
super(AuthorizedFileHandler, self).initialize(path, default_filename) super(AuthorizedFileHandler, self).initialize(path, default_filename)
self.server: Server = self.settings['server'] self.server: Server = self.settings['server']
self.endpoint: str = ""
def set_default_headers(self) -> None: def set_default_headers(self) -> None:
origin: Optional[str] = self.request.headers.get("Origin") origin: Optional[str] = self.request.headers.get("Origin")
@ -489,11 +488,11 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler):
self.cors_enabled = auth.check_cors(origin, self) self.cors_enabled = auth.check_cors(origin, self)
def prepare(self) -> None: def prepare(self) -> None:
app: MoonrakerApp = self.server.lookup_component("application")
self.endpoint = app.parse_endpoint(self.request.path or "")
auth: AuthComp = self.server.lookup_component('authorization', None) auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None and self._check_need_auth(): if auth is not None:
self.current_user = auth.check_authorized(self.request, self.endpoint) self.current_user = auth.authenticate_request(
self.request, self._check_need_auth()
)
def options(self, *args, **kwargs) -> None: def options(self, *args, **kwargs) -> None:
# Enable CORS if configured # Enable CORS if configured
@ -531,7 +530,7 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
self.api_defintion = api_definition self.api_defintion = api_definition
self.wrap_result = wrap_result self.wrap_result = wrap_result
self.content_type = content_type self.content_type = content_type
self.endpoint = api_definition.endpoint self.auth_required = api_definition.auth_required
# Converts query string values with type hints # Converts query string values with type hints
def _convert_type(self, value: str, hint: str) -> Any: def _convert_type(self, value: str, hint: str) -> Any:
@ -601,10 +600,11 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
def _log_debug(self, header: str, args: Any) -> None: def _log_debug(self, header: str, args: Any) -> None:
if self.server.is_verbose_enabled(): if self.server.is_verbose_enabled():
resp = args resp = args
endpoint = self.api_defintion.endpoint
if isinstance(args, dict): if isinstance(args, dict):
if ( if (
self.endpoint.startswith("/access") or endpoint.startswith("/access") or
self.endpoint.startswith("/machine/sudo/password") endpoint.startswith("/machine/sudo/password")
): ):
resp = {key: "<sanitized>" for key in args} resp = {key: "<sanitized>" for key in args}
elif isinstance(args, str): elif isinstance(args, str):
@ -663,7 +663,9 @@ class FileRequestHandler(AuthorizedFileHandler):
f"filename*=UTF-8\'\'{utf8_basename}") f"filename*=UTF-8\'\'{utf8_basename}")
async def delete(self, path: str) -> None: async def delete(self, path: str) -> None:
path = self.endpoint.lstrip("/").split("/", 2)[-1] app: MoonrakerApp = self.server.lookup_component("application")
endpoint = app.parse_endpoint(self.request.path or "")
path = endpoint.lstrip("/").split("/", 2)[-1]
path = url_unescape(path, plus=False) path = url_unescape(path, plus=False)
file_manager: FileManager file_manager: FileManager
file_manager = self.server.lookup_component('file_manager') file_manager = self.server.lookup_component('file_manager')
@ -944,6 +946,10 @@ class AuthorizedErrorHandler(AuthorizedRequestHandler):
self.finish(jsonw.dumps({'error': err})) self.finish(jsonw.dumps({'error': err}))
class RedirectHandler(AuthorizedRequestHandler): class RedirectHandler(AuthorizedRequestHandler):
def initialize(self) -> None:
super().initialize()
self.auth_required = False
def get(self, *args, **kwargs) -> None: def get(self, *args, **kwargs) -> None:
url: Optional[str] = self.get_argument('url', None) url: Optional[str] = self.get_argument('url', None)
if url is None: if url is None:
@ -972,7 +978,7 @@ class WelcomeHandler(tornado.web.RequestHandler):
auth: AuthComp = self.server.lookup_component("authorization", None) auth: AuthComp = self.server.lookup_component("authorization", None)
if auth is not None: if auth is not None:
try: try:
auth.check_authorized(self.request) auth.authenticate_request(self.request)
except tornado.web.HTTPError: except tornado.web.HTTPError:
authorized = False authorized = False
else: else:

View File

@ -168,6 +168,7 @@ class APIDefinition:
request_types: RequestType request_types: RequestType
transports: TransportType transports: TransportType
callback: Callable[[WebRequest], Coroutine] callback: Callable[[WebRequest], Coroutine]
auth_required: bool
_cache: ClassVar[Dict[str, APIDefinition]] = {} _cache: ClassVar[Dict[str, APIDefinition]] = {}
def request( def request(
@ -196,6 +197,7 @@ class APIDefinition:
request_types: Union[List[str], RequestType], request_types: Union[List[str], RequestType],
callback: Callable[[WebRequest], Coroutine], callback: Callable[[WebRequest], Coroutine],
transports: Union[List[str], TransportType] = TransportType.all(), transports: Union[List[str], TransportType] = TransportType.all(),
auth_required: bool = True,
is_remote: bool = False is_remote: bool = False
) -> APIDefinition: ) -> APIDefinition:
if isinstance(request_types, list): if isinstance(request_types, list):
@ -239,7 +241,7 @@ class APIDefinition:
api_def = cls( api_def = cls(
endpoint, http_path, rpc_methods, request_types, endpoint, http_path, rpc_methods, request_types,
transports, callback transports, callback, auth_required
) )
cls._cache[endpoint] = api_def cls._cache[endpoint] = api_def
return api_def return api_def
@ -335,7 +337,7 @@ class BaseRemoteConnection(APITransport):
def screen_rpc_request( def screen_rpc_request(
self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any] self, api_def: APIDefinition, req_type: RequestType, args: Dict[str, Any]
) -> None: ) -> None:
self.check_authenticated(api_def.endpoint) self.check_authenticated(api_def)
async def _process_message(self, message: str) -> None: async def _process_message(self, message: str) -> None:
try: try:
@ -366,16 +368,16 @@ class BaseRemoteConnection(APITransport):
self.user_info = auth.validate_jwt(token) self.user_info = auth.validate_jwt(token)
elif api_key is not None and self.user_info is None: elif api_key is not None and self.user_info is None:
self.user_info = auth.validate_api_key(api_key) self.user_info = auth.validate_api_key(api_key)
else: elif self._need_auth:
self.check_authenticated() raise self.server.error("Unauthorized", 401)
def check_authenticated(self, path: str = "") -> None: def check_authenticated(self, api_def: APIDefinition) -> None:
if not self._need_auth: if not self._need_auth:
return return
auth: AuthComp = self.server.lookup_component("authorization", None) auth: AuthComp = self.server.lookup_component("authorization", None)
if auth is None: if auth is None:
return return
if not auth.is_path_permitted(path): if api_def.auth_required:
raise self.server.error("Unauthorized", 401) raise self.server.error("Unauthorized", 401)
def on_user_logout(self, user: str) -> bool: def on_user_logout(self, user: str) -> bool:

View File

@ -27,7 +27,6 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Tuple, Tuple,
Set,
Optional, Optional,
Union, Union,
Dict, Dict,
@ -152,7 +151,6 @@ class Authorization:
self.user_db.sync(self.users) self.user_db.sync(self.users)
self.trusted_users: Dict[IPAddr, Any] = {} self.trusted_users: Dict[IPAddr, Any] = {}
self.oneshot_tokens: Dict[str, OneshotToken] = {} self.oneshot_tokens: Dict[str, OneshotToken] = {}
self.permitted_paths: Set[str] = set()
# Get allowed cors domains # Get allowed cors domains
self.cors_domains: List[str] = [] self.cors_domains: List[str] = []
@ -222,13 +220,10 @@ class Authorization:
self._prune_conn_handler) self._prune_conn_handler)
# Register Authorization Endpoints # Register Authorization Endpoints
self.permitted_paths.add("/server/redirect")
self.permitted_paths.add("/access/login")
self.permitted_paths.add("/access/refresh_jwt")
self.permitted_paths.add("/access/info")
self.server.register_endpoint( self.server.register_endpoint(
"/access/login", RequestType.POST, self._handle_login, "/access/login", RequestType.POST, self._handle_login,
transports=TransportType.HTTP | TransportType.WEBSOCKET transports=TransportType.HTTP | TransportType.WEBSOCKET,
auth_required=False
) )
self.server.register_endpoint( self.server.register_endpoint(
"/access/logout", RequestType.POST, self._handle_logout, "/access/logout", RequestType.POST, self._handle_logout,
@ -236,7 +231,8 @@ class Authorization:
) )
self.server.register_endpoint( self.server.register_endpoint(
"/access/refresh_jwt", RequestType.POST, self._handle_refresh_jwt, "/access/refresh_jwt", RequestType.POST, self._handle_refresh_jwt,
transports=TransportType.HTTP | TransportType.WEBSOCKET transports=TransportType.HTTP | TransportType.WEBSOCKET,
auth_required=False
) )
self.server.register_endpoint( self.server.register_endpoint(
"/access/user", RequestType.all(), self._handle_user_request, "/access/user", RequestType.all(), self._handle_user_request,
@ -261,7 +257,8 @@ class Authorization:
) )
self.server.register_endpoint( self.server.register_endpoint(
"/access/info", RequestType.GET, self._handle_info_request, "/access/info", RequestType.GET, self._handle_info_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET transports=TransportType.HTTP | TransportType.WEBSOCKET,
auth_required=False
) )
wsm: WebsocketManager = self.server.lookup_component("websockets") wsm: WebsocketManager = self.server.lookup_component("websockets")
wsm.register_notification("authorization:user_created") wsm.register_notification("authorization:user_created")
@ -272,12 +269,6 @@ class Authorization:
"authorization:user_logged_out", event_type="logout" "authorization:user_logged_out", event_type="logout"
) )
def register_permited_path(self, path: str) -> None:
self.permitted_paths.add(path)
def is_path_permitted(self, path: str) -> bool:
return path in self.permitted_paths
def _sync_user(self, username: str) -> None: def _sync_user(self, username: str) -> None:
self.user_db[username] = self.users[username] self.user_db[username] = self.users[username]
@ -770,13 +761,10 @@ class Authorization:
return False return False
return self.failed_logins.get(ip_addr, 0) >= self.max_logins return self.failed_logins.get(ip_addr, 0) >= self.max_logins
def check_authorized( def authenticate_request(
self, request: HTTPServerRequest, endpoint: str = "", self, request: HTTPServerRequest, auth_required: bool = True
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
if ( if request.method == "OPTIONS":
endpoint in self.permitted_paths
or request.method == "OPTIONS"
):
return None return None
# Check JSON Web Token # Check JSON Web Token
@ -804,14 +792,17 @@ class Authorization:
if key and key == self.api_key: if key and key == self.api_key:
return self.users[API_USER] return self.users[API_USER]
# If the force_logins option is enabled and at least one # If the force_logins option is enabled and at least one user is created
# user is created this is an unauthorized request # then trusted user authentication is disabled
if self.force_logins and len(self.users) > 1: if self.force_logins and len(self.users) > 1:
if not auth_required:
return None
raise HTTPError(401, "Unauthorized, Force Logins Enabled") raise HTTPError(401, "Unauthorized, Force Logins Enabled")
# Check if IP is trusted # Check if IP is trusted. If this endpoint doesn't require authentication
# then it is acceptable to return None
trusted_user = self._check_trusted_connection(ip) trusted_user = self._check_trusted_connection(ip)
if trusted_user is not None: if trusted_user is not None or not auth_required:
return trusted_user return trusted_user
raise HTTPError(401, "Unauthorized") raise HTTPError(401, "Unauthorized")

View File

@ -47,7 +47,6 @@ if TYPE_CHECKING:
from .shell_command import ShellCommandFactory as SCMDComp from .shell_command import ShellCommandFactory as SCMDComp
from .database import MoonrakerDatabase from .database import MoonrakerDatabase
from .file_manager.file_manager import FileManager from .file_manager.file_manager import FileManager
from .authorization import Authorization
from .announcements import Announcements from .announcements import Announcements
from .proc_stats import ProcStats from .proc_stats import ProcStats
from .dbus_manager import DbusManager from .dbus_manager import DbusManager
@ -1933,11 +1932,6 @@ class InstallValidator:
if self._sudo_requested: if self._sudo_requested:
return return
self._sudo_requested = True self._sudo_requested = True
auth: Optional[Authorization]
auth = self.server.lookup_component("authorization", None)
if auth is not None:
# Bypass authentication requirements
auth.register_permited_path("/machine/sudo/password")
machine: Machine = self.server.lookup_component("machine") machine: Machine = self.server.lookup_component("machine")
machine.register_sudo_request( machine.register_sudo_request(
self._on_password_received, self._on_password_received,

View File

@ -30,7 +30,6 @@ if TYPE_CHECKING:
from ..confighelper import ConfigHelper from ..confighelper import ConfigHelper
from ..common import WebRequest from ..common import WebRequest
from ..app import MoonrakerApp from ..app import MoonrakerApp
from .authorization import Authorization
from .machine import Machine from .machine import Machine
ZC_SERVICE_TYPE = "_moonraker._tcp.local." ZC_SERVICE_TYPE = "_moonraker._tcp.local."
@ -209,17 +208,14 @@ class SSDPServer(asyncio.protocols.DatagramProtocol):
self.boot_id = int(eventloop.get_loop_time()) self.boot_id = int(eventloop.get_loop_time())
self.config_id = 1 self.config_id = 1
self.ad_timer = eventloop.register_timer(self._advertise_presence) self.ad_timer = eventloop.register_timer(self._advertise_presence)
auth: Optional[Authorization]
auth = self.server.load_component(config, "authorization", None)
if auth is not None:
auth.register_permited_path("/server/zeroconf/ssdp")
self.server.register_endpoint( self.server.register_endpoint(
"/server/zeroconf/ssdp", "/server/zeroconf/ssdp",
RequestType.GET, RequestType.GET,
self._handle_xml_request, self._handle_xml_request,
transports=TransportType.HTTP, transports=TransportType.HTTP,
wrap_result=False, wrap_result=False,
content_type="application/xml" content_type="application/xml",
auth_required=False
) )
def _create_ssdp_socket( def _create_ssdp_socket(

View File

@ -57,7 +57,7 @@ class WebsocketManager:
) )
self.server.register_endpoint( self.server.register_endpoint(
"/server/connection/identify", RequestType.POST, self._handle_identify, "/server/connection/identify", RequestType.POST, self._handle_identify,
TransportType.WEBSOCKET TransportType.WEBSOCKET, auth_required=False
) )
self.server.register_component("websockets", self) self.server.register_component("websockets", self)
@ -321,7 +321,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
auth: AuthComp = self.server.lookup_component('authorization', None) auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None: if auth is not None:
try: try:
self._user_info = auth.check_authorized(self.request) self._user_info = auth.authenticate_request(self.request)
except Exception as e: except Exception as e:
logging.info(f"Websocket Failed Authentication: {e}") logging.info(f"Websocket Failed Authentication: {e}")
self._user_info = None self._user_info = None
@ -461,7 +461,7 @@ class BridgeSocket(WebSocketHandler):
) )
auth: AuthComp = self.server.lookup_component("authorization", None) auth: AuthComp = self.server.lookup_component("authorization", None)
if auth is not None: if auth is not None:
self.current_user = auth.check_authorized(self.request) self.current_user = auth.authenticate_request(self.request)
kconn: Klippy = self.server.lookup_component("klippy_connection") kconn: Klippy = self.server.lookup_component("klippy_connection")
try: try:
reader, writer = await kconn.open_klippy_connection() reader, writer = await kconn.open_klippy_connection()