Skip to content

Commit f6eba86

Browse files
committed
Finalize async generators in the correct context
1 parent 6f2870a commit f6eba86

File tree

4 files changed

+144
-0
lines changed

4 files changed

+144
-0
lines changed

newsfragments/92.feature.rst

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
trio-asyncio now properly finalizes asyncio-flavored async generators
2+
upon closure of the event loop. Previously, Trio's async generator finalizers
3+
would try to finalize all async generators in Trio mode, regardless of their
4+
flavor, which could lead to stderr spew.

tests/test_trio_asyncio.py

+68
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import types
44
import asyncio
55
import trio
6+
import trio.testing
67
import trio_asyncio
78
import contextlib
9+
import gc
810

911

1012
async def use_asyncio():
@@ -203,3 +205,69 @@ async def main():
203205
asyncio.run(main())
204206

205207
assert scope.value.code == 42
208+
209+
210+
@pytest.mark.trio
211+
@pytest.mark.parametrize("alive_on_exit", (False, True))
212+
@pytest.mark.parametrize("slow_finalizer", (False, True))
213+
@pytest.mark.parametrize("loop_timeout", (0, 1, 20))
214+
async def test_asyncgens(alive_on_exit, slow_finalizer, loop_timeout, autojump_clock):
215+
import sniffio
216+
217+
record = set()
218+
holder = []
219+
220+
async def agen(label):
221+
assert sniffio.current_async_library() == label
222+
try:
223+
yield 1
224+
finally:
225+
library = sniffio.current_async_library()
226+
try:
227+
await sys.modules[library].sleep(10 if slow_finalizer else 0)
228+
except (trio.Cancelled, asyncio.CancelledError):
229+
pass
230+
record.add((label, library))
231+
232+
async def iterate_in_trio():
233+
await agen("trio").asend(None)
234+
235+
async def aio_main(nursery):
236+
nursery.start_soon(iterate_in_trio)
237+
ag = agen("asyncio")
238+
await ag.asend(None)
239+
if alive_on_exit:
240+
holder.append(ag)
241+
else:
242+
del ag
243+
for _ in range(5):
244+
gc.collect()
245+
246+
sys.unraisablehook, prev_hook = sys.__unraisablehook__, sys.unraisablehook
247+
try:
248+
start_time = trio.current_time()
249+
with trio.move_on_after(loop_timeout) as scope:
250+
if loop_timeout == 0:
251+
scope.cancel()
252+
async with trio_asyncio.open_loop() as loop:
253+
async with trio.open_nursery() as nursery:
254+
await loop.run_aio_coroutine(aio_main(nursery))
255+
256+
# asyncio agen should be finalized as soon as asyncio loop ends
257+
assert ("asyncio", "asyncio") in record
258+
259+
# asyncio agen finalizer should be able to take a cancel
260+
if (slow_finalizer or loop_timeout == 0) and alive_on_exit:
261+
assert trio.current_time() == start_time + min(loop_timeout, 10)
262+
assert scope.cancelled_caught == (loop_timeout < 10)
263+
else:
264+
assert trio.current_time() == start_time
265+
assert not scope.cancelled_caught
266+
267+
# trio asyncgen should eventually be finalized in trio mode
268+
for _ in range(5):
269+
gc.collect()
270+
await trio.testing.wait_all_tasks_blocked()
271+
assert record == {("trio", "trio"), ("asyncio", "asyncio")}
272+
finally:
273+
sys.unraisablehook = prev_hook

trio_asyncio/_base.py

+31
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ class BaseTrioEventLoop(asyncio.SelectorEventLoop):
135135
# (threading) Thread this loop is running in
136136
_thread = None
137137

