Skip to content

Commit 53ea44b

Browse files
authored
sync: add CancellationToken::run_until_cancelled (tokio-rs#6618)
1 parent a865ca1 commit 53ea44b

File tree

2 files changed

+94
-0
lines changed

2 files changed

+94
-0
lines changed

tokio-util/src/sync/cancellation_token.rs

+46
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,52 @@ impl CancellationToken {
241241
pub fn drop_guard(self) -> DropGuard {
242242
DropGuard { inner: Some(self) }
243243
}
244+
245+
/// Runs a future to completion and returns its result wrapped inside of an `Option`
246+
/// unless the `CancellationToken` is cancelled. In that case the function returns
247+
/// `None` and the future gets dropped.
248+
///
249+
/// # Cancel safety
250+
///
251+
/// This method is only cancel safe if `fut` is cancel safe.
252+
pub async fn run_until_cancelled<F>(&self, fut: F) -> Option<F::Output>
253+
where
254+
F: Future,
255+
{
256+
pin_project! {
257+
/// A Future that is resolved once the corresponding [`CancellationToken`]
258+
/// is cancelled or a given Future gets resolved. It is biased towards the
259+
/// Future completion.
260+
#[must_use = "futures do nothing unless polled"]
261+
struct RunUntilCancelledFuture<'a, F: Future> {
262+
#[pin]
263+
cancellation: WaitForCancellationFuture<'a>,
264+
#[pin]
265+
future: F,
266+
}
267+
}
268+
269+
impl<'a, F: Future> Future for RunUntilCancelledFuture<'a, F> {
270+
type Output = Option<F::Output>;
271+
272+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
273+
let this = self.project();
274+
if let Poll::Ready(res) = this.future.poll(cx) {
275+
Poll::Ready(Some(res))
276+
} else if this.cancellation.poll(cx).is_ready() {
277+
Poll::Ready(None)
278+
} else {
279+
Poll::Pending
280+
}
281+
}
282+
}
283+
284+
RunUntilCancelledFuture {
285+
cancellation: self.cancelled(),
286+
future: fut,
287+
}
288+
.await
289+
}
244290
}
245291

246292
// ===== impl WaitForCancellationFuture =====

tokio-util/tests/sync_cancellation_token.rs

+48
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#![warn(rust_2018_idioms)]
22

33
use tokio::pin;
4+
use tokio::sync::oneshot;
45
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
56

67
use core::future::Future;
@@ -445,3 +446,50 @@ fn derives_send_sync() {
445446
assert_send::<WaitForCancellationFuture<'static>>();
446447
assert_sync::<WaitForCancellationFuture<'static>>();
447448
}
449+
450+
#[test]
451+
fn run_until_cancelled_test() {
452+
let (waker, _) = new_count_waker();
453+
454+
{
455+
let token = CancellationToken::new();
456+
457+
let fut = token.run_until_cancelled(std::future::pending::<()>());
458+
pin!(fut);
459+
460+
assert_eq!(
461+
Poll::Pending,
462+
fut.as_mut().poll(&mut Context::from_waker(&waker))
463+
);
464+
465+
token.cancel();
466+
467+
assert_eq!(
468+
Poll::Ready(None),
469+
fut.as_mut().poll(&mut Context::from_waker(&waker))
470+
);
471+
}
472+
473+
{
474+
let (tx, rx) = oneshot::channel::<()>();
475+
476+
let token = CancellationToken::new();
477+
let fut = token.run_until_cancelled(async move {
478+
rx.await.unwrap();
479+
42
480+
});
481+
pin!(fut);
482+
483+
assert_eq!(
484+
Poll::Pending,
485+
fut.as_mut().poll(&mut Context::from_waker(&waker))
486+
);
487+
488+
tx.send(()).unwrap();
489+
490+
assert_eq!(
491+
Poll::Ready(Some(42)),
492+
fut.as_mut().poll(&mut Context::from_waker(&waker))
493+
);
494+
}
495+
}

0 commit comments

Comments
 (0)