bulk_sensor: Simplify the registration of internal clients in BatchBulkHelper

Previously, the BatchBulkHelper class was designed primarily to
register webhook clients, and internal clients used a wrapper class
that emulated a webhooks client.

Change BatchBulkHelper to support regular internal callbacks, and
introduce a new BatchWebhooksClient class that can translate these
internal callback to webhooks client messages.

This makes it easier to register internal clients that can process the
bulk messages every batch interval.

Signed-off-by: Kevin O'Connor <kevin@koconnor.net>
This commit is contained in:
Kevin O'Connor 2023-12-17 00:15:55 -05:00
parent 3370134593
commit c716edafe2
5 changed files with 69 additions and 69 deletions

View File

@ -32,26 +32,29 @@ Accel_Measurement = collections.namedtuple(
# Helper class to obtain measurements # Helper class to obtain measurements
class AccelQueryHelper: class AccelQueryHelper:
def __init__(self, printer, cconn): def __init__(self, printer):
self.printer = printer self.printer = printer
self.cconn = cconn self.is_finished = False
print_time = printer.lookup_object('toolhead').get_last_move_time() print_time = printer.lookup_object('toolhead').get_last_move_time()
self.request_start_time = self.request_end_time = print_time self.request_start_time = self.request_end_time = print_time
self.samples = self.raw_samples = [] self.msgs = []
self.samples = []
def finish_measurements(self): def finish_measurements(self):
toolhead = self.printer.lookup_object('toolhead') toolhead = self.printer.lookup_object('toolhead')
self.request_end_time = toolhead.get_last_move_time() self.request_end_time = toolhead.get_last_move_time()
toolhead.wait_moves() toolhead.wait_moves()
self.cconn.finalize() self.is_finished = True
def _get_raw_samples(self): def handle_batch(self, msg):
raw_samples = self.cconn.get_messages() if self.is_finished:
if raw_samples: return False
self.raw_samples = raw_samples if len(self.msgs) >= 10000:
return self.raw_samples # Avoid filling up memory with too many samples
return False
self.msgs.append(msg)
return True
def has_valid_samples(self): def has_valid_samples(self):
raw_samples = self._get_raw_samples() for msg in self.msgs:
for msg in raw_samples: data = msg['data']
data = msg['params']['data']
first_sample_time = data[0][0] first_sample_time = data[0][0]
last_sample_time = data[-1][0] last_sample_time = data[-1][0]
if (first_sample_time > self.request_end_time if (first_sample_time > self.request_end_time
@ -60,21 +63,20 @@ class AccelQueryHelper:
# The time intervals [first_sample_time, last_sample_time] # The time intervals [first_sample_time, last_sample_time]
# and [request_start_time, request_end_time] have non-zero # and [request_start_time, request_end_time] have non-zero
# intersection. It is still theoretically possible that none # intersection. It is still theoretically possible that none
# of the samples from raw_samples fall into the time interval # of the samples from msgs fall into the time interval
# [request_start_time, request_end_time] if it is too narrow # [request_start_time, request_end_time] if it is too narrow
# or on very heavy data losses. In practice, that interval # or on very heavy data losses. In practice, that interval
# is at least 1 second, so this possibility is negligible. # is at least 1 second, so this possibility is negligible.
return True return True
return False return False
def get_samples(self): def get_samples(self):
raw_samples = self._get_raw_samples() if not self.msgs:
if not raw_samples:
return self.samples return self.samples
total = sum([len(m['params']['data']) for m in raw_samples]) total = sum([len(m['data']) for m in self.msgs])
count = 0 count = 0
self.samples = samples = [None] * total self.samples = samples = [None] * total
for msg in raw_samples: for msg in self.msgs:
for samp_time, x, y, z in msg['params']['data']: for samp_time, x, y, z in msg['data']:
if samp_time < self.request_start_time: if samp_time < self.request_start_time:
continue continue
if samp_time > self.request_end_time: if samp_time > self.request_end_time:
@ -250,8 +252,9 @@ class ADXL345:
"(e.g. faulty wiring) or a faulty adxl345 chip." % ( "(e.g. faulty wiring) or a faulty adxl345 chip." % (
reg, val, stored_val)) reg, val, stored_val))
def start_internal_client(self): def start_internal_client(self):
cconn = self.batch_bulk.add_internal_client() aqh = AccelQueryHelper(self.printer)
return AccelQueryHelper(self.printer, cconn) self.batch_bulk.add_client(aqh.handle_batch)
return aqh
# Measurement decoding # Measurement decoding
def _extract_samples(self, raw_samples): def _extract_samples(self, raw_samples):
# Load variables to optimize inner loop below # Load variables to optimize inner loop below

View File

@ -157,8 +157,14 @@ class AngleCalibration:
def do_calibration_moves(self): def do_calibration_moves(self):
move = self.printer.lookup_object('force_move').manual_move move = self.printer.lookup_object('force_move').manual_move
# Start data collection # Start data collection
angle_sensor = self.printer.lookup_object(self.name) msgs = []
cconn = angle_sensor.start_internal_client() is_finished = False
def handle_batch(msg):
if is_finished:
return False
msgs.append(msg)
return True
self.printer.lookup_object(self.name).add_client(handle_batch)
# Move stepper several turns (to allow internal sensor calibration) # Move stepper several turns (to allow internal sensor calibration)
microsteps, full_steps = self.get_microsteps() microsteps, full_steps = self.get_microsteps()
mcu_stepper = self.mcu_stepper mcu_stepper = self.mcu_stepper
@ -190,13 +196,12 @@ class AngleCalibration:
move(mcu_stepper, .5*rotation_dist + align_dist, move_speed) move(mcu_stepper, .5*rotation_dist + align_dist, move_speed)
toolhead.wait_moves() toolhead.wait_moves()
# Finish data collection # Finish data collection
cconn.finalize() is_finished = True
msgs = cconn.get_messages()
# Correlate query responses # Correlate query responses
cal = {} cal = {}
step = 0 step = 0
for msg in msgs: for msg in msgs:
for query_time, pos in msg['params']['data']: for query_time, pos in msg['data']:
# Add to step tracking # Add to step tracking
while step < len(times) and query_time > times[step][1]: while step < len(times) and query_time > times[step][1]:
step += 1 step += 1
@ -462,8 +467,8 @@ class Angle:
"spi_angle_end oid=%c sequence=%hu", oid=self.oid, cq=cmdqueue) "spi_angle_end oid=%c sequence=%hu", oid=self.oid, cq=cmdqueue)
def get_status(self, eventtime=None): def get_status(self, eventtime=None):
return {'temperature': self.sensor_helper.last_temperature} return {'temperature': self.sensor_helper.last_temperature}
def start_internal_client(self): def add_client(self, client_cb):
return self.batch_bulk.add_internal_client() self.batch_bulk.add_client(client_cb)
# Measurement decoding # Measurement decoding
def _extract_samples(self, raw_samples): def _extract_samples(self, raw_samples):
# Load variables to optimize inner loop below # Load variables to optimize inner loop below

