diff --git a/py/compile.c b/py/compile.c index e90b366e0e..e708bde8e0 100644 --- a/py/compile.c +++ b/py/compile.c @@ -1196,7 +1196,8 @@ STATIC void compile_declare_global(compiler_t *comp, mp_parse_node_t pn, qstr qs STATIC void compile_declare_nonlocal(compiler_t *comp, mp_parse_node_t pn, qstr qst, bool added, id_info_t *id_info) { if (added) { - scope_find_local_and_close_over(comp->scope_cur, id_info, qst); + id_info->kind = ID_INFO_KIND_GLOBAL_IMPLICIT; + scope_check_to_close_over(comp->scope_cur, id_info); if (id_info->kind == ID_INFO_KIND_GLOBAL_IMPLICIT) { compile_syntax_error(comp, pn, "no binding for nonlocal found"); } @@ -3391,6 +3392,14 @@ mp_raw_code_t *mp_compile_to_raw_code(mp_parse_tree_t *parse_tree, qstr source_f #endif } else { compile_scope(comp, s, MP_PASS_SCOPE); + + // Check if any implicitly declared variables should be closed over + for (size_t i = 0; i < s->id_info_len; ++i) { + id_info_t *id = &s->id_info[i]; + if (id->kind == ID_INFO_KIND_GLOBAL_IMPLICIT) { + scope_check_to_close_over(s, id); + } + } } // update maximim number of labels needed diff --git a/py/emitcommon.c b/py/emitcommon.c index 89cc2c9597..149e0b0f1f 100644 --- a/py/emitcommon.c +++ b/py/emitcommon.c @@ -35,7 +35,7 @@ void mp_emit_common_get_id_for_load(scope_t *scope, qstr qst) { bool added; id_info_t *id = scope_find_or_add_id(scope, qst, &added); if (added) { - scope_find_local_and_close_over(scope, id, qst); + id->kind = ID_INFO_KIND_GLOBAL_IMPLICIT; } } diff --git a/py/scope.c b/py/scope.c index 1a6ae7b8ad..8adb85b80f 100644 --- a/py/scope.c +++ b/py/scope.c @@ -130,21 +130,19 @@ STATIC void scope_close_over_in_parents(scope_t *scope, qstr qst) { } } -void scope_find_local_and_close_over(scope_t *scope, id_info_t *id, qstr qst) { +void scope_check_to_close_over(scope_t *scope, id_info_t *id) { if (scope->parent != NULL) { for (scope_t *s = scope->parent; s->parent != NULL; s = s->parent) { - id_info_t *id2 = scope_find(s, qst); + id_info_t *id2 = scope_find(s, id->qst); if (id2 != NULL) { if (id2->kind == ID_INFO_KIND_LOCAL || id2->kind == ID_INFO_KIND_CELL || id2->kind == ID_INFO_KIND_FREE) { id->kind = ID_INFO_KIND_FREE; - scope_close_over_in_parents(scope, qst); - return; + scope_close_over_in_parents(scope, id->qst); } break; } } } - id->kind = ID_INFO_KIND_GLOBAL_IMPLICIT; } #endif // MICROPY_ENABLE_COMPILER diff --git a/py/scope.h b/py/scope.h index 5e9a0eb7b2..d51bb90bbf 100644 --- a/py/scope.h +++ b/py/scope.h @@ -93,6 +93,6 @@ void scope_free(scope_t *scope); id_info_t *scope_find_or_add_id(scope_t *scope, qstr qstr, bool *added); id_info_t *scope_find(scope_t *scope, qstr qstr); id_info_t *scope_find_global(scope_t *scope, qstr qstr); -void scope_find_local_and_close_over(scope_t *scope, id_info_t *id, qstr qst); +void scope_check_to_close_over(scope_t *scope, id_info_t *id); #endif // MICROPY_INCLUDED_PY_SCOPE_H diff --git a/tests/basics/scope_implicit.py b/tests/basics/scope_implicit.py new file mode 100644 index 0000000000..aecda77156 --- /dev/null +++ b/tests/basics/scope_implicit.py @@ -0,0 +1,31 @@ +# test implicit scoping rules + +# implicit nonlocal, with variable defined after closure +def f(): + def g(): + return x # implicit nonlocal + x = 3 # variable defined after function that closes over it + return g +print(f()()) + +# implicit nonlocal at inner level, with variable defined after closure +def f(): + def g(): + def h(): + return x # implicit nonlocal + return h + x = 4 # variable defined after function that closes over it + return g +print(f()()()) + +# local variable which should not be implicitly made nonlocal +def f(): + x = 0 + def g(): + x # local because next statement assigns to it + x = 1 + g() +try: + f() +except NameError: + print('NameError') diff --git a/tests/run-tests b/tests/run-tests index d72ae9dc4a..62af90f28c 100755 --- a/tests/run-tests +++ b/tests/run-tests @@ -358,6 +358,7 @@ def run_tests(pyb, tests, args, base_path="."): skip_tests.add('basics/del_deref.py') # requires checking for unbound local skip_tests.add('basics/del_local.py') # requires checking for unbound local skip_tests.add('basics/exception_chain.py') # raise from is not supported + skip_tests.add('basics/scope_implicit.py') # requires checking for unbound local skip_tests.add('basics/try_finally_return2.py') # requires raise_varargs skip_tests.add('basics/unboundlocal.py') # requires checking for unbound local skip_tests.add('misc/features.py') # requires raise_varargs