websockets: set a websocket connection limit

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-11-21 20:11:58 -05:00
parent 5a22b21a40
commit 9d6719ed31
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
2 changed files with 15 additions and 1 deletions

View File

@ -58,6 +58,7 @@ if TYPE_CHECKING:
# 50 MiB Max Standard Body Size # 50 MiB Max Standard Body Size
MAX_BODY_SIZE = 50 * 1024 * 1024 MAX_BODY_SIZE = 50 * 1024 * 1024
MAX_WS_CONNS_DEFAULT = 50
EXCLUDED_ARGS = ["_", "token", "access_token", "connection_id"] EXCLUDED_ARGS = ["_", "token", "access_token", "connection_id"]
AUTHORIZED_EXTS = [".png", ".jpg"] AUTHORIZED_EXTS = [".png", ".jpg"]
DEFAULT_KLIPPY_LOG_PATH = "/tmp/klippy.log" DEFAULT_KLIPPY_LOG_PATH = "/tmp/klippy.log"
@ -169,6 +170,9 @@ class MoonrakerApp:
self.registered_base_handlers: List[str] = [] self.registered_base_handlers: List[str] = []
self.max_upload_size = config.getint('max_upload_size', 1024) self.max_upload_size = config.getint('max_upload_size', 1024)
self.max_upload_size *= 1024 * 1024 self.max_upload_size *= 1024 * 1024
max_ws_conns = config.getint(
'max_websocket_connections', MAX_WS_CONNS_DEFAULT
)
# SSL config # SSL config
self.cert_path: pathlib.Path = self._get_path_option( self.cert_path: pathlib.Path = self._get_path_option(
@ -193,6 +197,7 @@ class MoonrakerApp:
'websocket_ping_interval': 10, 'websocket_ping_interval': 10,
'websocket_ping_timeout': 30, 'websocket_ping_timeout': 30,
'server': self.server, 'server': self.server,
'max_websocket_connections': max_ws_conns,
'default_handler_class': AuthorizedErrorHandler, 'default_handler_class': AuthorizedErrorHandler,
'default_handler_args': {}, 'default_handler_args': {},
'log_function': self.log_request, 'log_function': self.log_request,

View File

@ -650,6 +650,8 @@ class BaseSocketClient(Subscribable):
raise NotImplementedError("Children must implement close_socket()") raise NotImplementedError("Children must implement close_socket()")
class WebSocket(WebSocketHandler, BaseSocketClient): class WebSocket(WebSocketHandler, BaseSocketClient):
connection_count: int = 0
def initialize(self) -> None: def initialize(self) -> None:
self.on_create(self.settings['server']) self.on_create(self.settings['server'])
self.ip_addr: str = self.request.remote_ip or "" self.ip_addr: str = self.request.remote_ip or ""
@ -663,6 +665,7 @@ class WebSocket(WebSocketHandler, BaseSocketClient):
return self._user_info return self._user_info
def open(self, *args, **kwargs) -> None: def open(self, *args, **kwargs) -> None:
self.__class__.connection_count += 1
self.set_nodelay(True) self.set_nodelay(True)
self._connected_time = self.eventloop.get_loop_time() self._connected_time = self.eventloop.get_loop_time()
agent = self.request.headers.get("User-Agent", "") agent = self.request.headers.get("User-Agent", "")
@ -686,6 +689,7 @@ class WebSocket(WebSocketHandler, BaseSocketClient):
def on_close(self) -> None: def on_close(self) -> None:
self.is_closed = True self.is_closed = True
self.__class__.connection_count -= 1
self.message_buf = [] self.message_buf = []
now = self.eventloop.get_loop_time() now = self.eventloop.get_loop_time()
pong_elapsed = now - self.last_pong_time pong_elapsed = now - self.last_pong_time
@ -725,7 +729,12 @@ class WebSocket(WebSocketHandler, BaseSocketClient):
return True return True
# Check Authorized User # Check Authorized User
def prepare(self): def prepare(self) -> None:
max_conns = self.settings["max_websocket_connections"]
if self.__class__.connection_count >= max_conns:
raise self.server.error(
"Maximum Number of Websocket Connections Reached"
)
auth: AuthComp = self.server.lookup_component('authorization', None) auth: AuthComp = self.server.lookup_component('authorization', None)
if auth is not None: if auth is not None:
try: try: