diff --git a/tools/mpy-tool.py b/tools/mpy-tool.py index fb18eb8cbd..cc30eafecd 100755 --- a/tools/mpy-tool.py +++ b/tools/mpy-tool.py @@ -182,7 +182,7 @@ mp_binary_op_method_name = ( ) -class Opcodes: +class Opcode: # fmt: off # Load, Store, Delete, Import, Make, Build, Unpack, Call, Jump, Exception, For, sTack, Return, Yield, Op MP_BC_BASE_RESERVED = (0x00) # ---------------- @@ -318,6 +318,13 @@ class Opcodes: for i in range(MP_BC_BINARY_OP_MULTI_NUM): mapping[MP_BC_BINARY_OP_MULTI + i] = "BINARY_OP %d %s" % (i, mp_binary_op_method_name[i]) + def __init__(self, offset, fmt, opcode_byte, arg, extra_arg): + self.offset = offset + self.fmt = fmt + self.opcode_byte = opcode_byte + self.arg = arg + self.extra_arg = extra_arg + # This definition of a small int covers all possible targets, in the sense that every # target can encode as a small int, an integer that passes this test. The minimum is set @@ -326,15 +333,31 @@ def mp_small_int_fits(i): return -0x2000 <= i <= 0x1FFF +def mp_encode_uint(val, signed=False): + encoded = bytearray([val & 0x7F]) + val >>= 7 + while val != 0 and val != -1: + encoded.insert(0, 0x80 | (val & 0x7F)) + val >>= 7 + if signed: + if val == -1 and encoded[0] & 0x40 == 0: + encoded.insert(0, 0xFF) + elif val == 0 and encoded[0] & 0x40 != 0: + encoded.insert(0, 0x80) + return encoded + + def mp_opcode_decode(bytecode, ip): opcode = bytecode[ip] ip_start = ip f = (0x000003A4 >> (2 * ((opcode) >> 4))) & 3 - extra_byte = (opcode & MP_BC_MASK_EXTRA_BYTE) == 0 ip += 1 - arg = 0 + arg = None + extra_arg = None if f in (MP_BC_FORMAT_QSTR, MP_BC_FORMAT_VAR_UINT): arg = bytecode[ip] & 0x7F + if opcode == Opcode.MP_BC_LOAD_CONST_SMALL_INT and arg & 0x40 != 0: + arg |= -1 << 7 while bytecode[ip] & 0x80 != 0: ip += 1 arg = arg << 7 | bytecode[ip] & 0x7F @@ -343,15 +366,50 @@ def mp_opcode_decode(bytecode, ip): if bytecode[ip] & 0x80 == 0: arg = bytecode[ip] ip += 1 - if opcode in Opcodes.ALL_OFFSET_SIGNED: + if opcode in Opcode.ALL_OFFSET_SIGNED: arg 
-= 0x40 else: arg = bytecode[ip] & 0x7F | bytecode[ip + 1] << 7 ip += 2 - if opcode in Opcodes.ALL_OFFSET_SIGNED: + if opcode in Opcode.ALL_OFFSET_SIGNED: arg -= 0x4000 - ip += extra_byte - return f, ip - ip_start, arg + if opcode & MP_BC_MASK_EXTRA_BYTE == 0: + extra_arg = bytecode[ip] + ip += 1 + return f, ip - ip_start, arg, extra_arg + + +def mp_opcode_encode(opcode): + overflow = False + encoded = bytearray([opcode.opcode_byte]) + if opcode.fmt in (MP_BC_FORMAT_QSTR, MP_BC_FORMAT_VAR_UINT): + signed = opcode.opcode_byte == Opcode.MP_BC_LOAD_CONST_SMALL_INT + encoded.extend(mp_encode_uint(opcode.arg, signed)) + elif opcode.fmt == MP_BC_FORMAT_OFFSET: + is_signed = opcode.opcode_byte in Opcode.ALL_OFFSET_SIGNED + + # The -2 accounts for this jump opcode taking 2 bytes (at least). + bytecode_offset = opcode.target.offset - opcode.offset - 2 + + # Check if the bytecode_offset is small enough to use a 1-byte encoding. + if (is_signed and -64 <= bytecode_offset <= 63) or ( + not is_signed and bytecode_offset <= 127 + ): + # Use a 1-byte jump offset. + if is_signed: + bytecode_offset += 0x40 + overflow = not (0 <= bytecode_offset <= 0x7F) + encoded.append(bytecode_offset & 0x7F) + else: + bytecode_offset -= 1 + if is_signed: + bytecode_offset += 0x4000 + overflow = not (0 <= bytecode_offset <= 0x7FFF) + encoded.append(0x80 | (bytecode_offset & 0x7F)) + encoded.append((bytecode_offset >> 7) & 0xFF) + if opcode.extra_arg is not None: + encoded.append(opcode.extra_arg) + return overflow, encoded def read_prelude_sig(read_byte): @@ -393,6 +451,21 @@ def read_prelude_size(read_byte): return I, C +# See py/bc.h:MP_BC_PRELUDE_SIZE_ENCODE macro. 
+def encode_prelude_size(I, C): + # Encode bit-wise as: xIIIIIIC + encoded = bytearray() + while True: + z = (I & 0x3F) << 1 | (C & 1) + C >>= 1 + I >>= 6 + if C | I: + z |= 0x80 + encoded.append(z) + if not C | I: + return encoded + + def extract_prelude(bytecode, ip): def local_read_byte(): b = bytecode[ip_ref[0]] @@ -400,6 +473,8 @@ def extract_prelude(bytecode, ip): return b ip_ref = [ip] # to close over ip in Python 2 and 3 + + # Read prelude signature. ( n_state, n_exc_stack, @@ -409,13 +484,12 @@ def extract_prelude(bytecode, ip): n_def_pos_args, ) = read_prelude_sig(local_read_byte) - n_info, n_cell = read_prelude_size(local_read_byte) - ip = ip_ref[0] + offset_prelude_size = ip_ref[0] - ip2 = ip - ip = ip2 + n_info + n_cell - # ip now points to first opcode - # ip2 points to simple_name qstr + # Read prelude size. + n_info, n_cell = read_prelude_size(local_read_byte) + + offset_source_info = ip_ref[0] # Extract simple_name and argument qstrs (var uints). args = [] @@ -428,11 +502,18 @@ def extract_prelude(bytecode, ip): break args.append(value) + offset_line_info = ip_ref[0] + offset_closure_info = offset_source_info + n_info + offset_opcodes = offset_source_info + n_info + n_cell + return ( - ip2, - ip, - ip_ref[0], + offset_prelude_size, + offset_source_info, + offset_line_info, + offset_closure_info, + offset_opcodes, (n_state, n_exc_stack, scope_flags, n_pos_args, n_kwonly_args, n_def_pos_args), + (n_info, n_cell), args, ) @@ -480,6 +561,8 @@ class CompiledModule: qstr_table, obj_table, raw_code, + qstr_table_file_offset, + obj_table_file_offset, raw_code_file_offset, escaped_name, ): @@ -489,8 +572,10 @@ class CompiledModule: self.header = header self.qstr_table = qstr_table self.obj_table = obj_table - self.raw_code_file_offset = raw_code_file_offset self.raw_code = raw_code + self.qstr_table_file_offset = qstr_table_file_offset + self.obj_table_file_offset = obj_table_file_offset + self.raw_code_file_offset = raw_code_file_offset self.escaped_name = 
escaped_name def hexdump(self): @@ -772,14 +857,17 @@ class RawCode(object): if code_kind in (MP_CODE_BYTECODE, MP_CODE_NATIVE_PY): ( - self.offset_names, - self.offset_opcodes, + self.offset_prelude_size, + self.offset_source_info, self.offset_line_info, - self.prelude, + self.offset_closure_info, + self.offset_opcodes, + self.prelude_signature, + self.prelude_size, self.names, ) = extract_prelude(self.fun_data, prelude_offset) - self.scope_flags = self.prelude[2] - self.n_pos_args = self.prelude[3] + self.scope_flags = self.prelude_signature[2] + self.n_pos_args = self.prelude_signature[3] self.simple_name = self.qstr_table[self.names[0]] else: self.simple_name = self.qstr_table[0] @@ -836,12 +924,12 @@ class RawCode(object): if self.code_kind == MP_CODE_BYTECODE: print(" #if MICROPY_PY_SYS_SETTRACE") print(" .prelude = {") - print(" .n_state = %u," % self.prelude[0]) - print(" .n_exc_stack = %u," % self.prelude[1]) - print(" .scope_flags = %u," % self.prelude[2]) - print(" .n_pos_args = %u," % self.prelude[3]) - print(" .n_kwonly_args = %u," % self.prelude[4]) - print(" .n_def_pos_args = %u," % self.prelude[5]) + print(" .n_state = %u," % self.prelude_signature[0]) + print(" .n_exc_stack = %u," % self.prelude_signature[1]) + print(" .scope_flags = %u," % self.prelude_signature[2]) + print(" .n_pos_args = %u," % self.prelude_signature[3]) + print(" .n_kwonly_args = %u," % self.prelude_signature[4]) + print(" .n_def_pos_args = %u," % self.prelude_signature[5]) print(" .qstr_block_name_idx = %u," % self.names[0]) print( " .line_info = fun_data_%s + %u," @@ -878,13 +966,13 @@ class RawCodeBytecode(RawCode): bc = self.fun_data print("simple_name:", self.simple_name.str) print(" raw bytecode:", len(bc), hexlify_to_str(bc)) - print(" prelude:", self.prelude) + print(" prelude:", self.prelude_signature) print(" args:", [self.qstr_table[i].str for i in self.names[1:]]) print(" line info:", hexlify_to_str(bc[self.offset_line_info : self.offset_opcodes])) ip = 
self.offset_opcodes while ip < len(bc): - fmt, sz, arg = mp_opcode_decode(bc, ip) - if bc[ip] == Opcodes.MP_BC_LOAD_CONST_OBJ: + fmt, sz, arg, _ = mp_opcode_decode(bc, ip) + if bc[ip] == Opcode.MP_BC_LOAD_CONST_OBJ: arg = repr(self.obj_table[arg]) if fmt == MP_BC_FORMAT_QSTR: arg = self.qstr_table[arg].str @@ -893,7 +981,7 @@ class RawCodeBytecode(RawCode): else: arg = "" print( - " %-11s %s %s" % (hexlify_to_str(bc[ip : ip + sz]), Opcodes.mapping[bc[ip]], arg) + " %-11s %s %s" % (hexlify_to_str(bc[ip : ip + sz]), Opcode.mapping[bc[ip]], arg) ) ip += sz self.disassemble_children() @@ -908,12 +996,12 @@ class RawCodeBytecode(RawCode): print("static const byte fun_data_%s[%u] = {" % (self.escaped_name, len(bc))) print(" ", end="") - for b in bc[: self.offset_names]: + for b in bc[: self.offset_source_info]: print("0x%02x," % b, end="") print(" // prelude") print(" ", end="") - for b in bc[self.offset_names : self.offset_line_info]: + for b in bc[self.offset_source_info : self.offset_line_info]: print("0x%02x," % b, end="") print(" // names: %s" % ", ".join(self.qstr_table[i].str for i in self.names)) @@ -924,8 +1012,8 @@ class RawCodeBytecode(RawCode): ip = self.offset_opcodes while ip < len(bc): - fmt, sz, arg = mp_opcode_decode(bc, ip) - opcode_name = Opcodes.mapping[bc[ip]] + fmt, sz, arg, _ = mp_opcode_decode(bc, ip) + opcode_name = Opcode.mapping[bc[ip]] if fmt == MP_BC_FORMAT_QSTR: opcode_name += " " + repr(self.qstr_table[arg].str) elif fmt in (MP_BC_FORMAT_VAR_UINT, MP_BC_FORMAT_OFFSET): @@ -1000,7 +1088,7 @@ class RawCodeNative(RawCode): ) if self.code_kind != MP_CODE_NATIVE_PY: return - print(" prelude:", self.prelude) + print(" prelude:", self.prelude_signature) print(" args:", [self.qstr_table[i].str for i in self.names[1:]]) print(" line info:", fun_data[self.offset_line_info : self.offset_opcodes]) ip = 0 @@ -1255,11 +1343,13 @@ def read_mpy(filename): n_obj = reader.read_uint() # Read qstrs and construct qstr table. 
+ qstr_table_file_offset = reader.tell() qstr_table = [] for i in range(n_qstr): qstr_table.append(read_qstr(reader, segments)) # Read objects and construct object table. + obj_table_file_offset = reader.tell() obj_table = [] for i in range(n_obj): obj_table.append(read_obj(reader, segments)) @@ -1279,6 +1369,8 @@ def read_mpy(filename): qstr_table, obj_table, raw_code, + qstr_table_file_offset, + obj_table_file_offset, raw_code_file_offset, cm_escaped_name, ) @@ -1477,25 +1569,100 @@ def freeze_mpy(base_qstrs, compiled_modules): print("*/") -def merge_mpy(raw_codes, output_file): - assert len(raw_codes) <= 2 # so var-uints all fit in 1 byte +def adjust_bytecode_qstr_obj_indices(bytecode_in, qstr_table_base, obj_table_base): + # Expand bytecode to a list of opcodes. + opcodes = [] + labels = {} + ip = 0 + while ip < len(bytecode_in): + fmt, sz, arg, extra_arg = mp_opcode_decode(bytecode_in, ip) + opcode = Opcode(ip, fmt, bytecode_in[ip], arg, extra_arg) + labels[ip] = opcode + opcodes.append(opcode) + ip += sz + if fmt == MP_BC_FORMAT_OFFSET: + opcode.arg += ip + + # Link jump opcodes to their destination. + for opcode in opcodes: + if opcode.fmt == MP_BC_FORMAT_OFFSET: + opcode.target = labels[opcode.arg] + + # Adjust bytecode as required. + for opcode in opcodes: + if opcode.fmt == MP_BC_FORMAT_QSTR: + opcode.arg += qstr_table_base + elif opcode.opcode_byte == Opcode.MP_BC_LOAD_CONST_OBJ: + opcode.arg += obj_table_base + + # Write out new bytecode. 
+ offset_changed = True + while offset_changed: + offset_changed = False + overflow = False + bytecode_out = b"" + for opcode in opcodes: + ip = len(bytecode_out) + if opcode.offset != ip: + offset_changed = True + opcode.offset = ip + opcode_overflow, encoded_opcode = mp_opcode_encode(opcode) + if opcode_overflow: + overflow = True + bytecode_out += encoded_opcode + + if overflow: + raise Exception("bytecode overflow") + + return bytecode_out + + +def rewrite_raw_code(rc, qstr_table_base, obj_table_base): + if rc.code_kind != MP_CODE_BYTECODE: + raise Exception("can only rewrite bytecode") + + source_info = bytearray() + for arg in rc.names: + source_info.extend(mp_encode_uint(qstr_table_base + arg)) + + closure_info = rc.fun_data[rc.offset_closure_info : rc.offset_opcodes] + + bytecode_in = memoryview(rc.fun_data)[rc.offset_opcodes :] + bytecode_out = adjust_bytecode_qstr_obj_indices(bytecode_in, qstr_table_base, obj_table_base) + + prelude_signature = rc.fun_data[: rc.offset_prelude_size] + prelude_size = encode_prelude_size(len(source_info), len(closure_info)) + + fun_data = prelude_signature + prelude_size + source_info + closure_info + bytecode_out + + output = mp_encode_uint(len(fun_data) << 3 | bool(len(rc.children)) << 2) + output += fun_data + + if rc.children: + output += mp_encode_uint(len(rc.children)) + for child in rc.children: + output += rewrite_raw_code(child, qstr_table_base, obj_table_base) + + return output + + +def merge_mpy(compiled_modules, output_file): merged_mpy = bytearray() - if len(raw_codes) == 1: - with open(raw_codes[0].mpy_source_file, "rb") as f: + if len(compiled_modules) == 1: + with open(compiled_modules[0].mpy_source_file, "rb") as f: merged_mpy.extend(f.read()) else: - main_rc = None - for rc in raw_codes: - if len(rc.qstr_table) > 1 or len(rc.obj_table) > 0: + main_cm_idx = None + for idx, cm in enumerate(compiled_modules): + if cm.header[2]: # Must use qstr_table and obj_table from this raw_code - if main_rc is not None: - 
raise Exception( - "can't merge files when more than one has a populated qstr or obj table" - ) - main_rc = rc - if main_rc is None: - main_rc = raw_codes[0] + if main_cm_idx is not None: + raise Exception("can't merge files when more than one contains native code") + main_cm_idx = idx + if main_cm_idx is not None: + # Shift main_cm to front of list. + compiled_modules.insert(0, compiled_modules.pop(main_cm_idx)) header = bytearray(4) header[0] = ord("M") @@ -1504,32 +1671,50 @@ def merge_mpy(raw_codes, output_file): header[3] = config.mp_small_int_bits merged_mpy.extend(header) - # Copy n_qstr, n_obj, qstr_table, obj_table from main_rc. - with open(main_rc.mpy_source_file, "rb") as f: - data = f.read(main_rc.raw_code_file_offset) - merged_mpy.extend(data[4:]) + n_qstr = 0 + n_obj = 0 + for cm in compiled_modules: + n_qstr += len(cm.qstr_table) + n_obj += len(cm.obj_table) + merged_mpy.extend(mp_encode_uint(n_qstr)) + merged_mpy.extend(mp_encode_uint(n_obj)) + + # Copy verbatim the qstr and object tables from all compiled modules. 
+ def copy_section(file, offset, offset2): + with open(file, "rb") as f: + f.seek(offset) + merged_mpy.extend(f.read(offset2 - offset)) + + for cm in compiled_modules: + copy_section(cm.mpy_source_file, cm.qstr_table_file_offset, cm.obj_table_file_offset) + for cm in compiled_modules: + copy_section(cm.mpy_source_file, cm.obj_table_file_offset, cm.raw_code_file_offset) bytecode = bytearray() - bytecode_len = 3 + len(raw_codes) * 5 + 2 - bytecode.append(bytecode_len << 3 | 1 << 2) # kind, has_children and length - bytecode.append(0b00000000) # signature prelude - bytecode.append(0b00000010) # size prelude; n_info=1 + bytecode.append(0b00000000) # prelude signature + bytecode.append(0b00000010) # prelude size (n_info=1, n_cell=0) bytecode.extend(b"\x00") # simple_name: qstr index 0 (will use source filename) - for idx in range(len(raw_codes)): + for idx in range(len(compiled_modules)): bytecode.append(0x32) # MP_BC_MAKE_FUNCTION bytecode.append(idx) # index raw code bytecode.extend(b"\x34\x00\x59") # MP_BC_CALL_FUNCTION, 0 args, MP_BC_POP_TOP bytecode.extend(b"\x51\x63") # MP_BC_LOAD_NONE, MP_BC_RETURN_VALUE + merged_mpy.extend(mp_encode_uint(len(bytecode) << 3 | 1 << 2)) # length, has_children merged_mpy.extend(bytecode) + merged_mpy.extend(mp_encode_uint(len(compiled_modules))) # n_children - merged_mpy.append(len(raw_codes)) # n_children - - for rc in raw_codes: - with open(rc.mpy_source_file, "rb") as f: - f.seek(rc.raw_code_file_offset) - data = f.read() # read rest of mpy file - merged_mpy.extend(data) + qstr_table_base = 0 + obj_table_base = 0 + for cm in compiled_modules: + if qstr_table_base == 0 and obj_table_base == 0: + with open(cm.mpy_source_file, "rb") as f: + f.seek(cm.raw_code_file_offset) + merged_mpy.extend(f.read()) + else: + merged_mpy.extend(rewrite_raw_code(cm.raw_code, qstr_table_base, obj_table_base)) + qstr_table_base += len(cm.qstr_table) + obj_table_base += len(cm.obj_table) if output_file is None: sys.stdout.buffer.write(merged_mpy)