Skip to content

Commit ad75cf7

Browse files
committed
Finalize async generators in the correct context
1 parent 6f2870a commit ad75cf7

File tree

5 files changed

+233
-5
lines changed

5 files changed

+233
-5
lines changed

newsfragments/92.feature.rst

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
trio-asyncio now properly finalizes asyncio-flavored async generators
2+
upon closure of the event loop. Previously, Trio's async generator finalizers
3+
would try to finalize all async generators in Trio mode, regardless of their
4+
flavor, which could lead to spurious errors.

tests/conftest.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212
@pytest.fixture
1313
async def loop():
1414
async with trio_asyncio.open_loop() as loop:
15-
try:
16-
yield loop
17-
finally:
18-
await loop.stop().wait()
15+
yield loop
1916

2017

2118
# auto-trio-ize all async functions

tests/test_trio_asyncio.py

+99
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import types
44
import asyncio
55
import trio
6+
import trio.testing
67
import trio_asyncio
78
import contextlib
9+
import gc
810

911

1012
async def use_asyncio():
@@ -203,3 +205,100 @@ async def main():
203205
asyncio.run(main())
204206

205207
assert scope.value.code == 42
208+
209+
210+
@pytest.mark.trio
211+
@pytest.mark.parametrize("alive_on_exit", (False, True))
212+
@pytest.mark.parametrize("slow_finalizer", (False, True))
213+
@pytest.mark.parametrize("loop_timeout", (0, 1, 20))
214+
async def test_asyncgens(alive_on_exit, slow_finalizer, loop_timeout, autojump_clock):
215+
import sniffio
216+
217+
record = set()
218+
holder = []
219+
220+
async def agen(label, extra):
221+
assert sniffio.current_async_library() == label
222+
if label == "asyncio":
223+
loop = asyncio.get_running_loop()
224+
try:
225+
yield 1
226+
finally:
227+
library = sniffio.current_async_library()
228+
if label == "asyncio":
229+
assert loop is asyncio.get_running_loop()
230+
try:
231+
await sys.modules[library].sleep(5 if slow_finalizer else 0)
232+
except (trio.Cancelled, asyncio.CancelledError):
233+
pass
234+
record.add((label + extra, library))
235+
236+
async def iterate_one(label, extra=""):
237+
ag = agen(label, extra)
238+
await ag.asend(None)
239+
if alive_on_exit:
240+
holder.append(ag)
241+
else:
242+
del ag
243+
244+
sys.unraisablehook, prev_hook = sys.__unraisablehook__, sys.unraisablehook
245+
try:
246+
start_time = trio.current_time()
247+
with trio.move_on_after(loop_timeout) as scope:
248+
if loop_timeout == 0:
249+
scope.cancel()
250+
async with trio_asyncio.open_loop() as loop:
251+
async with trio_asyncio.open_loop() as loop2:
252+
async with trio.open_nursery() as nursery:
253+
# Make sure the iterate_one aio tasks don't get
254+
# cancelled before they start:
255+
nursery.cancel_scope.shield = True
256+
try:
257+
nursery.start_soon(iterate_one, "trio")
258+
nursery.start_soon(
259+
loop.run_aio_coroutine, iterate_one("asyncio")
260+
)
261+
nursery.start_soon(
262+
loop2.run_aio_coroutine, iterate_one("asyncio", "2")
263+
)
264+
await loop.synchronize()
265+
await loop2.synchronize()
266+
finally:
267+
nursery.cancel_scope.shield = False
268+
if not alive_on_exit and sys.implementation.name == "pypy":
269+
for _ in range(5):
270+
gc.collect()
271+
272+
# asyncio agens should be finalized as soon as asyncio loop ends,
273+
# regardless of liveness
274+
assert ("asyncio", "asyncio") in record
275+
assert ("asyncio2", "asyncio") in record
276+
277+
# asyncio agen finalizers should be able to take a cancel
278+
if (slow_finalizer or loop_timeout == 0) and alive_on_exit:
279+
# Each loop finalizes in series, and takes 5 seconds
280+
# if slow_finalizer is true.
281+
assert trio.current_time() == start_time + min(loop_timeout, 10)
282+
assert scope.cancelled_caught == (loop_timeout < 10)
283+
else:
284+
# `not alive_on_exit` implies that the asyncio agen aclose() tasks
285+
# are started before loop shutdown, which means they'll be
286+
# cancelled during loop shutdown; this matches regular asyncio.
287+
#
288+
# `not slow_finalizer and loop_timeout > 0` implies that the agens
289+
# have time to complete before we cancel them.
290+
assert trio.current_time() == start_time
291+
assert not scope.cancelled_caught
292+
293+
# trio asyncgen should eventually be finalized in trio mode
294+
del holder[:]
295+
for _ in range(5):
296+
gc.collect()
297+
await trio.testing.wait_all_tasks_blocked()
298+
assert record == {
299+
("trio", "trio"),
300+
("asyncio", "asyncio"),
301+
("asyncio2", "asyncio"),
302+
}
303+
finally:
304+
sys.unraisablehook = prev_hook

