Skip to content

Commit f573b7b

Browse files
committed
Drop the join waker of a task eagerly when the task completes and there is no
join interest
1 parent bb7ca75 commit f573b7b

File tree

3 files changed

+153
-46
lines changed

3 files changed

+153
-46
lines changed

tokio/src/runtime/task/harness.rs

+50-16
Original file line numberDiff line numberDiff line change
@@ -284,21 +284,33 @@ where
284284
}
285285

286286
pub(super) fn drop_join_handle_slow(self) {
287+
use super::state::TransitionToJoinHandleDrop;
287288
// Try to unset `JOIN_INTEREST`. This must be done as a first step in
288289
// case the task concurrently completed.
289-
if self.state().unset_join_interested().is_err() {
290-
// It is our responsibility to drop the output. This is critical as
291-
// the task output may not be `Send` and as such must remain with
292-
// the scheduler or `JoinHandle`. i.e. if the output remains in the
293-
// task structure until the task is deallocated, it may be dropped
294-
// by a Waker on any arbitrary thread.
295-
//
296-
// Panics are delivered to the user via the `JoinHandle`. Given that
297-
// they are dropping the `JoinHandle`, we assume they are not
298-
// interested in the panic and swallow it.
299-
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
300-
self.core().drop_future_or_output();
301-
}));
290+
//
291+
// TODO Create new bit/flag in state -> Set WantToDropJoinWaker in transition when failing
292+
let transition = self.state().transition_to_join_handle_drop();
293+
match transition {
294+
TransitionToJoinHandleDrop::Failed => {
295+
// It is our responsibility to drop the output. This is critical as
296+
// the task output may not be `Send` and as such must remain with
297+
// the scheduler or `JoinHandle`. i.e. if the output remains in the
298+
// task structure until the task is deallocated, it may be dropped
299+
// by a Waker on any arbitrary thread.
300+
//
301+
// Panics are delivered to the user via the `JoinHandle`. Given that
302+
// they are dropping the `JoinHandle`, we assume they are not
303+
// interested in the panic and swallow it.
304+
let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| {
305+
self.core().drop_future_or_output();
306+
}));
307+
}
308+
TransitionToJoinHandleDrop::OkDropJoinWaker => unsafe {
309+
// If there is a waker associated with this task when the `JoinHandle` is about to get
310+
// dropped we want to also drop this waker if the task is already completed.
311+
self.trailer().set_waker(None);
312+
},
313+
TransitionToJoinHandleDrop::OkDoNothing => (),
302314
}
303315

304316
// Drop the `JoinHandle` reference, possibly deallocating the task
@@ -309,6 +321,7 @@ where
309321

