mqtt: resolve connection blocking issues
Override the paho-mqtt client "reconnect()" method with a method capable of taking a connected socket. This allows Moonraker to connect the socket asynchronously, then finish establishing the connection. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
0b31d7d0b2
commit
09550af466
|
@ -10,6 +10,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import ssl
|
||||||
from collections import deque
|
from collections import deque
|
||||||
import paho.mqtt.client as paho_mqtt
|
import paho.mqtt.client as paho_mqtt
|
||||||
from websockets import Subscribable, WebRequest, JsonRPC, APITransport
|
from websockets import Subscribable, WebRequest, JsonRPC, APITransport
|
||||||
|
@ -42,6 +43,100 @@ MQTT_PROTOCOLS = {
|
||||||
'v5': paho_mqtt.MQTTv5
|
'v5': paho_mqtt.MQTTv5
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ExtPahoClient(paho_mqtt.Client):
|
||||||
|
# Override reconnection to take a connected socket. This allows Moonraker
|
||||||
|
# create the socket connection asynchronously
|
||||||
|
def reconnect(self, sock: Optional[socket.socket] = None):
|
||||||
|
"""Reconnect the client after a disconnect. Can only be called after
|
||||||
|
connect()/connect_async()."""
|
||||||
|
if len(self._host) == 0:
|
||||||
|
raise ValueError('Invalid host.')
|
||||||
|
if self._port <= 0:
|
||||||
|
raise ValueError('Invalid port number.')
|
||||||
|
|
||||||
|
self._in_packet = {
|
||||||
|
"command": 0,
|
||||||
|
"have_remaining": 0,
|
||||||
|
"remaining_count": [],
|
||||||
|
"remaining_mult": 1,
|
||||||
|
"remaining_length": 0,
|
||||||
|
"packet": b"",
|
||||||
|
"to_process": 0,
|
||||||
|
"pos": 0}
|
||||||
|
|
||||||
|
with self._out_packet_mutex:
|
||||||
|
self._out_packet = deque() # type: ignore
|
||||||
|
|
||||||
|
with self._current_out_packet_mutex:
|
||||||
|
self._current_out_packet = None
|
||||||
|
|
||||||
|
with self._msgtime_mutex:
|
||||||
|
self._last_msg_in = paho_mqtt.time_func()
|
||||||
|
self._last_msg_out = paho_mqtt.time_func()
|
||||||
|
|
||||||
|
self._ping_t = 0
|
||||||
|
self._state = paho_mqtt.mqtt_cs_new
|
||||||
|
|
||||||
|
self._sock_close()
|
||||||
|
|
||||||
|
# Put messages in progress in a valid state.
|
||||||
|
self._messages_reconnect_reset()
|
||||||
|
|
||||||
|
if sock is None:
|
||||||
|
sock = self._create_socket_connection()
|
||||||
|
|
||||||
|
if self._ssl:
|
||||||
|
# SSL is only supported when SSLContext is available
|
||||||
|
# (implies Python >= 2.7.9 or >= 3.2)
|
||||||
|
|
||||||
|
verify_host = not self._tls_insecure
|
||||||
|
try:
|
||||||
|
# Try with server_hostname, even it's not supported in
|
||||||
|
# certain scenarios
|
||||||
|
sock = self._ssl_context.wrap_socket(
|
||||||
|
sock,
|
||||||
|
server_hostname=self._host,
|
||||||
|
do_handshake_on_connect=False,
|
||||||
|
)
|
||||||
|
except ssl.CertificateError:
|
||||||
|
# CertificateError is derived from ValueError
|
||||||
|
raise
|
||||||
|
except ValueError:
|
||||||
|
# Python version requires SNI in order to handle
|
||||||
|
# server_hostname, but SNI is not available
|
||||||
|
sock = self._ssl_context.wrap_socket(
|
||||||
|
sock,
|
||||||
|
do_handshake_on_connect=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If SSL context has already checked hostname, then don't need
|
||||||
|
# to do it again
|
||||||
|
if (hasattr(self._ssl_context, 'check_hostname') and
|
||||||
|
self._ssl_context.check_hostname):
|
||||||
|
verify_host = False
|
||||||
|
|
||||||
|
assert isinstance(sock, ssl.SSLSocket)
|
||||||
|
sock.settimeout(self._keepalive)
|
||||||
|
sock.do_handshake()
|
||||||
|
|
||||||
|
if verify_host:
|
||||||
|
ssl.match_hostname(sock.getpeercert(), self._host)
|
||||||
|
|
||||||
|
if self._transport == "websockets":
|
||||||
|
sock.settimeout(self._keepalive)
|
||||||
|
sock = paho_mqtt.WebsocketWrapper(
|
||||||
|
sock, self._host, self._port, self._ssl,
|
||||||
|
self._websocket_path, self._websocket_extra_headers
|
||||||
|
)
|
||||||
|
|
||||||
|
self._sock = sock
|
||||||
|
assert self._sock is not None
|
||||||
|
self._sock.setblocking(False)
|
||||||
|
self._registered_write = False
|
||||||
|
self._call_socket_open()
|
||||||
|
|
||||||
|
return self._send_connect(self._keepalive)
|
||||||
|
|
||||||
class SubscriptionHandle:
|
class SubscriptionHandle:
|
||||||
def __init__(self, topic: str, callback: FlexCallback) -> None:
|
def __init__(self, topic: str, callback: FlexCallback) -> None:
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
@ -95,12 +190,6 @@ class AIOHelper:
|
||||||
userdata: Any,
|
userdata: Any,
|
||||||
sock: socket.socket
|
sock: socket.socket
|
||||||
) -> None:
|
) -> None:
|
||||||
self.loop.call_soon_threadsafe(
|
|
||||||
self._do_socket_open, client, sock)
|
|
||||||
|
|
||||||
def _do_socket_open(self,
|
|
||||||
client: paho_mqtt.Client,
|
|
||||||
sock: socket.socket) -> None:
|
|
||||||
logging.info("MQTT Socket Opened")
|
logging.info("MQTT Socket Opened")
|
||||||
self.loop.add_reader(sock, client.loop_read)
|
self.loop.add_reader(sock, client.loop_read)
|
||||||
self.misc_task = self.loop.create_task(self.misc_loop())
|
self.misc_task = self.loop.create_task(self.misc_loop())
|
||||||
|
@ -177,7 +266,7 @@ class MQTTClient(APITransport, Subscribable):
|
||||||
raise config.error(
|
raise config.error(
|
||||||
"Option 'default_qos' in section [mqtt] must be "
|
"Option 'default_qos' in section [mqtt] must be "
|
||||||
"between 0 and 2")
|
"between 0 and 2")
|
||||||
self.client = paho_mqtt.Client(protocol=self.protocol)
|
self.client = ExtPahoClient(protocol=self.protocol)
|
||||||
self.client.on_connect = self._on_connect
|
self.client.on_connect = self._on_connect
|
||||||
self.client.on_message = self._on_message
|
self.client.on_message = self._on_message
|
||||||
self.client.on_disconnect = self._on_disconnect
|
self.client.on_disconnect = self._on_disconnect
|
||||||
|
@ -249,22 +338,10 @@ class MQTTClient(APITransport, Subscribable):
|
||||||
self.client.will_set(self.moonraker_status_topic,
|
self.client.will_set(self.moonraker_status_topic,
|
||||||
payload=json.dumps({'server': 'offline'}),
|
payload=json.dumps({'server': 'offline'}),
|
||||||
qos=self.qos, retain=True)
|
qos=self.qos, retain=True)
|
||||||
self.connect_task = self.event_loop.create_task(self._do_connect())
|
self.client.connect_async(self.address, self.port)
|
||||||
|
self.connect_task = self.event_loop.create_task(
|
||||||
async def _do_connect(self):
|
self._do_reconnect(first=True)
|
||||||
while True:
|
)
|
||||||
try:
|
|
||||||
await self.event_loop.run_in_thread(
|
|
||||||
self.client.connect, self.address, self.port)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
raise
|
|
||||||
except Exception:
|
|
||||||
await asyncio.sleep(2.)
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
self.client.socket().setsockopt(
|
|
||||||
socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
|
|
||||||
self.connect_task = None
|
|
||||||
|
|
||||||
async def _handle_klippy_identified(self) -> None:
|
async def _handle_klippy_identified(self) -> None:
|
||||||
if self.status_objs:
|
if self.status_objs:
|
||||||
|
@ -369,17 +446,27 @@ class MQTTClient(APITransport, Subscribable):
|
||||||
if unsub_fut is not None and not unsub_fut.done():
|
if unsub_fut is not None and not unsub_fut.done():
|
||||||
unsub_fut.set_result(None)
|
unsub_fut.set_result(None)
|
||||||
|
|
||||||
async def _do_reconnect(self) -> None:
|
async def _do_reconnect(self, first: bool = False) -> None:
|
||||||
logging.info("Attempting MQTT Reconnect")
|
logging.info("Attempting MQTT Connect/Reconnect")
|
||||||
self.event_loop
|
last_err: Exception = Exception()
|
||||||
while True:
|
while True:
|
||||||
|
if not first:
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(2.)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
first = False
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(2.)
|
sock = await self.event_loop.create_socket_connection(
|
||||||
|
(self.address, self.port), timeout=10
|
||||||
|
)
|
||||||
|
self.client.reconnect(sock)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
raise
|
||||||
try:
|
except Exception as e:
|
||||||
await self.event_loop.run_in_thread(self.client.reconnect)
|
if type(last_err) != type(e) or last_err.args != e.args:
|
||||||
except Exception:
|
logging.exception("MQTT Connection Error")
|
||||||
|
last_err = e
|
||||||
continue
|
continue
|
||||||
self.client.socket().setsockopt(
|
self.client.socket().setsockopt(
|
||||||
socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
|
socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
|
||||||
|
|
Loading…
Reference in New Issue