authorization: fix issue cors issue when an error is detected

Tornado clears the headers when an error is detected, "set_default_headers" must be overrridden so that errors are properly returned.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2020-11-14 15:53:42 -05:00
parent 7414781b3a
commit 2d2f8bfbcd
3 changed files with 35 additions and 28 deletions

View File

@ -117,14 +117,15 @@ class MoonrakerApp:
self.mutable_router = MutableRouter(self) self.mutable_router = MutableRouter(self)
app_handlers = [ app_handlers = [
(AnyMatches(), self.mutable_router), (AnyMatches(), self.mutable_router),
(r"/websocket", WebSocket, {'main_app': self}), (r"/websocket", WebSocket),
(r"/api/version", EmulateOctoprintHandler, {'main_app': self})] (r"/api/version", EmulateOctoprintHandler)]
self.app = tornado.web.Application( self.app = tornado.web.Application(
app_handlers, app_handlers,
serve_traceback=debug, serve_traceback=debug,
websocket_ping_interval=10, websocket_ping_interval=10,
websocket_ping_timeout=30) websocket_ping_timeout=30,
parent=self)
self.get_handler_delegate = self.app.get_handler_delegate self.get_handler_delegate = self.app.get_handler_delegate
# Register handlers # Register handlers
@ -165,7 +166,6 @@ class MoonrakerApp:
f"Websocket: {', '.join(api_def.ws_methods)}") f"Websocket: {', '.join(api_def.ws_methods)}")
self.wsm.register_remote_handler(api_def) self.wsm.register_remote_handler(api_def)
params = {} params = {}
params['main_app'] = self
params['arg_parser'] = api_def.parser params['arg_parser'] = api_def.parser
params['remote_callback'] = api_def.endpoint params['remote_callback'] = api_def.endpoint
self.mutable_router.add_handler( self.mutable_router.add_handler(
@ -182,7 +182,6 @@ class MoonrakerApp:
if "http" in protocol: if "http" in protocol:
msg += f" - HTTP: ({' '.join(request_methods)}) {uri}" msg += f" - HTTP: ({' '.join(request_methods)}) {uri}"
params = {} params = {}
params['main_app'] = self
params['methods'] = request_methods params['methods'] = request_methods
params['arg_parser'] = api_def.parser params['arg_parser'] = api_def.parser
params['callback'] = callback params['callback'] = callback
@ -206,12 +205,11 @@ class MoonrakerApp:
logging.info(f"Invalid file path: {file_path}") logging.info(f"Invalid file path: {file_path}")
return return
logging.debug(f"Registering static file: ({pattern}) {file_path}") logging.debug(f"Registering static file: ({pattern}) {file_path}")
params = {'main_app': self, 'path': file_path} params = {'path': file_path}
self.mutable_router.add_handler(pattern, FileRequestHandler, params) self.mutable_router.add_handler(pattern, FileRequestHandler, params)
def register_upload_handler(self, pattern): def register_upload_handler(self, pattern):
params = {'main_app': self} self.mutable_router.add_handler(pattern, FileUploadHandler, {})
self.mutable_router.add_handler(pattern, FileUploadHandler, params)
def remove_handler(self, endpoint): def remove_handler(self, endpoint):
api_def = self.api_cache.get(endpoint) api_def = self.api_cache.get(endpoint)
@ -260,8 +258,8 @@ class MoonrakerApp:
# ***** Dynamic Handlers***** # ***** Dynamic Handlers*****
class RemoteRequestHandler(AuthorizedRequestHandler): class RemoteRequestHandler(AuthorizedRequestHandler):
def initialize(self, main_app, remote_callback, arg_parser): def initialize(self, remote_callback, arg_parser):
super(RemoteRequestHandler, self).initialize(main_app) super(RemoteRequestHandler, self).initialize()
self.remote_callback = remote_callback self.remote_callback = remote_callback
self.query_parser = arg_parser self.query_parser = arg_parser
@ -283,8 +281,8 @@ class RemoteRequestHandler(AuthorizedRequestHandler):
self.finish({'result': result}) self.finish({'result': result})
class LocalRequestHandler(AuthorizedRequestHandler): class LocalRequestHandler(AuthorizedRequestHandler):
def initialize(self, main_app, callback, methods, arg_parser): def initialize(self, callback, methods, arg_parser):
super(LocalRequestHandler, self).initialize(main_app) super(LocalRequestHandler, self).initialize()
self.callback = callback self.callback = callback
self.methods = methods self.methods = methods
self.query_parser = arg_parser self.query_parser = arg_parser

View File

@ -207,17 +207,22 @@ class Authorization:
self.prune_handler.stop() self.prune_handler.stop()
class AuthorizedRequestHandler(tornado.web.RequestHandler): class AuthorizedRequestHandler(tornado.web.RequestHandler):
def initialize(self, main_app): def initialize(self):
self.server = main_app.get_server() app = self.settings['parent']
self.auth = main_app.get_auth() self.server = app.get_server()
self.wsm = main_app.get_websocket_manager() self.auth = app.get_auth()
self.cors_enabled = False self.wsm = app.get_websocket_manager()
def set_default_headers(self):
origin = self.request.headers.get("Origin")
# it is necessary to look up the parent app here,
# as initialize() may not yet be called
auth = self.settings['parent'].get_auth()
self.cors_enabled = auth.check_cors(origin, self)
def prepare(self): def prepare(self):
if not self.auth.check_authorized(self.request): if not self.auth.check_authorized(self.request):
raise tornado.web.HTTPError(401, "Unauthorized") raise tornado.web.HTTPError(401, "Unauthorized")
origin = self.request.headers.get("Origin")
self.cors_enabled = self.auth.check_cors(origin, self)
def options(self, *args, **kwargs): def options(self, *args, **kwargs):
# Enable CORS if configured # Enable CORS if configured
@ -244,17 +249,20 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
# Due to the way Python treats multiple inheritance its best # Due to the way Python treats multiple inheritance its best
# to create a separate authorized handler for serving files # to create a separate authorized handler for serving files
class AuthorizedFileHandler(tornado.web.StaticFileHandler): class AuthorizedFileHandler(tornado.web.StaticFileHandler):
def initialize(self, main_app, path, default_filename=None): def initialize(self, path, default_filename=None):
super(AuthorizedFileHandler, self).initialize(path, default_filename) super(AuthorizedFileHandler, self).initialize(path, default_filename)
self.server = main_app.get_server() app = self.settings['parent']
self.auth = main_app.get_auth() self.server = app.get_server()
self.cors_enabled = False self.auth = app.get_auth()
def set_default_headers(self):
origin = self.request.headers.get("Origin")
auth = self.settings['parent'].get_auth()
self.cors_enabled = auth.check_cors(origin, self)
def prepare(self): def prepare(self):
if not self.auth.check_authorized(self.request): if not self.auth.check_authorized(self.request):
raise tornado.web.HTTPError(401, "Unauthorized") raise tornado.web.HTTPError(401, "Unauthorized")
origin = self.request.headers.get("Origin")
self.cors_enabled = self.auth.check_cors(origin, self)
def options(self, *args, **kwargs): def options(self, *args, **kwargs):
# Enable CORS if configured # Enable CORS if configured

View File

@ -273,9 +273,10 @@ class WebsocketManager:
self.websockets = {} self.websockets = {}
class WebSocket(WebSocketHandler): class WebSocket(WebSocketHandler):
def initialize(self, main_app): def initialize(self):
self.auth = main_app.get_auth() app = self.settings['parent']
self.wsm = main_app.get_websocket_manager() self.auth = app.get_auth()
self.wsm = app.get_websocket_manager()
self.rpc = self.wsm.rpc self.rpc = self.wsm.rpc
self.uid = id(self) self.uid = id(self)