mqtt: resolve connection blocking issues

Override the paho-mqtt client "reconnect()" method with
a method capable of taking a connected socket.  This allows
Moonraker to connect the socket asynchronously, then finish
establishing the connection.

Signed-off-by:  Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
Eric Callahan 2022-06-28 15:21:48 -04:00
parent 0b31d7d0b2
commit 09550af466
No known key found for this signature in database
GPG Key ID: 5A1EB336DFB4C71B
1 changed files with 118 additions and 31 deletions

View File

@ -10,6 +10,7 @@ import asyncio
import logging import logging
import json import json
import pathlib import pathlib
import ssl
from collections import deque from collections import deque
import paho.mqtt.client as paho_mqtt import paho.mqtt.client as paho_mqtt
from websockets import Subscribable, WebRequest, JsonRPC, APITransport from websockets import Subscribable, WebRequest, JsonRPC, APITransport
@ -42,6 +43,100 @@ MQTT_PROTOCOLS = {
'v5': paho_mqtt.MQTTv5 'v5': paho_mqtt.MQTTv5
} }
class ExtPahoClient(paho_mqtt.Client):
# Override reconnection to take a connected socket. This allows Moonraker
# create the socket connection asynchronously
def reconnect(self, sock: Optional[socket.socket] = None):
"""Reconnect the client after a disconnect. Can only be called after
connect()/connect_async()."""
if len(self._host) == 0:
raise ValueError('Invalid host.')
if self._port <= 0:
raise ValueError('Invalid port number.')
self._in_packet = {
"command": 0,
"have_remaining": 0,
"remaining_count": [],
"remaining_mult": 1,
"remaining_length": 0,
"packet": b"",
"to_process": 0,
"pos": 0}
with self._out_packet_mutex:
self._out_packet = deque() # type: ignore
with self._current_out_packet_mutex:
self._current_out_packet = None
with self._msgtime_mutex:
self._last_msg_in = paho_mqtt.time_func()
self._last_msg_out = paho_mqtt.time_func()
self._ping_t = 0
self._state = paho_mqtt.mqtt_cs_new
self._sock_close()
# Put messages in progress in a valid state.
self._messages_reconnect_reset()
if sock is None:
sock = self._create_socket_connection()
if self._ssl:
# SSL is only supported when SSLContext is available
# (implies Python >= 2.7.9 or >= 3.2)
verify_host = not self._tls_insecure
try:
# Try with server_hostname, even it's not supported in
# certain scenarios
sock = self._ssl_context.wrap_socket(
sock,
server_hostname=self._host,
do_handshake_on_connect=False,
)
except ssl.CertificateError:
# CertificateError is derived from ValueError
raise
except ValueError:
# Python version requires SNI in order to handle
# server_hostname, but SNI is not available
sock = self._ssl_context.wrap_socket(
sock,
do_handshake_on_connect=False,
)
else:
# If SSL context has already checked hostname, then don't need
# to do it again
if (hasattr(self._ssl_context, 'check_hostname') and
self._ssl_context.check_hostname):
verify_host = False
assert isinstance(sock, ssl.SSLSocket)
sock.settimeout(self._keepalive)
sock.do_handshake()
if verify_host:
ssl.match_hostname(sock.getpeercert(), self._host)
if self._transport == "websockets":
sock.settimeout(self._keepalive)
sock = paho_mqtt.WebsocketWrapper(
sock, self._host, self._port, self._ssl,
self._websocket_path, self._websocket_extra_headers
)
self._sock = sock
assert self._sock is not None
self._sock.setblocking(False)
self._registered_write = False
self._call_socket_open()
return self._send_connect(self._keepalive)
class SubscriptionHandle: class SubscriptionHandle:
def __init__(self, topic: str, callback: FlexCallback) -> None: def __init__(self, topic: str, callback: FlexCallback) -> None:
self.callback = callback self.callback = callback
@ -95,12 +190,6 @@ class AIOHelper:
userdata: Any, userdata: Any,
sock: socket.socket sock: socket.socket
) -> None: ) -> None:
self.loop.call_soon_threadsafe(
self._do_socket_open, client, sock)
def _do_socket_open(self,
client: paho_mqtt.Client,
sock: socket.socket) -> None:
logging.info("MQTT Socket Opened") logging.info("MQTT Socket Opened")
self.loop.add_reader(sock, client.loop_read) self.loop.add_reader(sock, client.loop_read)
self.misc_task = self.loop.create_task(self.misc_loop()) self.misc_task = self.loop.create_task(self.misc_loop())
@ -177,7 +266,7 @@ class MQTTClient(APITransport, Subscribable):
raise config.error( raise config.error(
"Option 'default_qos' in section [mqtt] must be " "Option 'default_qos' in section [mqtt] must be "
"between 0 and 2") "between 0 and 2")
self.client = paho_mqtt.Client(protocol=self.protocol) self.client = ExtPahoClient(protocol=self.protocol)
self.client.on_connect = self._on_connect self.client.on_connect = self._on_connect
self.client.on_message = self._on_message self.client.on_message = self._on_message
self.client.on_disconnect = self._on_disconnect self.client.on_disconnect = self._on_disconnect
@ -249,22 +338,10 @@ class MQTTClient(APITransport, Subscribable):
self.client.will_set(self.moonraker_status_topic, self.client.will_set(self.moonraker_status_topic,
payload=json.dumps({'server': 'offline'}), payload=json.dumps({'server': 'offline'}),
qos=self.qos, retain=True) qos=self.qos, retain=True)
self.connect_task = self.event_loop.create_task(self._do_connect()) self.client.connect_async(self.address, self.port)
self.connect_task = self.event_loop.create_task(
async def _do_connect(self): self._do_reconnect(first=True)
while True: )
try:
await self.event_loop.run_in_thread(
self.client.connect, self.address, self.port)
except asyncio.CancelledError:
raise
except Exception:
await asyncio.sleep(2.)
else:
break
self.client.socket().setsockopt(
socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
self.connect_task = None
async def _handle_klippy_identified(self) -> None: async def _handle_klippy_identified(self) -> None:
if self.status_objs: if self.status_objs:
@ -369,17 +446,27 @@ class MQTTClient(APITransport, Subscribable):
if unsub_fut is not None and not unsub_fut.done(): if unsub_fut is not None and not unsub_fut.done():
unsub_fut.set_result(None) unsub_fut.set_result(None)
async def _do_reconnect(self) -> None: async def _do_reconnect(self, first: bool = False) -> None:
logging.info("Attempting MQTT Reconnect") logging.info("Attempting MQTT Connect/Reconnect")
self.event_loop last_err: Exception = Exception()
while True: while True:
if not first:
try: try:
await asyncio.sleep(2.) await asyncio.sleep(2.)
except asyncio.CancelledError: except asyncio.CancelledError:
break raise
first = False
try: try:
await self.event_loop.run_in_thread(self.client.reconnect) sock = await self.event_loop.create_socket_connection(
except Exception: (self.address, self.port), timeout=10
)
self.client.reconnect(sock)
except asyncio.CancelledError:
raise
except Exception as e:
if type(last_err) != type(e) or last_err.args != e.args:
logging.exception("MQTT Connection Error")
last_err = e
continue continue
self.client.socket().setsockopt( self.client.socket().setsockopt(
socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)