View File

@ -22,7 +22,7 @@ class BatchBulkHelper:
self.is_started = False self.is_started = False
self.batch_interval = batch_interval self.batch_interval = batch_interval
self.batch_timer = None self.batch_timer = None
self.clients = {} self.client_cbs = []
self.webhooks_start_resp = {} self.webhooks_start_resp = {}
# Periodic batch processing # Periodic batch processing
def _start(self): def _start(self):
@ -34,14 +34,14 @@ class BatchBulkHelper:
except self.printer.command_error as e: except self.printer.command_error as e:
logging.exception("BatchBulkHelper start callback error") logging.exception("BatchBulkHelper start callback error")
self.is_started = False self.is_started = False
self.clients.clear() del self.client_cbs[:]
raise raise
reactor = self.printer.get_reactor() reactor = self.printer.get_reactor()
systime = reactor.monotonic() systime = reactor.monotonic()
waketime = systime + self.batch_interval waketime = systime + self.batch_interval
self.batch_timer = reactor.register_timer(self._proc_batch, waketime) self.batch_timer = reactor.register_timer(self._proc_batch, waketime)
def _stop(self): def _stop(self):
self.clients.clear() del self.client_cbs[:]
self.printer.get_reactor().unregister_timer(self.batch_timer) self.printer.get_reactor().unregister_timer(self.batch_timer)
self.batch_timer = None self.batch_timer = None
if not self.is_started: if not self.is_started:
@ -50,9 +50,9 @@ class BatchBulkHelper:
self.stop_cb() self.stop_cb()
except self.printer.command_error as e: except self.printer.command_error as e:
logging.exception("BatchBulkHelper stop callback error") logging.exception("BatchBulkHelper stop callback error")
self.clients.clear() del self.client_cbs[:]
self.is_started = False self.is_started = False
if self.clients: if self.client_cbs:
# New client started while in process of stopping # New client started while in process of stopping
self._start() self._start()
def _proc_batch(self, eventtime): def _proc_batch(self, eventtime):
@ -64,51 +64,41 @@ class BatchBulkHelper:
return self.printer.get_reactor().NEVER return self.printer.get_reactor().NEVER
if not msg: if not msg:
return eventtime + self.batch_interval return eventtime + self.batch_interval
for cconn, template in list(self.clients.items()): for client_cb in list(self.client_cbs):
if cconn.is_closed(): res = client_cb(msg)
del self.clients[cconn] if not res:
if not self.clients: # This client no longer needs updates - unregister it
self.client_cbs.remove(client_cb)
if not self.client_cbs:
self._stop() self._stop()
return self.printer.get_reactor().NEVER return self.printer.get_reactor().NEVER
continue
tmp = dict(template)
tmp['params'] = msg
cconn.send(tmp)
return eventtime + self.batch_interval return eventtime + self.batch_interval
# Internal clients # Client registration
def add_internal_client(self): def add_client(self, client_cb):
cconn = InternalDumpClient() self.client_cbs.append(client_cb)
self.clients[cconn] = {}
self._start() self._start()
return cconn
# Webhooks registration # Webhooks registration
def _add_api_client(self, web_request): def _add_api_client(self, web_request):
cconn = web_request.get_client_connection() whbatch = BatchWebhooksClient(web_request)
template = web_request.get_dict('response_template', {}) self.add_client(whbatch.handle_batch)
self.clients[cconn] = template
self._start()
web_request.send(self.webhooks_start_resp) web_request.send(self.webhooks_start_resp)
def add_mux_endpoint(self, path, key, value, webhooks_start_resp): def add_mux_endpoint(self, path, key, value, webhooks_start_resp):
self.webhooks_start_resp = webhooks_start_resp self.webhooks_start_resp = webhooks_start_resp
wh = self.printer.lookup_object('webhooks') wh = self.printer.lookup_object('webhooks')
wh.register_mux_endpoint(path, key, value, self._add_api_client) wh.register_mux_endpoint(path, key, value, self._add_api_client)
# An "internal webhooks" wrapper for using BatchBulkHelper internally # A webhooks wrapper for use by BatchBulkHelper
class InternalDumpClient: class BatchWebhooksClient:
def __init__(self): def __init__(self, web_request):
self.msgs = [] self.cconn = web_request.get_client_connection()
self.is_done = False self.template = web_request.get_dict('response_template', {})
def get_messages(self): def handle_batch(self, msg):
return self.msgs if self.cconn.is_closed():
def finalize(self): return False
self.is_done = True tmp = dict(self.template)
def is_closed(self): tmp['params'] = msg
return self.is_done self.cconn.send(tmp)
def send(self, msg): return True
self.msgs.append(msg)
if len(self.msgs) >= 10000:
# Avoid filling up memory with too many samples
self.finalize()
# Helper class to store incoming messages in a queue # Helper class to store incoming messages in a queue
class BulkDataQueue: class BulkDataQueue:

View File

@ -97,8 +97,9 @@ class LIS2DW:
"(e.g. faulty wiring) or a faulty lis2dw chip." % ( "(e.g. faulty wiring) or a faulty lis2dw chip." % (
reg, val, stored_val)) reg, val, stored_val))
def start_internal_client(self): def start_internal_client(self):
cconn = self.bulk_batch.add_internal_client() aqh = adxl345.AccelQueryHelper(self.printer)
return adxl345.AccelQueryHelper(self.printer, cconn) self.batch_bulk.add_client(aqh.handle_batch)
return aqh
# Measurement decoding # Measurement decoding
def _extract_samples(self, raw_samples): def _extract_samples(self, raw_samples):
# Load variables to optimize inner loop below # Load variables to optimize inner loop below

View File

@ -109,8 +109,9 @@ class MPU9250:
def set_reg(self, reg, val, minclock=0): def set_reg(self, reg, val, minclock=0):
self.i2c.i2c_write([reg, val & 0xFF], minclock=minclock) self.i2c.i2c_write([reg, val & 0xFF], minclock=minclock)
def start_internal_client(self): def start_internal_client(self):
cconn = self.batch_bulk.add_internal_client() aqh = adxl345.AccelQueryHelper(self.printer)
return adxl345.AccelQueryHelper(self.printer, cconn) self.batch_bulk.add_client(aqh.handle_batch)
return aqh
# Measurement decoding # Measurement decoding
def _extract_samples(self, raw_samples): def _extract_samples(self, raw_samples):
# Load variables to optimize inner loop below # Load variables to optimize inner loop below