diff --git a/tools/mpy-tool.py b/tools/mpy-tool.py index 6a769f7b40..e3a812ebe0 100755 --- a/tools/mpy-tool.py +++ b/tools/mpy-tool.py @@ -91,19 +91,6 @@ class Config: config = Config() -class QStrType: - def __init__(self, str): - self.str = str - self.qstr_esc = qstrutil.qstr_escape(self.str) - self.qstr_id = "MP_QSTR_" + self.qstr_esc - - -# Initialise global list of qstrs with static qstrs -global_qstrs = [None] # MP_QSTRnull should never be referenced -for n in qstrutil.static_qstr_list: - global_qstrs.append(QStrType(n)) - - MP_CODE_BYTECODE = 2 MP_CODE_NATIVE_PY = 3 MP_CODE_NATIVE_VIPER = 4 @@ -469,6 +456,29 @@ def extract_prelude(bytecode, ip): ) +class QStrType: + def __init__(self, str): + self.str = str + self.qstr_esc = qstrutil.qstr_escape(self.str) + self.qstr_id = "MP_QSTR_" + self.qstr_esc + + +class GlobalQStrList: + def __init__(self): + # Initialise global list of qstrs with static qstrs + self.qstrs = [None] # MP_QSTRnull should never be referenced + for n in qstrutil.static_qstr_list: + self.qstrs.append(QStrType(n)) + + def add(self, s): + q = QStrType(s) + self.qstrs.append(q) + return q + + def get_by_index(self, i): + return self.qstrs[i] + + class MPFunTable: def __repr__(self): return "mp_fun_table" @@ -496,10 +506,6 @@ class CompiledModule: self.raw_code = raw_code self.escaped_name = escaped_name - def _unpack_qstr(self, ip): - qst = self.bytecode[ip] | self.bytecode[ip + 1] << 8 - return global_qstrs[qst] - def hexdump(self): with open(self.mpy_source_file, "rb") as f: WIDTH = 16 @@ -1077,8 +1083,7 @@ class RawCodeNative(RawCode): if qi < len(self.qstr_links) and i == self.qstr_links[qi][0]: # link qstr qi_off, qi_kind, qi_val = self.qstr_links[qi] - qst = global_qstrs[qi_val].qstr_id - i += self._link_qstr(i, qi_kind, qst) + i += self._link_qstr(i, qi_kind, qi_val.qstr_id) qi += 1 else: # copy machine code (max 16 bytes) @@ -1139,17 +1144,15 @@ def read_qstr(reader, segments): ln = reader.read_uint() if ln & 1: # static qstr - segments.append( - MPYSegment(MPYSegment.META, global_qstrs[ln >> 1].str, start_pos, start_pos) - ) - return ln >> 1 + q = global_qstrs.get_by_index(ln >> 1) + segments.append(MPYSegment(MPYSegment.META, q.str, start_pos, start_pos)) + return q ln >>= 1 start_pos = reader.tell() data = str_cons(reader.read_bytes(ln), "utf8") reader.read_byte() # read and discard null terminator segments.append(MPYSegment(MPYSegment.QSTR, data, start_pos, reader.tell())) - global_qstrs.append(QStrType(data)) - return len(global_qstrs) - 1 + return global_qstrs.add(data) def read_obj(reader, segments): @@ -1304,8 +1307,7 @@ def read_mpy(filename): # Read qstrs and construct qstr table. qstr_table = [] for i in range(n_qstr): - q = read_qstr(reader, segments) - qstr_table.append(global_qstrs[q]) + qstr_table.append(read_qstr(reader, segments)) # Read objects and construct object table. obj_table = [] @@ -1345,7 +1347,7 @@ def disassemble_mpy(compiled_modules): def freeze_mpy(base_qstrs, compiled_modules): # add to qstrs new = {} - for q in global_qstrs: + for q in global_qstrs.qstrs: # don't add duplicates if q is None or q.qstr_esc in base_qstrs or q.qstr_esc in new: continue @@ -1587,6 +1589,8 @@ def merge_mpy(raw_codes, output_file): def main(): + global global_qstrs + import argparse cmd_parser = argparse.ArgumentParser(description="A tool to work with MicroPython .mpy files.") @@ -1637,6 +1641,9 @@ def main(): config.MICROPY_QSTR_BYTES_IN_HASH = 1 base_qstrs = list(qstrutil.static_qstr_list) + # Create initial list of global qstrs. + global_qstrs = GlobalQStrList() + # Load all .mpy files. try: compiled_modules = [read_mpy(file) for file in args.files]