Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

runtime: reduce codegen per task #5213

Merged
merged 2 commits into from
Nov 21, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions tokio/src/runtime/task/abort.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::runtime::task::{Id, RawTask};
use crate::runtime::task::{Header, RawTask};
use std::fmt;
use std::panic::{RefUnwindSafe, UnwindSafe};

@@ -14,13 +14,12 @@ use std::panic::{RefUnwindSafe, UnwindSafe};
/// [`JoinHandle`]: crate::task::JoinHandle
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct AbortHandle {
raw: Option<RawTask>,
id: Id,
raw: RawTask,
}

impl AbortHandle {
pub(super) fn new(raw: Option<RawTask>, id: Id) -> Self {
Self { raw, id }
pub(super) fn new(raw: RawTask) -> Self {
Self { raw }
}

/// Abort the task associated with the handle.
@@ -35,9 +34,7 @@ impl AbortHandle {
/// [cancelled]: method@super::error::JoinError::is_cancelled
/// [`JoinHandle::abort`]: method@super::JoinHandle::abort
pub fn abort(&self) {
if let Some(ref raw) = self.raw {
raw.remote_abort();
}
self.raw.remote_abort();
}

/// Checks if the task associated with this `AbortHandle` has finished.
@@ -47,12 +44,8 @@ impl AbortHandle {
/// some time, and this method does not return `true` until it has
/// completed.
pub fn is_finished(&self) -> bool {
if let Some(raw) = self.raw {
let state = raw.header().state.load();
state.is_complete()
} else {
true
}
let state = self.raw.state().load();
state.is_complete()
}

/// Returns a [task ID] that uniquely identifies this task relative to other
@@ -67,7 +60,8 @@ impl AbortHandle {
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn id(&self) -> super::Id {
self.id
// Safety: The header pointer is valid.
unsafe { Header::get_id(self.raw.header_ptr()) }
}
}

@@ -79,16 +73,15 @@ impl RefUnwindSafe for AbortHandle {}

impl fmt::Debug for AbortHandle {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("AbortHandle")
.field("id", &self.id)
.finish()
// Safety: The header pointer is valid.
let id_ptr = unsafe { Header::get_id_ptr(self.raw.header_ptr()) };
let id = unsafe { id_ptr.as_ref() };
fmt.debug_struct("AbortHandle").field("id", id).finish()
}
}

impl Drop for AbortHandle {
fn drop(&mut self) {
if let Some(raw) = self.raw.take() {
raw.drop_abort_handle();
}
self.raw.drop_abort_handle();
}
}
74 changes: 67 additions & 7 deletions tokio/src/runtime/task/core.rs
Original file line number Diff line number Diff line change
@@ -25,6 +25,9 @@ use std::task::{Context, Poll, Waker};
///
/// It is critical for `Header` to be the first field as the task structure will
/// be referenced by both *mut Cell and *mut Header.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// const fns in raw.rs.
#[repr(C)]
pub(super) struct Cell<T: Future, S> {
/// Hot task state data
@@ -44,15 +47,19 @@ pub(super) struct CoreStage<T: Future> {
/// The core of the task.
///
/// Holds the future or output, depending on the stage of execution.
///
/// Any changes to the layout of this struct _must_ also be reflected in the
/// const fns in raw.rs.
#[repr(C)]
pub(super) struct Core<T: Future, S> {
/// Scheduler used to drive this future.
pub(super) scheduler: S,

/// Either the future or the output.
pub(super) stage: CoreStage<T>,

/// The task's ID, used for populating `JoinError`s.
pub(super) task_id: Id,

/// Either the future or the output.
pub(super) stage: CoreStage<T>,
}

/// Crate public as this is also needed by the pool.
@@ -82,7 +89,7 @@ pub(crate) struct Header {

/// The tracing ID for this instrumented task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
pub(super) id: Option<tracing::Id>,
pub(super) tracing_id: Option<tracing::Id>,
}

unsafe impl Send for Header {}
@@ -117,15 +124,15 @@ impl<T: Future, S: Schedule> Cell<T, S> {
/// structures.
pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box<Cell<T, S>> {
#[cfg(all(tokio_unstable, feature = "tracing"))]
let id = future.id();
let tracing_id = future.id();
let result = Box::new(Cell {
header: Header {
state,
queue_next: UnsafeCell::new(None),
vtable: raw::vtable::<T, S>(),
owner_id: UnsafeCell::new(0),
#[cfg(all(tokio_unstable, feature = "tracing"))]
id,
tracing_id,
},
core: Core {
scheduler,
@@ -144,8 +151,16 @@ impl<T: Future, S: Schedule> Cell<T, S> {
{
let trailer_addr = (&result.trailer) as *const Trailer as usize;
let trailer_ptr = unsafe { Header::get_trailer(NonNull::from(&result.header)) };

assert_eq!(trailer_addr, trailer_ptr.as_ptr() as usize);

let scheduler_addr = (&result.core.scheduler) as *const S as usize;
let scheduler_ptr =
unsafe { Header::get_scheduler::<S>(NonNull::from(&result.header)) };
assert_eq!(scheduler_addr, scheduler_ptr.as_ptr() as usize);

let id_addr = (&result.core.task_id) as *const Id as usize;
let id_ptr = unsafe { Header::get_id_ptr(NonNull::from(&result.header)) };
assert_eq!(id_addr, id_ptr.as_ptr() as usize);
}

result
@@ -295,6 +310,51 @@ impl Header {
let trailer = me.as_ptr().cast::<u8>().add(offset).cast::<Trailer>();
NonNull::new_unchecked(trailer)
}

/// Gets a pointer to the scheduler of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
///
/// The generic type S must be set to the correct scheduler type for this
/// task.
pub(super) unsafe fn get_scheduler<S>(me: NonNull<Header>) -> NonNull<S> {
let offset = me.as_ref().vtable.scheduler_offset;
let scheduler = me.as_ptr().cast::<u8>().add(offset).cast::<S>();
NonNull::new_unchecked(scheduler)
}

/// Gets a pointer to the id of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
pub(super) unsafe fn get_id_ptr(me: NonNull<Header>) -> NonNull<Id> {
let offset = me.as_ref().vtable.id_offset;
let id = me.as_ptr().cast::<u8>().add(offset).cast::<Id>();
NonNull::new_unchecked(id)
}

/// Gets the id of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
pub(super) unsafe fn get_id(me: NonNull<Header>) -> Id {
let ptr = Header::get_id_ptr(me).as_ptr();
*ptr
}

/// Gets the tracing id of the task containing this `Header`.
///
/// # Safety
///
/// The provided raw pointer must point at the header of a task.
#[cfg(all(tokio_unstable, feature = "tracing"))]
pub(super) unsafe fn get_tracing_id(me: &NonNull<Header>) -> Option<&tracing::Id> {
me.as_ref().tracing_id.as_ref()
}
}

impl Trailer {
186 changes: 92 additions & 94 deletions tokio/src/runtime/task/harness.rs
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ use crate::future::Future;
use crate::runtime::task::core::{Cell, Core, Header, Trailer};
use crate::runtime::task::state::{Snapshot, State};
use crate::runtime::task::waker::waker_ref;
use crate::runtime::task::{JoinError, Notified, Schedule, Task};
use crate::runtime::task::{JoinError, Notified, RawTask, Schedule, Task};

use std::mem;
use std::mem::ManuallyDrop;
@@ -47,11 +47,102 @@ where
}
}

/// Task operations that can be implemented without being generic over the
/// scheduler or task. Only one version of these methods should exist in the
/// final binary.
impl RawTask {
pub(super) fn drop_reference(self) {
if self.state().ref_dec() {
self.dealloc();
}
}

/// This call consumes a ref-count and notifies the task. This will create a
/// new Notified and submit it if necessary.
///
/// The caller does not need to hold a ref-count besides the one that was
/// passed to this call.
pub(super) fn wake_by_val(&self) {
use super::state::TransitionToNotifiedByVal;

match self.state().transition_to_notified_by_val() {
TransitionToNotifiedByVal::Submit => {
// The caller has given us a ref-count, and the transition has
// created a new ref-count, so we now hold two. We turn the new
// ref-count Notified and pass it to the call to `schedule`.
//
// The old ref-count is retained for now to ensure that the task
// is not dropped during the call to `schedule` if the call
// drops the task it was given.
self.schedule();

// Now that we have completed the call to schedule, we can
// release our ref-count.
self.drop_reference();
}
TransitionToNotifiedByVal::Dealloc => {
self.dealloc();
}
TransitionToNotifiedByVal::DoNothing => {}
}
}

/// This call notifies the task. It will not consume any ref-counts, but the
/// caller should hold a ref-count. This will create a new Notified and
/// submit it if necessary.
pub(super) fn wake_by_ref(&self) {
use super::state::TransitionToNotifiedByRef;

match self.state().transition_to_notified_by_ref() {
TransitionToNotifiedByRef::Submit => {
// The transition above incremented the ref-count for a new task
// and the caller also holds a ref-count. The caller's ref-count
// ensures that the task is not destroyed even if the new task
// is dropped before `schedule` returns.
self.schedule();
}
TransitionToNotifiedByRef::DoNothing => {}
}
}

/// Remotely aborts the task.
///
/// The caller should hold a ref-count, but we do not consume it.
///
/// This is similar to `shutdown` except that it asks the runtime to perform
/// the shutdown. This is necessary to avoid the shutdown happening in the
/// wrong thread for non-Send tasks.
pub(super) fn remote_abort(&self) {
if self.state().transition_to_notified_and_cancel() {
// The transition has created a new ref-count, which we turn into
// a Notified and pass to the task.
//
// Since the caller holds a ref-count, the task cannot be destroyed
// before the call to `schedule` returns even if the call drops the
// `Notified` internally.
self.schedule();
}
}

/// Try to set the waker notified when the task is complete. Returns true if
/// the task has already completed. If this call returns false, then the
/// waker will not be notified.
pub(super) fn try_set_join_waker(&self, waker: &Waker) -> bool {
can_read_output(self.header(), self.trailer(), waker)
}
}

impl<T, S> Harness<T, S>
where
T: Future,
S: Schedule,
{
pub(super) fn drop_reference(self) {
if self.state().ref_dec() {
self.dealloc();
}
}

/// Polls the inner future. A ref-count is consumed.
///
/// All necessary state checks and transitions are performed.
@@ -185,13 +276,6 @@ where
}
}

/// Try to set the waker notified when the task is complete. Returns true if
/// the task has already completed. If this call returns false, then the
/// waker will not be notified.
pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool {
can_read_output(self.header(), self.trailer(), waker)
}

pub(super) fn drop_join_handle_slow(self) {
// Try to unset `JOIN_INTEREST`. This must be done as a first step in
// case the task concurrently completed.
@@ -214,92 +298,6 @@ where
self.drop_reference();
}

/// Remotely aborts the task.
///
/// The caller should hold a ref-count, but we do not consume it.
///
/// This is similar to `shutdown` except that it asks the runtime to perform
/// the shutdown. This is necessary to avoid the shutdown happening in the
/// wrong thread for non-Send tasks.
pub(super) fn remote_abort(self) {
if self.state().transition_to_notified_and_cancel() {
// The transition has created a new ref-count, which we turn into
// a Notified and pass to the task.
//
// Since the caller holds a ref-count, the task cannot be destroyed
// before the call to `schedule` returns even if the call drops the
// `Notified` internally.
self.core()
.scheduler
.schedule(Notified(self.get_new_task()));
}
}

// ===== waker behavior =====

/// This call consumes a ref-count and notifies the task. This will create a
/// new Notified and submit it if necessary.
///
/// The caller does not need to hold a ref-count besides the one that was
/// passed to this call.
pub(super) fn wake_by_val(self) {
use super::state::TransitionToNotifiedByVal;

match self.state().transition_to_notified_by_val() {
TransitionToNotifiedByVal::Submit => {
// The caller has given us a ref-count, and the transition has
// created a new ref-count, so we now hold two. We turn the new
// ref-count Notified and pass it to the call to `schedule`.
//
// The old ref-count is retained for now to ensure that the task
// is not dropped during the call to `schedule` if the call
// drops the task it was given.
self.core()
.scheduler
.schedule(Notified(self.get_new_task()));

// Now that we have completed the call to schedule, we can
// release our ref-count.
self.drop_reference();
}
TransitionToNotifiedByVal::Dealloc => {
self.dealloc();
}
TransitionToNotifiedByVal::DoNothing => {}
}
}

/// This call notifies the task. It will not consume any ref-counts, but the
/// caller should hold a ref-count. This will create a new Notified and
/// submit it if necessary.
pub(super) fn wake_by_ref(&self) {
use super::state::TransitionToNotifiedByRef;

match self.state().transition_to_notified_by_ref() {
TransitionToNotifiedByRef::Submit => {
// The transition above incremented the ref-count for a new task
// and the caller also holds a ref-count. The caller's ref-count
// ensures that the task is not destroyed even if the new task
// is dropped before `schedule` returns.
self.core()
.scheduler
.schedule(Notified(self.get_new_task()));
}
TransitionToNotifiedByRef::DoNothing => {}
}
}

pub(super) fn drop_reference(self) {
if self.state().ref_dec() {
self.dealloc();
}
}

#[cfg(all(tokio_unstable, feature = "tracing"))]
pub(super) fn id(&self) -> Option<&tracing::Id> {
self.header().id.as_ref()
}

// ====== internal ======

/// Completes the task. This method assumes that the state is RUNNING.
67 changes: 24 additions & 43 deletions tokio/src/runtime/task/join.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::runtime::task::{Id, RawTask};
use crate::runtime::task::{Header, RawTask};

use std::fmt;
use std::future::Future;
@@ -154,8 +154,7 @@ cfg_rt! {
/// [`std::thread::JoinHandle`]: std::thread::JoinHandle
/// [`JoinError`]: crate::task::JoinError
pub struct JoinHandle<T> {
raw: Option<RawTask>,
id: Id,
raw: RawTask,
_p: PhantomData<T>,
}
}
@@ -167,10 +166,9 @@ impl<T> UnwindSafe for JoinHandle<T> {}
impl<T> RefUnwindSafe for JoinHandle<T> {}

impl<T> JoinHandle<T> {
pub(super) fn new(raw: RawTask, id: Id) -> JoinHandle<T> {
pub(super) fn new(raw: RawTask) -> JoinHandle<T> {
JoinHandle {
raw: Some(raw),
id,
raw,
_p: PhantomData,
}
}
@@ -209,9 +207,7 @@ impl<T> JoinHandle<T> {
/// ```
/// [cancelled]: method@super::error::JoinError::is_cancelled
pub fn abort(&self) {
if let Some(raw) = self.raw {
raw.remote_abort();
}
self.raw.remote_abort();
}

/// Checks if the task associated with this `JoinHandle` has finished.
@@ -243,31 +239,22 @@ impl<T> JoinHandle<T> {
/// ```
/// [`abort`]: method@JoinHandle::abort
pub fn is_finished(&self) -> bool {
if let Some(raw) = self.raw {
let state = raw.header().state.load();
state.is_complete()
} else {
true
}
let state = self.raw.header().state.load();
state.is_complete()
}

/// Set the waker that is notified when the task completes.
pub(crate) fn set_join_waker(&mut self, waker: &Waker) {
if let Some(raw) = self.raw {
if raw.try_set_join_waker(waker) {
// In this case the task has already completed. We wake the waker immediately.
waker.wake_by_ref();
}
if self.raw.try_set_join_waker(waker) {
// In this case the task has already completed. We wake the waker immediately.
waker.wake_by_ref();
}
}

/// Returns a new `AbortHandle` that can be used to remotely abort this task.
pub(crate) fn abort_handle(&self) -> super::AbortHandle {
let raw = self.raw.map(|raw| {
raw.ref_inc();
raw
});
super::AbortHandle::new(raw, self.id)
self.raw.ref_inc();
super::AbortHandle::new(self.raw)
}

/// Returns a [task ID] that uniquely identifies this task relative to other
@@ -282,7 +269,8 @@ impl<T> JoinHandle<T> {
#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn id(&self) -> super::Id {
self.id
// Safety: The header pointer is valid.
unsafe { Header::get_id(self.raw.header_ptr()) }
}
}

@@ -297,13 +285,6 @@ impl<T> Future for JoinHandle<T> {
// Keep track of task budget
let coop = ready!(crate::runtime::coop::poll_proceed(cx));

// Raw should always be set. If it is not, this is due to polling after
// completion
let raw = self
.raw
.as_ref()
.expect("polling after `JoinHandle` already completed");

// Try to read the task output. If the task is not yet complete, the
// waker is stored and is notified once the task does complete.
//
@@ -316,7 +297,8 @@ impl<T> Future for JoinHandle<T> {
//
// The type of `T` must match the task's output type.
unsafe {
raw.try_read_output(&mut ret as *mut _ as *mut (), cx.waker());
self.raw
.try_read_output(&mut ret as *mut _ as *mut (), cx.waker());
}

if ret.is_ready() {
@@ -329,13 +311,11 @@ impl<T> Future for JoinHandle<T> {

impl<T> Drop for JoinHandle<T> {
fn drop(&mut self) {
if let Some(raw) = self.raw.take() {
if raw.header().state.drop_join_handle_fast().is_ok() {
return;
}

raw.drop_join_handle_slow();
if self.raw.state().drop_join_handle_fast().is_ok() {
return;
}

self.raw.drop_join_handle_slow();
}
}

@@ -344,8 +324,9 @@ where
T: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("JoinHandle")
.field("id", &self.id)
.finish()
// Safety: The header pointer is valid.
let id_ptr = unsafe { Header::get_id_ptr(self.raw.header_ptr()) };
let id = unsafe { id_ptr.as_ref() };
fmt.debug_struct("JoinHandle").field("id", id).finish()
}
}
2 changes: 1 addition & 1 deletion tokio/src/runtime/task/mod.rs
Original file line number Diff line number Diff line change
@@ -338,7 +338,7 @@ cfg_rt! {
raw,
_p: PhantomData,
});
let join = JoinHandle::new(raw, id);
let join = JoinHandle::new(raw);

(task, notified, join)
}
134 changes: 97 additions & 37 deletions tokio/src/runtime/task/raw.rs
Original file line number Diff line number Diff line change
@@ -14,45 +14,47 @@ pub(super) struct Vtable {
/// Polls the future.
pub(super) poll: unsafe fn(NonNull<Header>),

/// Schedules the task for execution on the runtime.
pub(super) schedule: unsafe fn(NonNull<Header>),

/// Deallocates the memory.
pub(super) dealloc: unsafe fn(NonNull<Header>),

/// Reads the task output, if complete.
pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker),

/// Try to set the waker notified when the task is complete. Returns true if
/// the task has already completed. If this call returns false, then the
/// waker will not be notified.
pub(super) try_set_join_waker: unsafe fn(NonNull<Header>, &Waker) -> bool,

/// The join handle has been dropped.
pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>),

/// An abort handle has been dropped.
pub(super) drop_abort_handle: unsafe fn(NonNull<Header>),

/// The task is remotely aborted.
pub(super) remote_abort: unsafe fn(NonNull<Header>),

/// Scheduler is being shutdown.
pub(super) shutdown: unsafe fn(NonNull<Header>),

/// The number of bytes that the `trailer` field is offset from the header.
pub(super) trailer_offset: usize,

/// The number of bytes that the `scheduler` field is offset from the header.
pub(super) scheduler_offset: usize,

/// The number of bytes that the `id` field is offset from the header.
pub(super) id_offset: usize,
}

/// Get the vtable for the requested `T` and `S` generics.
pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
&Vtable {
poll: poll::<T, S>,
schedule: schedule::<S>,
dealloc: dealloc::<T, S>,
try_read_output: try_read_output::<T, S>,
try_set_join_waker: try_set_join_waker::<T, S>,
drop_join_handle_slow: drop_join_handle_slow::<T, S>,
drop_abort_handle: drop_abort_handle::<T, S>,
remote_abort: remote_abort::<T, S>,
shutdown: shutdown::<T, S>,
trailer_offset: TrailerOffsetHelper::<T, S>::OFFSET,
trailer_offset: OffsetHelper::<T, S>::TRAILER_OFFSET,
scheduler_offset: OffsetHelper::<T, S>::SCHEDULER_OFFSET,
id_offset: OffsetHelper::<T, S>::ID_OFFSET,
}
}

@@ -61,17 +63,31 @@ pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
///
/// See this thread for more info:
/// <https://users.rust-lang.org/t/custom-vtables-with-integers/78508>
struct TrailerOffsetHelper<T, S>(T, S);
impl<T: Future, S: Schedule> TrailerOffsetHelper<T, S> {
struct OffsetHelper<T, S>(T, S);
impl<T: Future, S: Schedule> OffsetHelper<T, S> {
// Pass `size_of`/`align_of` as arguments rather than calling them directly
// inside `get_trailer_offset` because trait bounds on generic parameters
// of const fn are unstable on our MSRV.
const OFFSET: usize = get_trailer_offset(
const TRAILER_OFFSET: usize = get_trailer_offset(
std::mem::size_of::<Header>(),
std::mem::size_of::<Core<T, S>>(),
std::mem::align_of::<Core<T, S>>(),
std::mem::align_of::<Trailer>(),
);

// The `scheduler` is the first field of `Core`, so it has the same
// offset as `Core`.
const SCHEDULER_OFFSET: usize = get_core_offset(
std::mem::size_of::<Header>(),
std::mem::align_of::<Core<T, S>>(),
);

const ID_OFFSET: usize = get_id_offset(
std::mem::size_of::<Header>(),
std::mem::align_of::<Core<T, S>>(),
std::mem::size_of::<S>(),
std::mem::align_of::<Id>(),
);
}

/// Compute the offset of the `Trailer` field in `Cell<T, S>` using the
@@ -101,6 +117,44 @@ const fn get_trailer_offset(
offset
}

/// Compute the offset of the `Core<T, S>` field in `Cell<T, S>` using the
/// `#[repr(C)]` algorithm.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
const fn get_core_offset(header_size: usize, core_align: usize) -> usize {
let mut offset = header_size;

let core_misalign = offset % core_align;
if core_misalign > 0 {
offset += core_align - core_misalign;
}

offset
}

/// Compute the offset of the `Id` field in `Cell<T, S>` using the
/// `#[repr(C)]` algorithm.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
const fn get_id_offset(
header_size: usize,
core_align: usize,
scheduler_size: usize,
id_align: usize,
) -> usize {
let mut offset = get_core_offset(header_size, core_align);
offset += scheduler_size;

let id_misalign = offset % id_align;
if id_misalign > 0 {
offset += id_align - id_misalign;
}

offset
}

impl RawTask {
pub(super) fn new<T, S>(task: T, scheduler: S, id: Id) -> RawTask
where
@@ -121,19 +175,36 @@ impl RawTask {
self.ptr
}

/// Returns a reference to the task's meta structure.
///
/// Safe as `Header` is `Sync`.
pub(super) fn trailer_ptr(&self) -> NonNull<Trailer> {
unsafe { Header::get_trailer(self.ptr) }
}

/// Returns a reference to the task's header.
pub(super) fn header(&self) -> &Header {
unsafe { self.ptr.as_ref() }
}

/// Returns a reference to the task's trailer.
pub(super) fn trailer(&self) -> &Trailer {
unsafe { &*self.trailer_ptr().as_ptr() }
}

/// Returns a reference to the task's state.
pub(super) fn state(&self) -> &State {
&self.header().state
}

/// Safety: mutual exclusion is required to call this function.
pub(super) fn poll(self) {
let vtable = self.header().vtable;
unsafe { (vtable.poll)(self.ptr) }
}

pub(super) fn schedule(self) {
let vtable = self.header().vtable;
unsafe { (vtable.schedule)(self.ptr) }
}

pub(super) fn dealloc(self) {
let vtable = self.header().vtable;
unsafe {
@@ -148,11 +219,6 @@ impl RawTask {
(vtable.try_read_output)(self.ptr, dst, waker);
}

pub(super) fn try_set_join_waker(self, waker: &Waker) -> bool {
let vtable = self.header().vtable;
unsafe { (vtable.try_set_join_waker)(self.ptr, waker) }
}

pub(super) fn drop_join_handle_slow(self) {
let vtable = self.header().vtable;
unsafe { (vtable.drop_join_handle_slow)(self.ptr) }
@@ -168,11 +234,6 @@ impl RawTask {
unsafe { (vtable.shutdown)(self.ptr) }
}

pub(super) fn remote_abort(self) {
let vtable = self.header().vtable;
unsafe { (vtable.remote_abort)(self.ptr) }
}

/// Increment the task's reference count.
///
/// Currently, this is used only when creating an `AbortHandle`.
@@ -194,6 +255,15 @@ unsafe fn poll<T: Future, S: Schedule>(ptr: NonNull<Header>) {
harness.poll();
}

unsafe fn schedule<S: Schedule>(ptr: NonNull<Header>) {
use crate::runtime::task::{Notified, Task};

let scheduler = Header::get_scheduler::<S>(ptr);
scheduler
.as_ref()
.schedule(Notified(Task::from_raw(ptr.cast())));
}

unsafe fn dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>) {
let harness = Harness::<T, S>::from_raw(ptr);
harness.dealloc();
@@ -210,11 +280,6 @@ unsafe fn try_read_output<T: Future, S: Schedule>(
harness.try_read_output(out, waker);
}

unsafe fn try_set_join_waker<T: Future, S: Schedule>(ptr: NonNull<Header>, waker: &Waker) -> bool {
let harness = Harness::<T, S>::from_raw(ptr);
harness.try_set_join_waker(waker)
}

unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) {
let harness = Harness::<T, S>::from_raw(ptr);
harness.drop_join_handle_slow()
@@ -225,11 +290,6 @@ unsafe fn drop_abort_handle<T: Future, S: Schedule>(ptr: NonNull<Header>) {
harness.drop_reference();
}

unsafe fn remote_abort<T: Future, S: Schedule>(ptr: NonNull<Header>) {
let harness = Harness::<T, S>::from_raw(ptr);
harness.remote_abort()
}

unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) {
let harness = Harness::<T, S>::from_raw(ptr);
harness.shutdown()
82 changes: 28 additions & 54 deletions tokio/src/runtime/task/waker.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::future::Future;
use crate::runtime::task::harness::Harness;
use crate::runtime::task::{Header, Schedule};
use crate::runtime::task::{Header, RawTask, Schedule};

use std::marker::PhantomData;
use std::mem::ManuallyDrop;
@@ -28,7 +27,7 @@ where
// point and not an *owned* waker, we must ensure that `drop` is never
// called on this waker instance. This is done by wrapping it with
// `ManuallyDrop` and then never calling drop.
let waker = unsafe { ManuallyDrop::new(Waker::from_raw(raw_waker::<T, S>(*header))) };
let waker = unsafe { ManuallyDrop::new(Waker::from_raw(raw_waker(*header))) };

WakerRef {
waker,
@@ -46,8 +45,8 @@ impl<S> ops::Deref for WakerRef<'_, S> {

cfg_trace! {
macro_rules! trace {
($harness:expr, $op:expr) => {
if let Some(id) = $harness.id() {
($header:expr, $op:expr) => {
if let Some(id) = Header::get_tracing_id(&$header) {
tracing::trace!(
target: "tokio::task::waker",
op = $op,
@@ -60,71 +59,46 @@ cfg_trace! {

cfg_not_trace! {
macro_rules! trace {
($harness:expr, $op:expr) => {
($header:expr, $op:expr) => {
// noop
let _ = &$harness;
let _ = &$header;
}
}
}

unsafe fn clone_waker<T, S>(ptr: *const ()) -> RawWaker
where
T: Future,
S: Schedule,
{
let header = ptr as *const Header;
let ptr = NonNull::new_unchecked(ptr as *mut Header);
let harness = Harness::<T, S>::from_raw(ptr);
trace!(harness, "waker.clone");
(*header).state.ref_inc();
raw_waker::<T, S>(ptr)
unsafe fn clone_waker(ptr: *const ()) -> RawWaker {
let header = NonNull::new_unchecked(ptr as *mut Header);
trace!(header, "waker.clone");
header.as_ref().state.ref_inc();
raw_waker(header)
}

unsafe fn drop_waker<T, S>(ptr: *const ())
where
T: Future,
S: Schedule,
{
unsafe fn drop_waker(ptr: *const ()) {
let ptr = NonNull::new_unchecked(ptr as *mut Header);
let harness = Harness::<T, S>::from_raw(ptr);
trace!(harness, "waker.drop");
harness.drop_reference();
trace!(ptr, "waker.drop");
let raw = RawTask::from_raw(ptr);
raw.drop_reference();
}

unsafe fn wake_by_val<T, S>(ptr: *const ())
where
T: Future,
S: Schedule,
{
unsafe fn wake_by_val(ptr: *const ()) {
let ptr = NonNull::new_unchecked(ptr as *mut Header);
let harness = Harness::<T, S>::from_raw(ptr);
trace!(harness, "waker.wake");
harness.wake_by_val();
trace!(ptr, "waker.wake");
let raw = RawTask::from_raw(ptr);
raw.wake_by_val();
}

// Wake without consuming the waker
unsafe fn wake_by_ref<T, S>(ptr: *const ())
where
T: Future,
S: Schedule,
{
unsafe fn wake_by_ref(ptr: *const ()) {
let ptr = NonNull::new_unchecked(ptr as *mut Header);
let harness = Harness::<T, S>::from_raw(ptr);
trace!(harness, "waker.wake_by_ref");
harness.wake_by_ref();
trace!(ptr, "waker.wake_by_ref");
let raw = RawTask::from_raw(ptr);
raw.wake_by_ref();
}

fn raw_waker<T, S>(header: NonNull<Header>) -> RawWaker
where
T: Future,
S: Schedule,
{
static WAKER_VTABLE: RawWakerVTable =
RawWakerVTable::new(clone_waker, wake_by_val, wake_by_ref, drop_waker);

fn raw_waker(header: NonNull<Header>) -> RawWaker {
let ptr = header.as_ptr() as *const ();
let vtable = &RawWakerVTable::new(
clone_waker::<T, S>,
wake_by_val::<T, S>,
wake_by_ref::<T, S>,
drop_waker::<T, S>,
);
RawWaker::new(ptr, vtable)
RawWaker::new(ptr, &WAKER_VTABLE)
}