310322
/// Completes the task. This method assumes that the state is RUNNING.
311323
fn complete(self) {
324+
use super::state::TransitionToTerminal;
312325
// The future has completed and its output has been written to the task
313326
// stage. We transition from running to complete.
314327

@@ -346,8 +359,29 @@ where
346359
// The task has completed execution and will no longer be scheduled.
347360
let num_release = self.release();
348361

349-
if self.state().transition_to_terminal(num_release) {
350-
self.dealloc();
362+
match self.state().transition_to_terminal(num_release) {
363+
TransitionToTerminal::OkDoNothing => (),
364+
TransitionToTerminal::OkDealloc => {
365+
self.dealloc();
366+
}
367+
TransitionToTerminal::FailedDropJoinWaker => {
368+
// Safety: In this case we are the only one referencing the task and the active
369+
// waker is the only one preventing the task from being deallocated so noone else
370+
// will try to access the waker here.
371+
unsafe {
372+
self.trailer().set_waker(None);
373+
}
374+
375+
// We do not expect this to happen since `TransitionToTerminal::DropJoinWaker`
376+
// will only be returned when after dropping the JoinWaker the task can be
377+
// safely. Because after this failed transition the COMPLETE bit is still set
378+
// its fine to transition to terminal in two steps here
379+
if let TransitionToTerminal::OkDealloc =
380+
self.state().transition_to_terminal(num_release)
381+
{
382+
self.dealloc();
383+
}
384+
}
351385
}
352386
}
353387

@@ -387,7 +421,7 @@ fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool {
387421

388422
debug_assert!(snapshot.is_join_interested());
389423

390-
if !snapshot.is_complete() {
424+
if !snapshot.is_complete() && !snapshot.is_terminal() {
391425
// If the task is not complete, try storing the provided waker in the
392426
// task's waker field.
393427

tokio/src/runtime/task/state.rs

+72-30
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,12 @@ const JOIN_WAKER: usize = 0b10_000;
3636
/// The task has been forcibly cancelled.
3737
const CANCELLED: usize = 0b100_000;
3838

39+
const TERMINAL: usize = 0b1_000_000;
40+
3941
/// All bits.
40-
const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED;
42+
// const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED;
43+
const STATE_MASK: usize =
44+
LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED | TERMINAL;
4145

4246
/// Bits used by the ref count portion of the state.
4347
const REF_COUNT_MASK: usize = !STATE_MASK;
@@ -89,6 +93,20 @@ pub(crate) enum TransitionToNotifiedByRef {
8993
Submit,
9094
}
9195

96+
#[must_use]
97+
pub(crate) enum TransitionToJoinHandleDrop {
98+
Failed,
99+
OkDoNothing,
100+
OkDropJoinWaker,
101+
}
102+
103+
#[must_use]
104+
pub(crate) enum TransitionToTerminal {
105+
OkDoNothing,
106+
OkDealloc,
107+
FailedDropJoinWaker,
108+
}
109+
92110
/// All transitions are performed via RMW operations. This establishes an
93111
/// unambiguous modification order.
94112
impl State {
@@ -174,30 +192,69 @@ impl State {
174192
})
175193
}
176194

195+
pub(super) fn transition_to_join_handle_drop(&self) -> TransitionToJoinHandleDrop {
196+
self.fetch_update_action(|mut snapshot| {
197+
if snapshot.is_join_interested() {
198+
snapshot.unset_join_interested()
199+
}
200+
201+
if snapshot.is_complete() && !snapshot.is_terminal() {
202+
(TransitionToJoinHandleDrop::Failed, None)
203+
} else if snapshot.is_join_waker_set() {
204+
snapshot.unset_join_waker();
205+
(TransitionToJoinHandleDrop::OkDropJoinWaker, Some(snapshot))
206+
} else {
207+
(TransitionToJoinHandleDrop::OkDoNothing, Some(snapshot))
208+
}
209+
})
210+
}
211+
177212
/// Transitions the task from `Running` -> `Complete`.
178213
pub(super) fn transition_to_complete(&self) -> Snapshot {
179214
const DELTA: usize = RUNNING | COMPLETE;
180215

181216
let prev = Snapshot(self.val.fetch_xor(DELTA, AcqRel));
182217
assert!(prev.is_running());
183218
assert!(!prev.is_complete());
219+
assert!(!prev.is_terminal());
184220

185221
Snapshot(prev.0 ^ DELTA)
186222
}
187223

188224
/// Transitions from `Complete` -> `Terminal`, decrementing the reference
189225
/// count the specified number of times.
190226
///
191-
/// Returns true if the task should be deallocated.
192-
pub(super) fn transition_to_terminal(&self, count: usize) -> bool {
193-
let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel));
194-
assert!(
195-
prev.ref_count() >= count,
196-
"current: {}, sub: {}",
197-
prev.ref_count(),
198-
count
199-
);
200-
prev.ref_count() == count
227+
/// Returns `TransitionToTerminal::OkDoNothing` if transition was successful but the task can
228+
/// not already be deallocated.
229+
/// Returns `TransitionToTerminal::OkDealloc` if the task should be deallocated.
230+
/// Returns `TransitionToTerminal::FailedDropJoinWaker` if the transition failed because of a
231+
/// the join waker being the only last. In this case the reference count will not be decremented
232+
/// but the `JOIN_WAKER` bit will be unset.
233+
pub(super) fn transition_to_terminal(&self, count: usize) -> TransitionToTerminal {
234+
self.fetch_update_action(|mut snapshot| {
235+
assert!(!snapshot.is_running());
236+
assert!(snapshot.is_complete());
237+
assert!(!snapshot.is_terminal());
238+
assert!(
239+
snapshot.ref_count() >= count,
240+
"current: {}, sub: {}",
241+
snapshot.ref_count(),
242+
count
243+
);
244+
245+
if snapshot.ref_count() == count {
246+
snapshot.0 -= count * REF_ONE;
247+
snapshot.0 |= TERMINAL;
248+
(TransitionToTerminal::OkDealloc, Some(snapshot))
249+
} else if !snapshot.is_join_interested() && snapshot.is_join_waker_set() {
250+
snapshot.unset_join_waker();
251+
(TransitionToTerminal::FailedDropJoinWaker, Some(snapshot))
252+
} else {
253+
snapshot.0 -= count * REF_ONE;
254+
snapshot.0 |= TERMINAL;
255+
(TransitionToTerminal::OkDoNothing, Some(snapshot))
256+
}
257+
})
201258
}
202259

203260
/// Transitions the state to `NOTIFIED`.
@@ -371,25 +428,6 @@ impl State {
371428
.map_err(|_| ())
372429
}
373430

374-
/// Tries to unset the `JOIN_INTEREST` flag.
375-
///
376-
/// Returns `Ok` if the operation happens before the task transitions to a
377-
/// completed state, `Err` otherwise.
378-
pub(super) fn unset_join_interested(&self) -> UpdateResult {
379-
self.fetch_update(|curr| {
380-
assert!(curr.is_join_interested());
381-
382-
if curr.is_complete() {
383-
return None;
384-
}
385-
386-
let mut next = curr;
387-
next.unset_join_interested();
388-
389-
Some(next)
390-
})
391-
}
392-
393431
/// Sets the `JOIN_WAKER` bit.
394432
///
395433
/// Returns `Ok` if the bit is set, `Err` otherwise. This operation fails if
@@ -557,6 +595,10 @@ impl Snapshot {
557595
self.0 & COMPLETE == COMPLETE
558596
}
559597

598+
pub(super) fn is_terminal(self) -> bool {
599+
self.0 & TERMINAL == TERMINAL
600+
}
601+
560602
pub(super) fn is_join_interested(self) -> bool {
561603
self.0 & JOIN_INTEREST == JOIN_INTEREST
562604
}

tokio/src/runtime/tests/loom_multi_thread.rs

+31
Original file line numberDiff line numberDiff line change
@@ -459,3 +459,34 @@ impl<T: Future> Future for Track<T> {
459459
})
460460
}
461461
}
462+
463+
#[test]
464+
fn timo_test() {
465+
use crate::sync::mpsc::channel;
466+
467+
loom::model(|| {
468+
let pool = mk_pool(2);
469+
470+
pool.block_on(async move {
471+
let (tx, mut rx) = channel(1);
472+
473+
let (a_closer, mut wait_for_close_a) = channel::<()>(1);
474+
let (b_closer, mut wait_for_close_b) = channel::<()>(1);
475+
476+
let a = spawn(async move {
477+
let b = rx.recv().await.unwrap();
478+
479+
futures::future::select(std::pin::pin!(b), std::pin::pin!(a_closer.send(()))).await;
480+
});
481+
482+
let b = spawn(async move {
483+
let _ = a.await;
484+
let _ = b_closer.send(()).await;
485+
});
486+
487+
tx.send(b).await.unwrap();
488+
489+
futures::future::join(wait_for_close_a.recv(), wait_for_close_b.recv()).await;
490+
});
491+
});
492+
}

0 commit comments

Comments
 (0)