Skip to content

Commit b8ac94e

Browse files
authoredJan 30, 2025··
rt: add before and after task poll callbacks (tokio-rs#7120)
Add callbacks for poll start and stop, enabling users to instrument these points in the runtime's life cycle.
1 parent 5086e56 commit b8ac94e

File tree

8 files changed

+347
-17
lines changed

8 files changed

+347
-17
lines changed
 

‎tokio/src/runtime/builder.rs

+111
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ pub struct Builder {
8888
/// To run before each task is spawned.
8989
pub(super) before_spawn: Option<TaskCallback>,
9090

91+
/// To run before each poll
92+
#[cfg(tokio_unstable)]
93+
pub(super) before_poll: Option<TaskCallback>,
94+
95+
/// To run after each poll
96+
#[cfg(tokio_unstable)]
97+
pub(super) after_poll: Option<TaskCallback>,
98+
9199
/// To run after each task is terminated.
92100
pub(super) after_termination: Option<TaskCallback>,
93101

@@ -306,6 +314,11 @@ impl Builder {
306314
before_spawn: None,
307315
after_termination: None,
308316

317+
#[cfg(tokio_unstable)]
318+
before_poll: None,
319+
#[cfg(tokio_unstable)]
320+
after_poll: None,
321+
309322
keep_alive: None,
310323

311324
// Defaults for these values depend on the scheduler kind, so we get them
@@ -743,6 +756,92 @@ impl Builder {
743756
self
744757
}
745758

759+
/// Executes function `f` just before a task is polled
760+
///
761+
/// `f` is called within the Tokio context, so functions like
762+
/// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being
763+
/// invoked immediately.
764+
///
765+
/// **Note**: This is an [unstable API][unstable]. The public API of this type
766+
/// may break in 1.x releases. See [the documentation on unstable
767+
/// features][unstable] for details.
768+
///
769+
/// [unstable]: crate#unstable-features
770+
///
771+
/// # Examples
772+
///
773+
/// ```
774+
/// # use std::sync::{atomic::AtomicUsize, Arc};
775+
/// # use tokio::task::yield_now;
776+
/// # pub fn main() {
777+
/// let poll_start_counter = Arc::new(AtomicUsize::new(0));
778+
/// let poll_start = poll_start_counter.clone();
779+
/// let rt = tokio::runtime::Builder::new_multi_thread()
780+
/// .enable_all()
781+
/// .on_before_task_poll(move |meta| {
782+
/// println!("task {} is about to be polled", meta.id())
783+
/// })
784+
/// .build()
785+
/// .unwrap();
786+
/// let task = rt.spawn(async {
787+
/// yield_now().await;
788+
/// });
789+
/// let _ = rt.block_on(task);
790+
///
791+
/// # }
792+
/// ```
793+
#[cfg(tokio_unstable)]
794+
pub fn on_before_task_poll<F>(&mut self, f: F) -> &mut Self
795+
where
796+
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
797+
{
798+
self.before_poll = Some(std::sync::Arc::new(f));
799+
self
800+
}
801+
802+
/// Executes function `f` just after a task is polled
803+
///
804+
/// `f` is called within the Tokio context, so functions like
805+
/// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being
806+
/// invoked immediately.
807+
///
808+
/// **Note**: This is an [unstable API][unstable]. The public API of this type
809+
/// may break in 1.x releases. See [the documentation on unstable
810+
/// features][unstable] for details.
811+
///
812+
/// [unstable]: crate#unstable-features
813+
///
814+
/// # Examples
815+
///
816+
/// ```
817+
/// # use std::sync::{atomic::AtomicUsize, Arc};
818+
/// # use tokio::task::yield_now;
819+
/// # pub fn main() {
820+
/// let poll_stop_counter = Arc::new(AtomicUsize::new(0));
821+
/// let poll_stop = poll_stop_counter.clone();
822+
/// let rt = tokio::runtime::Builder::new_multi_thread()
823+
/// .enable_all()
824+
/// .on_after_task_poll(move |meta| {
825+
/// println!("task {} completed polling", meta.id());
826+
/// })
827+
/// .build()
828+
/// .unwrap();
829+
/// let task = rt.spawn(async {
830+
/// yield_now().await;
831+
/// });
832+
/// let _ = rt.block_on(task);
833+
///
834+
/// # }
835+
/// ```
836+
#[cfg(tokio_unstable)]
837+
pub fn on_after_task_poll<F>(&mut self, f: F) -> &mut Self
838+
where
839+
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
840+
{
841+
self.after_poll = Some(std::sync::Arc::new(f));
842+
self
843+
}
844+
746845
/// Executes function `f` just after a task is terminated.
747846
///
748847
/// `f` is called within the Tokio context, so functions like
@@ -1410,6 +1509,10 @@ impl Builder {
14101509
before_park: self.before_park.clone(),
14111510
after_unpark: self.after_unpark.clone(),
14121511
before_spawn: self.before_spawn.clone(),
1512+
#[cfg(tokio_unstable)]
1513+
before_poll: self.before_poll.clone(),
1514+
#[cfg(tokio_unstable)]
1515+
after_poll: self.after_poll.clone(),
14131516
after_termination: self.after_termination.clone(),
14141517
global_queue_interval: self.global_queue_interval,
14151518
event_interval: self.event_interval,
@@ -1560,6 +1663,10 @@ cfg_rt_multi_thread! {
15601663
before_park: self.before_park.clone(),
15611664
after_unpark: self.after_unpark.clone(),
15621665
before_spawn: self.before_spawn.clone(),
1666+
#[cfg(tokio_unstable)]
1667+
before_poll: self.before_poll.clone(),
1668+
#[cfg(tokio_unstable)]
1669+
after_poll: self.after_poll.clone(),
15631670
after_termination: self.after_termination.clone(),
15641671
global_queue_interval: self.global_queue_interval,
15651672
event_interval: self.event_interval,
@@ -1610,6 +1717,10 @@ cfg_rt_multi_thread! {
16101717
after_unpark: self.after_unpark.clone(),
16111718
before_spawn: self.before_spawn.clone(),
16121719
after_termination: self.after_termination.clone(),
1720+
#[cfg(tokio_unstable)]
1721+
before_poll: self.before_poll.clone(),
1722+
#[cfg(tokio_unstable)]
1723+
after_poll: self.after_poll.clone(),
16131724
global_queue_interval: self.global_queue_interval,
16141725
event_interval: self.event_interval,
16151726
local_queue_capacity: self.local_queue_capacity,

‎tokio/src/runtime/config.rs

+8
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ pub(crate) struct Config {
2727
/// To run after each task is terminated.
2828
pub(crate) after_termination: Option<TaskCallback>,
2929

30+
/// To run before each poll
31+
#[cfg(tokio_unstable)]
32+
pub(crate) before_poll: Option<TaskCallback>,
33+
34+
/// To run after each poll
35+
#[cfg(tokio_unstable)]
36+
pub(crate) after_poll: Option<TaskCallback>,
37+
3038
/// The multi-threaded scheduler includes a per-worker LIFO slot used to
3139
/// store the last scheduled task. This can improve certain usage patterns,
3240
/// especially message passing between tasks. However, this LIFO slot is not

‎tokio/src/runtime/scheduler/current_thread/mod.rs

+13
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ impl CurrentThread {
145145
task_hooks: TaskHooks {
146146
task_spawn_callback: config.before_spawn.clone(),
147147
task_terminate_callback: config.after_termination.clone(),
148+
#[cfg(tokio_unstable)]
149+
before_poll_callback: config.before_poll.clone(),
150+
#[cfg(tokio_unstable)]
151+
after_poll_callback: config.after_poll.clone(),
148152
},
149153
shared: Shared {
150154
inject: Inject::new(),
@@ -766,8 +770,17 @@ impl CoreGuard<'_> {
766770

767771
let task = context.handle.shared.owned.assert_owner(task);
768772

773+
#[cfg(tokio_unstable)]
774+
let task_id = task.task_id();
775+
769776
let (c, ()) = context.run_task(core, || {
777+
#[cfg(tokio_unstable)]
778+
context.handle.task_hooks.poll_start_callback(task_id);
779+
770780
task.run();
781+
782+
#[cfg(tokio_unstable)]
783+
context.handle.task_hooks.poll_stop_callback(task_id);
771784
});
772785

773786
core = c;

‎tokio/src/runtime/scheduler/multi_thread/worker.rs

+23-4
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,7 @@ pub(super) fn create(
282282

283283
let remotes_len = remotes.len();
284284
let handle = Arc::new(Handle {
285-
task_hooks: TaskHooks {
286-
task_spawn_callback: config.before_spawn.clone(),
287-
task_terminate_callback: config.after_termination.clone(),
288-
},
285+
task_hooks: TaskHooks::from_config(&config),
289286
shared: Shared {
290287
remotes: remotes.into_boxed_slice(),
291288
inject,
@@ -574,6 +571,9 @@ impl Context {
574571
}
575572

576573
fn run_task(&self, task: Notified, mut core: Box<Core>) -> RunResult {
574+
#[cfg(tokio_unstable)]
575+
let task_id = task.task_id();
576+
577577
let task = self.worker.handle.shared.owned.assert_owner(task);
578578

579579
// Make sure the worker is not in the **searching** state. This enables
@@ -593,7 +593,16 @@ impl Context {
593593

594594
// Run the task
595595
coop::budget(|| {
596+
// Unlike the poll time above, poll start callback is attached to the task id,
597+
// so it is tightly associated with the actual poll invocation.
598+
#[cfg(tokio_unstable)]
599+
self.worker.handle.task_hooks.poll_start_callback(task_id);
600+
596601
task.run();
602+
603+
#[cfg(tokio_unstable)]
604+
self.worker.handle.task_hooks.poll_stop_callback(task_id);
605+
597606
let mut lifo_polls = 0;
598607

599608
// As long as there is budget remaining and a task exists in the
@@ -656,7 +665,17 @@ impl Context {
656665
// Run the LIFO task, then loop
657666
*self.core.borrow_mut() = Some(core);
658667
let task = self.worker.handle.shared.owned.assert_owner(task);
668+
669+
#[cfg(tokio_unstable)]
670+
let task_id = task.task_id();
671+
672+
#[cfg(tokio_unstable)]
673+
self.worker.handle.task_hooks.poll_start_callback(task_id);
674+
659675
task.run();
676+
677+
#[cfg(tokio_unstable)]
678+
self.worker.handle.task_hooks.poll_stop_callback(task_id);
660679
}
661680
})
662681
}

‎tokio/src/runtime/scheduler/multi_thread_alt/worker.rs

+1-4
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,7 @@ pub(super) fn create(
303303
let (inject, inject_synced) = inject::Shared::new();
304304

305305
let handle = Arc::new(Handle {
306-
task_hooks: TaskHooks {
307-
task_spawn_callback: config.before_spawn.clone(),
308-
task_terminate_callback: config.after_termination.clone(),
309-
},
306+
task_hooks: TaskHooks::from_config(&config),
310307
shared: Shared {
311308
remotes: remotes.into_boxed_slice(),
312309
inject,

‎tokio/src/runtime/task/mod.rs

+23-9
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,13 @@ pub(crate) struct LocalNotified<S: 'static> {
256256
_not_send: PhantomData<*const ()>,
257257
}
258258

259+
impl<S> LocalNotified<S> {
260+
#[cfg(tokio_unstable)]
261+
pub(crate) fn task_id(&self) -> Id {
262+
self.task.id()
263+
}
264+
}
265+
259266
/// A task that is not owned by any `OwnedTasks`. Used for blocking tasks.
260267
/// This type holds two ref-counts.
261268
pub(crate) struct UnownedTask<S: 'static> {
@@ -386,6 +393,16 @@ impl<S: 'static> Task<S> {
386393
self.raw.header_ptr()
387394
}
388395

396+
/// Returns a [task ID] that uniquely identifies this task relative to other
397+
/// currently spawned tasks.
398+
///
399+
/// [task ID]: crate::task::Id
400+
#[cfg(tokio_unstable)]
401+
pub(crate) fn id(&self) -> crate::task::Id {
402+
// Safety: The header pointer is valid.
403+
unsafe { Header::get_id(self.raw.header_ptr()) }
404+
}
405+
389406
cfg_taskdump! {
390407
/// Notify the task for task dumping.
391408
///
@@ -400,22 +417,19 @@ impl<S: 'static> Task<S> {
400417
}
401418
}
402419

403-
/// Returns a [task ID] that uniquely identifies this task relative to other
404-
/// currently spawned tasks.
405-
///
406-
/// [task ID]: crate::task::Id
407-
#[cfg(tokio_unstable)]
408-
pub(crate) fn id(&self) -> crate::task::Id {
409-
// Safety: The header pointer is valid.
410-
unsafe { Header::get_id(self.raw.header_ptr()) }
411-
}
412420
}
413421
}
414422

415423
impl<S: 'static> Notified<S> {
416424
fn header(&self) -> &Header {
417425
self.0.header()
418426
}
427+
428+
#[cfg(tokio_unstable)]
429+
#[allow(dead_code)]
430+
pub(crate) fn task_id(&self) -> crate::task::Id {
431+
self.0.id()
432+
}
419433
}
420434

421435
impl<S: 'static> Notified<S> {

‎tokio/src/runtime/task_hooks.rs

+40
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,57 @@
11
use std::marker::PhantomData;
22

3+
use super::Config;
4+
35
impl TaskHooks {
46
pub(crate) fn spawn(&self, meta: &TaskMeta<'_>) {
57
if let Some(f) = self.task_spawn_callback.as_ref() {
68
f(meta)
79
}
810
}
11+
12+
#[allow(dead_code)]
13+
pub(crate) fn from_config(config: &Config) -> Self {
14+
Self {
15+
task_spawn_callback: config.before_spawn.clone(),
16+
task_terminate_callback: config.after_termination.clone(),
17+
#[cfg(tokio_unstable)]
18+
before_poll_callback: config.before_poll.clone(),
19+
#[cfg(tokio_unstable)]
20+
after_poll_callback: config.after_poll.clone(),
21+
}
22+
}
23+
24+
#[cfg(tokio_unstable)]
25+
#[inline]
26+
pub(crate) fn poll_start_callback(&self, id: super::task::Id) {
27+
if let Some(poll_start) = &self.before_poll_callback {
28+
(poll_start)(&TaskMeta {
29+
id,
30+
_phantom: std::marker::PhantomData,
31+
})
32+
}
33+
}
34+
35+
#[cfg(tokio_unstable)]
36+
#[inline]
37+
pub(crate) fn poll_stop_callback(&self, id: super::task::Id) {
38+
if let Some(poll_stop) = &self.after_poll_callback {
39+
(poll_stop)(&TaskMeta {
40+
id,
41+
_phantom: std::marker::PhantomData,
42+
})
43+
}
44+
}
945
}
1046

1147
#[derive(Clone)]
1248
pub(crate) struct TaskHooks {
1349
pub(crate) task_spawn_callback: Option<TaskCallback>,
1450
pub(crate) task_terminate_callback: Option<TaskCallback>,
51+
#[cfg(tokio_unstable)]
52+
pub(crate) before_poll_callback: Option<TaskCallback>,
53+
#[cfg(tokio_unstable)]
54+
pub(crate) after_poll_callback: Option<TaskCallback>,
1555
}
1656

1757
/// Task metadata supplied to user-provided hooks for task events.

‎tokio/tests/rt_poll_callbacks.rs

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#![allow(unknown_lints, unexpected_cfgs)]
2+
#![cfg(tokio_unstable)]
3+
4+
use std::sync::{atomic::AtomicUsize, Arc, Mutex};
5+
6+
use tokio::task::yield_now;
7+
8+
#[cfg(not(target_os = "wasi"))]
9+
#[test]
10+
fn callbacks_fire_multi_thread() {
11+
let poll_start_counter = Arc::new(AtomicUsize::new(0));
12+
let poll_stop_counter = Arc::new(AtomicUsize::new(0));
13+
let poll_start = poll_start_counter.clone();
14+
let poll_stop = poll_stop_counter.clone();
15+
16+
let before_task_poll_callback_task_id: Arc<Mutex<Option<tokio::task::Id>>> =
17+
Arc::new(Mutex::new(None));
18+
let after_task_poll_callback_task_id: Arc<Mutex<Option<tokio::task::Id>>> =
19+
Arc::new(Mutex::new(None));
20+
21+
let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id);
22+
let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id);
23+
let rt = tokio::runtime::Builder::new_multi_thread()
24+
.enable_all()
25+
.on_before_task_poll(move |task_meta| {
26+
before_task_poll_callback_task_id_ref
27+
.lock()
28+
.unwrap()
29+
.replace(task_meta.id());
30+
poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
31+
})
32+
.on_after_task_poll(move |task_meta| {
33+
after_task_poll_callback_task_id_ref
34+
.lock()
35+
.unwrap()
36+
.replace(task_meta.id());
37+
poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
38+
})
39+
.build()
40+
.unwrap();
41+
let task = rt.spawn(async {
42+
yield_now().await;
43+
yield_now().await;
44+
yield_now().await;
45+
});
46+
47+
let spawned_task_id = task.id();
48+
49+
rt.block_on(task).expect("task should succeed");
50+
// We need to drop the runtime to guarantee the workers have exited (and thus called the callback)
51+
drop(rt);
52+
53+
assert_eq!(
54+
before_task_poll_callback_task_id.lock().unwrap().unwrap(),
55+
spawned_task_id
56+
);
57+
assert_eq!(
58+
after_task_poll_callback_task_id.lock().unwrap().unwrap(),
59+
spawned_task_id
60+
);
61+
let actual_count = 4;
62+
assert_eq!(
63+
poll_start.load(std::sync::atomic::Ordering::Relaxed),
64+
actual_count,
65+
"unexpected number of poll starts"
66+
);
67+
assert_eq!(
68+
poll_stop.load(std::sync::atomic::Ordering::Relaxed),
69+
actual_count,
70+
"unexpected number of poll stops"
71+
);
72+
}
73+
74+
#[test]
75+
fn callbacks_fire_current_thread() {
76+
let poll_start_counter = Arc::new(AtomicUsize::new(0));
77+
let poll_stop_counter = Arc::new(AtomicUsize::new(0));
78+
let poll_start = poll_start_counter.clone();
79+
let poll_stop = poll_stop_counter.clone();
80+
81+
let before_task_poll_callback_task_id: Arc<Mutex<Option<tokio::task::Id>>> =
82+
Arc::new(Mutex::new(None));
83+
let after_task_poll_callback_task_id: Arc<Mutex<Option<tokio::task::Id>>> =
84+
Arc::new(Mutex::new(None));
85+
86+
let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id);
87+
let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id);
88+
let rt = tokio::runtime::Builder::new_current_thread()
89+
.enable_all()
90+
.on_before_task_poll(move |task_meta| {
91+
before_task_poll_callback_task_id_ref
92+
.lock()
93+
.unwrap()
94+
.replace(task_meta.id());
95+
poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
96+
})
97+
.on_after_task_poll(move |task_meta| {
98+
after_task_poll_callback_task_id_ref
99+
.lock()
100+
.unwrap()
101+
.replace(task_meta.id());
102+
poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
103+
})
104+
.build()
105+
.unwrap();
106+
107+
let task = rt.spawn(async {
108+
yield_now().await;
109+
yield_now().await;
110+
yield_now().await;
111+
});
112+
113+
let spawned_task_id = task.id();
114+
115+
let _ = rt.block_on(task);
116+
drop(rt);
117+
118+
assert_eq!(
119+
before_task_poll_callback_task_id.lock().unwrap().unwrap(),
120+
spawned_task_id
121+
);
122+
assert_eq!(
123+
after_task_poll_callback_task_id.lock().unwrap().unwrap(),
124+
spawned_task_id
125+
);
126+
assert_eq!(poll_start.load(std::sync::atomic::Ordering::Relaxed), 4);
127+
assert_eq!(poll_stop.load(std::sync::atomic::Ordering::Relaxed), 4);
128+
}

0 commit comments

Comments
 (0)
Please sign in to comment.