|
3 | 3 | import types
|
4 | 4 | import asyncio
|
5 | 5 | import trio
|
| 6 | +import trio.testing |
6 | 7 | import trio_asyncio
|
7 | 8 | import contextlib
|
| 9 | +import gc |
8 | 10 |
|
9 | 11 |
|
10 | 12 | async def use_asyncio():
|
@@ -203,3 +205,100 @@ async def main():
|
203 | 205 | asyncio.run(main())
|
204 | 206 |
|
205 | 207 | assert scope.value.code == 42
|
| 208 | + |
| 209 | + |
| 210 | +@pytest.mark.trio |
| 211 | +@pytest.mark.parametrize("alive_on_exit", (False, True)) |
| 212 | +@pytest.mark.parametrize("slow_finalizer", (False, True)) |
| 213 | +@pytest.mark.parametrize("loop_timeout", (0, 1, 20)) |
| 214 | +async def test_asyncgens(alive_on_exit, slow_finalizer, loop_timeout, autojump_clock): |
| 215 | + import sniffio |
| 216 | + |
| 217 | + record = set() |
| 218 | + holder = [] |
| 219 | + |
| 220 | + async def agen(label, extra): |
| 221 | + assert sniffio.current_async_library() == label |
| 222 | + if label == "asyncio": |
| 223 | + loop = asyncio.get_running_loop() |
| 224 | + try: |
| 225 | + yield 1 |
| 226 | + finally: |
| 227 | + library = sniffio.current_async_library() |
| 228 | + if label == "asyncio": |
| 229 | + assert loop is asyncio.get_running_loop() |
| 230 | + try: |
| 231 | + await sys.modules[library].sleep(5 if slow_finalizer else 0) |
| 232 | + except (trio.Cancelled, asyncio.CancelledError): |
| 233 | + pass |
| 234 | + record.add((label + extra, library)) |
| 235 | + |
| 236 | + async def iterate_one(label, extra=""): |
| 237 | + ag = agen(label, extra) |
| 238 | + await ag.asend(None) |
| 239 | + if alive_on_exit: |
| 240 | + holder.append(ag) |
| 241 | + else: |
| 242 | + del ag |
| 243 | + |
| 244 | + sys.unraisablehook, prev_hook = sys.__unraisablehook__, sys.unraisablehook |
| 245 | + try: |
| 246 | + start_time = trio.current_time() |
| 247 | + with trio.move_on_after(loop_timeout) as scope: |
| 248 | + if loop_timeout == 0: |
| 249 | + scope.cancel() |
| 250 | + async with trio_asyncio.open_loop() as loop: |
| 251 | + async with trio_asyncio.open_loop() as loop2: |
| 252 | + async with trio.open_nursery() as nursery: |
| 253 | + # Make sure the iterate_one aio tasks don't get |
| 254 | + # cancelled before they start: |
| 255 | + nursery.cancel_scope.shield = True |
| 256 | + try: |
| 257 | + nursery.start_soon(iterate_one, "trio") |
| 258 | + nursery.start_soon( |
| 259 | + loop.run_aio_coroutine, iterate_one("asyncio") |
| 260 | + ) |
| 261 | + nursery.start_soon( |
| 262 | + loop2.run_aio_coroutine, iterate_one("asyncio", "2") |
| 263 | + ) |
| 264 | + await loop.synchronize() |
| 265 | + await loop2.synchronize() |
| 266 | + finally: |
| 267 | + nursery.cancel_scope.shield = False |
| 268 | + if not alive_on_exit and sys.implementation.name == "pypy": |
| 269 | + for _ in range(5): |
| 270 | + gc.collect() |
| 271 | + |
| 272 | + # asyncio agens should be finalized as soon as asyncio loop ends, |
| 273 | + # regardless of liveness |
| 274 | + assert ("asyncio", "asyncio") in record |
| 275 | + assert ("asyncio2", "asyncio") in record |
| 276 | + |
| 277 | + # asyncio agen finalizers should be able to take a cancel |
| 278 | + if (slow_finalizer or loop_timeout == 0) and alive_on_exit: |
| 279 | + # Each loop finalizes in series, and takes 5 seconds |
| 280 | + # if slow_finalizer is true. |
| 281 | + assert trio.current_time() == start_time + min(loop_timeout, 10) |
| 282 | + assert scope.cancelled_caught == (loop_timeout < 10) |
| 283 | + else: |
| 284 | + # `not alive_on_exit` implies that the asyncio agen aclose() tasks |
| 285 | + # are started before loop shutdown, which means they'll be |
| 286 | + # cancelled during loop shutdown; this matches regular asyncio. |
| 287 | + # |
| 288 | + # `not slow_finalizer and loop_timeout > 0` implies that the agens |
| 289 | + # have time to complete before we cancel them. |
| 290 | + assert trio.current_time() == start_time |
| 291 | + assert not scope.cancelled_caught |
| 292 | + |
| 293 | + # trio asyncgen should eventually be finalized in trio mode |
| 294 | + del holder[:] |
| 295 | + for _ in range(5): |
| 296 | + gc.collect() |
| 297 | + await trio.testing.wait_all_tasks_blocked() |
| 298 | + assert record == { |
| 299 | + ("trio", "trio"), |
| 300 | + ("asyncio", "asyncio"), |
| 301 | + ("asyncio2", "asyncio"), |
| 302 | + } |
| 303 | + finally: |
| 304 | + sys.unraisablehook = prev_hook |
0 commit comments