From 09550af466497a7e4b7c801acf318fb6446ff44a Mon Sep 17 00:00:00 2001 From: Eric Callahan Date: Tue, 28 Jun 2022 15:21:48 -0400 Subject: [PATCH] mqtt: resolve connection blocking issues Override the paho-mqtt client "reconnect()" method with a method capable of taking a connected socket. This allows Moonraker to connect the socket asynchronously, then finish establishing the connection. Signed-off-by: Eric Callahan --- moonraker/components/mqtt.py | 149 +++++++++++++++++++++++++++-------- 1 file changed, 118 insertions(+), 31 deletions(-) diff --git a/moonraker/components/mqtt.py b/moonraker/components/mqtt.py index 5434d23..31f396b 100644 --- a/moonraker/components/mqtt.py +++ b/moonraker/components/mqtt.py @@ -10,6 +10,7 @@ import asyncio import logging import json import pathlib +import ssl from collections import deque import paho.mqtt.client as paho_mqtt from websockets import Subscribable, WebRequest, JsonRPC, APITransport @@ -42,6 +43,100 @@ MQTT_PROTOCOLS = { 'v5': paho_mqtt.MQTTv5 } +class ExtPahoClient(paho_mqtt.Client): + # Override reconnection to take a connected socket. This allows Moonraker + # create the socket connection asynchronously + def reconnect(self, sock: Optional[socket.socket] = None): + """Reconnect the client after a disconnect. Can only be called after + connect()/connect_async().""" + if len(self._host) == 0: + raise ValueError('Invalid host.') + if self._port <= 0: + raise ValueError('Invalid port number.') + + self._in_packet = { + "command": 0, + "have_remaining": 0, + "remaining_count": [], + "remaining_mult": 1, + "remaining_length": 0, + "packet": b"", + "to_process": 0, + "pos": 0} + + with self._out_packet_mutex: + self._out_packet = deque() # type: ignore + + with self._current_out_packet_mutex: + self._current_out_packet = None + + with self._msgtime_mutex: + self._last_msg_in = paho_mqtt.time_func() + self._last_msg_out = paho_mqtt.time_func() + + self._ping_t = 0 + self._state = paho_mqtt.mqtt_cs_new + + self._sock_close() + + # Put messages in progress in a valid state. + self._messages_reconnect_reset() + + if sock is None: + sock = self._create_socket_connection() + + if self._ssl: + # SSL is only supported when SSLContext is available + # (implies Python >= 2.7.9 or >= 3.2) + + verify_host = not self._tls_insecure + try: + # Try with server_hostname, even it's not supported in + # certain scenarios + sock = self._ssl_context.wrap_socket( + sock, + server_hostname=self._host, + do_handshake_on_connect=False, + ) + except ssl.CertificateError: + # CertificateError is derived from ValueError + raise + except ValueError: + # Python version requires SNI in order to handle + # server_hostname, but SNI is not available + sock = self._ssl_context.wrap_socket( + sock, + do_handshake_on_connect=False, + ) + else: + # If SSL context has already checked hostname, then don't need + # to do it again + if (hasattr(self._ssl_context, 'check_hostname') and + self._ssl_context.check_hostname): + verify_host = False + + assert isinstance(sock, ssl.SSLSocket) + sock.settimeout(self._keepalive) + sock.do_handshake() + + if verify_host: + ssl.match_hostname(sock.getpeercert(), self._host) + + if self._transport == "websockets": + sock.settimeout(self._keepalive) + sock = paho_mqtt.WebsocketWrapper( + sock, self._host, self._port, self._ssl, + self._websocket_path, self._websocket_extra_headers + ) + + self._sock = sock + assert self._sock is not None + self._sock.setblocking(False) + self._registered_write = False + self._call_socket_open() + + return self._send_connect(self._keepalive) + class SubscriptionHandle: def __init__(self, topic: str, callback: FlexCallback) -> None: self.callback = callback @@ -95,12 +190,6 @@ class AIOHelper: userdata: Any, sock: socket.socket ) -> None: - self.loop.call_soon_threadsafe( - self._do_socket_open, client, sock) - - def _do_socket_open(self, - client: paho_mqtt.Client, - sock: socket.socket) -> None: logging.info("MQTT Socket Opened") self.loop.add_reader(sock, client.loop_read) self.misc_task = self.loop.create_task(self.misc_loop()) @@ -177,7 +266,7 @@ class MQTTClient(APITransport, Subscribable): raise config.error( "Option 'default_qos' in section [mqtt] must be " "between 0 and 2") - self.client = paho_mqtt.Client(protocol=self.protocol) + self.client = ExtPahoClient(protocol=self.protocol) self.client.on_connect = self._on_connect self.client.on_message = self._on_message self.client.on_disconnect = self._on_disconnect @@ -249,22 +338,10 @@ class MQTTClient(APITransport, Subscribable): self.client.will_set(self.moonraker_status_topic, payload=json.dumps({'server': 'offline'}), qos=self.qos, retain=True) - self.connect_task = self.event_loop.create_task(self._do_connect()) - - async def _do_connect(self): - while True: - try: - await self.event_loop.run_in_thread( - self.client.connect, self.address, self.port) - except asyncio.CancelledError: - raise - except Exception: - await asyncio.sleep(2.) - else: - break - self.client.socket().setsockopt( - socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) - self.connect_task = None + self.client.connect_async(self.address, self.port) + self.connect_task = self.event_loop.create_task( + self._do_reconnect(first=True) + ) async def _handle_klippy_identified(self) -> None: if self.status_objs: @@ -369,17 +446,27 @@ class MQTTClient(APITransport, Subscribable): if unsub_fut is not None and not unsub_fut.done(): unsub_fut.set_result(None) - async def _do_reconnect(self) -> None: - logging.info("Attempting MQTT Reconnect") - self.event_loop + async def _do_reconnect(self, first: bool = False) -> None: + logging.info("Attempting MQTT Connect/Reconnect") + last_err: Exception = Exception() while True: + if not first: + try: + await asyncio.sleep(2.) + except asyncio.CancelledError: + raise + first = False try: - await asyncio.sleep(2.) + sock = await self.event_loop.create_socket_connection( + (self.address, self.port), timeout=10 + ) + self.client.reconnect(sock) except asyncio.CancelledError: - break - try: - await self.event_loop.run_in_thread(self.client.reconnect) - except Exception: + raise + except Exception as e: + if type(last_err) != type(e) or last_err.args != e.args: + logging.exception("MQTT Connection Error") + last_err = e continue self.client.socket().setsockopt( socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)