diff --git a/moonraker/components/application.py b/moonraker/components/application.py index a983294..62fc42d 100644 --- a/moonraker/components/application.py +++ b/moonraker/components/application.py @@ -458,27 +458,50 @@ class MoonrakerApp: self.template_cache[asset_name] = asset_tmpl return asset_tmpl +def _set_cors_headers(req_hdlr: tornado.web.RequestHandler) -> None: + request = req_hdlr.request + origin: Optional[str] = request.headers.get("Origin") + if origin is None: + return + req_hdlr.set_header("Access-Control-Allow-Origin", origin) + if req_hdlr.request.method == "OPTIONS": + req_hdlr.set_header( + "Access-Control-Allow-Methods", + "GET, POST, PUT, DELETE, OPTIONS" + ) + req_hdlr.set_header( + "Access-Control-Allow-Headers", + "Origin, Accept, Content-Type, X-Requested-With, " + "X-CRSF-Token, Authorization, X-Access-Token, " + "X-Api-Key" + ) + req_pvt_header = req_hdlr.request.headers.get( + "Access-Control-Request-Private-Network", None + ) + if req_pvt_header == "true": + req_hdlr.set_header("Access-Control-Allow-Private-Network", "true") + + class AuthorizedRequestHandler(tornado.web.RequestHandler): def initialize(self) -> None: self.server: Server = self.settings['server'] self.auth_required: bool = True + self.cors_enabled = False def set_default_headers(self) -> None: - origin: Optional[str] = self.request.headers.get("Origin") - # it is necessary to look up the parent app here, - # as initialize() may not yet be called - server: Server = self.settings['server'] - auth: AuthComp = server.lookup_component('authorization', None) - self.cors_enabled = False - if auth is not None: - self.cors_enabled = auth.check_cors(origin, self) + if getattr(self, "cors_enabled", False): + _set_cors_headers(self) - def prepare(self) -> None: + async def prepare(self) -> None: auth: AuthComp = self.server.lookup_component('authorization', None) if auth is not None: - self.current_user = auth.authenticate_request( + self.current_user = await auth.authenticate_request( self.request, self.auth_required ) + origin: Optional[str] = self.request.headers.get("Origin") + self.cors_enabled = await auth.check_cors(origin) + if self.cors_enabled: + _set_cors_headers(self) def options(self, *args, **kwargs) -> None: # Enable CORS if configured @@ -520,23 +543,22 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler): ) -> None: super(AuthorizedFileHandler, self).initialize(path, default_filename) self.server: Server = self.settings['server'] + self.cors_enabled = False def set_default_headers(self) -> None: - origin: Optional[str] = self.request.headers.get("Origin") - # it is necessary to look up the parent app here, - # as initialize() may not yet be called - server: Server = self.settings['server'] - auth: AuthComp = server.lookup_component('authorization', None) - self.cors_enabled = False - if auth is not None: - self.cors_enabled = auth.check_cors(origin, self) + if getattr(self, "cors_enabled", False): + _set_cors_headers(self) - def prepare(self) -> None: + async def prepare(self) -> None: auth: AuthComp = self.server.lookup_component('authorization', None) if auth is not None: - self.current_user = auth.authenticate_request( + self.current_user = await auth.authenticate_request( self.request, self._check_need_auth() ) + origin: Optional[str] = self.request.headers.get("Origin") + self.cors_enabled = await auth.check_cors(origin) + if self.cors_enabled: + _set_cors_headers(self) def options(self, *args, **kwargs) -> None: # Enable CORS if configured @@ -915,8 +937,10 @@ class FileUploadHandler(AuthorizedRequestHandler): self.parse_lock = Lock() self.parse_failed: bool = False - def prepare(self) -> None: - super(FileUploadHandler, self).prepare() + async def prepare(self) -> None: + ret = super(FileUploadHandler, self).prepare() + if ret is not None: + await ret content_type: str = self.request.headers.get("Content-Type", "") logging.info( f"Upload Request Received from {self.request.remote_ip}\n" @@ -1017,8 +1041,10 @@ class FileUploadHandler(AuthorizedRequestHandler): # Default Handler for unregistered endpoints class AuthorizedErrorHandler(AuthorizedRequestHandler): - def prepare(self) -> None: - super(AuthorizedRequestHandler, self).prepare() + async def prepare(self) -> None: + ret = super(AuthorizedRequestHandler, self).prepare() + if ret is not None: + await ret self.set_status(404) raise tornado.web.HTTPError(404) @@ -1038,7 +1064,7 @@ class RedirectHandler(AuthorizedRequestHandler): super().initialize() self.auth_required = False - def get(self, *args, **kwargs) -> None: + async def get(self, *args, **kwargs) -> None: url: Optional[str] = self.get_argument('url', None) if url is None: try: @@ -1052,7 +1078,7 @@ class RedirectHandler(AuthorizedRequestHandler): assert url is not None # validate the url origin auth: AuthComp = self.server.lookup_component('authorization', None) - if auth is None or not auth.check_cors(url.rstrip("/")): + if auth is None or not await auth.check_cors(url.rstrip("/")): raise tornado.web.HTTPError( 400, f"Unauthorized URL redirect: {url}") self.redirect(url) @@ -1066,7 +1092,7 @@ class WelcomeHandler(tornado.web.RequestHandler): auth: AuthComp = self.server.lookup_component("authorization", None) if auth is not None: try: - auth.authenticate_request(self.request) + await auth.authenticate_request(self.request) except tornado.web.HTTPError: authorized = False else: diff --git a/moonraker/components/authorization.py b/moonraker/components/authorization.py index 69c2ee3..2af1c3c 100644 --- a/moonraker/components/authorization.py +++ b/moonraker/components/authorization.py @@ -38,7 +38,6 @@ if TYPE_CHECKING: from ..common import WebRequest from .websockets import WebsocketManager from tornado.httputil import HTTPServerRequest - from tornado.web import RequestHandler from .database import MoonrakerDatabase as DBComp from .ldap import MoonrakerLDAP IPAddr = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -58,6 +57,7 @@ def base64url_decode(data: str) -> bytes: ONESHOT_TIMEOUT = 5 TRUSTED_CONNECTION_TIMEOUT = 3600 +FQDN_CACHE_TIMEOUT = 84000 PRUNE_CHECK_TIME = 300. AUTH_SOURCES = ["moonraker", "ldap"] @@ -80,6 +80,7 @@ class Authorization: self.enable_api_key = config.getboolean('enable_api_key', True) self.max_logins = config.getint("max_login_attempts", None, above=0) self.failed_logins: Dict[IPAddr, int] = {} + self.fqdn_cache: Dict[IPAddr, Dict[str, Any]] = {} if self.default_source not in AUTH_SOURCES: raise config.error( "[authorization]: option 'default_source' - Invalid " @@ -669,8 +670,13 @@ class Authorization: exp_time: float = user_info['expires_at'] if cur_time >= exp_time: self.trusted_users.pop(ip, None) - logging.info( - f"Trusted Connection Expired, IP: {ip}") + logging.info(f"Trusted Connection Expired, IP: {ip}") + for ip, fqdn_info in list(self.fqdn_cache.items()): + exp_time = fqdn_info["expires_at"] + if cur_time >= exp_time: + domain: str = fqdn_info["domain"] + self.fqdn_cache.pop(ip, None) + logging.info(f"Cached FQDN Expired, IP: {ip}, domain: {domain}") return eventtime + PRUNE_CHECK_TIME def _oneshot_token_expire_handler(self, token): @@ -709,27 +715,42 @@ class Authorization: raise HTTPError(401, "JWT Decode Error") return None - def _check_authorized_ip(self, ip: IPAddr) -> bool: + async def _check_authorized_ip(self, ip: IPAddr) -> bool: if ip in self.trusted_ips: return True for rng in self.trusted_ranges: if ip in rng: return True - fqdn = socket.getfqdn(str(ip)).lower() - if fqdn in self.trusted_domains: - return True + if self.trusted_domains: + if ip in self.fqdn_cache: + fqdn: str = self.fqdn_cache[ip]["domain"] + else: + eventloop = self.server.get_event_loop() + try: + fut = eventloop.run_in_thread(socket.getfqdn, str(ip)) + fqdn = await asyncio.wait_for(fut, 5.0) + except asyncio.TimeoutError: + logging.info("Call to socket.getfqdn() timed out") + return False + else: + fqdn = fqdn.lower() + self.fqdn_cache[ip] = { + "expires_at": time.time() + FQDN_CACHE_TIMEOUT, + "domain": fqdn + } + return fqdn in self.trusted_domains return False - def _check_trusted_connection(self, - ip: Optional[IPAddr] - ) -> Optional[Dict[str, Any]]: + async def _check_trusted_connection( + self, ip: Optional[IPAddr] + ) -> Optional[Dict[str, Any]]: if ip is not None: curtime = time.time() exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT if ip in self.trusted_users: self.trusted_users[ip]['expires_at'] = exp_time return self.trusted_users[ip] - elif self._check_authorized_ip(ip): + elif await self._check_authorized_ip(ip): logging.info( f"Trusted Connection Detected, IP: {ip}") self.trusted_users[ip] = { @@ -761,7 +782,7 @@ class Authorization: return False return self.failed_logins.get(ip_addr, 0) >= self.max_logins - def authenticate_request( + async def authenticate_request( self, request: HTTPServerRequest, auth_required: bool = True ) -> Optional[Dict[str, Any]]: if request.method == "OPTIONS": @@ -801,16 +822,13 @@ class Authorization: # Check if IP is trusted. If this endpoint doesn't require authentication # then it is acceptable to return None - trusted_user = self._check_trusted_connection(ip) + trusted_user = await self._check_trusted_connection(ip) if trusted_user is not None or not auth_required: return trusted_user raise HTTPError(401, "Unauthorized") - def check_cors(self, - origin: Optional[str], - req_hdlr: Optional[RequestHandler] = None - ) -> bool: + async def check_cors(self, origin: Optional[str]) -> bool: if origin is None or not self.cors_domains: return False for regex in self.cors_domains: @@ -819,7 +837,6 @@ class Authorization: if match.group() == origin: logging.debug(f"CORS Pattern Matched, origin: {origin} " f" | pattern: {regex}") - self._set_cors_headers(origin, req_hdlr) return True else: logging.debug(f"Partial Cors Match: {match.group()}") @@ -834,37 +851,13 @@ class Authorization: except ValueError: pass else: - if self._check_authorized_ip(ipaddr): - logging.debug( - f"Cors request matched trusted IP: {ip}") - self._set_cors_headers(origin, req_hdlr) + if await self._check_authorized_ip(ipaddr): + logging.debug(f"Cors request matched trusted IP: {ip}") return True logging.debug(f"No CORS match for origin: {origin}\n" f"Patterns: {self.cors_domains}") return False - def _set_cors_headers(self, - origin: str, - req_hdlr: Optional[RequestHandler] - ) -> None: - if req_hdlr is None: - return - req_hdlr.set_header("Access-Control-Allow-Origin", origin) - if req_hdlr.request.method == "OPTIONS": - req_hdlr.set_header( - "Access-Control-Allow-Methods", - "GET, POST, PUT, DELETE, OPTIONS") - req_hdlr.set_header( - "Access-Control-Allow-Headers", - "Origin, Accept, Content-Type, X-Requested-With, " - "X-CRSF-Token, Authorization, X-Access-Token, " - "X-Api-Key") - if req_hdlr.request.headers.get( - "Access-Control-Request-Private-Network", None) == "true": - req_hdlr.set_header( - "Access-Control-Allow-Private-Network", - "true") - def cors_enabled(self) -> bool: return self.cors_domains is not None diff --git a/moonraker/components/websockets.py b/moonraker/components/websockets.py index bd14460..7d0c85b 100644 --- a/moonraker/components/websockets.py +++ b/moonraker/components/websockets.py @@ -238,6 +238,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection): self.on_create(self.settings['server']) self._ip_addr = parse_ip_address(self.request.remote_ip or "") self.last_pong_time: float = self.eventloop.get_loop_time() + self.cors_allowed: bool = False @property def ip_addr(self) -> Optional[IPAddress]: @@ -308,10 +309,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection): def check_origin(self, origin: str) -> bool: if not super(WebSocket, self).check_origin(origin): - auth: AuthComp = self.server.lookup_component('authorization', None) - if auth is not None: - return auth.check_cors(origin) - return False + return self.cors_allowed return True def on_user_logout(self, user: str) -> bool: @@ -321,7 +319,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection): return False # Check Authorized User - def prepare(self) -> None: + async def prepare(self) -> None: max_conns = self.settings["max_websocket_connections"] if self.__class__.connection_count >= max_conns: raise self.server.error( @@ -330,11 +328,16 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection): auth: AuthComp = self.server.lookup_component('authorization', None) if auth is not None: try: - self._user_info = auth.authenticate_request(self.request) + self._user_info = await auth.authenticate_request(self.request) except Exception as e: logging.info(f"Websocket Failed Authentication: {e}") self._user_info = None self._need_auth = True + if "Origin" in self.request.headers: + origin = self.request.headers.get("Origin") + else: + origin = self.request.headers.get("Sec-Websocket-Origin", None) + self.cors_allowed = await auth.check_cors(origin) def close_socket(self, code: int, reason: str) -> None: self.close(code, reason) @@ -351,6 +354,7 @@ class BridgeSocket(WebSocketHandler): self.klippy_writer: Optional[asyncio.StreamWriter] = None self.klippy_write_buf: List[bytes] = [] self.klippy_queue_busy: bool = False + self.cors_allowed: bool = False @property def ip_addr(self) -> Optional[IPAddress]: @@ -459,10 +463,7 @@ class BridgeSocket(WebSocketHandler): def check_origin(self, origin: str) -> bool: if not super().check_origin(origin): - auth: AuthComp = self.server.lookup_component('authorization', None) - if auth is not None: - return auth.check_cors(origin) - return False + return self.cors_allowed return True # Check Authorized User @@ -474,7 +475,12 @@ class BridgeSocket(WebSocketHandler): ) auth: AuthComp = self.server.lookup_component("authorization", None) if auth is not None: - self.current_user = auth.authenticate_request(self.request) + self.current_user = await auth.authenticate_request(self.request) + if "Origin" in self.request.headers: + origin = self.request.headers.get("Origin") + else: + origin = self.request.headers.get("Sec-Websocket-Origin", None) + self.cors_allowed = await auth.check_cors(origin) kconn: Klippy = self.server.lookup_component("klippy_connection") try: reader, writer = await kconn.open_klippy_connection()