Skip to content

Commit

Permalink
Add trio.testing.wait_all_threads_completed
Browse files Browse the repository at this point in the history
This is the equivalent of trio.testing.wait_all_tasks_blocked but for
threads managed by trio. This is useful when writing tests that use
to_thread
  • Loading branch information
VincentVanlaer committed Jan 25, 2024
1 parent 556df86 commit 1faf1c8
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 1 deletion.
4 changes: 4 additions & 0 deletions docs/source/reference-testing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ Inter-task ordering

.. autofunction:: wait_all_tasks_blocked

.. autofunction:: wait_all_threads_completed

.. autofunction:: active_thread_count


.. _testing-streams:

Expand Down
1 change: 1 addition & 0 deletions newsfragments/2937.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `trio.testing.wait_all_threads_completed`, which blocks until no threads are running tasks. This is intended to be used in the same way as `trio.testing.wait_all_tasks_blocked`.
40 changes: 40 additions & 0 deletions src/trio/_tests/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from_thread_run,
from_thread_run_sync,
to_thread_run_sync,
wait_all_threads_completed,
)
from ..testing import wait_all_tasks_blocked

Expand Down Expand Up @@ -1106,3 +1107,42 @@ async def test_cancellable_warns() -> None:

with pytest.warns(TrioDeprecationWarning):
await to_thread_run_sync(bool, cancellable=True)


async def test_wait_all_threads_completed() -> None:
no_threads_left = False
e1 = Event()
e2 = Event()

e1_exited = Event()
e2_exited = Event()

async def wait_event(e: Event, e_exit: Event) -> None:
def thread() -> None:
from_thread_run(e.wait)

await to_thread_run_sync(thread)
e_exit.set()

async def wait_no_threads_left() -> None:
nonlocal no_threads_left
await wait_all_threads_completed()
no_threads_left = True

async with _core.open_nursery() as nursery:
nursery.start_soon(wait_event, e1, e1_exited)
nursery.start_soon(wait_event, e2, e2_exited)
await wait_all_tasks_blocked()
nursery.start_soon(wait_no_threads_left)
await wait_all_tasks_blocked()
assert not no_threads_left

e1.set()
await e1_exited.wait()
await wait_all_tasks_blocked()
assert not no_threads_left

e2.set()
await e2_exited.wait()
await wait_all_tasks_blocked()
assert no_threads_left
80 changes: 79 additions & 1 deletion src/trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

from ._core import (
RunVar,
TrioInternalError,
TrioToken,
checkpoint,
disable_ki_protection,
enable_ki_protection,
start_thread_soon,
)
from ._deprecate import warn_deprecated
from ._sync import CapacityLimiter
from ._sync import CapacityLimiter, Event
from ._util import coroutine_or_error

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,6 +54,79 @@ class _ParentTaskData(threading.local):
_thread_counter = count()


class _ActiveThreadCount:
count: int
event: Event

def __init__(self) -> None:
self.count = 0
self.event = Event()


_active_threads_local: RunVar[_ActiveThreadCount] = RunVar("active_threads")


def _increment_active_threads() -> None:
try:
active_threads_local = _active_threads_local.get()
except LookupError:
active_threads_local = _ActiveThreadCount()
_active_threads_local.set(active_threads_local)

active_threads_local.count += 1


def _decrement_active_threads() -> None:
try:
active_threads_local = _active_threads_local.get()
active_threads_local.count -= 1
if active_threads_local.count == 0:
active_threads_local.event.set()
active_threads_local.event = Event()
except LookupError as e:
raise TrioInternalError(

Check warning on line 87 in src/trio/_threads.py

View check run for this annotation

Codecov / codecov/patch

src/trio/_threads.py#L86-L87

Added lines #L86 - L87 were not covered by tests
"Tried to decrement active threads while _active_threads_local is unset"
) from e


async def wait_all_threads_completed() -> None:
"""Wait until no threads are still running tasks.
This is intended to be used when testing code with trio.to_thread to
make sure no tasks are still making progress in a thread. See the
following code for a usage example::
async def wait_all_settled():
while True:
await trio.testing.wait_all_threads_complete()
await trio.testing.wait_all_tasks_blocked()
if trio.testing.active_thread_count() == 0:
break
"""

await checkpoint()

try:
active_threads_local = _active_threads_local.get()
except LookupError:
active_threads_local = _ActiveThreadCount()
_active_threads_local.set(active_threads_local)

Check warning on line 113 in src/trio/_threads.py

View check run for this annotation

Codecov / codecov/patch

src/trio/_threads.py#L111-L113

Added lines #L111 - L113 were not covered by tests

while active_threads_local.count != 0:
await active_threads_local.event.wait()


def active_thread_count() -> int:
"""Returns the number of threads that are currently running a task
See `trio.testing.wait_all_threads_completed`
"""
try:
return _active_threads_local.get().count
except LookupError:
return 0

Check warning on line 127 in src/trio/_threads.py

View check run for this annotation

Codecov / codecov/patch

src/trio/_threads.py#L124-L127

Added lines #L124 - L127 were not covered by tests


def current_default_thread_limiter() -> CapacityLimiter:
"""Get the default `~trio.CapacityLimiter` used by
`trio.to_thread.run_sync`.
Expand Down Expand Up @@ -377,6 +452,7 @@ def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None:
await limiter.acquire_on_behalf_of(placeholder)
try:
start_thread_soon(worker_fn, deliver_worker_fn_result, thread_name)
_increment_active_threads()
except:
limiter.release_on_behalf_of(placeholder)
raise
Expand All @@ -397,12 +473,14 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:
object
] = await trio.lowlevel.wait_task_rescheduled(abort)
if isinstance(msg_from_thread, outcome.Outcome):
_decrement_active_threads()
return msg_from_thread.unwrap()
elif isinstance(msg_from_thread, Run):
await msg_from_thread.run()
elif isinstance(msg_from_thread, RunSync):
msg_from_thread.run_sync()
else: # pragma: no cover, internal debugging guard TODO: use assert_never
_decrement_active_threads()
raise TypeError(
"trio.to_thread.run_sync received unrecognized thread message {!r}."
"".format(msg_from_thread)
Expand Down
4 changes: 4 additions & 0 deletions src/trio/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
MockClock as MockClock,
wait_all_tasks_blocked as wait_all_tasks_blocked,
)
from .._threads import (
active_thread_count as active_thread_count,
wait_all_threads_completed as wait_all_threads_completed,
)
from .._util import fixup_module_metadata
from ._check_streams import (
check_half_closeable_stream as check_half_closeable_stream,
Expand Down

0 comments on commit 1faf1c8

Please sign in to comment.