trio_asyncio/_base.py

+74
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,71 @@ def shutdown(self, wait=None):
9191
self._running = False
9292

9393

94+
class AsyncGeneratorDispatcher:
95+
"""Helper object providing async generator hooks that route
96+
finalization to either the correct trio-asyncio event loop or the
97+
outer Trio run, depending on where the generator was first iterated.
98+
"""
99+
100+
def __init__(self, prev_hooks):
101+
self.prev_hooks = prev_hooks
102+
self.refcnt = 1
103+
104+
@classmethod
105+
def install(cls):
106+
current_hooks = sys.get_asyncgen_hooks()
107+
108+
# These hooks should either be our own AsyncGeneratorDispatcher
109+
# (for another trio-asyncio loop) or Trio's hooks. Both of those
110+
# provide both hooks.
111+
assert current_hooks.firstiter is not None
112+
assert current_hooks.finalizer is not None
113+
114+
matches = (
115+
getattr(current_hooks.firstiter, "__func__", None) is cls.firstiter
116+
) + (getattr(current_hooks.finalizer, "__func__", None) is cls.finalizer)
117+
if matches == 0:
118+
# Create a new dispatcher that forwards non-trio-asyncio asyncgens
119+
# to the current_hooks
120+
dispatcher = cls(prev_hooks=current_hooks)
121+
sys.set_asyncgen_hooks(
122+
firstiter=dispatcher.firstiter, finalizer=dispatcher.finalizer
123+
)
124+
else:
125+
# Take a new reference to the dispatcher that the current_hooks
126+
# refer to
127+
assert matches == 2
128+
dispatcher = current_hooks.firstiter.__self__
129+
assert dispatcher is current_hooks.finalizer.__self__
130+
assert isinstance(dispatcher, cls)
131+
dispatcher.refcnt += 1
132+
return dispatcher
133+
134+
def uninstall(self):
135+
self.refcnt -= 1
136+
if self.refcnt <= 0:
137+
sys.set_asyncgen_hooks(*self.prev_hooks)
138+
assert self.refcnt == 0
139+
140+
def firstiter(self, agen):
141+
if sniffio_library.name == "asyncio":
142+
loop = asyncio.get_running_loop()
143+
agen.ag_frame.f_locals["@trio_asyncio_loop"] = loop
144+
return loop._asyncgen_firstiter_hook(agen)
145+
else:
146+
return self.prev_hooks.firstiter(agen)
147+
148+
def finalizer(self, agen):
149+
try:
150+
loop = agen.ag_frame.f_locals.get("@trio_asyncio_loop")
151+
except AttributeError: # pragma: no cover
152+
loop = None
153+
if loop is not None:
154+
return loop._asyncgen_finalizer_hook(agen)
155+
else:
156+
return self.prev_hooks.finalizer(agen)
157+
158+
94159
class BaseTrioEventLoop(asyncio.SelectorEventLoop):
95160
"""An asyncio event loop that runs on top of Trio.
96161
@@ -135,6 +200,10 @@ class BaseTrioEventLoop(asyncio.SelectorEventLoop):
135200
# (threading) Thread this loop is running in
136201
_thread = None
137202

203+
# An instance of AsyncGeneratorDispatcher for handling asyncio async
204+
# generators; it may be shared by multiple running trio-asyncio loops
205+
_asyncgen_dispatcher = None
206+
138207
def __init__(self, queue_len=None):
139208
if queue_len is None:
140209
queue_len = math.inf
@@ -629,6 +698,7 @@ async def _main_loop_init(self, nursery):
629698
self._nursery = nursery
630699
self._task = trio.lowlevel.current_task()
631700
self._token = trio.lowlevel.current_trio_token()
701+
self._asyncgen_dispatcher = AsyncGeneratorDispatcher.install()
632702

633703
async def _main_loop(self, task_status=trio.TASK_STATUS_IGNORED):
634704
"""Run the loop by processing its event queue.
@@ -738,6 +808,10 @@ async def _main_loop_exit(self):
738808
except TrioAsyncioExit:
739809
pass
740810

