diff --git a/py/modbuiltins.c b/py/modbuiltins.c index c9a49685a4..a7e49a1ed9 100644 --- a/py/modbuiltins.c +++ b/py/modbuiltins.c @@ -322,7 +322,7 @@ STATIC mp_obj_t mp_builtin_next(size_t n_args, const mp_obj_t *args) { if (n_args == 1) { mp_obj_t ret = mp_iternext_allow_raise(args[0]); if (ret == MP_OBJ_STOP_ITERATION) { - mp_raise_type(&mp_type_StopIteration); + mp_raise_StopIteration(MP_STATE_THREAD(stop_iteration_arg)); } else { return ret; } @@ -336,7 +336,7 @@ MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mp_builtin_next_obj, 1, 2, mp_builtin_next); STATIC mp_obj_t mp_builtin_next(mp_obj_t o) { mp_obj_t ret = mp_iternext_allow_raise(o); if (ret == MP_OBJ_STOP_ITERATION) { - mp_raise_type(&mp_type_StopIteration); + mp_raise_StopIteration(MP_STATE_THREAD(stop_iteration_arg)); } else { return ret; } diff --git a/py/mpstate.h b/py/mpstate.h index e42d13fb20..07335bae4c 100644 --- a/py/mpstate.h +++ b/py/mpstate.h @@ -266,6 +266,9 @@ typedef struct _mp_state_thread_t { // pending exception object (MP_OBJ_NULL if not pending) volatile mp_obj_t mp_pending_exception; + // If MP_OBJ_STOP_ITERATION is propagated then this holds its argument. + mp_obj_t stop_iteration_arg; + #if MICROPY_PY_SYS_SETTRACE mp_obj_t prof_trace_callback; bool prof_callback_is_executing; diff --git a/py/objgenerator.c b/py/objgenerator.c index 1ee7b8b1db..784310092e 100644 --- a/py/objgenerator.c +++ b/py/objgenerator.c @@ -238,16 +238,20 @@ mp_vm_return_kind_t mp_obj_gen_resume(mp_obj_t self_in, mp_obj_t send_value, mp_ return ret_kind; } -STATIC mp_obj_t gen_resume_and_raise(mp_obj_t self_in, mp_obj_t send_value, mp_obj_t throw_value) { +STATIC mp_obj_t gen_resume_and_raise(mp_obj_t self_in, mp_obj_t send_value, mp_obj_t throw_value, bool raise_stop_iteration) { mp_obj_t ret; switch (mp_obj_gen_resume(self_in, send_value, throw_value, &ret)) { case MP_VM_RETURN_NORMAL: default: - // Optimize return w/o value in case generator is used in for loop - if (ret == mp_const_none || ret == MP_OBJ_STOP_ITERATION) { - return MP_OBJ_STOP_ITERATION; + // A normal return is a StopIteration, either raise it or return + // MP_OBJ_STOP_ITERATION as an optimisation. + if (ret == mp_const_none) { + ret = MP_OBJ_NULL; + } + if (raise_stop_iteration) { + mp_raise_StopIteration(ret); } else { - nlr_raise(mp_obj_new_exception_arg1(&mp_type_StopIteration, ret)); + return mp_make_stop_iteration(ret); } case MP_VM_RETURN_YIELD: @@ -259,16 +263,11 @@ STATIC mp_obj_t gen_resume_and_raise(mp_obj_t self_in, mp_obj_t send_value, mp_o } STATIC mp_obj_t gen_instance_iternext(mp_obj_t self_in) { - return gen_resume_and_raise(self_in, mp_const_none, MP_OBJ_NULL); + return gen_resume_and_raise(self_in, mp_const_none, MP_OBJ_NULL, false); } STATIC mp_obj_t gen_instance_send(mp_obj_t self_in, mp_obj_t send_value) { - mp_obj_t ret = gen_resume_and_raise(self_in, send_value, MP_OBJ_NULL); - if (ret == MP_OBJ_STOP_ITERATION) { - mp_raise_type(&mp_type_StopIteration); - } else { - return ret; - } + return gen_resume_and_raise(self_in, send_value, MP_OBJ_NULL, true); } STATIC MP_DEFINE_CONST_FUN_OBJ_2(gen_instance_send_obj, gen_instance_send); @@ -290,12 +289,7 @@ STATIC mp_obj_t gen_instance_throw(size_t n_args, const mp_obj_t *args) { exc = args[2]; } - mp_obj_t ret = gen_resume_and_raise(args[0], mp_const_none, exc); - if (ret == MP_OBJ_STOP_ITERATION) { - mp_raise_type(&mp_type_StopIteration); - } else { - return ret; - } + return gen_resume_and_raise(args[0], mp_const_none, exc, true); } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(gen_instance_throw_obj, 2, 4, gen_instance_throw); diff --git a/py/objgetitemiter.c b/py/objgetitemiter.c index 2670314ab2..31ed4a9228 100644 --- a/py/objgetitemiter.c +++ b/py/objgetitemiter.c @@ -48,7 +48,6 @@ STATIC mp_obj_t it_iternext(mp_obj_t self_in) { // an exception was raised mp_obj_type_t *t = (mp_obj_type_t *)((mp_obj_base_t *)nlr.ret_val)->type; if (t == &mp_type_StopIteration || t == &mp_type_IndexError) { - // return MP_OBJ_STOP_ITERATION instead of raising return MP_OBJ_STOP_ITERATION; } else { // re-raise exception diff --git a/py/runtime.c b/py/runtime.c index b53711bbe2..2c849fe950 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -1217,6 +1217,7 @@ mp_obj_t mp_getiter(mp_obj_t o_in, mp_obj_iter_buf_t *iter_buf) { mp_obj_t mp_iternext_allow_raise(mp_obj_t o_in) { const mp_obj_type_t *type = mp_obj_get_type(o_in); if (type->iternext != NULL) { + MP_STATE_THREAD(stop_iteration_arg) = MP_OBJ_NULL; return type->iternext(o_in); } else { // check for __next__ method @@ -1242,6 +1243,7 @@ mp_obj_t mp_iternext(mp_obj_t o_in) { MP_STACK_CHECK(); // enumerate, filter, map and zip can recursively call mp_iternext const mp_obj_type_t *type = mp_obj_get_type(o_in); if (type->iternext != NULL) { + MP_STATE_THREAD(stop_iteration_arg) = MP_OBJ_NULL; return type->iternext(o_in); } else { // check for __next__ method @@ -1256,7 +1258,7 @@ mp_obj_t mp_iternext(mp_obj_t o_in) { return ret; } else { if (mp_obj_is_subclass_fast(MP_OBJ_FROM_PTR(((mp_obj_base_t *)nlr.ret_val)->type), MP_OBJ_FROM_PTR(&mp_type_StopIteration))) { - return MP_OBJ_STOP_ITERATION; + return mp_make_stop_iteration(mp_obj_exception_get_value(MP_OBJ_FROM_PTR(nlr.ret_val))); } else { nlr_jump(nlr.ret_val); } @@ -1281,14 +1283,18 @@ mp_vm_return_kind_t mp_resume(mp_obj_t self_in, mp_obj_t send_value, mp_obj_t th } if (type->iternext != NULL && send_value == mp_const_none) { + MP_STATE_THREAD(stop_iteration_arg) = MP_OBJ_NULL; mp_obj_t ret = type->iternext(self_in); *ret_val = ret; if (ret != MP_OBJ_STOP_ITERATION) { return MP_VM_RETURN_YIELD; } else { // The generator is finished. - // This is an optimised "raise StopIteration(None)". - *ret_val = mp_const_none; + // This is an optimised "raise StopIteration(*ret_val)". + *ret_val = MP_STATE_THREAD(stop_iteration_arg); + if (*ret_val == MP_OBJ_NULL) { + *ret_val = mp_const_none; + } return MP_VM_RETURN_NORMAL; } } @@ -1559,6 +1565,14 @@ NORETURN void mp_raise_NotImplementedError(mp_rom_error_text_t msg) { #endif +NORETURN void mp_raise_StopIteration(mp_obj_t arg) { + if (arg == MP_OBJ_NULL) { + mp_raise_type(&mp_type_StopIteration); + } else { + nlr_raise(mp_obj_new_exception_arg1(&mp_type_StopIteration, arg)); + } +} + NORETURN void mp_raise_OSError(int errno_) { nlr_raise(mp_obj_new_exception_arg1(&mp_type_OSError, MP_OBJ_NEW_SMALL_INT(errno_))); } diff --git a/py/runtime.h b/py/runtime.h index 7d2cb94e84..8484479a54 100644 --- a/py/runtime.h +++ b/py/runtime.h @@ -154,6 +154,11 @@ mp_obj_t mp_iternext_allow_raise(mp_obj_t o); // may return MP_OBJ_STOP_ITERATIO mp_obj_t mp_iternext(mp_obj_t o); // will always return MP_OBJ_STOP_ITERATION instead of raising StopIteration(...) mp_vm_return_kind_t mp_resume(mp_obj_t self_in, mp_obj_t send_value, mp_obj_t throw_value, mp_obj_t *ret_val); +static inline mp_obj_t mp_make_stop_iteration(mp_obj_t o) { + MP_STATE_THREAD(stop_iteration_arg) = o; + return MP_OBJ_STOP_ITERATION; +} + mp_obj_t mp_make_raise_obj(mp_obj_t o); mp_obj_t mp_import_name(qstr name, mp_obj_t fromlist, mp_obj_t level); @@ -179,6 +184,7 @@ NORETURN void mp_raise_TypeError(mp_rom_error_text_t msg); NORETURN void mp_raise_NotImplementedError(mp_rom_error_text_t msg); #endif +NORETURN void mp_raise_StopIteration(mp_obj_t arg); NORETURN void mp_raise_OSError(int errno_); NORETURN void mp_raise_recursion_depth(void); diff --git a/tests/basics/gen_yield_from_stopped.py b/tests/basics/gen_yield_from_stopped.py index 468679b615..82feefed08 100644 --- a/tests/basics/gen_yield_from_stopped.py +++ b/tests/basics/gen_yield_from_stopped.py @@ -16,3 +16,15 @@ try: next(run()) except StopIteration: print("StopIteration") + + +# Where "f" is a native generator +def run(): + print((yield from f)) + + +f = zip() +try: + next(run()) +except StopIteration: + print("StopIteration") diff --git a/tests/basics/stopiteration.py b/tests/basics/stopiteration.py new file mode 100644 index 0000000000..d4719c9bc3 --- /dev/null +++ b/tests/basics/stopiteration.py @@ -0,0 +1,63 @@ +# test StopIteration interaction with generators + +try: + enumerate, exec +except: + print("SKIP") + raise SystemExit + + +def get_stop_iter_arg(msg, code): + try: + exec(code) + print("FAIL") + except StopIteration as er: + print(msg, er.args) + + +class A: + def __iter__(self): + return self + + def __next__(self): + raise StopIteration(42) + + +class B: + def __getitem__(self, index): + # argument to StopIteration should get ignored + raise StopIteration(42) + + +def gen(x): + return x + yield + + +def gen2(x): + try: + yield + except ValueError: + pass + return x + + +get_stop_iter_arg("next", "next(A())") +get_stop_iter_arg("iter", "next(iter(B()))") +get_stop_iter_arg("enumerate", "next(enumerate(A()))") +get_stop_iter_arg("map", "next(map(lambda x:x, A()))") +get_stop_iter_arg("zip", "next(zip(A()))") +g = gen(None) +get_stop_iter_arg("generator0", "next(g)") +get_stop_iter_arg("generator1", "next(g)") +g = gen(42) +get_stop_iter_arg("generator0", "next(g)") +get_stop_iter_arg("generator1", "next(g)") +get_stop_iter_arg("send", "gen(None).send(None)") +get_stop_iter_arg("send", "gen(42).send(None)") +g = gen2(None) +next(g) +get_stop_iter_arg("throw", "g.throw(ValueError)") +g = gen2(42) +next(g) +get_stop_iter_arg("throw", "g.throw(ValueError)")