authorization: fix blocking call to socket.getfqdn()

If the upstream DNS server is not available the call to socket.getfqdn()
will block until a timeout occurs.  This blocks Moonraker's event loop,
resulting in carnage.

Call getfqdn() in a thread with a timeout of 5 seconds.  In addition,
only request the fqdn if the user has one or more trusted domains
configured.  Finally, cache resolved  FQDNs for 24 hours to limit
repeated DNS queries.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2024-01-22 11:49:51 -05:00
parent 3d44c51613
commit d1f97f2658
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
3 changed files with 107 additions and 82 deletions

View File

@ -458,27 +458,50 @@ class MoonrakerApp:
self.template_cache[asset_name] = asset_tmpl
return asset_tmpl
def _set_cors_headers(req_hdlr: tornado.web.RequestHandler) -> None:
request = req_hdlr.request
origin: Optional[str] = request.headers.get("Origin")
if origin is None:
return
req_hdlr.set_header("Access-Control-Allow-Origin", origin)
if req_hdlr.request.method == "OPTIONS":
req_hdlr.set_header(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, OPTIONS"
)
req_hdlr.set_header(
"Access-Control-Allow-Headers",
"Origin, Accept, Content-Type, X-Requested-With, "
"X-CRSF-Token, Authorization, X-Access-Token, "
"X-Api-Key"
)
req_pvt_header = req_hdlr.request.headers.get(
"Access-Control-Request-Private-Network", None
)
if req_pvt_header == "true":
req_hdlr.set_header("Access-Control-Allow-Private-Network", "true")
class AuthorizedRequestHandler(tornado.web.RequestHandler):
def initialize(self) -> None:
self.server: Server = self.settings['server']
self.auth_required: bool = True
self.cors_enabled = False
def set_default_headers(self) -> None:
origin: Optional[str] = self.request.headers.get("Origin")
# it is necessary to look up the parent app here,
# as initialize() may not yet be called
server: Server = self.settings['server']
auth: AuthComp = server.lookup_component('authorization', None)
self.cors_enabled = False
if auth is not None:
self.cors_enabled = auth.check_cors(origin, self)
if getattr(self, "cors_enabled", False):
_set_cors_headers(self)
def prepare(self) -> None:
async def prepare(self) -> None:
auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None:
self.current_user = auth.authenticate_request(
self.current_user = await auth.authenticate_request(
self.request, self.auth_required
)
origin: Optional[str] = self.request.headers.get("Origin")
self.cors_enabled = await auth.check_cors(origin)
if self.cors_enabled:
_set_cors_headers(self)
def options(self, *args, **kwargs) -> None:
# Enable CORS if configured
@ -520,23 +543,22 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler):
) -> None:
super(AuthorizedFileHandler, self).initialize(path, default_filename)
self.server: Server = self.settings['server']
self.cors_enabled = False
def set_default_headers(self) -> None:
origin: Optional[str] = self.request.headers.get("Origin")
# it is necessary to look up the parent app here,
# as initialize() may not yet be called
server: Server = self.settings['server']
auth: AuthComp = server.lookup_component('authorization', None)
self.cors_enabled = False
if auth is not None:
self.cors_enabled = auth.check_cors(origin, self)
if getattr(self, "cors_enabled", False):
_set_cors_headers(self)
def prepare(self) -> None:
async def prepare(self) -> None:
auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None:
self.current_user = auth.authenticate_request(
self.current_user = await auth.authenticate_request(
self.request, self._check_need_auth()
)
origin: Optional[str] = self.request.headers.get("Origin")
self.cors_enabled = await auth.check_cors(origin)
if self.cors_enabled:
_set_cors_headers(self)
def options(self, *args, **kwargs) -> None:
# Enable CORS if configured
@ -915,8 +937,10 @@ class FileUploadHandler(AuthorizedRequestHandler):
self.parse_lock = Lock()
self.parse_failed: bool = False
def prepare(self) -> None:
super(FileUploadHandler, self).prepare()
async def prepare(self) -> None:
ret = super(FileUploadHandler, self).prepare()
if ret is not None:
await ret
content_type: str = self.request.headers.get("Content-Type", "")
logging.info(
f"Upload Request Received from {self.request.remote_ip}\n"
@ -1017,8 +1041,10 @@ class FileUploadHandler(AuthorizedRequestHandler):
# Default Handler for unregistered endpoints
class AuthorizedErrorHandler(AuthorizedRequestHandler):
def prepare(self) -> None:
super(AuthorizedRequestHandler, self).prepare()
async def prepare(self) -> None:
ret = super(AuthorizedRequestHandler, self).prepare()
if ret is not None:
await ret
self.set_status(404)
raise tornado.web.HTTPError(404)
@ -1038,7 +1064,7 @@ class RedirectHandler(AuthorizedRequestHandler):
super().initialize()
self.auth_required = False
def get(self, *args, **kwargs) -> None:
async def get(self, *args, **kwargs) -> None:
url: Optional[str] = self.get_argument('url', None)
if url is None:
try:
@ -1052,7 +1078,7 @@ class RedirectHandler(AuthorizedRequestHandler):
assert url is not None
# validate the url origin
auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is None or not auth.check_cors(url.rstrip("/")):
if auth is None or not await auth.check_cors(url.rstrip("/")):
raise tornado.web.HTTPError(
400, f"Unauthorized URL redirect: {url}")
self.redirect(url)
@ -1066,7 +1092,7 @@ class WelcomeHandler(tornado.web.RequestHandler):
auth: AuthComp = self.server.lookup_component("authorization", None)
if auth is not None:
try:
auth.authenticate_request(self.request)
await auth.authenticate_request(self.request)
except tornado.web.HTTPError:
authorized = False
else:

View File

@ -38,7 +38,6 @@ if TYPE_CHECKING:
from ..common import WebRequest
from .websockets import WebsocketManager
from tornado.httputil import HTTPServerRequest
from tornado.web import RequestHandler
from .database import MoonrakerDatabase as DBComp
from .ldap import MoonrakerLDAP
IPAddr = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
@ -58,6 +57,7 @@ def base64url_decode(data: str) -> bytes:
ONESHOT_TIMEOUT = 5
TRUSTED_CONNECTION_TIMEOUT = 3600
FQDN_CACHE_TIMEOUT = 84000
PRUNE_CHECK_TIME = 300.
AUTH_SOURCES = ["moonraker", "ldap"]
@ -80,6 +80,7 @@ class Authorization:
self.enable_api_key = config.getboolean('enable_api_key', True)
self.max_logins = config.getint("max_login_attempts", None, above=0)
self.failed_logins: Dict[IPAddr, int] = {}
self.fqdn_cache: Dict[IPAddr, Dict[str, Any]] = {}
if self.default_source not in AUTH_SOURCES:
raise config.error(
"[authorization]: option 'default_source' - Invalid "
@ -669,8 +670,13 @@ class Authorization:
exp_time: float = user_info['expires_at']
if cur_time >= exp_time:
self.trusted_users.pop(ip, None)
logging.info(
f"Trusted Connection Expired, IP: {ip}")
logging.info(f"Trusted Connection Expired, IP: {ip}")
for ip, fqdn_info in list(self.fqdn_cache.items()):
exp_time = fqdn_info["expires_at"]
if cur_time >= exp_time:
domain: str = fqdn_info["domain"]
self.fqdn_cache.pop(ip, None)
logging.info(f"Cached FQDN Expired, IP: {ip}, domain: {domain}")
return eventtime + PRUNE_CHECK_TIME
def _oneshot_token_expire_handler(self, token):
@ -709,19 +715,34 @@ class Authorization:
raise HTTPError(401, "JWT Decode Error")
return None
def _check_authorized_ip(self, ip: IPAddr) -> bool:
async def _check_authorized_ip(self, ip: IPAddr) -> bool:
if ip in self.trusted_ips:
return True
for rng in self.trusted_ranges:
if ip in rng:
return True
fqdn = socket.getfqdn(str(ip)).lower()
if fqdn in self.trusted_domains:
return True
if self.trusted_domains:
if ip in self.fqdn_cache:
fqdn: str = self.fqdn_cache[ip]["domain"]
else:
eventloop = self.server.get_event_loop()
try:
fut = eventloop.run_in_thread(socket.getfqdn, str(ip))
fqdn = await asyncio.wait_for(fut, 5.0)
except asyncio.TimeoutError:
logging.info("Call to socket.getfqdn() timed out")
return False
else:
fqdn = fqdn.lower()
self.fqdn_cache[ip] = {
"expires_at": time.time() + FQDN_CACHE_TIMEOUT,
"domain": fqdn
}
return fqdn in self.trusted_domains
return False
def _check_trusted_connection(self,
ip: Optional[IPAddr]
async def _check_trusted_connection(
self, ip: Optional[IPAddr]
) -> Optional[Dict[str, Any]]:
if ip is not None:
curtime = time.time()
@ -729,7 +750,7 @@ class Authorization:
if ip in self.trusted_users:
self.trusted_users[ip]['expires_at'] = exp_time
return self.trusted_users[ip]
elif self._check_authorized_ip(ip):
elif await self._check_authorized_ip(ip):
logging.info(
f"Trusted Connection Detected, IP: {ip}")
self.trusted_users[ip] = {
@ -761,7 +782,7 @@ class Authorization:
return False
return self.failed_logins.get(ip_addr, 0) >= self.max_logins
def authenticate_request(
async def authenticate_request(
self, request: HTTPServerRequest, auth_required: bool = True
) -> Optional[Dict[str, Any]]:
if request.method == "OPTIONS":
@ -801,16 +822,13 @@ class Authorization:
# Check if IP is trusted. If this endpoint doesn't require authentication
# then it is acceptable to return None
trusted_user = self._check_trusted_connection(ip)
trusted_user = await self._check_trusted_connection(ip)
if trusted_user is not None or not auth_required:
return trusted_user
raise HTTPError(401, "Unauthorized")
def check_cors(self,
origin: Optional[str],
req_hdlr: Optional[RequestHandler] = None
) -> bool:
async def check_cors(self, origin: Optional[str]) -> bool:
if origin is None or not self.cors_domains:
return False
for regex in self.cors_domains:
@ -819,7 +837,6 @@ class Authorization:
if match.group() == origin:
logging.debug(f"CORS Pattern Matched, origin: {origin} "
f" | pattern: {regex}")
self._set_cors_headers(origin, req_hdlr)
return True
else:
logging.debug(f"Partial Cors Match: {match.group()}")
@ -834,37 +851,13 @@ class Authorization:
except ValueError:
pass
else:
if self._check_authorized_ip(ipaddr):
logging.debug(
f"Cors request matched trusted IP: {ip}")
self._set_cors_headers(origin, req_hdlr)
if await self._check_authorized_ip(ipaddr):
logging.debug(f"Cors request matched trusted IP: {ip}")
return True
logging.debug(f"No CORS match for origin: {origin}\n"
f"Patterns: {self.cors_domains}")
return False
def _set_cors_headers(self,
origin: str,
req_hdlr: Optional[RequestHandler]
) -> None:
if req_hdlr is None:
return
req_hdlr.set_header("Access-Control-Allow-Origin", origin)
if req_hdlr.request.method == "OPTIONS":
req_hdlr.set_header(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, OPTIONS")
req_hdlr.set_header(
"Access-Control-Allow-Headers",
"Origin, Accept, Content-Type, X-Requested-With, "
"X-CRSF-Token, Authorization, X-Access-Token, "
"X-Api-Key")
if req_hdlr.request.headers.get(
"Access-Control-Request-Private-Network", None) == "true":
req_hdlr.set_header(
"Access-Control-Allow-Private-Network",
"true")
def cors_enabled(self) -> bool:
return self.cors_domains is not None

View File

@ -238,6 +238,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
self.on_create(self.settings['server'])
self._ip_addr = parse_ip_address(self.request.remote_ip or "")
self.last_pong_time: float = self.eventloop.get_loop_time()
self.cors_allowed: bool = False
@property
def ip_addr(self) -> Optional[IPAddress]:
@ -308,10 +309,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
def check_origin(self, origin: str) -> bool:
if not super(WebSocket, self).check_origin(origin):
auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None:
return auth.check_cors(origin)
return False
return self.cors_allowed
return True
def on_user_logout(self, user: str) -> bool:
@ -321,7 +319,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
return False
# Check Authorized User
def prepare(self) -> None:
async def prepare(self) -> None:
max_conns = self.settings["max_websocket_connections"]
if self.__class__.connection_count >= max_conns:
raise self.server.error(
@ -330,11 +328,16 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None:
try:
self._user_info = auth.authenticate_request(self.request)
self._user_info = await auth.authenticate_request(self.request)
except Exception as e:
logging.info(f"Websocket Failed Authentication: {e}")
self._user_info = None
self._need_auth = True
if "Origin" in self.request.headers:
origin = self.request.headers.get("Origin")
else:
origin = self.request.headers.get("Sec-Websocket-Origin", None)
self.cors_allowed = await auth.check_cors(origin)
def close_socket(self, code: int, reason: str) -> None:
self.close(code, reason)
@ -351,6 +354,7 @@ class BridgeSocket(WebSocketHandler):
self.klippy_writer: Optional[asyncio.StreamWriter] = None
self.klippy_write_buf: List[bytes] = []
self.klippy_queue_busy: bool = False
self.cors_allowed: bool = False
@property
def ip_addr(self) -> Optional[IPAddress]:
@ -459,10 +463,7 @@ class BridgeSocket(WebSocketHandler):
def check_origin(self, origin: str) -> bool:
if not super().check_origin(origin):
auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None:
return auth.check_cors(origin)
return False
return self.cors_allowed
return True
# Check Authorized User
@ -474,7 +475,12 @@ class BridgeSocket(WebSocketHandler):
)
auth: AuthComp = self.server.lookup_component("authorization", None)
if auth is not None:
self.current_user = auth.authenticate_request(self.request)
self.current_user = await auth.authenticate_request(self.request)
if "Origin" in self.request.headers:
origin = self.request.headers.get("Origin")
else:
origin = self.request.headers.get("Sec-Websocket-Origin", None)
self.cors_allowed = await auth.check_cors(origin)
kconn: Klippy = self.server.lookup_component("klippy_connection")
try:
reader, writer = await kconn.open_klippy_connection()