Skip to content

Commit 4380c3d

Browse files
authored
sync: Added WeakSender to sync::broadcast::channel (tokio-rs#7100)
1 parent 383da87 commit 4380c3d

File tree

4 files changed

+349
-3
lines changed

4 files changed

+349
-3
lines changed

tokio/src/sync/broadcast.rs

+164-3
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ use std::future::Future;
128128
use std::marker::PhantomPinned;
129129
use std::pin::Pin;
130130
use std::ptr::NonNull;
131-
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};
131+
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release, SeqCst};
132132
use std::task::{ready, Context, Poll, Waker};
133133

134134
/// Sending-half of the [`broadcast`] channel.
@@ -166,6 +166,40 @@ pub struct Sender<T> {
166166
shared: Arc<Shared<T>>,
167167
}
168168

169+
/// A sender that does not prevent the channel from being closed.
170+
///
171+
/// If all [`Sender`] instances of a channel were dropped and only `WeakSender`
172+
/// instances remain, the channel is closed.
173+
///
174+
/// In order to send messages, the `WeakSender` needs to be upgraded using
175+
/// [`WeakSender::upgrade`], which returns `Option<Sender>`. It returns `None`
176+
/// if all `Sender`s have been dropped, and otherwise it returns a `Sender`.
177+
///
178+
/// [`Sender`]: Sender
179+
/// [`WeakSender::upgrade`]: WeakSender::upgrade
180+
///
181+
/// # Examples
182+
///
183+
/// ```
184+
/// use tokio::sync::broadcast::channel;
185+
///
186+
/// #[tokio::main]
187+
/// async fn main() {
188+
/// let (tx, _rx) = channel::<i32>(15);
189+
/// let tx_weak = tx.downgrade();
190+
///
191+
/// // Upgrading will succeed because `tx` still exists.
192+
/// assert!(tx_weak.upgrade().is_some());
193+
///
194+
/// // If we drop `tx`, then it will fail.
195+
/// drop(tx);
196+
/// assert!(tx_weak.clone().upgrade().is_none());
197+
/// }
198+
/// ```
199+
pub struct WeakSender<T> {
200+
shared: Arc<Shared<T>>,
201+
}
202+
169203
/// Receiving-half of the [`broadcast`] channel.
170204
///
171205
/// Must not be used concurrently. Messages may be retrieved using
@@ -317,6 +351,9 @@ struct Shared<T> {
317351
/// Number of outstanding Sender handles.
318352
num_tx: AtomicUsize,
319353

354+
/// Number of outstanding weak Sender handles.
355+
num_weak_tx: AtomicUsize,
356+
320357
/// Notify when the last subscribed [`Receiver`] drops.
321358
notify_last_rx_drop: Notify,
322359
}
@@ -475,6 +512,9 @@ pub fn channel<T: Clone>(capacity: usize) -> (Sender<T>, Receiver<T>) {
475512
unsafe impl<T: Send> Send for Sender<T> {}
476513
unsafe impl<T: Send> Sync for Sender<T> {}
477514

515+
unsafe impl<T: Send> Send for WeakSender<T> {}
516+
unsafe impl<T: Send> Sync for WeakSender<T> {}
517+
478518
unsafe impl<T: Send> Send for Receiver<T> {}
479519
unsafe impl<T: Send> Sync for Receiver<T> {}
480520

@@ -533,6 +573,7 @@ impl<T> Sender<T> {
533573
waiters: LinkedList::new(),
534574
}),
535575
num_tx: AtomicUsize::new(1),
576+
num_weak_tx: AtomicUsize::new(0),
536577
notify_last_rx_drop: Notify::new(),
537578
});
538579

@@ -656,6 +697,18 @@ impl<T> Sender<T> {
656697
new_receiver(shared)
657698
}
658699

