app: add annotations

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Arksine 2021-05-11 18:13:33 -04:00
parent 420ba065da
commit b91df6642d
1 changed files with 150 additions and 81 deletions

View File

@ -4,6 +4,7 @@
# #
# This file may be distributed under the terms of the GNU GPLv3 license # This file may be distributed under the terms of the GNU GPLv3 license
from __future__ import annotations
import os import os
import mimetypes import mimetypes
import logging import logging
@ -13,15 +14,37 @@ import traceback
import tornado import tornado
import tornado.iostream import tornado.iostream
import tornado.httputil import tornado.httputil
import tornado.web
from inspect import isclass from inspect import isclass
from tornado.escape import url_unescape from tornado.escape import url_unescape
from tornado.routing import Rule, PathMatches, AnyMatches from tornado.routing import Rule, PathMatches, AnyMatches
from tornado.http1connection import HTTP1Connection
from tornado.log import access_log from tornado.log import access_log
from utils import ServerError from utils import ServerError
from websockets import WebRequest, WebsocketManager, WebSocket from websockets import WebRequest, WebsocketManager, WebSocket
from streaming_form_data import StreamingFormDataParser from streaming_form_data import StreamingFormDataParser
from streaming_form_data.targets import FileTarget, ValueTarget, SHA256Target from streaming_form_data.targets import FileTarget, ValueTarget, SHA256Target
# Annotation imports
from typing import (
TYPE_CHECKING,
Any,
Optional,
Callable,
Coroutine,
Union,
Dict,
List,
)
if TYPE_CHECKING:
from tornado.httpserver import HTTPServer
from moonraker import Server
from confighelper import ConfigHelper
from components.file_manager import FileManager
import components.authorization
MessageDelgate = Optional[tornado.httputil.HTTPMessageDelegate]
AuthComp = Optional[components.authorization.Authorization]
# These endpoints are reserved for klippy/server communication only and are # These endpoints are reserved for klippy/server communication only and are
# not exposed via http or the websocket # not exposed via http or the websocket
RESERVED_ENDPOINTS = [ RESERVED_ENDPOINTS = [
@ -35,12 +58,16 @@ EXCLUDED_ARGS = ["_", "token", "connection_id"]
DEFAULT_KLIPPY_LOG_PATH = "/tmp/klippy.log" DEFAULT_KLIPPY_LOG_PATH = "/tmp/klippy.log"
class MutableRouter(tornado.web.ReversibleRuleRouter): class MutableRouter(tornado.web.ReversibleRuleRouter):
def __init__(self, application): def __init__(self, application: MoonrakerApp) -> None:
self.application = application self.application = application
self.pattern_to_rule = {} self.pattern_to_rule: Dict[str, Rule] = {}
super(MutableRouter, self).__init__(None) super(MutableRouter, self).__init__(None)
def get_target_delegate(self, target, request, **target_params): def get_target_delegate(self,
target: Any,
request: tornado.httputil.HTTPServerRequest,
**target_params
) -> MessageDelgate:
if isclass(target) and issubclass(target, tornado.web.RequestHandler): if isclass(target) and issubclass(target, tornado.web.RequestHandler):
return self.application.get_handler_delegate( return self.application.get_handler_delegate(
request, target, **target_params) request, target, **target_params)
@ -48,17 +75,21 @@ class MutableRouter(tornado.web.ReversibleRuleRouter):
return super(MutableRouter, self).get_target_delegate( return super(MutableRouter, self).get_target_delegate(
target, request, **target_params) target, request, **target_params)
def has_rule(self, pattern): def has_rule(self, pattern: str) -> bool:
return pattern in self.pattern_to_rule return pattern in self.pattern_to_rule
def add_handler(self, pattern, target, target_params): def add_handler(self,
pattern: str,
target: Any,
target_params: Optional[Dict[str, Any]]
) -> None:
if pattern in self.pattern_to_rule: if pattern in self.pattern_to_rule:
self.remove_handler(pattern) self.remove_handler(pattern)
new_rule = Rule(PathMatches(pattern), target, target_params) new_rule = Rule(PathMatches(pattern), target, target_params)
self.pattern_to_rule[pattern] = new_rule self.pattern_to_rule[pattern] = new_rule
self.rules.append(new_rule) self.rules.append(new_rule)
def remove_handler(self, pattern): def remove_handler(self, pattern: str) -> None:
rule = self.pattern_to_rule.pop(pattern, None) rule = self.pattern_to_rule.pop(pattern, None)
if rule is not None: if rule is not None:
try: try:
@ -67,8 +98,12 @@ class MutableRouter(tornado.web.ReversibleRuleRouter):
logging.exception(f"Unable to remove rule: {pattern}") logging.exception(f"Unable to remove rule: {pattern}")
class APIDefinition: class APIDefinition:
def __init__(self, endpoint, http_uri, ws_methods, def __init__(self,
request_methods, need_object_parser): endpoint: str,
http_uri: str,
ws_methods: List[str],
request_methods: Union[str, List[str]],
need_object_parser: bool):
self.endpoint = endpoint self.endpoint = endpoint
self.uri = http_uri self.uri = http_uri
self.ws_methods = ws_methods self.ws_methods = ws_methods
@ -78,11 +113,11 @@ class APIDefinition:
self.need_object_parser = need_object_parser self.need_object_parser = need_object_parser
class MoonrakerApp: class MoonrakerApp:
def __init__(self, config): def __init__(self, config: ConfigHelper) -> None:
self.server = config.get_server() self.server = config.get_server()
self.tornado_server = None self.tornado_server: Optional[HTTPServer] = None
self.api_cache = {} self.api_cache: Dict[str, APIDefinition] = {}
self.registered_base_handlers = [] self.registered_base_handlers: List[str] = []
self.max_upload_size = config.getint('max_upload_size', 1024) self.max_upload_size = config.getint('max_upload_size', 1024)
self.max_upload_size *= 1024 * 1024 self.max_upload_size *= 1024 * 1024
@ -96,7 +131,7 @@ class MoonrakerApp:
self.debug = config.getboolean('enable_debug_logging', False) self.debug = config.getboolean('enable_debug_logging', False)
log_level = logging.DEBUG if self.debug else logging.INFO log_level = logging.DEBUG if self.debug else logging.INFO
logging.getLogger().setLevel(log_level) logging.getLogger().setLevel(log_level)
app_args = { app_args: Dict[str, Any] = {
'serve_traceback': self.debug, 'serve_traceback': self.debug,
'websocket_ping_interval': 10, 'websocket_ping_interval': 10,
'websocket_ping_timeout': 30, 'websocket_ping_timeout': 30,
@ -108,7 +143,7 @@ class MoonrakerApp:
# Set up HTTP only requests # Set up HTTP only requests
self.mutable_router = MutableRouter(self) self.mutable_router = MutableRouter(self)
app_handlers = [ app_handlers: List[Any] = [
(AnyMatches(), self.mutable_router), (AnyMatches(), self.mutable_router),
(r"/websocket", WebSocket)] (r"/websocket", WebSocket)]
self.app = tornado.web.Application(app_handlers, **app_args) self.app = tornado.web.Application(app_handlers, **app_args)
@ -122,12 +157,12 @@ class MoonrakerApp:
self.register_static_file_handler( self.register_static_file_handler(
"klippy.log", DEFAULT_KLIPPY_LOG_PATH, force=True) "klippy.log", DEFAULT_KLIPPY_LOG_PATH, force=True)
def listen(self, host, port): def listen(self, host: str, port: int) -> None:
self.tornado_server = self.app.listen( self.tornado_server = self.app.listen(
port, address=host, max_body_size=MAX_BODY_SIZE, port, address=host, max_body_size=MAX_BODY_SIZE,
xheaders=True) xheaders=True)
def log_request(self, handler): def log_request(self, handler: tornado.web.RequestHandler) -> None:
status_code = handler.get_status() status_code = handler.get_status()
if not self.debug and status_code in [200, 204, 206, 304]: if not self.debug and status_code in [200, 204, 206, 304]:
# don't log successful requests in release mode # don't log successful requests in release mode
@ -147,19 +182,19 @@ class MoonrakerApp:
f"{status_code} {handler._request_summary()} " f"{status_code} {handler._request_summary()} "
f"[{username}] {request_time:.2f}ms") f"[{username}] {request_time:.2f}ms")
def get_server(self): def get_server(self) -> Server:
return self.server return self.server
def get_websocket_manager(self): def get_websocket_manager(self) -> WebsocketManager:
return self.wsm return self.wsm
async def close(self): async def close(self) -> None:
if self.tornado_server is not None: if self.tornado_server is not None:
self.tornado_server.stop() self.tornado_server.stop()
await self.tornado_server.close_all_connections() await self.tornado_server.close_all_connections()
await self.wsm.close() await self.wsm.close()
def register_remote_handler(self, endpoint): def register_remote_handler(self, endpoint: str) -> None:
if endpoint in RESERVED_ENDPOINTS: if endpoint in RESERVED_ENDPOINTS:
return return
api_def = self._create_api_definition(endpoint) api_def = self._create_api_definition(endpoint)
@ -171,7 +206,7 @@ class MoonrakerApp:
f"HTTP: ({' '.join(api_def.request_methods)}) {api_def.uri}; " f"HTTP: ({' '.join(api_def.request_methods)}) {api_def.uri}; "
f"Websocket: {', '.join(api_def.ws_methods)}") f"Websocket: {', '.join(api_def.ws_methods)}")
self.wsm.register_remote_handler(api_def) self.wsm.register_remote_handler(api_def)
params = {} params: Dict[str, Any] = {}
params['methods'] = api_def.request_methods params['methods'] = api_def.request_methods
params['callback'] = api_def.endpoint params['callback'] = api_def.endpoint
params['need_object_parser'] = api_def.need_object_parser params['need_object_parser'] = api_def.need_object_parser
@ -179,9 +214,13 @@ class MoonrakerApp:
api_def.uri, DynamicRequestHandler, params) api_def.uri, DynamicRequestHandler, params)
self.registered_base_handlers.append(api_def.uri) self.registered_base_handlers.append(api_def.uri)
def register_local_handler(self, uri, request_methods, def register_local_handler(self,
callback, protocol=["http", "websocket"], uri: str,
wrap_result=True): request_methods: List[str],
callback: Callable[[WebRequest], Coroutine],
protocol: List[str] = ["http", "websocket"],
wrap_result: bool = True
) -> None:
if uri in self.registered_base_handlers: if uri in self.registered_base_handlers:
return return
api_def = self._create_api_definition( api_def = self._create_api_definition(
@ -189,7 +228,7 @@ class MoonrakerApp:
msg = "Registering local endpoint" msg = "Registering local endpoint"
if "http" in protocol: if "http" in protocol:
msg += f" - HTTP: ({' '.join(request_methods)}) {uri}" msg += f" - HTTP: ({' '.join(request_methods)}) {uri}"
params = {} params: dict[str, Any] = {}
params['methods'] = request_methods params['methods'] = request_methods
params['callback'] = callback params['callback'] = callback
params['wrap_result'] = wrap_result params['wrap_result'] = wrap_result
@ -201,7 +240,11 @@ class MoonrakerApp:
self.wsm.register_local_handler(api_def, callback) self.wsm.register_local_handler(api_def, callback)
logging.info(msg) logging.info(msg)
def register_static_file_handler(self, pattern, file_path, force=False): def register_static_file_handler(self,
pattern: str,
file_path: str,
force: bool = False
) -> None:
if pattern[0] != "/": if pattern[0] != "/":
pattern = "/server/files/" + pattern pattern = "/server/files/" + pattern
if os.path.isfile(file_path) or force: if os.path.isfile(file_path) or force:
@ -217,19 +260,23 @@ class MoonrakerApp:
params = {'path': file_path} params = {'path': file_path}
self.mutable_router.add_handler(pattern, FileRequestHandler, params) self.mutable_router.add_handler(pattern, FileRequestHandler, params)
def register_upload_handler(self, pattern): def register_upload_handler(self, pattern: str) -> None:
self.mutable_router.add_handler( self.mutable_router.add_handler(
pattern, FileUploadHandler, pattern, FileUploadHandler,
{'max_upload_size': self.max_upload_size}) {'max_upload_size': self.max_upload_size})
def remove_handler(self, endpoint): def remove_handler(self, endpoint: str) -> None:
api_def = self.api_cache.get(endpoint) api_def = self.api_cache.get(endpoint)
if api_def is not None: if api_def is not None:
self.wsm.remove_handler(api_def.uri) self.mutable_router.remove_handler(api_def.uri)
self.mutable_router.remove_handler(api_def.ws_method) for ws_method in api_def.ws_methods:
self.wsm.remove_handler(ws_method)
def _create_api_definition(self, endpoint, request_methods=[], def _create_api_definition(self,
is_remote=True): endpoint: str,
request_methods: List[str] = [],
is_remote=True
) -> APIDefinition:
if endpoint in self.api_cache: if endpoint in self.api_cache:
return self.api_cache[endpoint] return self.api_cache[endpoint]
if endpoint[0] == '/': if endpoint[0] == '/':
@ -264,25 +311,25 @@ class MoonrakerApp:
return api_def return api_def
class AuthorizedRequestHandler(tornado.web.RequestHandler): class AuthorizedRequestHandler(tornado.web.RequestHandler):
def initialize(self): def initialize(self) -> None:
self.server = self.settings['parent'].get_server() self.server: Server = self.settings['parent'].get_server()
def set_default_headers(self): def set_default_headers(self) -> None:
origin = self.request.headers.get("Origin") origin: Optional[str] = self.request.headers.get("Origin")
# it is necessary to look up the parent app here, # it is necessary to look up the parent app here,
# as initialize() may not yet be called # as initialize() may not yet be called
server = self.settings['parent'].get_server() server: Server = self.settings['parent'].get_server()
auth = server.lookup_component('authorization', None) auth: AuthComp = server.lookup_component('authorization', None)
self.cors_enabled = False self.cors_enabled = False
if auth is not None: if auth is not None:
self.cors_enabled = auth.check_cors(origin, self) self.cors_enabled = auth.check_cors(origin, self)
def prepare(self): def prepare(self) -> None:
auth = self.server.lookup_component('authorization', None) auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None: if auth is not None:
self.current_user = auth.check_authorized(self.request) self.current_user = auth.check_authorized(self.request)
def options(self, *args, **kwargs): def options(self, *args, **kwargs) -> None:
# Enable CORS if configured # Enable CORS if configured
if self.cors_enabled: if self.cors_enabled:
self.set_status(204) self.set_status(204)
@ -290,22 +337,23 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
else: else:
super(AuthorizedRequestHandler, self).options() super(AuthorizedRequestHandler, self).options()
def get_associated_websocket(self): def get_associated_websocket(self) -> Optional[WebSocket]:
# Return associated websocket connection if an id # Return associated websocket connection if an id
# was provided by the request # was provided by the request
conn = None conn = None
conn_id = self.get_argument('connection_id', None) conn_id: Any = self.get_argument('connection_id', None)
if conn_id is not None: if conn_id is not None:
try: try:
conn_id = int(conn_id) conn_id = int(conn_id)
except Exception: except Exception:
pass pass
else: else:
wsm = self.settings['parent'].get_websocket_manager() parent: MoonrakerApp = self.settings['parent']
wsm: WebsocketManager = parent.get_websocket_manager()
conn = wsm.get_websocket(conn_id) conn = wsm.get_websocket(conn_id)
return conn return conn
def write_error(self, status_code, **kwargs): def write_error(self, status_code: int, **kwargs) -> None:
err = {'code': status_code, 'message': self._reason} err = {'code': status_code, 'message': self._reason}
if 'exc_info' in kwargs: if 'exc_info' in kwargs:
err['traceback'] = "\n".join( err['traceback'] = "\n".join(
@ -315,26 +363,29 @@ class AuthorizedRequestHandler(tornado.web.RequestHandler):
# Due to the way Python treats multiple inheritance its best # Due to the way Python treats multiple inheritance its best
# to create a separate authorized handler for serving files # to create a separate authorized handler for serving files
class AuthorizedFileHandler(tornado.web.StaticFileHandler): class AuthorizedFileHandler(tornado.web.StaticFileHandler):
def initialize(self, path, default_filename=None): def initialize(self,
path: str,
default_filename: Optional[str] = None
) -> None:
super(AuthorizedFileHandler, self).initialize(path, default_filename) super(AuthorizedFileHandler, self).initialize(path, default_filename)
self.server = self.settings['parent'].get_server() self.server: Server = self.settings['parent'].get_server()
def set_default_headers(self): def set_default_headers(self) -> None:
origin = self.request.headers.get("Origin") origin: Optional[str] = self.request.headers.get("Origin")
# it is necessary to look up the parent app here, # it is necessary to look up the parent app here,
# as initialize() may not yet be called # as initialize() may not yet be called
server = self.settings['parent'].get_server() server: Server = self.settings['parent'].get_server()
auth = server.lookup_component('authorization', None) auth: AuthComp = server.lookup_component('authorization', None)
self.cors_enabled = False self.cors_enabled = False
if auth is not None: if auth is not None:
self.cors_enabled = auth.check_cors(origin, self) self.cors_enabled = auth.check_cors(origin, self)
def prepare(self): def prepare(self) -> None:
auth = self.server.lookup_component('authorization', None) auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None and self.request.method != "GET": if auth is not None and self.request.method != "GET":
self.current_user = auth.check_authorized(self.request) self.current_user = auth.check_authorized(self.request)
def options(self, *args, **kwargs): def options(self, *args, **kwargs) -> None:
# Enable CORS if configured # Enable CORS if configured
if self.cors_enabled: if self.cors_enabled:
self.set_status(204) self.set_status(204)
@ -342,7 +393,7 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler):
else: else:
super(AuthorizedFileHandler, self).options() super(AuthorizedFileHandler, self).options()
def write_error(self, status_code, **kwargs): def write_error(self, status_code: int, **kwargs) -> None:
err = {'code': status_code, 'message': self._reason} err = {'code': status_code, 'message': self._reason}
if 'exc_info' in kwargs: if 'exc_info' in kwargs:
err['traceback'] = "\n".join( err['traceback'] = "\n".join(
@ -350,8 +401,14 @@ class AuthorizedFileHandler(tornado.web.StaticFileHandler):
self.finish({'error': err}) self.finish({'error': err})
class DynamicRequestHandler(AuthorizedRequestHandler): class DynamicRequestHandler(AuthorizedRequestHandler):
def initialize(self, callback, methods, need_object_parser=False, def initialize(
is_remote=True, wrap_result=True): self,
callback: Union[str, Callable[[WebRequest], Coroutine]] = "",
methods: List[str] = [],
need_object_parser: bool = False,
is_remote: bool = True,
wrap_result: bool = True
) -> None:
super(DynamicRequestHandler, self).initialize() super(DynamicRequestHandler, self).initialize()
self.callback = callback self.callback = callback
self.methods = methods self.methods = methods
@ -362,8 +419,8 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
else self._default_parser else self._default_parser
# Converts query string values with type hints # Converts query string values with type hints
def _convert_type(self, value, hint): def _convert_type(self, value: str, hint: str) -> Any:
type_funcs = { type_funcs: Dict[str, Callable] = {
"int": int, "float": float, "int": int, "float": float,
"bool": lambda x: x.lower() == "true", "bool": lambda x: x.lower() == "true",
"json": json.loads} "json": json.loads}
@ -379,7 +436,7 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
return value return value
return converted return converted
def _default_parser(self): def _default_parser(self) -> Dict[str, Any]:
args = {} args = {}
for key in self.request.arguments.keys(): for key in self.request.arguments.keys():
if key in EXCLUDED_ARGS: if key in EXCLUDED_ARGS:
@ -392,8 +449,8 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
args[key_parts[0]] = self._convert_type(val, key_parts[1]) args[key_parts[0]] = self._convert_type(val, key_parts[1])
return args return args
def _object_parser(self): def _object_parser(self) -> Dict[str, Dict[str, Any]]:
args = {} args: Dict[str, Any] = {}
for key in self.request.arguments.keys(): for key in self.request.arguments.keys():
if key in EXCLUDED_ARGS: if key in EXCLUDED_ARGS:
continue continue
@ -405,7 +462,7 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
logging.debug(f"Parsed Arguments: {args}") logging.debug(f"Parsed Arguments: {args}")
return {'objects': args} return {'objects': args}
def parse_args(self): def parse_args(self) -> Dict[str, Any]:
try: try:
args = self._parse_query() args = self._parse_query()
except Exception: except Exception:
@ -423,28 +480,36 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
args[key] = value args[key] = value
return args return args
async def get(self, *args, **kwargs): async def get(self, *args, **kwargs) -> None:
await self._process_http_request() await self._process_http_request()
async def post(self, *args, **kwargs): async def post(self, *args, **kwargs) -> None:
await self._process_http_request() await self._process_http_request()
async def delete(self, *args, **kwargs): async def delete(self, *args, **kwargs) -> None:
await self._process_http_request() await self._process_http_request()
async def _do_local_request(self, args, conn): async def _do_local_request(self,
args: Dict[str, Any],
conn: Optional[WebSocket]
) -> Any:
assert callable(self.callback)
return await self.callback( return await self.callback(
WebRequest(self.request.path, args, self.request.method, WebRequest(self.request.path, args, self.request.method,
conn=conn, ip_addr=self.request.remote_ip, conn=conn, ip_addr=self.request.remote_ip,
user=self.current_user)) user=self.current_user))
async def _do_remote_request(self, args, conn): async def _do_remote_request(self,
args: Dict[str, Any],
conn: Optional[WebSocket]
) -> Any:
assert isinstance(self.callback, str)
return await self.server.make_request( return await self.server.make_request(
WebRequest(self.callback, args, conn=conn, WebRequest(self.callback, args, conn=conn,
ip_addr=self.request.remote_ip, ip_addr=self.request.remote_ip,
user=self.current_user)) user=self.current_user))
async def _process_http_request(self): async def _process_http_request(self) -> None:
if self.request.method not in self.methods: if self.request.method not in self.methods:
raise tornado.web.HTTPError(405) raise tornado.web.HTTPError(405)
conn = self.get_associated_websocket() conn = self.get_associated_websocket()
@ -459,15 +524,16 @@ class DynamicRequestHandler(AuthorizedRequestHandler):
self.finish(result) self.finish(result)
class FileRequestHandler(AuthorizedFileHandler): class FileRequestHandler(AuthorizedFileHandler):
def set_extra_headers(self, path): def set_extra_headers(self, path: str) -> None:
# The call below shold never return an empty string, # The call below shold never return an empty string,
# as the path should have already been validated to be # as the path should have already been validated to be
# a file # a file
assert isinstance(self.absolute_path, str)
basename = os.path.basename(self.absolute_path) basename = os.path.basename(self.absolute_path)
self.set_header( self.set_header(
"Content-Disposition", f"attachment; filename={basename}") "Content-Disposition", f"attachment; filename={basename}")
async def delete(self, path): async def delete(self, path: str) -> None:
path = self.request.path.lstrip("/").split("/", 2)[-1] path = self.request.path.lstrip("/").split("/", 2)[-1]
path = url_unescape(path, plus=False) path = url_unescape(path, plus=False)
file_manager = self.server.lookup_component('file_manager') file_manager = self.server.lookup_component('file_manager')
@ -568,9 +634,10 @@ class FileRequestHandler(AuthorizedFileHandler):
assert self.request.method == "HEAD" assert self.request.method == "HEAD"
@classmethod @classmethod
def _get_cached_version(cls, abs_path: str): def _get_cached_version(cls, abs_path: str) -> Optional[str]:
with cls._lock: with cls._lock:
hashes = cls._static_hashes hashes: Dict[str, Dict[str, Any]] = \
cls._static_hashes # type: ignore
try: try:
mtime = datetime.datetime.fromtimestamp( mtime = datetime.datetime.fromtimestamp(
os.path.getmtime(abs_path), tz=datetime.timezone.utc) os.path.getmtime(abs_path), tz=datetime.timezone.utc)
@ -596,14 +663,16 @@ class FileRequestHandler(AuthorizedFileHandler):
@tornado.web.stream_request_body @tornado.web.stream_request_body
class FileUploadHandler(AuthorizedRequestHandler): class FileUploadHandler(AuthorizedRequestHandler):
def initialize(self, max_upload_size): def initialize(self, max_upload_size: int = MAX_BODY_SIZE) -> None:
super(FileUploadHandler, self).initialize() super(FileUploadHandler, self).initialize()
self.file_manager = self.server.lookup_component('file_manager') self.file_manager: FileManager = self.server.lookup_component(
'file_manager')
self.max_upload_size = max_upload_size self.max_upload_size = max_upload_size
def prepare(self): def prepare(self) -> None:
super(FileUploadHandler, self).prepare() super(FileUploadHandler, self).prepare()
if self.request.method == "POST": if self.request.method == "POST":
assert isinstance(self.request.connection, HTTP1Connection)
self.request.connection.set_max_body_size(self.max_upload_size) self.request.connection.set_max_body_size(self.max_upload_size)
tmpname = self.file_manager.gen_temp_upload_path() tmpname = self.file_manager.gen_temp_upload_path()
self._targets = { self._targets = {
@ -620,11 +689,11 @@ class FileUploadHandler(AuthorizedRequestHandler):
for name, target in self._targets.items(): for name, target in self._targets.items():
self._parser.register(name, target) self._parser.register(name, target)
def data_received(self, chunk): def data_received(self, chunk: bytes) -> None:
if self.request.method == "POST": if self.request.method == "POST":
self._parser.data_received(chunk) self._parser.data_received(chunk)
async def post(self): async def post(self) -> None:
form_args = {} form_args = {}
chk_target = self._targets.pop('checksum') chk_target = self._targets.pop('checksum')
calc_chksum = self._sha256_target.value.lower() calc_chksum = self._sha256_target.value.lower()
@ -659,15 +728,15 @@ class FileUploadHandler(AuthorizedRequestHandler):
# Default Handler for unregistered endpoints # Default Handler for unregistered endpoints
class AuthorizedErrorHandler(AuthorizedRequestHandler): class AuthorizedErrorHandler(AuthorizedRequestHandler):
def prepare(self): def prepare(self) -> None:
super(AuthorizedRequestHandler, self).prepare() super(AuthorizedRequestHandler, self).prepare()
self.set_status(404) self.set_status(404)
raise tornado.web.HTTPError(404) raise tornado.web.HTTPError(404)
def check_xsrf_cookie(self): def check_xsrf_cookie(self) -> None:
pass pass
def write_error(self, status_code, **kwargs): def write_error(self, status_code: int, **kwargs) -> None:
err = {'code': status_code, 'message': self._reason} err = {'code': status_code, 'message': self._reason}
if 'exc_info' in kwargs: if 'exc_info' in kwargs:
err['traceback'] = "\n".join( err['traceback'] = "\n".join(