138+
# Trio's async generator hooks, as obtained before we added ours
139+
_trio_asyncgen_hooks = None
140+
138141
def __init__(self, queue_len=None):
139142
if queue_len is None:
140143
queue_len = math.inf
@@ -629,6 +632,30 @@ async def _main_loop_init(self, nursery):
629632
self._nursery = nursery
630633
self._task = trio.lowlevel.current_task()
631634
self._token = trio.lowlevel.current_trio_token()
635+
self._trio_asyncgen_hooks = sys.get_asyncgen_hooks()
636+
assert self._trio_asyncgen_hooks.firstiter is not None
637+
assert self._trio_asyncgen_hooks.finalizer is not None
638+
sys.set_asyncgen_hooks(
639+
firstiter=self._dispatch_asyncgen_firstiter,
640+
finalizer=self._dispatch_asyncgen_finalizer,
641+
)
642+
643+
def _dispatch_asyncgen_firstiter(self, agen):
644+
if sniffio_library.name == "asyncio":
645+
return self._asyncgen_firstiter_hook(agen)
646+
else:
647+
agen.ag_frame.f_locals["@trio_asyncio_trio_asyncgen"] = True
648+
return self._trio_asyncgen_hooks.firstiter(agen)
649+
650+
def _dispatch_asyncgen_finalizer(self, agen):
651+
try:
652+
is_ours = not agen.ag_frame.f_locals.get("@trio_asyncio_trio_asyncgen")
653+
except AttributeError: # pragma: no cover
654+
is_ours = True
655+
if is_ours:
656+
return self._asyncgen_finalizer_hook(agen)
657+
else:
658+
return self._trio_asyncgen_hooks.finalizer(agen)
632659

633660
async def _main_loop(self, task_status=trio.TASK_STATUS_IGNORED):
634661
"""Run the loop by processing its event queue.
@@ -738,6 +765,10 @@ async def _main_loop_exit(self):
738765
except TrioAsyncioExit:
739766
pass
740767

768+
# Restore previous async generator hooks
769+
sys.set_asyncgen_hooks(*self._trio_asyncgen_hooks)
770+
self._trio_asyncgen_hooks = None
771+
741772
# Kill off unprocessed work
742773
self._cancel_fds()
743774
self._cancel_timers()

trio_asyncio/_loop.py

+41
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,45 @@ async def wait_for_sync():
560560
tasks_nursery.cancel_scope.cancel()
561561

562562
finally:
563+
# If we have any async generators left, finalize them before
564+
# closing the event loop. Make sure that the finalizers have a
565+
# chance to actually start before they're exposed to any
566+
# external cancellation, since asyncio doesn't guarantee that
567+
# cancelled tasks have a chance to start first.
568+
569+
asyncgens_done = trio.Event()
570+
if len(loop._asyncgens) == 0:
571+
asyncgens_done.set()
572+
else:
573+
shield_asyncgen_finalizers = trio.CancelScope(shield=True)
574+
575+
async def sentinel():
576+
try:
577+
yield
578+
finally:
579+
try:
580+
# Open-coded asyncio version of loop.synchronize();
581+
# since we closed the tasks_nursery, we can't do
582+
# any more asyncio-to-trio-mode conversions
583+
w = asyncio.Event()
584+
loop.call_soon(w.set)
585+
await w.wait()
586+
finally:
587+
shield_asyncgen_finalizers.shield = False
588+
589+
async def shutdown_asyncgens_from_aio():
590+
agen = sentinel()
591+
await agen.asend(None)
592+
try:
593+
await loop.shutdown_asyncgens()
594+
finally:
595+
asyncgens_done.set()
596+
597+
@loop_nursery.start_soon
598+
async def shutdown_asyncgens_from_trio():
599+
with shield_asyncgen_finalizers:
600+
await loop.run_aio_coroutine(shutdown_asyncgens_from_aio())
601+
563602
if forwarded_cancellation is not None:
564603
# Now that we're outside the shielded tasks_nursery, we can
565604
# add this cancellation to the set of errors propagating out
@@ -570,6 +609,8 @@ async def forward_cancellation():
570609
raise forwarded_cancellation
571610

572611
try:
612+
with trio.CancelScope(shield=True):
613+
await asyncgens_done.wait()
573614
await loop._main_loop_exit()
574615
finally:
575616
loop.close()

0 commit comments

Comments
 (0)