app: refactor the dynamic request handlers

Unify the Local and Remote request handlers into a single handler.  This reduces duplicated code.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-03-01 11:39:48 -05:00
parent fb24917f1a
commit d878340a7a
1 changed files with 39 additions and 68 deletions

View File

@ -61,14 +61,14 @@ class MutableRouter(tornado.web.ReversibleRuleRouter):
class APIDefinition: class APIDefinition:
def __init__(self, endpoint, http_uri, ws_methods, def __init__(self, endpoint, http_uri, ws_methods,
request_methods, parser): request_methods, need_object_parser):
self.endpoint = endpoint self.endpoint = endpoint
self.uri = http_uri self.uri = http_uri
self.ws_methods = ws_methods self.ws_methods = ws_methods
if not isinstance(request_methods, list): if not isinstance(request_methods, list):
request_methods = [request_methods] request_methods = [request_methods]
self.request_methods = request_methods self.request_methods = request_methods
self.parser = parser self.need_object_parser = need_object_parser
class MoonrakerApp: class MoonrakerApp:
def __init__(self, config): def __init__(self, config):
@ -146,10 +146,11 @@ class MoonrakerApp:
f"Websocket: {', '.join(api_def.ws_methods)}") f"Websocket: {', '.join(api_def.ws_methods)}")
self.wsm.register_remote_handler(api_def) self.wsm.register_remote_handler(api_def)
params = {} params = {}
params['query_parser'] = api_def.parser params['methods'] = api_def.request_methods
params['remote_callback'] = api_def.endpoint params['callback'] = api_def.endpoint
params['need_object_parser'] = api_def.need_object_parser
self.mutable_router.add_handler( self.mutable_router.add_handler(
api_def.uri, RemoteRequestHandler, params) api_def.uri, DynamicRequestHandler, params)
self.registered_base_handlers.append(api_def.uri) self.registered_base_handlers.append(api_def.uri)
def register_local_handler(self, uri, request_methods, def register_local_handler(self, uri, request_methods,
@ -164,10 +165,10 @@ class MoonrakerApp:
msg += f" - HTTP: ({' '.join(request_methods)}) {uri}" msg += f" - HTTP: ({' '.join(request_methods)}) {uri}"
params = {} params = {}
params['methods'] = request_methods params['methods'] = request_methods
params['query_parser'] = api_def.parser
params['callback'] = callback params['callback'] = callback
params['wrap_result'] = wrap_result params['wrap_result'] = wrap_result
self.mutable_router.add_handler(uri, LocalRequestHandler, params) params['is_remote'] = False
self.mutable_router.add_handler(uri, DynamicRequestHandler, params)
self.registered_base_handlers.append(uri) self.registered_base_handlers.append(uri)
if "websocket" in protocol: if "websocket" in protocol:
msg += f" - Websocket: {', '.join(api_def.ws_methods)}" msg += f" - Websocket: {', '.join(api_def.ws_methods)}"
@ -228,24 +229,23 @@ class MoonrakerApp:
raise self.server.error( raise self.server.error(
"Invalid API definition. Number of websocket methods must " "Invalid API definition. Number of websocket methods must "
"match the number of request methods") "match the number of request methods")
if endpoint.startswith("objects/"): need_object_parser = endpoint.startswith("objects/")
parser = "_status_parser"
else:
parser = "_default_parser"
api_def = APIDefinition(endpoint, uri, ws_methods, api_def = APIDefinition(endpoint, uri, ws_methods,
request_methods, parser) request_methods, need_object_parser)
self.api_cache[endpoint] = api_def self.api_cache[endpoint] = api_def
return api_def return api_def
# ***** Dynamic Handlers***** class DynamicRequestHandler(AuthorizedRequestHandler):
class DynamicRequestBase(AuthorizedRequestHandler): def initialize(self, callback, methods, need_object_parser=False,
def initialize(self, query_parser): is_remote=True, wrap_result=True):
super(DynamicRequestBase, self).initialize() super(DynamicRequestHandler, self).initialize()
try: self.callback = callback
self.query_parser = getattr(self, query_parser) self.methods = methods
except Exception: self.wrap_result = wrap_result
self.query_parser = lambda: {} self._do_request = self._do_remote_request if is_remote \
else self._do_local_request
self._parse_query = self._object_parser if need_object_parser \
else self._default_parser
# Converts query string values with type hints # Converts query string values with type hints
def _convert_type(value, hint): def _convert_type(value, hint):
@ -278,7 +278,7 @@ class DynamicRequestBase(AuthorizedRequestHandler):
args[key_parts[0]] = self._convert_type(val, key_parts[1]) args[key_parts[0]] = self._convert_type(val, key_parts[1])
return args return args
def _status_parser(self): def _object_parser(self):
args = {} args = {}
for key in self.request.arguments.keys(): for key in self.request.arguments.keys():
if key in EXCLUDED_ARGS: if key in EXCLUDED_ARGS:
@ -292,7 +292,7 @@ class DynamicRequestBase(AuthorizedRequestHandler):
return {'objects': args} return {'objects': args}
def parse_args(self): def parse_args(self):
args = self.query_parser() args = self._parse_query()
if self.request.headers.get('Content-Type', "") == "application/json": if self.request.headers.get('Content-Type', "") == "application/json":
try: try:
args.update(json.loads(self.request.body)) args.update(json.loads(self.request.body))
@ -303,66 +303,37 @@ class DynamicRequestBase(AuthorizedRequestHandler):
args[key] = value args[key] = value
return args return args
class RemoteRequestHandler(DynamicRequestBase):
def initialize(self, remote_callback, query_parser):
super(RemoteRequestHandler, self).initialize(query_parser)
self.remote_callback = remote_callback
async def get(self, *args, **kwargs): async def get(self, *args, **kwargs):
await self._process_http_request() await self._process_http_request()
async def post(self, *args, **kwargs): async def post(self, *args, **kwargs):
await self._process_http_request() await self._process_http_request()
async def _process_http_request(self):
conn = self.get_associated_websocket()
args = self.parse_args()
try:
result = await self.server.make_request(
WebRequest(self.remote_callback, args, conn=conn))
except ServerError as e:
raise tornado.web.HTTPError(
e.status_code, str(e)) from e
self.finish({'result': result})
class LocalRequestHandler(DynamicRequestBase):
def initialize(self, callback, methods, query_parser, wrap_result):
super(LocalRequestHandler, self).initialize(query_parser)
self.callback = callback
self.methods = methods
self.wrap_result = wrap_result
async def get(self, *args, **kwargs):
if 'GET' in self.methods:
await self._process_http_request('GET')
else:
raise tornado.web.HTTPError(405)
async def post(self, *args, **kwargs):
if 'POST' in self.methods:
await self._process_http_request('POST')
else:
raise tornado.web.HTTPError(405)
async def delete(self, *args, **kwargs): async def delete(self, *args, **kwargs):
if 'DELETE' in self.methods: await self._process_http_request()
await self._process_http_request('DELETE')
else:
raise tornado.web.HTTPError(405)
async def _process_http_request(self, method): async def _do_local_request(self, args, conn):
return await self.callback(
WebRequest(self.request.path, args, self.request.method,
conn=conn))
async def _do_remote_request(self, args, conn):
return await self.server.make_request(
WebRequest(self.callback, args, conn=conn))
async def _process_http_request(self):
if self.request.method not in self.methods:
raise tornado.web.HTTPError(405)
conn = self.get_associated_websocket() conn = self.get_associated_websocket()
args = self.parse_args() args = self.parse_args()
try: try:
result = await self.callback( result = await self._do_request(args, conn)
WebRequest(self.request.path, args, method, conn=conn))
except ServerError as e: except ServerError as e:
raise tornado.web.HTTPError( raise tornado.web.HTTPError(
e.status_code, str(e)) from e e.status_code, str(e)) from e
if self.wrap_result: if self.wrap_result:
self.finish({'result': result}) result = {'result': result}
else: self.finish(result)
self.finish(result)
class FileRequestHandler(AuthorizedFileHandler): class FileRequestHandler(AuthorizedFileHandler):
def set_extra_headers(self, path): def set_extra_headers(self, path):