Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

net: add UdpSocket::peek_sender() #5520

Merged
merged 3 commits into from
Mar 17, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
net: add UdpSocket::peek_sender() and variants
closes #5491
abonander committed Mar 7, 2023

Verified

This commit was signed with the committer’s verified signature.
abonander Austin Bonander
commit 4a5cac6419ca91ab758fa87a8a8c62319944acc6
4 changes: 2 additions & 2 deletions tokio/Cargo.toml
Original file line number Diff line number Diff line change
@@ -109,7 +109,7 @@ num_cpus = { version = "1.8.0", optional = true }
parking_lot = { version = "0.12.0", optional = true }

[target.'cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))'.dependencies]
socket2 = { version = "0.4.4", optional = true, features = [ "all" ] }
socket2 = { version = "0.4.9", optional = true, features = [ "all" ] }

# Currently unstable. The API exposed by these features may be broken at any time.
# Requires `--cfg tokio_unstable` to enable.
@@ -146,7 +146,7 @@ mockall = "0.11.1"
async-stream = "0.3"

[target.'cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))'.dev-dependencies]
socket2 = "0.4"
socket2 = "0.4.9"
tempfile = "3.1.0"

[target.'cfg(not(all(any(target_arch = "wasm32", target_arch = "wasm64"), target_os = "unknown")))'.dev-dependencies]
69 changes: 69 additions & 0 deletions tokio/src/net/udp.rs
Original file line number Diff line number Diff line change
@@ -1331,6 +1331,11 @@ impl UdpSocket {
/// Make sure to always use a sufficiently large buffer to hold the
/// maximum UDP packet size, which can be up to 65536 bytes in size.
///
/// MacOS will return an error if you pass a zero-sized buffer.
///
/// If you're merely interested in learning the sender of the data at the head of the queue,
/// try [`peek_sender`].
///
/// # Examples
///
/// ```no_run
@@ -1349,6 +1354,8 @@ impl UdpSocket {
/// Ok(())
/// }
/// ```
///
/// [`peek_sender`]: method@Self::peek_sender
pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.io
.registration()
@@ -1371,6 +1378,11 @@ impl UdpSocket {
/// Make sure to always use a sufficiently large buffer to hold the
/// maximum UDP packet size, which can be up to 65536 bytes in size.
///
/// MacOS will return an error if you pass a zero-sized buffer.
///
/// If you're merely interested in learning the sender of the data at the head of the queue,
/// try [`poll_peek_sender`].
///
/// # Return value
///
/// The function returns:
@@ -1382,6 +1394,8 @@ impl UdpSocket {
/// # Errors
///
/// This function may encounter any standard I/O error except `WouldBlock`.
///
/// [`poll_peek_sender`]: method@Self::poll_peek_sender
pub fn poll_peek_from(
&self,
cx: &mut Context<'_>,
@@ -1404,6 +1418,61 @@ impl UdpSocket {
Poll::Ready(Ok(addr))
}

/// Retrieve the sender of the data at the head of the input queue, waiting if empty.
///
/// This is equivalent to calling [`peek_from`] with a zero-sized buffer,
/// but suppresses the `WSAEMSGSIZE` error on Windows and the "invalid argument" error on macOS.
///
/// [`peek_from`]: method@Self::peek_from
pub async fn peek_sender(&self) -> io::Result<SocketAddr> {
self.io
.registration()
.async_io(Interest::READABLE, || self.peek_sender_inner())
.await
}

/// Retrieve the sender of the data at the head of the input queue,
/// scheduling a wakeup if empty.
///
/// This is equivalent to calling [`poll_peek_from`] with a zero-sized buffer,
/// but suppresses the `WSAEMSGSIZE` error on Windows and the "invalid argument" error on macOS.
///
/// # Notes
///
/// Note that on multiple calls to a `poll_*` method in the recv direction, only the
/// `Waker` from the `Context` passed to the most recent call will be scheduled to
/// receive a wakeup.
///
/// [`poll_peek_from`]: method@Self::poll_peek_from
pub fn poll_peek_sender(&self, cx: &mut Context<'_>) -> Poll<io::Result<SocketAddr>> {
self.io
.registration()
.poll_read_io(cx, || self.peek_sender_inner())
}

/// Try to retrieve the sender of the data at the head of the input queue.
///
/// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is
/// returned. This function is usually paired with `readable()`.
pub fn try_peek_sender(&self) -> io::Result<SocketAddr> {
self.io
.registration()
.try_io(Interest::READABLE, || self.peek_sender_inner())
}

#[inline]
fn peek_sender_inner(&self) -> io::Result<SocketAddr> {
self.io.try_io(|| {
self.as_socket()
.peek_sender()?
// May be `None` if the platform doesn't populate the sender for some reason.
// In testing, that only occurred on macOS if you pass a zero-sized buffer,
// but the implementation of `Socket::peek_sender()` covers that.
.as_socket()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "sender not available"))
})
}

/// Gets the value of the `SO_BROADCAST` option for this socket.
///
/// For more information about this option, see [`set_broadcast`].
86 changes: 86 additions & 0 deletions tokio/tests/udp.rs
Original file line number Diff line number Diff line change
@@ -134,6 +134,92 @@ async fn send_to_peek_from_poll() -> std::io::Result<()> {
Ok(())
}

#[tokio::test]
async fn peek_sender() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

let sender_addr = sender.local_addr()?;
let receiver_addr = receiver.local_addr()?;

let msg = b"Hello, world!";
sender.send_to(msg, receiver_addr).await?;

let peeked_sender = receiver.peek_sender().await?;
assert_eq!(peeked_sender, sender_addr);

// Assert that `peek_sender()` returns the right sender but
// doesn't remove from the receive queue.
let mut recv_buf = [0u8; 32];
let (read, received_sender) = receiver.recv_from(&mut recv_buf).await?;

assert_eq!(&recv_buf[..read], msg);
assert_eq!(received_sender, peeked_sender);

Ok(())
}

#[tokio::test]
async fn poll_peek_sender() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

let sender_addr = sender.local_addr()?;
let receiver_addr = receiver.local_addr()?;

let msg = b"Hello, world!";
poll_fn(|cx| sender.poll_send_to(cx, msg, receiver_addr)).await?;

let peeked_sender = poll_fn(|cx| receiver.poll_peek_sender(cx)).await?;
assert_eq!(peeked_sender, sender_addr);

// Assert that `poll_peek_sender()` returns the right sender but
// doesn't remove from the receive queue.
let mut recv_buf = [0u8; 32];
let mut read = ReadBuf::new(&mut recv_buf);
let received_sender = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?;

assert_eq!(read.filled(), msg);
assert_eq!(received_sender, peeked_sender);

Ok(())
}

#[tokio::test]
async fn try_peek_sender() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

let sender_addr = sender.local_addr()?;
let receiver_addr = receiver.local_addr()?;

let msg = b"Hello, world!";
sender.send_to(msg, receiver_addr).await?;

let peeked_sender = loop {
match receiver.try_peek_sender() {
Ok(peeked_sender) => break peeked_sender,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
receiver.readable().await?;
}
Err(e) => return Err(e),
}
};

assert_eq!(peeked_sender, sender_addr);

// Assert that `try_peek_sender()` returns the right sender but
// didn't remove from the receive queue.
let mut recv_buf = [0u8; 32];
// We already peeked the sender so there must be data in the receive queue.
let (read, received_sender) = receiver.try_recv_from(&mut recv_buf).unwrap();

assert_eq!(&recv_buf[..read], msg);
assert_eq!(received_sender, peeked_sender);

Ok(())
}

#[tokio::test]
async fn split() -> std::io::Result<()> {
let socket = UdpSocket::bind("127.0.0.1:0").await?;