authorization: fix blocking call to socket.getfqdn()
If the upstream DNS server is not available the call to socket.getfqdn() will block until a timeout occurs. This blocks Moonraker's event loop, resulting in carnage. Call getfqdn() in a thread with a timeout of 5 seconds. In addition, only request the fqdn if the user has one or more trusted domains configured. Finally, cache resolved FQDNs for 24 hours to limit repeated DNS queries. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
3d44c51613
commit
d1f97f2658
|
@ -458,27 +458,50 @@ class MoonrakerApp:
|
|||
self.template_cache[asset_name] = asset_tmpl
|
||||
return asset_tmpl
|
||||
|
||||
def _set_cors_headers(req_hdlr: tornado.web.RequestHandler) -> None:
|
||||
request = req_hdlr.request
|
||||
origin: Optional[str] = request.headers.get("Origin")
|
||||
if origin is None:
|
||||
return
|
||||
req_hdlr.set_header("Access-Control-Allow-Origin", origin)
|
||||
if req_hdlr.request.method == "OPTIONS":
|
||||
req_hdlr.set_header(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET, POST, PUT, DELETE, OPTIONS"
|
||||
)
|
||||
req_hdlr.set_header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Origin, Accept, Content-Type, X-Requested-With, "
|
||||
"X-CRSF-Token, Authorization, X-Access-Token, "
|
||||
"X-Api-Key"
|
||||
)
|
||||
req_pvt_header = req_hdlr.request.headers.get(
|
||||
"Access-Control-Request-Private-Network", None
|
||||
)
|
||||
if req_pvt_header == "true":
|
||||
req_hdlr.set_header("Access-Control-Allow-Private-Network", "true")
|
||||
|
||||
|
||||
class AuthorizedRequestHandler(tornado.web.RequestHandler):
|
||||
def initialize(self) -> None:
|
||||
self.server: Server = self.settings['server']
|
||||
self.auth_required: bool = True
|
||||
self.cors_enabled = False
|
||||
|
||||
def set_default_headers(self) -> None:
|
||||
origin: Optional[str] = self.request.headers.get("Origin")
|
||||
# it is necessary to look up the parent app here,
|
||||
# as initialize() may not yet be called
|
||||
server: Server = self.settings['server']
|
||||
auth: AuthComp = server.lookup_component('authorization', None)
|
||||
self.cors_enabled = False
|
||||
if auth is not None:
|
||||
self.cors_enabled = auth.check_cors(origin, self)
|
||||
if getattr(self, "cors_enabled", False):
|
||||
_set_cors_headers(self)
|
||||
|
||||
def prepare(self) -> None:
|
||||
async def prepare(self) -> None:
|
||||
auth: AuthComp = self.server.lookup_component('authorization', None)
|
||||
if auth is not None:
|
||||
self.current_user = auth.authenticate_request(
|
||||
self.current_user = await auth.authenticate_request(
|
||||
self.request, self.auth_required
|
||||
)
|
||||
origin: Optional[str] = self.request.headers.get("Origin")
|
||||
self.cors_enabled = await auth.check_cors(origin)
|
||||
if self.cors_enabled:
|
||||
_set_cors_headers(self)
|
||||
|
||||
def options(self, *args, **kwargs) -> None:
|
||||
# Enable CORS if configured
|
||||
|
@ -520,23 +543,22 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler):
|
|||
) -> None:
|
||||
super(AuthorizedFileHandler, self).initialize(path, default_filename)
|
||||
self.server: Server = self.settings['server']
|
||||
self.cors_enabled = False
|
||||
|
||||
def set_default_headers(self) -> None:
|
||||
origin: Optional[str] = self.request.headers.get("Origin")
|
||||
# it is necessary to look up the parent app here,
|
||||
# as initialize() may not yet be called
|
||||
server: Server = self.settings['server']
|
||||
auth: AuthComp = server.lookup_component('authorization', None)
|
||||
self.cors_enabled = False
|
||||
if auth is not None:
|
||||
self.cors_enabled = auth.check_cors(origin, self)
|
||||
if getattr(self, "cors_enabled", False):
|
||||
_set_cors_headers(self)
|
||||
|
||||
def prepare(self) -> None:
|
||||
async def prepare(self) -> None:
|
||||
auth: AuthComp = self.server.lookup_component('authorization', None)
|
||||
if auth is not None:
|
||||
self.current_user = auth.authenticate_request(
|
||||
self.current_user = await auth.authenticate_request(
|
||||
self.request, self._check_need_auth()
|
||||
)
|
||||
origin: Optional[str] = self.request.headers.get("Origin")
|
||||
self.cors_enabled = await auth.check_cors(origin)
|
||||
if self.cors_enabled:
|
||||
_set_cors_headers(self)
|
||||
|
||||
def options(self, *args, **kwargs) -> None:
|
||||
# Enable CORS if configured
|
||||
|
@ -915,8 +937,10 @@ class FileUploadHandler(AuthorizedRequestHandler):
|
|||
self.parse_lock = Lock()
|
||||
self.parse_failed: bool = False
|
||||
|
||||
def prepare(self) -> None:
|
||||
super(FileUploadHandler, self).prepare()
|
||||
async def prepare(self) -> None:
|
||||
ret = super(FileUploadHandler, self).prepare()
|
||||
if ret is not None:
|
||||
await ret
|
||||
content_type: str = self.request.headers.get("Content-Type", "")
|
||||
logging.info(
|
||||
f"Upload Request Received from {self.request.remote_ip}\n"
|
||||
|
@ -1017,8 +1041,10 @@ class FileUploadHandler(AuthorizedRequestHandler):
|
|||
|
||||
# Default Handler for unregistered endpoints
|
||||
class AuthorizedErrorHandler(AuthorizedRequestHandler):
|
||||
def prepare(self) -> None:
|
||||
super(AuthorizedRequestHandler, self).prepare()
|
||||
async def prepare(self) -> None:
|
||||
ret = super(AuthorizedRequestHandler, self).prepare()
|
||||
if ret is not None:
|
||||
await ret
|
||||
self.set_status(404)
|
||||
raise tornado.web.HTTPError(404)
|
||||
|
||||
|
@ -1038,7 +1064,7 @@ class RedirectHandler(AuthorizedRequestHandler):
|
|||
super().initialize()
|
||||
self.auth_required = False
|
||||
|
||||
def get(self, *args, **kwargs) -> None:
|
||||
async def get(self, *args, **kwargs) -> None:
|
||||
url: Optional[str] = self.get_argument('url', None)
|
||||
if url is None:
|
||||
try:
|
||||
|
@ -1052,7 +1078,7 @@ class RedirectHandler(AuthorizedRequestHandler):
|
|||
assert url is not None
|
||||
# validate the url origin
|
||||
auth: AuthComp = self.server.lookup_component('authorization', None)
|
||||
if auth is None or not auth.check_cors(url.rstrip("/")):
|
||||
if auth is None or not await auth.check_cors(url.rstrip("/")):
|
||||
raise tornado.web.HTTPError(
|
||||
400, f"Unauthorized URL redirect: {url}")
|
||||
self.redirect(url)
|
||||
|
@ -1066,7 +1092,7 @@ class WelcomeHandler(tornado.web.RequestHandler):
|
|||
auth: AuthComp = self.server.lookup_component("authorization", None)
|
||||
if auth is not None:
|
||||
try:
|
||||
auth.authenticate_request(self.request)
|
||||
await auth.authenticate_request(self.request)
|
||||
except tornado.web.HTTPError:
|
||||
authorized = False
|
||||
else:
|
||||
|
|
|
@ -38,7 +38,6 @@ if TYPE_CHECKING:
|
|||
from ..common import WebRequest
|
||||
from .websockets import WebsocketManager
|
||||
from tornado.httputil import HTTPServerRequest
|
||||
from tornado.web import RequestHandler
|
||||
from .database import MoonrakerDatabase as DBComp
|
||||
from .ldap import MoonrakerLDAP
|
||||
IPAddr = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
|
||||
|
@ -58,6 +57,7 @@ def base64url_decode(data: str) -> bytes:
|
|||
|
||||
ONESHOT_TIMEOUT = 5
|
||||
TRUSTED_CONNECTION_TIMEOUT = 3600
|
||||
FQDN_CACHE_TIMEOUT = 84000
|
||||
PRUNE_CHECK_TIME = 300.
|
||||
|
||||
AUTH_SOURCES = ["moonraker", "ldap"]
|
||||
|
@ -80,6 +80,7 @@ class Authorization:
|
|||
self.enable_api_key = config.getboolean('enable_api_key', True)
|
||||
self.max_logins = config.getint("max_login_attempts", None, above=0)
|
||||
self.failed_logins: Dict[IPAddr, int] = {}
|
||||
self.fqdn_cache: Dict[IPAddr, Dict[str, Any]] = {}
|
||||
if self.default_source not in AUTH_SOURCES:
|
||||
raise config.error(
|
||||
"[authorization]: option 'default_source' - Invalid "
|
||||
|
@ -669,8 +670,13 @@ class Authorization:
|
|||
exp_time: float = user_info['expires_at']
|
||||
if cur_time >= exp_time:
|
||||
self.trusted_users.pop(ip, None)
|
||||
logging.info(
|
||||
f"Trusted Connection Expired, IP: {ip}")
|
||||
logging.info(f"Trusted Connection Expired, IP: {ip}")
|
||||
for ip, fqdn_info in list(self.fqdn_cache.items()):
|
||||
exp_time = fqdn_info["expires_at"]
|
||||
if cur_time >= exp_time:
|
||||
domain: str = fqdn_info["domain"]
|
||||
self.fqdn_cache.pop(ip, None)
|
||||
logging.info(f"Cached FQDN Expired, IP: {ip}, domain: {domain}")
|
||||
return eventtime + PRUNE_CHECK_TIME
|
||||
|
||||
def _oneshot_token_expire_handler(self, token):
|
||||
|
@ -709,27 +715,42 @@ class Authorization:
|
|||
raise HTTPError(401, "JWT Decode Error")
|
||||
return None
|
||||
|
||||
def _check_authorized_ip(self, ip: IPAddr) -> bool:
|
||||
async def _check_authorized_ip(self, ip: IPAddr) -> bool:
|
||||
if ip in self.trusted_ips:
|
||||
return True
|
||||
for rng in self.trusted_ranges:
|
||||
if ip in rng:
|
||||
return True
|
||||
fqdn = socket.getfqdn(str(ip)).lower()
|
||||
if fqdn in self.trusted_domains:
|
||||
return True
|
||||
if self.trusted_domains:
|
||||
if ip in self.fqdn_cache:
|
||||
fqdn: str = self.fqdn_cache[ip]["domain"]
|
||||
else:
|
||||
eventloop = self.server.get_event_loop()
|
||||
try:
|
||||
fut = eventloop.run_in_thread(socket.getfqdn, str(ip))
|
||||
fqdn = await asyncio.wait_for(fut, 5.0)
|
||||
except asyncio.TimeoutError:
|
||||
logging.info("Call to socket.getfqdn() timed out")
|
||||
return False
|
||||
else:
|
||||
fqdn = fqdn.lower()
|
||||
self.fqdn_cache[ip] = {
|
||||
"expires_at": time.time() + FQDN_CACHE_TIMEOUT,
|
||||
"domain": fqdn
|
||||
}
|
||||
return fqdn in self.trusted_domains
|
||||
return False
|
||||
|
||||
def _check_trusted_connection(self,
|
||||
ip: Optional[IPAddr]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
async def _check_trusted_connection(
|
||||
self, ip: Optional[IPAddr]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if ip is not None:
|
||||
curtime = time.time()
|
||||
exp_time = curtime + TRUSTED_CONNECTION_TIMEOUT
|
||||
if ip in self.trusted_users:
|
||||
self.trusted_users[ip]['expires_at'] = exp_time
|
||||
return self.trusted_users[ip]
|
||||
elif self._check_authorized_ip(ip):
|
||||
elif await self._check_authorized_ip(ip):
|
||||
logging.info(
|
||||
f"Trusted Connection Detected, IP: {ip}")
|
||||
self.trusted_users[ip] = {
|
||||
|
@ -761,7 +782,7 @@ class Authorization:
|
|||
return False
|
||||
return self.failed_logins.get(ip_addr, 0) >= self.max_logins
|
||||
|
||||
def authenticate_request(
|
||||
async def authenticate_request(
|
||||
self, request: HTTPServerRequest, auth_required: bool = True
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
if request.method == "OPTIONS":
|
||||
|
@ -801,16 +822,13 @@ class Authorization:
|
|||
|
||||
# Check if IP is trusted. If this endpoint doesn't require authentication
|
||||
# then it is acceptable to return None
|
||||
trusted_user = self._check_trusted_connection(ip)
|
||||
trusted_user = await self._check_trusted_connection(ip)
|
||||
if trusted_user is not None or not auth_required:
|
||||
return trusted_user
|
||||
|
||||
raise HTTPError(401, "Unauthorized")
|
||||
|
||||
def check_cors(self,
|
||||
origin: Optional[str],
|
||||
req_hdlr: Optional[RequestHandler] = None
|
||||
) -> bool:
|
||||
async def check_cors(self, origin: Optional[str]) -> bool:
|
||||
if origin is None or not self.cors_domains:
|
||||
return False
|
||||
for regex in self.cors_domains:
|
||||
|
@ -819,7 +837,6 @@ class Authorization:
|
|||
if match.group() == origin:
|
||||
logging.debug(f"CORS Pattern Matched, origin: {origin} "
|
||||
f" | pattern: {regex}")
|
||||
self._set_cors_headers(origin, req_hdlr)
|
||||
return True
|
||||
else:
|
||||
logging.debug(f"Partial Cors Match: {match.group()}")
|
||||
|
@ -834,37 +851,13 @@ class Authorization:
|
|||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
if self._check_authorized_ip(ipaddr):
|
||||
logging.debug(
|
||||
f"Cors request matched trusted IP: {ip}")
|
||||
self._set_cors_headers(origin, req_hdlr)
|
||||
if await self._check_authorized_ip(ipaddr):
|
||||
logging.debug(f"Cors request matched trusted IP: {ip}")
|
||||
return True
|
||||
logging.debug(f"No CORS match for origin: {origin}\n"
|
||||
f"Patterns: {self.cors_domains}")
|
||||
return False
|
||||
|
||||
def _set_cors_headers(self,
|
||||
origin: str,
|
||||
req_hdlr: Optional[RequestHandler]
|
||||
) -> None:
|
||||
if req_hdlr is None:
|
||||
return
|
||||
req_hdlr.set_header("Access-Control-Allow-Origin", origin)
|
||||
if req_hdlr.request.method == "OPTIONS":
|
||||
req_hdlr.set_header(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET, POST, PUT, DELETE, OPTIONS")
|
||||
req_hdlr.set_header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Origin, Accept, Content-Type, X-Requested-With, "
|
||||
"X-CRSF-Token, Authorization, X-Access-Token, "
|
||||
"X-Api-Key")
|
||||
if req_hdlr.request.headers.get(
|
||||
"Access-Control-Request-Private-Network", None) == "true":
|
||||
req_hdlr.set_header(
|
||||
"Access-Control-Allow-Private-Network",
|
||||
"true")
|
||||
|
||||
def cors_enabled(self) -> bool:
|
||||
return self.cors_domains is not None
|
||||
|
||||
|
|
|
@ -238,6 +238,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
|
|||
self.on_create(self.settings['server'])
|
||||
self._ip_addr = parse_ip_address(self.request.remote_ip or "")
|
||||
self.last_pong_time: float = self.eventloop.get_loop_time()
|
||||
self.cors_allowed: bool = False
|
||||
|
||||
@property
|
||||
def ip_addr(self) -> Optional[IPAddress]:
|
||||
|
@ -308,10 +309,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
|
|||
|
||||
def check_origin(self, origin: str) -> bool:
|
||||
if not super(WebSocket, self).check_origin(origin):
|
||||
auth: AuthComp = self.server.lookup_component('authorization', None)
|
||||
if auth is not None:
|
||||
return auth.check_cors(origin)
|
||||
return False
|
||||
return self.cors_allowed
|
||||
return True
|
||||
|
||||
def on_user_logout(self, user: str) -> bool:
|
||||
|
@ -321,7 +319,7 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
|
|||
return False
|
||||
|
||||
# Check Authorized User
|
||||
def prepare(self) -> None:
|
||||
async def prepare(self) -> None:
|
||||
max_conns = self.settings["max_websocket_connections"]
|
||||
if self.__class__.connection_count >= max_conns:
|
||||
raise self.server.error(
|
||||
|
@ -330,11 +328,16 @@ class WebSocket(WebSocketHandler, BaseRemoteConnection):
|
|||
auth: AuthComp = self.server.lookup_component('authorization', None)
|
||||
if auth is not None:
|
||||
try:
|
||||
self._user_info = auth.authenticate_request(self.request)
|
||||
self._user_info = await auth.authenticate_request(self.request)
|
||||
except Exception as e:
|
||||
logging.info(f"Websocket Failed Authentication: {e}")
|
||||
self._user_info = None
|
||||
self._need_auth = True
|
||||
if "Origin" in self.request.headers:
|
||||
origin = self.request.headers.get("Origin")
|
||||
else:
|
||||
origin = self.request.headers.get("Sec-Websocket-Origin", None)
|
||||
self.cors_allowed = await auth.check_cors(origin)
|
||||
|
||||
def close_socket(self, code: int, reason: str) -> None:
|
||||
self.close(code, reason)
|
||||
|
@ -351,6 +354,7 @@ class BridgeSocket(WebSocketHandler):
|
|||
self.klippy_writer: Optional[asyncio.StreamWriter] = None
|
||||
self.klippy_write_buf: List[bytes] = []
|
||||
self.klippy_queue_busy: bool = False
|
||||
self.cors_allowed: bool = False
|
||||
|
||||
@property
|
||||
def ip_addr(self) -> Optional[IPAddress]:
|
||||
|
@ -459,10 +463,7 @@ class BridgeSocket(WebSocketHandler):
|
|||
|
||||
def check_origin(self, origin: str) -> bool:
|
||||
if not super().check_origin(origin):
|
||||
auth: AuthComp = self.server.lookup_component('authorization', None)
|
||||
if auth is not None:
|
||||
return auth.check_cors(origin)
|
||||
return False
|
||||
return self.cors_allowed
|
||||
return True
|
||||
|
||||
# Check Authorized User
|
||||
|
@ -474,7 +475,12 @@ class BridgeSocket(WebSocketHandler):
|
|||
)
|
||||
auth: AuthComp = self.server.lookup_component("authorization", None)
|
||||
if auth is not None:
|
||||
self.current_user = auth.authenticate_request(self.request)
|
||||
self.current_user = await auth.authenticate_request(self.request)
|
||||
if "Origin" in self.request.headers:
|
||||
origin = self.request.headers.get("Origin")
|
||||
else:
|
||||
origin = self.request.headers.get("Sec-Websocket-Origin", None)
|
||||
self.cors_allowed = await auth.check_cors(origin)
|
||||
kconn: Klippy = self.server.lookup_component("klippy_connection")
|
||||
try:
|
||||
reader, writer = await kconn.open_klippy_connection()
|
||||
|
|
Loading…
Reference in New Issue