700+
/// Converts the `Sender` to a [`WeakSender`] that does not count
701+
/// towards RAII semantics, i.e. if all `Sender` instances of the
702+
/// channel were dropped and only `WeakSender` instances remain,
703+
/// the channel is closed.
704+
#[must_use = "Downgrade creates a WeakSender without destroying the original non-weak sender."]
705+
pub fn downgrade(&self) -> WeakSender<T> {
706+
self.shared.num_weak_tx.fetch_add(1, Relaxed);
707+
WeakSender {
708+
shared: self.shared.clone(),
709+
}
710+
}
711+
659712
/// Returns the number of queued values.
660713
///
661714
/// A value is queued until it has either been seen by all receivers that were alive at the time
@@ -858,6 +911,16 @@ impl<T> Sender<T> {
858911

859912
self.shared.notify_rx(tail);
860913
}
914+
915+
/// Returns the number of [`Sender`] handles.
916+
pub fn strong_count(&self) -> usize {
917+
self.shared.num_tx.load(Acquire)
918+
}
919+
920+
/// Returns the number of [`WeakSender`] handles.
921+
pub fn weak_count(&self) -> usize {
922+
self.shared.num_weak_tx.load(Acquire)
923+
}
861924
}
862925

863926
/// Create a new `Receiver` which reads starting from the tail.
@@ -998,20 +1061,76 @@ impl<T> Shared<T> {
9981061
impl<T> Clone for Sender<T> {
9991062
fn clone(&self) -> Sender<T> {
10001063
let shared = self.shared.clone();
1001-
shared.num_tx.fetch_add(1, SeqCst);
1064+
shared.num_tx.fetch_add(1, Relaxed);
10021065

10031066
Sender { shared }
10041067
}
10051068
}
10061069

10071070
impl<T> Drop for Sender<T> {
10081071
fn drop(&mut self) {
1009-
if 1 == self.shared.num_tx.fetch_sub(1, SeqCst) {
1072+
if 1 == self.shared.num_tx.fetch_sub(1, AcqRel) {
10101073
self.close_channel();
10111074
}
10121075
}
10131076
}
10141077

1078+
impl<T> WeakSender<T> {
1079+
/// Tries to convert a `WeakSender` into a [`Sender`].
1080+
///
1081+
/// This will return `Some` if there are other `Sender` instances alive and
1082+
/// the channel wasn't previously dropped, otherwise `None` is returned.
1083+
#[must_use]
1084+
pub fn upgrade(&self) -> Option<Sender<T>> {
1085+
let mut tx_count = self.shared.num_tx.load(Acquire);
1086+
1087+
loop {
1088+
if tx_count == 0 {
1089+
// channel is closed so this WeakSender can not be upgraded
1090+
return None;
1091+
}
1092+
1093+
match self
1094+
.shared
1095+
.num_tx
1096+
.compare_exchange_weak(tx_count, tx_count + 1, Relaxed, Acquire)
1097+
{
1098+
Ok(_) => {
1099+
return Some(Sender {
1100+
shared: self.shared.clone(),
1101+
})
1102+
}
1103+
Err(prev_count) => tx_count = prev_count,
1104+
}
1105+
}
1106+
}
1107+
1108+
/// Returns the number of [`Sender`] handles.
1109+
pub fn strong_count(&self) -> usize {
1110+
self.shared.num_tx.load(Acquire)
1111+
}
1112+
1113+
/// Returns the number of [`WeakSender`] handles.
1114+
pub fn weak_count(&self) -> usize {
1115+
self.shared.num_weak_tx.load(Acquire)
1116+
}
1117+
}
1118+
1119+
impl<T> Clone for WeakSender<T> {
1120+
fn clone(&self) -> WeakSender<T> {
1121+
let shared = self.shared.clone();
1122+
shared.num_weak_tx.fetch_add(1, Relaxed);
1123+
1124+
Self { shared }
1125+
}
1126+
}
1127+
1128+
impl<T> Drop for WeakSender<T> {
1129+
fn drop(&mut self) {
1130+
self.shared.num_weak_tx.fetch_sub(1, AcqRel);
1131+
}
1132+
}
1133+
10151134
impl<T> Receiver<T> {
10161135
/// Returns the number of messages that were sent into the channel and that
10171136
/// this [`Receiver`] has yet to receive.
@@ -1213,6 +1332,42 @@ impl<T> Receiver<T> {
12131332

12141333
Ok(RecvGuard { slot })
12151334
}
1335+
1336+
/// Returns the number of [`Sender`] handles.
1337+
pub fn sender_strong_count(&self) -> usize {
1338+
self.shared.num_tx.load(Acquire)
1339+
}
1340+
1341+
/// Returns the number of [`WeakSender`] handles.
1342+
pub fn sender_weak_count(&self) -> usize {
1343+
self.shared.num_weak_tx.load(Acquire)
1344+
}
1345+
1346+
/// Checks if a channel is closed.
1347+
///
1348+
/// This method returns `true` if the channel has been closed. The channel is closed
1349+
/// when all [`Sender`] have been dropped.
1350+
///
1351+
/// [`Sender`]: crate::sync::broadcast::Sender
1352+
///
1353+
/// # Examples
1354+
/// ```
1355+
/// use tokio::sync::broadcast;
1356+
///
1357+
/// #[tokio::main]
1358+
/// async fn main() {
1359+
/// let (tx, rx) = broadcast::channel::<()>(10);
1360+
/// assert!(!rx.is_closed());
1361+
///
1362+
/// drop(tx);
1363+
///
1364+
/// assert!(rx.is_closed());
1365+
/// }
1366+
/// ```
1367+
pub fn is_closed(&self) -> bool {
1368+
// Channel is closed when there are no strong senders left active
1369+
self.shared.num_tx.load(Acquire) == 0
1370+
}
12161371
}
12171372

12181373
impl<T: Clone> Receiver<T> {
@@ -1534,6 +1689,12 @@ impl<T> fmt::Debug for Sender<T> {
15341689
}
15351690
}
15361691

1692+
impl<T> fmt::Debug for WeakSender<T> {
1693+
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1694+
write!(fmt, "broadcast::WeakSender")
1695+
}
1696+
}
1697+
15371698
impl<T> fmt::Debug for Receiver<T> {
15381699
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
15391700
write!(fmt, "broadcast::Receiver")

tokio/tests/async_send_sync.rs

+3
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,9 @@ assert_value!(tokio::sync::broadcast::Receiver<YY>: Send & Sync & Unpin);
394394
assert_value!(tokio::sync::broadcast::Sender<NN>: !Send & !Sync & Unpin);
395395
assert_value!(tokio::sync::broadcast::Sender<YN>: Send & Sync & Unpin);
396396
assert_value!(tokio::sync::broadcast::Sender<YY>: Send & Sync & Unpin);
397+
assert_value!(tokio::sync::broadcast::WeakSender<NN>: !Send & !Sync & Unpin);
398+
assert_value!(tokio::sync::broadcast::WeakSender<YN>: Send & Sync & Unpin);
399+
assert_value!(tokio::sync::broadcast::WeakSender<YY>: Send & Sync & Unpin);
397400
assert_value!(tokio::sync::futures::Notified<'_>: Send & Sync & !Unpin);
398401
assert_value!(tokio::sync::mpsc::OwnedPermit<NN>: !Send & !Sync & Unpin);
399402
assert_value!(tokio::sync::mpsc::OwnedPermit<YN>: Send & Sync & Unpin);

tokio/tests/sync_broadcast.rs

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ macro_rules! assert_closed {
5656
trait AssertSend: Send + Sync {}
5757
impl AssertSend for broadcast::Sender<i32> {}
5858
impl AssertSend for broadcast::Receiver<i32> {}
59+
impl AssertSend for broadcast::WeakSender<i32> {}
5960

6061
#[test]
6162
fn send_try_recv_bounded() {

0 commit comments

Comments
 (0)