Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rt: implement initial set of task hooks #6742

Merged
merged 12 commits into from
Aug 27, 2024
12 changes: 11 additions & 1 deletion tokio/src/runtime/blocking/schedule.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[cfg(feature = "test-util")]
use crate::runtime::scheduler;
use crate::runtime::task::{self, Task};
use crate::runtime::task::{self, Task, TaskHarnessScheduleHooks};
use crate::runtime::Handle;

/// `task::Schedule` implementation that does nothing (except some bookkeeping
Expand All @@ -12,6 +12,7 @@ use crate::runtime::Handle;
pub(crate) struct BlockingSchedule {
#[cfg(feature = "test-util")]
handle: Handle,
hooks: TaskHarnessScheduleHooks,
}

impl BlockingSchedule {
Expand All @@ -32,6 +33,9 @@ impl BlockingSchedule {
BlockingSchedule {
#[cfg(feature = "test-util")]
handle: handle.clone(),
hooks: TaskHarnessScheduleHooks {
task_terminate_callback: handle.inner.hooks().task_terminate_callback.clone(),
},
}
}
}
Expand All @@ -57,4 +61,10 @@ impl task::Schedule for BlockingSchedule {
fn schedule(&self, _task: task::Notified<Self>) {
unreachable!();
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: self.hooks.task_terminate_callback.clone(),
}
}
}
106 changes: 105 additions & 1 deletion tokio/src/runtime/builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#![cfg_attr(loom, allow(unused_imports))]

use crate::runtime::handle::Handle;
use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime};
#[cfg(tokio_unstable)]
use crate::runtime::TaskMeta;
use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback};
use crate::util::rand::{RngSeed, RngSeedGenerator};