811+
# Restore previous async generator hooks
812+
self._asyncgen_dispatcher.uninstall()
813+
self._asyncgen_dispatcher = None
814+
741815
# Kill off unprocessed work
742816
self._cancel_fds()
743817
self._cancel_timers()

trio_asyncio/_loop.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
import trio
77
import asyncio
8+
import warnings
89
import threading
910
from contextvars import ContextVar
1011
from contextlib import asynccontextmanager
@@ -560,6 +561,49 @@ async def wait_for_sync():
560561
tasks_nursery.cancel_scope.cancel()
561562

562563
finally:
564+
# If we have any async generators left, finalize them before
565+
# closing the event loop. Make sure that the finalizers have a
566+
# chance to actually start before they're exposed to any
567+
# external cancellation, since asyncio doesn't guarantee that
568+
# cancelled tasks have a chance to start first.
569+
570+
asyncgens_done = trio.Event()
571+
should_warn = False
572+
if len(loop._asyncgens) == 0:
573+
asyncgens_done.set()
574+
elif not loop.is_running():
575+
asyncgens_done.set()
576+
should_warn = True
577+
else:
578+
shield_asyncgen_finalizers = trio.CancelScope(shield=True)
579+
580+
async def sentinel():
581+
try:
582+
yield
583+
finally:
584+
try:
585+
# Open-coded asyncio version of loop.synchronize();
586+
# since we closed the tasks_nursery, we can't do
587+
# any more asyncio-to-trio-mode conversions
588+
w = asyncio.Event()
589+
loop.call_soon(w.set)
590+
await w.wait()
591+
finally:
592+
shield_asyncgen_finalizers.shield = False
593+
594+
async def shutdown_asyncgens_from_aio():
595+
agen = sentinel()
596+
await agen.asend(None)
597+
try:
598+
await loop.shutdown_asyncgens()
599+
finally:
600+
asyncgens_done.set()
601+
602+
@loop_nursery.start_soon
603+
async def shutdown_asyncgens_from_trio():
604+
with shield_asyncgen_finalizers:
605+
await loop.run_aio_coroutine(shutdown_asyncgens_from_aio())
606+
563607
if forwarded_cancellation is not None:
564608
# Now that we're outside the shielded tasks_nursery, we can
565609
# add this cancellation to the set of errors propagating out
@@ -570,7 +614,17 @@ async def forward_cancellation():
570614
raise forwarded_cancellation
571615

572616
try:
573-
await loop._main_loop_exit()
617+
try:
618+
if should_warn:
619+
warnings.warn(
620+
"trio-asyncio loop was stopped before its async "
621+
"generators were finalized; weird stuff might happen",
622+
RuntimeWarning,
623+
)
624+
finally:
625+
with trio.CancelScope(shield=True):
626+
await asyncgens_done.wait()
627+
await loop._main_loop_exit()
574628
finally:
575629
loop.close()
576630
current_loop.reset(old_loop)

0 commit comments

Comments
 (0)