From 5b5562c1d13fa0b78fd13731cebdaa911eb47726 Mon Sep 17 00:00:00 2001 From: Damien George Date: Sat, 31 May 2014 17:59:11 +0100 Subject: [PATCH] py: Fix stack underflow with optimised for loop. --- py/compile.c | 2 +- py/vm.c | 6 +++--- tests/basics/for_return.py | 7 +++++++ 3 files changed, 11 insertions(+), 4 deletions(-) create mode 100644 tests/basics/for_return.py diff --git a/py/compile.c b/py/compile.c index c90772a7e3..1f0d90570e 100644 --- a/py/compile.c +++ b/py/compile.c @@ -1745,7 +1745,7 @@ void compile_while_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) { // And, if the loop never runs, the loop variable should never be assigned void compile_for_stmt_optimised_range(compiler_t *comp, mp_parse_node_t pn_var, mp_parse_node_t pn_start, mp_parse_node_t pn_end, mp_parse_node_t pn_step, mp_parse_node_t pn_body, mp_parse_node_t pn_else) { START_BREAK_CONTINUE_BLOCK - comp->break_label |= MP_EMIT_BREAK_FROM_FOR; + // note that we don't need to pop anything when breaking from an optimise for loop uint top_label = comp_next_label(comp); uint entry_label = comp_next_label(comp); diff --git a/py/vm.c b/py/vm.c index aa7e0e2cfc..75093d2401 100644 --- a/py/vm.c +++ b/py/vm.c @@ -159,8 +159,8 @@ mp_vm_return_kind_t mp_execute_bytecode(const byte *code, const mp_obj_t *args, #if DETECT_VM_STACK_OVERFLOW if (vm_return_kind == MP_VM_RETURN_NORMAL) { - if (sp != state) { - printf("Stack misalign: %d\n", sp - state); + if (sp < state) { + printf("VM stack underflow: " INT_FMT "\n", sp - state); assert(0); } } @@ -178,7 +178,7 @@ mp_vm_return_kind_t mp_execute_bytecode(const byte *code, const mp_obj_t *args, } } if (overflow) { - printf("VM stack overflow state=%p n_state+1=%u\n", state, n_state); + printf("VM stack overflow state=%p n_state+1=" UINT_FMT "\n", state, n_state); assert(0); } } diff --git a/tests/basics/for_return.py b/tests/basics/for_return.py new file mode 100644 index 0000000000..0441352ad9 --- /dev/null +++ b/tests/basics/for_return.py @@ -0,0 +1,7 @@ +# test returning from within a for loop + +def f(): + for i in [1, 2, 3]: + return i + +print(f())