use std::fmt;
Expand Down Expand Up @@ -78,6 +82,12 @@ pub struct Builder {
/// To run after each thread is unparked.
pub(super) after_unpark: Option<Callback>,

/// To run before each task is spawned.
pub(super) before_spawn: Option<TaskCallback>,

/// To run after each task is terminated.
pub(super) after_termination: Option<TaskCallback>,

/// Customizable keep alive timeout for `BlockingPool`
pub(super) keep_alive: Option<Duration>,

Expand Down Expand Up @@ -290,6 +300,9 @@ impl Builder {
before_park: None,
after_unpark: None,

before_spawn: None,
after_termination: None,

keep_alive: None,

// Defaults for these values depend on the scheduler kind, so we get them
Expand Down Expand Up @@ -677,6 +690,91 @@ impl Builder {
self
}

/// Executes function `f` just before a task is spawned.
///
/// `f` is called within the Tokio context, so functions like
/// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being
/// invoked immediately.
///
/// This can be used for bookkeeping or monitoring purposes.
///
/// Note: There can only be one spawn callback for a runtime; calling this function more
/// than once replaces the last callback defined, rather than adding to it.
///
/// This *does not* support [`LocalSet`](crate::task::LocalSet) at this time.
///
/// # Examples
///
/// ```
/// # use tokio::runtime;
/// # pub fn main() {
/// let runtime = runtime::Builder::new_current_thread()
/// .on_task_spawn(|_| {
/// println!("spawning task");
/// })
/// .build()
/// .unwrap();
///
/// runtime.block_on(async {
/// tokio::task::spawn(std::future::ready(()));
///
/// for _ in 0..64 {
/// tokio::task::yield_now().await;
/// }
/// })
/// # }
/// ```
#[cfg(all(not(loom), tokio_unstable))]
pub fn on_task_spawn<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
{
self.before_spawn = Some(std::sync::Arc::new(f));
self
}

/// Executes function `f` just after a task is terminated.
///
/// `f` is called within the Tokio context, so functions like
/// [`tokio::spawn`](crate::spawn) can be called.
///
/// This can be used for bookkeeping or monitoring purposes.
///
/// Note: There can only be one task termination callback for a runtime; calling this
/// function more than once replaces the last callback defined, rather than adding to it.
///
/// This *does not* support [`LocalSet`](crate::task::LocalSet) at this time.
///
/// # Examples
///
/// ```
/// # use tokio::runtime;
/// # pub fn main() {
/// let runtime = runtime::Builder::new_current_thread()
/// .on_task_terminate(|_| {
/// println!("killing task");
/// })
/// .build()
/// .unwrap();
///
/// runtime.block_on(async {
/// tokio::task::spawn(std::future::ready(()));
///
/// for _ in 0..64 {
/// tokio::task::yield_now().await;
/// }
/// })
/// # }
/// ```
#[cfg(all(not(loom), tokio_unstable))]
pub fn on_task_terminate<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&TaskMeta<'_>) + Send + Sync + 'static,
{
self.after_termination = Some(std::sync::Arc::new(f));
self
}

/// Creates the configured `Runtime`.
///
/// The returned `Runtime` instance is ready to spawn tasks.
Expand Down Expand Up @@ -1118,6 +1216,8 @@ impl Builder {
Config {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
local_queue_capacity: self.local_queue_capacity,
Expand Down Expand Up @@ -1269,6 +1369,8 @@ cfg_rt_multi_thread! {
Config {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
local_queue_capacity: self.local_queue_capacity,
Expand Down Expand Up @@ -1316,6 +1418,8 @@ cfg_rt_multi_thread! {
Config {
before_park: self.before_park.clone(),
after_unpark: self.after_unpark.clone(),
before_spawn: self.before_spawn.clone(),
after_termination: self.after_termination.clone(),
global_queue_interval: self.global_queue_interval,
event_interval: self.event_interval,
local_queue_capacity: self.local_queue_capacity,
Expand Down
8 changes: 7 additions & 1 deletion tokio/src/runtime/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
any(not(all(tokio_unstable, feature = "full")), target_family = "wasm"),
allow(dead_code)
)]
use crate::runtime::Callback;
use crate::runtime::{Callback, TaskCallback};
use crate::util::RngSeedGenerator;

pub(crate) struct Config {
Expand All @@ -21,6 +21,12 @@ pub(crate) struct Config {
/// Callback for a worker unparking itself
pub(crate) after_unpark: Option<Callback>,

/// To run before each task is spawned.
pub(crate) before_spawn: Option<TaskCallback>,

/// To run after each task is terminated.
pub(crate) after_termination: Option<TaskCallback>,

/// The multi-threaded scheduler includes a per-worker LIFO slot used to
/// store the last scheduled task. This can improve certain usage patterns,
/// especially message passing between tasks. However, this LIFO slot is not
Expand Down
7 changes: 7 additions & 0 deletions tokio/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,13 @@ cfg_rt! {
pub use dump::Dump;
}

mod task_hooks;
pub(crate) use task_hooks::{TaskHooks, TaskCallback};
#[cfg(tokio_unstable)]
pub use task_hooks::TaskMeta;
#[cfg(not(tokio_unstable))]
pub(crate) use task_hooks::TaskMeta;

mod handle;
pub use handle::{EnterGuard, Handle, TryCurrentError};

Expand Down
27 changes: 25 additions & 2 deletions tokio/src/runtime/scheduler/current_thread/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ use crate::loom::sync::atomic::AtomicBool;
use crate::loom::sync::Arc;
use crate::runtime::driver::{self, Driver};
use crate::runtime::scheduler::{self, Defer, Inject};
use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task};
use crate::runtime::{blocking, context, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics};
use crate::runtime::task::{
self, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks,
};
use crate::runtime::{
blocking, context, Config, MetricsBatch, SchedulerMetrics, TaskHooks, TaskMeta, WorkerMetrics,
};
use crate::sync::notify::Notify;
use crate::util::atomic_cell::AtomicCell;
use crate::util::{waker_ref, RngSeedGenerator, Wake, WakerRef};
Expand Down Expand Up @@ -41,6 +45,9 @@ pub(crate) struct Handle {

/// Current random number generator seed
pub(crate) seed_generator: RngSeedGenerator,

/// User-supplied hooks to invoke for things
pub(crate) task_hooks: TaskHooks,
}

/// Data required for executing the scheduler. The struct is passed around to
Expand Down Expand Up @@ -131,6 +138,10 @@ impl CurrentThread {
.unwrap_or(DEFAULT_GLOBAL_QUEUE_INTERVAL);

let handle = Arc::new(Handle {
task_hooks: TaskHooks {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
},
shared: Shared {
inject: Inject::new(),
owned: OwnedTasks::new(1),
Expand Down Expand Up @@ -436,6 +447,12 @@ impl Handle {
{
let (handle, notified) = me.shared.owned.bind(future, me.clone(), id);

me.task_hooks.spawn(&TaskMeta {
#[cfg(tokio_unstable)]
id,
_phantom: Default::default(),
});

if let Some(notified) = notified {
me.schedule(notified);
}
Expand Down Expand Up @@ -600,6 +617,12 @@ impl Schedule for Arc<Handle> {
});
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: self.task_hooks.task_terminate_callback.clone(),
}
}

cfg_unstable! {
fn unhandled_panic(&self) {
use crate::runtime::UnhandledPanic;
Expand Down
12 changes: 12 additions & 0 deletions tokio/src/runtime/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ cfg_rt! {

pub(crate) mod inject;
pub(crate) use inject::Inject;

use crate::runtime::TaskHooks;
}

cfg_rt_multi_thread! {
Expand Down Expand Up @@ -151,6 +153,16 @@ cfg_rt! {
}
}

pub(crate) fn hooks(&self) -> &TaskHooks {
match self {
Handle::CurrentThread(h) => &h.task_hooks,
#[cfg(feature = "rt-multi-thread")]
Handle::MultiThread(h) => &h.task_hooks,
#[cfg(all(tokio_unstable, feature = "rt-multi-thread"))]
Handle::MultiThreadAlt(h) => &h.task_hooks,
}
}

cfg_rt_multi_thread! {
cfg_unstable! {
pub(crate) fn expect_multi_thread_alt(&self) -> &Arc<multi_thread_alt::Handle> {
Expand Down
10 changes: 10 additions & 0 deletions tokio/src/runtime/scheduler/multi_thread/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::runtime::scheduler::multi_thread::worker;
use crate::runtime::{
blocking, driver,
task::{self, JoinHandle},
TaskHooks, TaskMeta,
};
use crate::util::RngSeedGenerator;

Expand All @@ -28,6 +29,9 @@ pub(crate) struct Handle {

/// Current random number generator seed
pub(crate) seed_generator: RngSeedGenerator,

/// User-supplied hooks to invoke for things
pub(crate) task_hooks: TaskHooks,
}

impl Handle {
Expand All @@ -51,6 +55,12 @@ impl Handle {
{
let (handle, notified) = me.shared.owned.bind(future, me.clone(), id);

me.task_hooks.spawn(&TaskMeta {
#[cfg(tokio_unstable)]
id,
_phantom: Default::default(),
});

me.schedule_option_task_without_yield(notified);

handle
Expand Down
14 changes: 12 additions & 2 deletions tokio/src/runtime/scheduler/multi_thread/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@

use crate::loom::sync::{Arc, Mutex};
use crate::runtime;
use crate::runtime::context;
use crate::runtime::scheduler::multi_thread::{
idle, queue, Counters, Handle, Idle, Overflow, Parker, Stats, TraceStatus, Unparker,
};
use crate::runtime::scheduler::{inject, Defer, Lock};
use crate::runtime::task::OwnedTasks;
use crate::runtime::task::{OwnedTasks, TaskHarnessScheduleHooks};
use crate::runtime::{
blocking, coop, driver, scheduler, task, Config, SchedulerMetrics, WorkerMetrics,
};
use crate::runtime::{context, TaskHooks};
use crate::util::atomic_cell::AtomicCell;
use crate::util::rand::{FastRand, RngSeedGenerator};

Expand Down Expand Up @@ -284,6 +284,10 @@ pub(super) fn create(

let remotes_len = remotes.len();
let handle = Arc::new(Handle {
task_hooks: TaskHooks {
task_spawn_callback: config.before_spawn.clone(),
task_terminate_callback: config.after_termination.clone(),
},
shared: Shared {
remotes: remotes.into_boxed_slice(),
inject,
Expand Down Expand Up @@ -1037,6 +1041,12 @@ impl task::Schedule for Arc<Handle> {
self.schedule_task(task, false);
}

fn hooks(&self) -> TaskHarnessScheduleHooks {
TaskHarnessScheduleHooks {
task_terminate_callback: self.task_hooks.task_terminate_callback.clone(),
}
}

fn yield_now(&self, task: Notified) {
self.schedule_task(task, true);
}
Expand Down
Loading
Loading