msgproto: Convert static strings to a more generic enumeration system
Signed-off-by: Kevin O'Connor <kevin@koconnor.net>
This commit is contained in:
parent
7d73a35805
commit
db6e2d4c9e
|
@ -483,7 +483,7 @@ class MCU:
|
||||||
if self._is_shutdown:
|
if self._is_shutdown:
|
||||||
return
|
return
|
||||||
self._is_shutdown = True
|
self._is_shutdown = True
|
||||||
self._shutdown_msg = msg = params['#msg']
|
self._shutdown_msg = msg = params['static_string_id']
|
||||||
logging.info("MCU '%s' %s: %s\n%s\n%s", self._name, params['#name'],
|
logging.info("MCU '%s' %s: %s\n%s\n%s", self._name, params['#name'],
|
||||||
self._shutdown_msg, self._clocksync.dump_debug(),
|
self._shutdown_msg, self._clocksync.dump_debug(),
|
||||||
self._serial.dump_debug())
|
self._serial.dump_debug())
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Protocol definitions for firmware communication
|
# Protocol definitions for firmware communication
|
||||||
#
|
#
|
||||||
# Copyright (C) 2016,2017 Kevin O'Connor <kevin@koconnor.net>
|
# Copyright (C) 2016-2019 Kevin O'Connor <kevin@koconnor.net>
|
||||||
#
|
#
|
||||||
# This file may be distributed under the terms of the GNU GPLv3 license.
|
# This file may be distributed under the terms of the GNU GPLv3 license.
|
||||||
import json, zlib, logging
|
import json, zlib, logging
|
||||||
|
@ -37,9 +37,10 @@ def crc16_ccitt(buf):
|
||||||
return crc
|
return crc
|
||||||
|
|
||||||
class PT_uint32:
|
class PT_uint32:
|
||||||
is_int = 1
|
is_int = True
|
||||||
|
is_dynamic_string = False
|
||||||
max_length = 5
|
max_length = 5
|
||||||
signed = 0
|
signed = False
|
||||||
def encode(self, out, v):
|
def encode(self, out, v):
|
||||||
if v >= 0xc000000 or v < -0x4000000: out.append((v>>28) & 0x7f | 0x80)
|
if v >= 0xc000000 or v < -0x4000000: out.append((v>>28) & 0x7f | 0x80)
|
||||||
if v >= 0x180000 or v < -0x80000: out.append((v>>21) & 0x7f | 0x80)
|
if v >= 0x180000 or v < -0x80000: out.append((v>>21) & 0x7f | 0x80)
|
||||||
|
@ -61,17 +62,18 @@ class PT_uint32:
|
||||||
return v, pos
|
return v, pos
|
||||||
|
|
||||||
class PT_int32(PT_uint32):
|
class PT_int32(PT_uint32):
|
||||||
signed = 1
|
signed = True
|
||||||
class PT_uint16(PT_uint32):
|
class PT_uint16(PT_uint32):
|
||||||
max_length = 3
|
max_length = 3
|
||||||
class PT_int16(PT_int32):
|
class PT_int16(PT_int32):
|
||||||
signed = 1
|
signed = True
|
||||||
max_length = 3
|
max_length = 3
|
||||||
class PT_byte(PT_uint32):
|
class PT_byte(PT_uint32):
|
||||||
max_length = 2
|
max_length = 2
|
||||||
|
|
||||||
class PT_string:
|
class PT_string:
|
||||||
is_int = 0
|
is_int = False
|
||||||
|
is_dynamic_string = True
|
||||||
max_length = 64
|
max_length = 64
|
||||||
def encode(self, out, v):
|
def encode(self, out, v):
|
||||||
out.append(len(v))
|
out.append(len(v))
|
||||||
|
@ -91,22 +93,55 @@ MessageTypes = {
|
||||||
'%s': PT_string(), '%.*s': PT_progmem_buffer(), '%*s': PT_buffer(),
|
'%s': PT_string(), '%.*s': PT_progmem_buffer(), '%*s': PT_buffer(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class Enumeration:
|
||||||
|
is_int = False
|
||||||
|
is_dynamic_string = False
|
||||||
|
def __init__(self, pt, enum_name, enums):
|
||||||
|
self.pt = pt
|
||||||
|
self.max_length = pt.max_length
|
||||||
|
self.enum_name = enum_name
|
||||||
|
self.enums = enums
|
||||||
|
self.reverse_enums = {v: k for k, v in enums.items()}
|
||||||
|
def encode(self, out, v):
|
||||||
|
tv = self.enums.get(v)
|
||||||
|
if tv is None:
|
||||||
|
raise error("Unknown value '%s' in enumeration '%s'" % (
|
||||||
|
v, self.enum_name))
|
||||||
|
self.pt.encode(out, tv)
|
||||||
|
def parse(self, s, pos):
|
||||||
|
v, pos = self.pt.parse(s, pos)
|
||||||
|
tv = self.reverse_enums.get(v)
|
||||||
|
if tv is None:
|
||||||
|
tv = "?%d" % (v,)
|
||||||
|
return tv, pos
|
||||||
|
|
||||||
|
# Lookup the message types for a format string
|
||||||
|
def lookup_params(msgformat, enumerations={}):
|
||||||
|
out = []
|
||||||
|
argparts = [arg.split('=') for arg in msgformat.split()[1:]]
|
||||||
|
for name, fmt in argparts:
|
||||||
|
pt = MessageTypes[fmt]
|
||||||
|
for enum_name, enums in enumerations.items():
|
||||||
|
if name == enum_name or name.endswith('_' + enum_name):
|
||||||
|
pt = Enumeration(pt, enum_name, enums)
|
||||||
|
break
|
||||||
|
out.append((name, pt))
|
||||||
|
return out
|
||||||
|
|
||||||
# Update the message format to be compatible with python's % operator
|
# Update the message format to be compatible with python's % operator
|
||||||
def convert_msg_format(msgformat):
|
def convert_msg_format(msgformat):
|
||||||
mf = msgformat.replace('%c', '%u')
|
for c in ['%u', '%i', '%hu', '%hi', '%c', '%.*s', '%*s']:
|
||||||
mf = mf.replace('%.*s', '%s').replace('%*s', '%s')
|
msgformat = msgformat.replace(c, '%s')
|
||||||
return mf
|
return msgformat
|
||||||
|
|
||||||
class MessageFormat:
|
class MessageFormat:
|
||||||
def __init__(self, msgid, msgformat):
|
def __init__(self, msgid, msgformat, enumerations={}):
|
||||||
self.msgid = msgid
|
self.msgid = msgid
|
||||||
self.msgformat = msgformat
|
self.msgformat = msgformat
|
||||||
self.debugformat = convert_msg_format(msgformat)
|
self.debugformat = convert_msg_format(msgformat)
|
||||||
parts = msgformat.split()
|
self.name = msgformat.split()[0]
|
||||||
self.name = parts[0]
|
self.param_names = lookup_params(msgformat, enumerations)
|
||||||
argparts = [arg.split('=') for arg in parts[1:]]
|
self.param_types = [t for name, t in self.param_names]
|
||||||
self.param_types = [MessageTypes[fmt] for name, fmt in argparts]
|
|
||||||
self.param_names = [(name, MessageTypes[fmt]) for name, fmt in argparts]
|
|
||||||
self.name_to_type = dict(self.param_names)
|
self.name_to_type = dict(self.param_names)
|
||||||
def encode(self, params):
|
def encode(self, params):
|
||||||
out = []
|
out = []
|
||||||
|
@ -131,7 +166,7 @@ class MessageFormat:
|
||||||
out = []
|
out = []
|
||||||
for name, t in self.param_names:
|
for name, t in self.param_names:
|
||||||
v = params[name]
|
v = params[name]
|
||||||
if not t.is_int:
|
if t.is_dynamic_string:
|
||||||
v = repr(v)
|
v = repr(v)
|
||||||
out.append(v)
|
out.append(v)
|
||||||
return self.debugformat % tuple(out)
|
return self.debugformat % tuple(out)
|
||||||
|
@ -162,7 +197,7 @@ class OutputFormat:
|
||||||
out = []
|
out = []
|
||||||
for t in self.param_types:
|
for t in self.param_types:
|
||||||
v, pos = t.parse(s, pos)
|
v, pos = t.parse(s, pos)
|
||||||
if not t.is_int:
|
if t.is_dynamic_string:
|
||||||
v = repr(v)
|
v = repr(v)
|
||||||
out.append(v)
|
out.append(v)
|
||||||
outmsg = self.debugformat % tuple(out)
|
outmsg = self.debugformat % tuple(out)
|
||||||
|
@ -183,10 +218,10 @@ class MessageParser:
|
||||||
error = error
|
error = error
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.unknown = UnknownFormat()
|
self.unknown = UnknownFormat()
|
||||||
|
self.enumerations = {}
|
||||||
self.command_ids = []
|
self.command_ids = []
|
||||||
self.messages_by_id = {}
|
self.messages_by_id = {}
|
||||||
self.messages_by_name = {}
|
self.messages_by_name = {}
|
||||||
self.static_strings = {}
|
|
||||||
self.config = {}
|
self.config = {}
|
||||||
self.version = self.build_versions = ""
|
self.version = self.build_versions = ""
|
||||||
self.raw_identify_data = ""
|
self.raw_identify_data = ""
|
||||||
|
@ -239,9 +274,6 @@ class MessageParser:
|
||||||
if pos != len(s)-MESSAGE_TRAILER_SIZE:
|
if pos != len(s)-MESSAGE_TRAILER_SIZE:
|
||||||
raise error("Extra data at end of message")
|
raise error("Extra data at end of message")
|
||||||
params['#name'] = mid.name
|
params['#name'] = mid.name
|
||||||
static_string_id = params.get('static_string_id')
|
|
||||||
if static_string_id is not None:
|
|
||||||
params['#msg'] = self.static_strings.get(static_string_id, "?")
|
|
||||||
return params
|
return params
|
||||||
def encode(self, seq, cmd):
|
def encode(self, seq, cmd):
|
||||||
msglen = MESSAGE_MIN + len(cmd)
|
msglen = MESSAGE_MIN + len(cmd)
|
||||||
|
@ -282,27 +314,47 @@ class MessageParser:
|
||||||
argparts = dict(arg.split('=', 1) for arg in parts[1:])
|
argparts = dict(arg.split('=', 1) for arg in parts[1:])
|
||||||
for name, value in argparts.items():
|
for name, value in argparts.items():
|
||||||
t = mp.name_to_type[name]
|
t = mp.name_to_type[name]
|
||||||
if t.is_int:
|
if t.is_dynamic_string:
|
||||||
|
tval = self._parse_buffer(value)
|
||||||
|
elif t.is_int:
|
||||||
tval = int(value, 0)
|
tval = int(value, 0)
|
||||||
else:
|
else:
|
||||||
tval = self._parse_buffer(value)
|
tval = value
|
||||||
argparts[name] = tval
|
argparts[name] = tval
|
||||||
except:
|
except:
|
||||||
#traceback.print_exc()
|
#logging.exception("Unable to extract params")
|
||||||
raise error("Unable to extract params from: %s" % (msgname,))
|
raise error("Unable to extract params from: %s" % (msgname,))
|
||||||
try:
|
try:
|
||||||
cmd = mp.encode_by_name(**argparts)
|
cmd = mp.encode_by_name(**argparts)
|
||||||
except:
|
except:
|
||||||
#traceback.print_exc()
|
#logging.exception("Unable to encode")
|
||||||
raise error("Unable to encode: %s" % (msgname,))
|
raise error("Unable to encode: %s" % (msgname,))
|
||||||
return cmd
|
return cmd
|
||||||
|
def _fill_enumerations(self, enumerations):
|
||||||
|
for add_name, add_enums in enumerations.items():
|
||||||
|
enums = self.enumerations.setdefault(add_name, {})
|
||||||
|
for enum, value in add_enums.items():
|
||||||
|
if type(value) == type(0):
|
||||||
|
# Simple enumeration
|
||||||
|
enums[str(enum)] = value
|
||||||
|
continue
|
||||||
|
# Enumeration range
|
||||||
|
enum = enum_root = str(enum)
|
||||||
|
while enum_root and enum_root[-1].isdigit():
|
||||||
|
enum_root = enum_root[:-1]
|
||||||
|
start_enum = 0
|
||||||
|
if len(enum_root) != len(enum):
|
||||||
|
start_enum = int(enum[len(enum_root):])
|
||||||
|
start_value, count = value
|
||||||
|
for i in range(count):
|
||||||
|
enums[enum_root + str(start_enum + i)] = start_value + i
|
||||||
def _init_messages(self, messages, output_ids=[]):
|
def _init_messages(self, messages, output_ids=[]):
|
||||||
for msgformat, msgid in messages.items():
|
for msgformat, msgid in messages.items():
|
||||||
msgid = int(msgid)
|
msgid = int(msgid)
|
||||||
if msgid in output_ids:
|
if msgid in output_ids:
|
||||||
self.messages_by_id[msgid] = OutputFormat(msgid, msgformat)
|
self.messages_by_id[msgid] = OutputFormat(msgid, msgformat)
|
||||||
continue
|
continue
|
||||||
msg = MessageFormat(msgid, msgformat)
|
msg = MessageFormat(msgid, msgformat, self.enumerations)
|
||||||
self.messages_by_id[msgid] = msg
|
self.messages_by_id[msgid] = msg
|
||||||
self.messages_by_name[msg.name] = msg
|
self.messages_by_name[msg.name] = msg
|
||||||
def process_identify(self, data, decompress=True):
|
def process_identify(self, data, decompress=True):
|
||||||
|
@ -311,6 +363,7 @@ class MessageParser:
|
||||||
data = zlib.decompress(data)
|
data = zlib.decompress(data)
|
||||||
self.raw_identify_data = data
|
self.raw_identify_data = data
|
||||||
data = json.loads(data)
|
data = json.loads(data)
|
||||||
|
self._fill_enumerations(data.get('enumerations', {}))
|
||||||
commands = data.get('commands')
|
commands = data.get('commands')
|
||||||
responses = data.get('responses')
|
responses = data.get('responses')
|
||||||
output = data.get('output', {})
|
output = data.get('output', {})
|
||||||
|
@ -319,8 +372,6 @@ class MessageParser:
|
||||||
all_messages.update(output)
|
all_messages.update(output)
|
||||||
self.command_ids = sorted(commands.values())
|
self.command_ids = sorted(commands.values())
|
||||||
self._init_messages(all_messages, output.values())
|
self._init_messages(all_messages, output.values())
|
||||||
static_strings = data.get('static_strings', {})
|
|
||||||
self.static_strings = {int(k): v for k, v in static_strings.items()}
|
|
||||||
self.config.update(data.get('config', {}))
|
self.config.update(data.get('config', {}))
|
||||||
self.version = data.get('version', '')
|
self.version = data.get('version', '')
|
||||||
self.build_versions = data.get('build_versions', '')
|
self.build_versions = data.get('build_versions', '')
|
||||||
|
|
|
@ -31,7 +31,6 @@ class SerialReader:
|
||||||
# Message handlers
|
# Message handlers
|
||||||
handlers = {
|
handlers = {
|
||||||
'#unknown': self.handle_unknown, '#output': self.handle_output,
|
'#unknown': self.handle_unknown, '#output': self.handle_output,
|
||||||
'shutdown': self.handle_output, 'is_shutdown': self.handle_output
|
|
||||||
}
|
}
|
||||||
self.handlers = { (k, None): v for k, v in handlers.items() }
|
self.handlers = { (k, None): v for k, v in handlers.items() }
|
||||||
def _bg_thread(self):
|
def _bg_thread(self):
|
||||||
|
|
|
@ -60,25 +60,45 @@ Handlers.append(HandleCallList())
|
||||||
|
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
# Static string generation
|
# Enumeration and static string generation
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
STATIC_STRING_MIN = 2
|
STATIC_STRING_MIN = 2
|
||||||
|
|
||||||
# Generate a dynamic string to integer mapping
|
# Generate a dynamic string to integer mapping
|
||||||
class HandleStaticStrings:
|
class HandleEnumerations:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.static_strings = []
|
self.static_strings = []
|
||||||
self.found_strings = {}
|
self.enumerations = {}
|
||||||
self.ctr_dispatch = { '_DECL_STATIC_STR': self.decl_static_str }
|
self.ctr_dispatch = {
|
||||||
|
'_DECL_STATIC_STR': self.decl_static_str,
|
||||||
|
'_DECL_ENUMERATION': self.decl_enumeration,
|
||||||
|
'_DECL_ENUMERATION_RANGE': self.decl_enumeration_range
|
||||||
|
}
|
||||||
|
def add_enumeration(self, enum, name, value):
|
||||||
|
enums = self.enumerations.setdefault(enum, {})
|
||||||
|
if name in enums and enums[name] != value:
|
||||||
|
error("Conflicting definition for enumeration '%s %s'" % (
|
||||||
|
enum, name))
|
||||||
|
enums[name] = value
|
||||||
|
def decl_enumeration(self, req):
|
||||||
|
enum, name, value = req.split()[1:]
|
||||||
|
self.add_enumeration(enum, name, decode_integer(value))
|
||||||
|
def decl_enumeration_range(self, req):
|
||||||
|
enum, name, count, value = req.split()[1:]
|
||||||
|
try:
|
||||||
|
count = int(count, 0)
|
||||||
|
except ValueError as e:
|
||||||
|
error("Invalid enumeration count in '%s'" % (req,))
|
||||||
|
self.add_enumeration(enum, name, (decode_integer(value), count))
|
||||||
def decl_static_str(self, req):
|
def decl_static_str(self, req):
|
||||||
msg = req.split(None, 1)[1]
|
msg = req.split(None, 1)[1]
|
||||||
if msg not in self.found_strings:
|
if msg not in self.static_strings:
|
||||||
self.found_strings[msg] = 1
|
|
||||||
self.static_strings.append(msg)
|
self.static_strings.append(msg)
|
||||||
def update_data_dictionary(self, data):
|
def update_data_dictionary(self, data):
|
||||||
data['static_strings'] = { i + STATIC_STRING_MIN: s
|
for i, s in enumerate(self.static_strings):
|
||||||
for i, s in enumerate(self.static_strings) }
|
self.add_enumeration("static_string_id", s, i + STATIC_STRING_MIN)
|
||||||
|
data['enumerations'] = self.enumerations
|
||||||
def generate_code(self, options):
|
def generate_code(self, options):
|
||||||
code = []
|
code = []
|
||||||
for i, s in enumerate(self.static_strings):
|
for i, s in enumerate(self.static_strings):
|
||||||
|
@ -94,7 +114,7 @@ ctr_lookup_static_string(const char *str)
|
||||||
"""
|
"""
|
||||||
return fmt % ("".join(code).strip(),)
|
return fmt % ("".join(code).strip(),)
|
||||||
|
|
||||||
Handlers.append(HandleStaticStrings())
|
Handlers.append(HandleEnumerations())
|
||||||
|
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
|
@ -21,6 +21,13 @@
|
||||||
#define DECL_CONSTANT_STR(NAME, VALUE) \
|
#define DECL_CONSTANT_STR(NAME, VALUE) \
|
||||||
DECL_CTR("_DECL_CONSTANT_STR " NAME " " VALUE)
|
DECL_CTR("_DECL_CONSTANT_STR " NAME " " VALUE)
|
||||||
|
|
||||||
|
// Declare an enumeration
|
||||||
|
#define DECL_ENUMERATION(ENUM, NAME, VALUE) \
|
||||||
|
DECL_CTR_INT("_DECL_ENUMERATION " ENUM " " NAME, (VALUE))
|
||||||
|
#define DECL_ENUMERATION_RANGE(ENUM, NAME, VALUE, COUNT) \
|
||||||
|
DECL_CTR_INT("_DECL_ENUMERATION_RANGE " ENUM " " NAME \
|
||||||
|
" " __stringify(COUNT), (VALUE))
|
||||||
|
|
||||||
// Send an output message (and declare a static message type for it)
|
// Send an output message (and declare a static message type for it)
|
||||||
#define output(FMT, args...) \
|
#define output(FMT, args...) \
|
||||||
command_sendf(_DECL_OUTPUT(FMT) , ##args )
|
command_sendf(_DECL_OUTPUT(FMT) , ##args )
|
||||||
|
|
Loading…
Reference in New Issue