diff --git a/extmod/asyncio/core.py b/extmod/asyncio/core.py index 214cc52f45..e5af3038f7 100644 --- a/extmod/asyncio/core.py +++ b/extmod/asyncio/core.py @@ -219,6 +219,11 @@ def run_until_complete(main_task=None): elif t.state is None: # Task is already finished and nothing await'ed on the task, # so call the exception handler. + + # Save exception raised by the coro for later use. + t.data = exc + + # Create exception context and call the exception handler. _exc_context["exception"] = exc _exc_context["future"] = t Loop.call_exception_handler(_exc_context) diff --git a/extmod/asyncio/funcs.py b/extmod/asyncio/funcs.py index 599091dfbd..3ef8a76b1d 100644 --- a/extmod/asyncio/funcs.py +++ b/extmod/asyncio/funcs.py @@ -63,9 +63,6 @@ class _Remove: # async def gather(*aws, return_exceptions=False): - if not aws: - return [] - def done(t, er): # Sub-task "t" has finished, with exception "er". nonlocal state @@ -86,26 +83,39 @@ def gather(*aws, return_exceptions=False): # Gather waiting is done, schedule the main gather task. core._task_queue.push(gather_task) + # Prepare the sub-tasks for the gather. + # The `state` variable counts the number of tasks to wait for, and can be negative + # if the gather should not run at all (because a task already had an exception). ts = [core._promote_to_task(aw) for aw in aws] + state = 0 for i in range(len(ts)): - if ts[i].state is not True: - # Task is not running, gather not currently supported for this case. + if ts[i].state is True: + # Task is running, register the callback to call when the task is done. + ts[i].state = done + state += 1 + elif not ts[i].state: + # Task finished already. + if not isinstance(ts[i].data, StopIteration): + # Task finished by raising an exception. + if not return_exceptions: + # Do not run this gather at all. + state = -len(ts) + else: + # Task being waited on, gather not currently supported for this case. raise RuntimeError("can't gather") - # Register the callback to call when the task is done. - ts[i].state = done # Set the state for execution of the gather. gather_task = core.cur_task - state = len(ts) cancel_all = False - # Wait for the a sub-task to need attention. - gather_task.data = _Remove - try: - yield - except core.CancelledError as er: - cancel_all = True - state = er + # Wait for a sub-task to need attention (if there are any to wait for). + if state > 0: + gather_task.data = _Remove + try: + yield + except core.CancelledError as er: + cancel_all = True + state = er # Clean up tasks. for i in range(len(ts)): @@ -118,8 +128,13 @@ def gather(*aws, return_exceptions=False): # Sub-task ran to completion, get its return value. ts[i] = ts[i].data.value else: - # Sub-task had an exception with return_exceptions==True, so get its exception. - ts[i] = ts[i].data + # Sub-task had an exception. + if return_exceptions: + # Get the sub-task exception to return in the list of return values. + ts[i] = ts[i].data + elif isinstance(state, int): + # Raise the sub-task exception, if there is not already an exception to raise. + state = ts[i].data # Either this gather was cancelled, or one of the sub-tasks raised an exception with # return_exceptions==False, so reraise the exception here. diff --git a/tests/extmod/asyncio_gather_finished_early.py b/tests/extmod/asyncio_gather_finished_early.py new file mode 100644 index 0000000000..030e79e357 --- /dev/null +++ b/tests/extmod/asyncio_gather_finished_early.py @@ -0,0 +1,65 @@ +# Test asyncio.gather() when a task is already finished before the gather starts. + +try: + import asyncio +except ImportError: + print("SKIP") + raise SystemExit + + +# CPython and MicroPython differ in when they signal (and print) that a task raised an +# uncaught exception. So define an empty custom_handler() to suppress this output. +def custom_handler(loop, context): + pass + + +async def task_that_finishes_early(id, event, fail): + print("task_that_finishes_early", id) + event.set() + if fail: + raise ValueError("intentional exception", id) + + +async def task_that_runs(): + for i in range(5): + print("task_that_runs", i) + await asyncio.sleep(0) + + +async def main(start_task_that_runs, task_fail, return_exceptions): + print("== start", start_task_that_runs, task_fail, return_exceptions) + + # Set exception handler to suppress exception output. + loop = asyncio.get_event_loop() + loop.set_exception_handler(custom_handler) + + # Create tasks. + event_a = asyncio.Event() + event_b = asyncio.Event() + tasks = [] + if start_task_that_runs: + tasks.append(asyncio.create_task(task_that_runs())) + tasks.append(asyncio.create_task(task_that_finishes_early("a", event_a, task_fail))) + tasks.append(asyncio.create_task(task_that_finishes_early("b", event_b, task_fail))) + + # Make sure task_that_finishes_early() are both done, before calling gather(). + await event_a.wait() + await event_b.wait() + + # Gather the tasks. + try: + result = "complete", await asyncio.gather(*tasks, return_exceptions=return_exceptions) + except Exception as er: + result = "exception", er, start_task_that_runs and tasks[0].done() + + # Wait for the final task to finish (if it was started). + if start_task_that_runs: + await tasks[0] + + # Print results. + print(result) + + +# Run the test in the 8 different combinations of its arguments. +for i in range(8): + asyncio.run(main(bool(i & 4), bool(i & 2), bool(i & 1)))