app: change `enable_cors` option to `cors_domains`

Rather than allow all origins as was the default with "enable_cors", users may not specify the domains allowed.  If "*" is specified, all domains are allowed.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2020-11-11 20:44:27 -05:00
parent 2cfc5b9501
commit ea62bc9ed1
3 changed files with 37 additions and 28 deletions

View File

@ -112,7 +112,6 @@ class MoonrakerApp:
mimetypes.add_type('text/plain', '.gcode') mimetypes.add_type('text/plain', '.gcode')
mimetypes.add_type('text/plain', '.cfg') mimetypes.add_type('text/plain', '.cfg')
debug = config.getboolean('enable_debug_logging', True) debug = config.getboolean('enable_debug_logging', True)
enable_cors = config.getboolean('enable_cors', False)
# Set up HTTP only requests # Set up HTTP only requests
self.mutable_router = MutableRouter(self) self.mutable_router = MutableRouter(self)
@ -125,8 +124,7 @@ class MoonrakerApp:
app_handlers, app_handlers,
serve_traceback=debug, serve_traceback=debug,
websocket_ping_interval=10, websocket_ping_interval=10,
websocket_ping_timeout=30, websocket_ping_timeout=30)
enable_cors=enable_cors)
self.get_handler_delegate = self.app.get_handler_delegate self.get_handler_delegate = self.app.get_handler_delegate
# Register handlers # Register handlers

View File

@ -26,6 +26,11 @@ class Authorization:
self.trusted_connections = {} self.trusted_connections = {}
self.access_tokens = {} self.access_tokens = {}
# Get allowed cors domains
cors_cfg = config.get('cors_domains', "").strip()
self.cors_domains = [d.strip() for d in cors_cfg.split('\n')
if d.strip()]
# Get Trusted Clients # Get Trusted Clients
self.trusted_ips = [] self.trusted_ips = []
self.trusted_ranges = [] self.trusted_ranges = []
@ -176,6 +181,28 @@ class Authorization:
return True return True
return False return False
def check_cors(self, origin, request=None):
if origin in self.cors_domains:
logging.debug(f"CORS Domain Allowed: {origin}")
self._set_cors_headers(origin, request)
elif "*" in self.cors_domains:
self._set_cors_headers("*", request)
else:
return False
return True
def _set_cors_headers(self, origin, request):
if request is None:
return
request.set_header("Access-Control-Allow-Origin", origin)
request.set_header(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, OPTIONS")
request.set_header(
"Access-Control-Allow-Headers",
"Origin, Accept, Content-Type, X-Requested-With, "
"X-CRSF-Token")
def close(self): def close(self):
self.prune_handler.stop() self.prune_handler.stop()
@ -184,25 +211,17 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
self.server = main_app.get_server() self.server = main_app.get_server()
self.auth = main_app.get_auth() self.auth = main_app.get_auth()
self.wsm = main_app.get_websocket_manager() self.wsm = main_app.get_websocket_manager()
self.cors_enabled = False
def prepare(self): def prepare(self):
if not self.auth.check_authorized(self.request): if not self.auth.check_authorized(self.request):
raise tornado.web.HTTPError(401, "Unauthorized") raise tornado.web.HTTPError(401, "Unauthorized")
origin = self.request.headers.get("Origin")
def set_default_headers(self): self.cors_enabled = self.auth.check_cors(origin, self)
if self.settings['enable_cors']:
self.set_header("Access-Control-Allow-Origin", "*")
self.set_header(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, OPTIONS")
self.set_header(
"Access-Control-Allow-Headers",
"Origin, Accept, Content-Type, X-Requested-With, "
"X-CRSF-Token")
def options(self, *args, **kwargs): def options(self, *args, **kwargs):
# Enable CORS if configured # Enable CORS if configured
if self.settings['enable_cors']: if self.cors_enabled:
self.set_status(204) self.set_status(204)
self.finish() self.finish()
else: else:
@ -229,25 +248,17 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler):
super(AuthorizedFileHandler, self).initialize(path, default_filename) super(AuthorizedFileHandler, self).initialize(path, default_filename)
self.server = main_app.get_server() self.server = main_app.get_server()
self.auth = main_app.get_auth() self.auth = main_app.get_auth()
self.cors_enabled = False
def prepare(self): def prepare(self):
if not self.auth.check_authorized(self.request): if not self.auth.check_authorized(self.request):
raise tornado.web.HTTPError(401, "Unauthorized") raise tornado.web.HTTPError(401, "Unauthorized")
origin = self.request.headers.get("Origin")
def set_default_headers(self): self.cors_enabled = self.auth.check_cors(origin, self)
if self.settings['enable_cors']:
self.set_header("Access-Control-Allow-Origin", "*")
self.set_header(
"Access-Control-Allow-Methods",
"GET, POST, PUT, DELETE, OPTIONS")
self.set_header(
"Access-Control-Allow-Headers",
"Origin, Accept, Content-Type, X-Requested-With, "
"X-CRSF-Token")
def options(self, *args, **kwargs): def options(self, *args, **kwargs):
# Enable CORS if configured # Enable CORS if configured
if self.settings['enable_cors']: if self.cors_enabled:
self.set_status(204) self.set_status(204)
self.finish() self.finish()
else: else:

View File

@ -314,7 +314,7 @@ class WebSocket(WebSocketHandler):
io_loop.spawn_callback(self.wsm.remove_websocket, self) io_loop.spawn_callback(self.wsm.remove_websocket, self)
def check_origin(self, origin): def check_origin(self, origin):
if self.settings['enable_cors']: if self.auth.check_cors(origin):
# allow CORS # allow CORS
return True return True
else: else: