From db6e2d4c9ef1e2f8f231821f7506252a406ea2f0 Mon Sep 17 00:00:00 2001 From: Kevin O'Connor Date: Tue, 5 Mar 2019 12:42:35 -0500 Subject: [PATCH] msgproto: Convert static strings to a more generic enumeration system Signed-off-by: Kevin O'Connor --- klippy/mcu.py | 2 +- klippy/msgproto.py | 107 +++++++++++++++++++++++++++++---------- klippy/serialhdl.py | 1 - scripts/buildcommands.py | 38 ++++++++++---- src/command.h | 7 +++ 5 files changed, 116 insertions(+), 39 deletions(-) diff --git a/klippy/mcu.py b/klippy/mcu.py index 5a00c54c..a5b97557 100644 --- a/klippy/mcu.py +++ b/klippy/mcu.py @@ -483,7 +483,7 @@ class MCU: if self._is_shutdown: return self._is_shutdown = True - self._shutdown_msg = msg = params['#msg'] + self._shutdown_msg = msg = params['static_string_id'] logging.info("MCU '%s' %s: %s\n%s\n%s", self._name, params['#name'], self._shutdown_msg, self._clocksync.dump_debug(), self._serial.dump_debug()) diff --git a/klippy/msgproto.py b/klippy/msgproto.py index 6bb5c159..41f0a11b 100644 --- a/klippy/msgproto.py +++ b/klippy/msgproto.py @@ -1,6 +1,6 @@ # Protocol definitions for firmware communication # -# Copyright (C) 2016,2017 Kevin O'Connor +# Copyright (C) 2016-2019 Kevin O'Connor # # This file may be distributed under the terms of the GNU GPLv3 license. import json, zlib, logging @@ -37,9 +37,10 @@ def crc16_ccitt(buf): return crc class PT_uint32: - is_int = 1 + is_int = True + is_dynamic_string = False max_length = 5 - signed = 0 + signed = False def encode(self, out, v): if v >= 0xc000000 or v < -0x4000000: out.append((v>>28) & 0x7f | 0x80) if v >= 0x180000 or v < -0x80000: out.append((v>>21) & 0x7f | 0x80) @@ -61,17 +62,18 @@ class PT_uint32: return v, pos class PT_int32(PT_uint32): - signed = 1 + signed = True class PT_uint16(PT_uint32): max_length = 3 class PT_int16(PT_int32): - signed = 1 + signed = True max_length = 3 class PT_byte(PT_uint32): max_length = 2 class PT_string: - is_int = 0 + is_int = False + is_dynamic_string = True max_length = 64 def encode(self, out, v): out.append(len(v)) @@ -91,22 +93,55 @@ MessageTypes = { '%s': PT_string(), '%.*s': PT_progmem_buffer(), '%*s': PT_buffer(), } +class Enumeration: + is_int = False + is_dynamic_string = False + def __init__(self, pt, enum_name, enums): + self.pt = pt + self.max_length = pt.max_length + self.enum_name = enum_name + self.enums = enums + self.reverse_enums = {v: k for k, v in enums.items()} + def encode(self, out, v): + tv = self.enums.get(v) + if tv is None: + raise error("Unknown value '%s' in enumeration '%s'" % ( + v, self.enum_name)) + self.pt.encode(out, tv) + def parse(self, s, pos): + v, pos = self.pt.parse(s, pos) + tv = self.reverse_enums.get(v) + if tv is None: + tv = "?%d" % (v,) + return tv, pos + +# Lookup the message types for a format string +def lookup_params(msgformat, enumerations={}): + out = [] + argparts = [arg.split('=') for arg in msgformat.split()[1:]] + for name, fmt in argparts: + pt = MessageTypes[fmt] + for enum_name, enums in enumerations.items(): + if name == enum_name or name.endswith('_' + enum_name): + pt = Enumeration(pt, enum_name, enums) + break + out.append((name, pt)) + return out + # Update the message format to be compatible with python's % operator def convert_msg_format(msgformat): - mf = msgformat.replace('%c', '%u') - mf = mf.replace('%.*s', '%s').replace('%*s', '%s') - return mf + for c in ['%u', '%i', '%hu', '%hi', '%c', '%.*s', '%*s']: + msgformat = msgformat.replace(c, '%s') + return msgformat class MessageFormat: - def __init__(self, msgid, msgformat): + def __init__(self, msgid, msgformat, enumerations={}): self.msgid = msgid self.msgformat = msgformat self.debugformat = convert_msg_format(msgformat) - parts = msgformat.split() - self.name = parts[0] - argparts = [arg.split('=') for arg in parts[1:]] - self.param_types = [MessageTypes[fmt] for name, fmt in argparts] - self.param_names = [(name, MessageTypes[fmt]) for name, fmt in argparts] + self.name = msgformat.split()[0] + self.param_names = lookup_params(msgformat, enumerations) + self.param_types = [t for name, t in self.param_names] self.name_to_type = dict(self.param_names) def encode(self, params): out = [] @@ -131,7 +166,7 @@ class MessageFormat: out = [] for name, t in self.param_names: v = params[name] - if not t.is_int: + if t.is_dynamic_string: v = repr(v) out.append(v) return self.debugformat % tuple(out) @@ -162,7 +197,7 @@ class OutputFormat: out = [] for t in self.param_types: v, pos = t.parse(s, pos) - if not t.is_int: + if t.is_dynamic_string: v = repr(v) out.append(v) outmsg = self.debugformat % tuple(out) @@ -183,10 +218,10 @@ class MessageParser: error = error def __init__(self): self.unknown = UnknownFormat() + self.enumerations = {} self.command_ids = [] self.messages_by_id = {} self.messages_by_name = {} - self.static_strings = {} self.config = {} self.version = self.build_versions = "" self.raw_identify_data = "" @@ -239,9 +274,6 @@ class MessageParser: if pos != len(s)-MESSAGE_TRAILER_SIZE: raise error("Extra data at end of message") params['#name'] = mid.name - static_string_id = params.get('static_string_id') - if static_string_id is not None: - params['#msg'] = self.static_strings.get(static_string_id, "?") return params def encode(self, seq, cmd): msglen = MESSAGE_MIN + len(cmd) @@ -282,27 +314,47 @@ class MessageParser: argparts = dict(arg.split('=', 1) for arg in parts[1:]) for name, value in argparts.items(): t = mp.name_to_type[name] - if t.is_int: + if t.is_dynamic_string: + tval = self._parse_buffer(value) + elif t.is_int: tval = int(value, 0) else: - tval = self._parse_buffer(value) + tval = value argparts[name] = tval except: - #traceback.print_exc() + #logging.exception("Unable to extract params") raise error("Unable to extract params from: %s" % (msgname,)) try: cmd = mp.encode_by_name(**argparts) except: - #traceback.print_exc() + #logging.exception("Unable to encode") raise error("Unable to encode: %s" % (msgname,)) return cmd + def _fill_enumerations(self, enumerations): + for add_name, add_enums in enumerations.items(): + enums = self.enumerations.setdefault(add_name, {}) + for enum, value in add_enums.items(): + if type(value) == type(0): + # Simple enumeration + enums[str(enum)] = value + continue + # Enumeration range + enum = enum_root = str(enum) + while enum_root and enum_root[-1].isdigit(): + enum_root = enum_root[:-1] + start_enum = 0 + if len(enum_root) != len(enum): + start_enum = int(enum[len(enum_root):]) + start_value, count = value + for i in range(count): + enums[enum_root + str(start_enum + i)] = start_value + i def _init_messages(self, messages, output_ids=[]): for msgformat, msgid in messages.items(): msgid = int(msgid) if msgid in output_ids: self.messages_by_id[msgid] = OutputFormat(msgid, msgformat) continue - msg = MessageFormat(msgid, msgformat) + msg = MessageFormat(msgid, msgformat, self.enumerations) self.messages_by_id[msgid] = msg self.messages_by_name[msg.name] = msg def process_identify(self, data, decompress=True): @@ -311,6 +363,7 @@ class MessageParser: data = zlib.decompress(data) self.raw_identify_data = data data = json.loads(data) + self._fill_enumerations(data.get('enumerations', {})) commands = data.get('commands') responses = data.get('responses') output = data.get('output', {}) @@ -319,8 +372,6 @@ class MessageParser: all_messages.update(output) self.command_ids = sorted(commands.values()) self._init_messages(all_messages, output.values()) - static_strings = data.get('static_strings', {}) - self.static_strings = {int(k): v for k, v in static_strings.items()} self.config.update(data.get('config', {})) self.version = data.get('version', '') self.build_versions = data.get('build_versions', '') diff --git a/klippy/serialhdl.py b/klippy/serialhdl.py index 0aa477c5..1751a226 100644 --- a/klippy/serialhdl.py +++ b/klippy/serialhdl.py @@ -31,7 +31,6 @@ class SerialReader: # Message handlers handlers = { '#unknown': self.handle_unknown, '#output': self.handle_output, - 'shutdown': self.handle_output, 'is_shutdown': self.handle_output } self.handlers = { (k, None): v for k, v in handlers.items() } def _bg_thread(self): diff --git a/scripts/buildcommands.py b/scripts/buildcommands.py index 46e6dc36..43e318e7 100644 --- a/scripts/buildcommands.py +++ b/scripts/buildcommands.py @@ -60,25 +60,45 @@ Handlers.append(HandleCallList()) ###################################################################### -# Static string generation +# Enumeration and static string generation ###################################################################### STATIC_STRING_MIN = 2 # Generate a dynamic string to integer mapping -class HandleStaticStrings: +class HandleEnumerations: def __init__(self): self.static_strings = [] - self.found_strings = {} - self.ctr_dispatch = { '_DECL_STATIC_STR': self.decl_static_str } + self.enumerations = {} + self.ctr_dispatch = { + '_DECL_STATIC_STR': self.decl_static_str, + '_DECL_ENUMERATION': self.decl_enumeration, + '_DECL_ENUMERATION_RANGE': self.decl_enumeration_range + } + def add_enumeration(self, enum, name, value): + enums = self.enumerations.setdefault(enum, {}) + if name in enums and enums[name] != value: + error("Conflicting definition for enumeration '%s %s'" % ( + enum, name)) + enums[name] = value + def decl_enumeration(self, req): + enum, name, value = req.split()[1:] + self.add_enumeration(enum, name, decode_integer(value)) + def decl_enumeration_range(self, req): + enum, name, count, value = req.split()[1:] + try: + count = int(count, 0) + except ValueError as e: + error("Invalid enumeration count in '%s'" % (req,)) + self.add_enumeration(enum, name, (decode_integer(value), count)) def decl_static_str(self, req): msg = req.split(None, 1)[1] - if msg not in self.found_strings: - self.found_strings[msg] = 1 + if msg not in self.static_strings: self.static_strings.append(msg) def update_data_dictionary(self, data): - data['static_strings'] = { i + STATIC_STRING_MIN: s - for i, s in enumerate(self.static_strings) } + for i, s in enumerate(self.static_strings): + self.add_enumeration("static_string_id", s, i + STATIC_STRING_MIN) + data['enumerations'] = self.enumerations def generate_code(self, options): code = [] for i, s in enumerate(self.static_strings): @@ -94,7 +114,7 @@ ctr_lookup_static_string(const char *str) """ return fmt % ("".join(code).strip(),) -Handlers.append(HandleStaticStrings()) +Handlers.append(HandleEnumerations()) ###################################################################### diff --git a/src/command.h b/src/command.h index e0911733..6cf2ca6b 100644 --- a/src/command.h +++ b/src/command.h @@ -21,6 +21,13 @@ #define DECL_CONSTANT_STR(NAME, VALUE) \ DECL_CTR("_DECL_CONSTANT_STR " NAME " " VALUE) +// Declare an enumeration +#define DECL_ENUMERATION(ENUM, NAME, VALUE) \ + DECL_CTR_INT("_DECL_ENUMERATION " ENUM " " NAME, (VALUE)) +#define DECL_ENUMERATION_RANGE(ENUM, NAME, VALUE, COUNT) \ + DECL_CTR_INT("_DECL_ENUMERATION_RANGE " ENUM " " NAME \ + " " __stringify(COUNT), (VALUE)) + // Send an output message (and declare a static message type for it) #define output(FMT, args...) \ command_sendf(_DECL_OUTPUT(FMT) , ##args )