mqtt: resolve connection blocking issues
Override the paho-mqtt client "reconnect()" method with a method capable of taking a connected socket. This allows Moonraker to connect the socket asynchronously, then finish establishing the connection. Signed-off-by: Eric Callahan <arksine.code@gmail.com>
This commit is contained in:
parent
0b31d7d0b2
commit
09550af466
|
@ -10,6 +10,7 @@ import asyncio
|
|||
import logging
|
||||
import json
|
||||
import pathlib
|
||||
import ssl
|
||||
from collections import deque
|
||||
import paho.mqtt.client as paho_mqtt
|
||||
from websockets import Subscribable, WebRequest, JsonRPC, APITransport
|
||||
|
@ -42,6 +43,100 @@ MQTT_PROTOCOLS = {
|
|||
'v5': paho_mqtt.MQTTv5
|
||||
}
|
||||
|
||||
class ExtPahoClient(paho_mqtt.Client):
|
||||
# Override reconnection to take a connected socket. This allows Moonraker
|
||||
# create the socket connection asynchronously
|
||||
def reconnect(self, sock: Optional[socket.socket] = None):
|
||||
"""Reconnect the client after a disconnect. Can only be called after
|
||||
connect()/connect_async()."""
|
||||
if len(self._host) == 0:
|
||||
raise ValueError('Invalid host.')
|
||||
if self._port <= 0:
|
||||
raise ValueError('Invalid port number.')
|
||||
|
||||
self._in_packet = {
|
||||
"command": 0,
|
||||
"have_remaining": 0,
|
||||
"remaining_count": [],
|
||||
"remaining_mult": 1,
|
||||
"remaining_length": 0,
|
||||
"packet": b"",
|
||||
"to_process": 0,
|
||||
"pos": 0}
|
||||
|
||||
with self._out_packet_mutex:
|
||||
self._out_packet = deque() # type: ignore
|
||||
|
||||
with self._current_out_packet_mutex:
|
||||
self._current_out_packet = None
|
||||
|
||||
with self._msgtime_mutex:
|
||||
self._last_msg_in = paho_mqtt.time_func()
|
||||
self._last_msg_out = paho_mqtt.time_func()
|
||||
|
||||
self._ping_t = 0
|
||||
self._state = paho_mqtt.mqtt_cs_new
|
||||
|
||||
self._sock_close()
|
||||
|
||||
# Put messages in progress in a valid state.
|
||||
self._messages_reconnect_reset()
|
||||
|
||||
if sock is None:
|
||||
sock = self._create_socket_connection()
|
||||
|
||||
if self._ssl:
|
||||
# SSL is only supported when SSLContext is available
|
||||
# (implies Python >= 2.7.9 or >= 3.2)
|
||||
|
||||
verify_host = not self._tls_insecure
|
||||
try:
|
||||
# Try with server_hostname, even it's not supported in
|
||||
# certain scenarios
|
||||
sock = self._ssl_context.wrap_socket(
|
||||
sock,
|
||||
server_hostname=self._host,
|
||||
do_handshake_on_connect=False,
|
||||
)
|
||||
except ssl.CertificateError:
|
||||
# CertificateError is derived from ValueError
|
||||
raise
|
||||
except ValueError:
|
||||
# Python version requires SNI in order to handle
|
||||
# server_hostname, but SNI is not available
|
||||
sock = self._ssl_context.wrap_socket(
|
||||
sock,
|
||||
do_handshake_on_connect=False,
|
||||
)
|
||||
else:
|
||||
# If SSL context has already checked hostname, then don't need
|
||||
# to do it again
|
||||
if (hasattr(self._ssl_context, 'check_hostname') and
|
||||
self._ssl_context.check_hostname):
|
||||
verify_host = False
|
||||
|
||||
assert isinstance(sock, ssl.SSLSocket)
|
||||
sock.settimeout(self._keepalive)
|
||||
sock.do_handshake()
|
||||
|
||||
if verify_host:
|
||||
ssl.match_hostname(sock.getpeercert(), self._host)
|
||||
|
||||
if self._transport == "websockets":
|
||||
sock.settimeout(self._keepalive)
|
||||
sock = paho_mqtt.WebsocketWrapper(
|
||||
sock, self._host, self._port, self._ssl,
|
||||
self._websocket_path, self._websocket_extra_headers
|
||||
)
|
||||
|
||||
self._sock = sock
|
||||
assert self._sock is not None
|
||||
self._sock.setblocking(False)
|
||||
self._registered_write = False
|
||||
self._call_socket_open()
|
||||
|
||||
return self._send_connect(self._keepalive)
|
||||
|
||||
class SubscriptionHandle:
|
||||
def __init__(self, topic: str, callback: FlexCallback) -> None:
|
||||
self.callback = callback
|
||||
|
@ -95,12 +190,6 @@ class AIOHelper:
|
|||
userdata: Any,
|
||||
sock: socket.socket
|
||||
) -> None:
|
||||
self.loop.call_soon_threadsafe(
|
||||
self._do_socket_open, client, sock)
|
||||
|
||||
def _do_socket_open(self,
|
||||
client: paho_mqtt.Client,
|
||||
sock: socket.socket) -> None:
|
||||
logging.info("MQTT Socket Opened")
|
||||
self.loop.add_reader(sock, client.loop_read)
|
||||
self.misc_task = self.loop.create_task(self.misc_loop())
|
||||
|
@ -177,7 +266,7 @@ class MQTTClient(APITransport, Subscribable):
|
|||
raise config.error(
|
||||
"Option 'default_qos' in section [mqtt] must be "
|
||||
"between 0 and 2")
|
||||
self.client = paho_mqtt.Client(protocol=self.protocol)
|
||||
self.client = ExtPahoClient(protocol=self.protocol)
|
||||
self.client.on_connect = self._on_connect
|
||||
self.client.on_message = self._on_message
|
||||
self.client.on_disconnect = self._on_disconnect
|
||||
|
@ -249,22 +338,10 @@ class MQTTClient(APITransport, Subscribable):
|
|||
self.client.will_set(self.moonraker_status_topic,
|
||||
payload=json.dumps({'server': 'offline'}),
|
||||
qos=self.qos, retain=True)
|
||||
self.connect_task = self.event_loop.create_task(self._do_connect())
|
||||
|
||||
async def _do_connect(self):
|
||||
while True:
|
||||
try:
|
||||
await self.event_loop.run_in_thread(
|
||||
self.client.connect, self.address, self.port)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
await asyncio.sleep(2.)
|
||||
else:
|
||||
break
|
||||
self.client.socket().setsockopt(
|
||||
socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
|
||||
self.connect_task = None
|
||||
self.client.connect_async(self.address, self.port)
|
||||
self.connect_task = self.event_loop.create_task(
|
||||
self._do_reconnect(first=True)
|
||||
)
|
||||
|
||||
async def _handle_klippy_identified(self) -> None:
|
||||
if self.status_objs:
|
||||
|
@ -369,17 +446,27 @@ class MQTTClient(APITransport, Subscribable):
|
|||
if unsub_fut is not None and not unsub_fut.done():
|
||||
unsub_fut.set_result(None)
|
||||
|
||||
async def _do_reconnect(self) -> None:
|
||||
logging.info("Attempting MQTT Reconnect")
|
||||
self.event_loop
|
||||
async def _do_reconnect(self, first: bool = False) -> None:
|
||||
logging.info("Attempting MQTT Connect/Reconnect")
|
||||
last_err: Exception = Exception()
|
||||
while True:
|
||||
if not first:
|
||||
try:
|
||||
await asyncio.sleep(2.)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
first = False
|
||||
try:
|
||||
await asyncio.sleep(2.)
|
||||
sock = await self.event_loop.create_socket_connection(
|
||||
(self.address, self.port), timeout=10
|
||||
)
|
||||
self.client.reconnect(sock)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
try:
|
||||
await self.event_loop.run_in_thread(self.client.reconnect)
|
||||
except Exception:
|
||||
raise
|
||||
except Exception as e:
|
||||
if type(last_err) != type(e) or last_err.args != e.args:
|
||||
logging.exception("MQTT Connection Error")
|
||||
last_err = e
|
||||
continue
|
||||
self.client.socket().setsockopt(
|
||||
socket.SOL_SOCKET, socket.SO_SNDBUF, 2048)
|
||||
|
|
Loading…
Reference in New Issue