diff --git a/tokio/src/io/util/buf_reader.rs b/tokio/src/io/util/buf_reader.rs index 271f61bf2ae..cc65ef2099f 100644 --- a/tokio/src/io/util/buf_reader.rs +++ b/tokio/src/io/util/buf_reader.rs @@ -1,11 +1,11 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::io; +use std::io::{self, SeekFrom}; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{cmp, fmt}; +use std::{cmp, fmt, mem}; pin_project! { /// The `BufReader` struct adds buffering to any reader. @@ -30,6 +30,7 @@ pin_project! { pub(super) buf: Box<[u8]>, pub(super) pos: usize, pub(super) cap: usize, + pub(super) seek_state: SeekState, } } @@ -48,6 +49,7 @@ impl BufReader { buf: buffer.into_boxed_slice(), pos: 0, cap: 0, + seek_state: SeekState::Init, } } @@ -141,6 +143,122 @@ impl AsyncBufRead for BufReader { } } +#[derive(Debug, Clone, Copy)] +pub(super) enum SeekState { + /// start_seek has not been called. + Init, + /// start_seek has been called, but poll_complete has not yet been called. + Start(SeekFrom), + /// Waiting for completion of the first poll_complete in the `n.checked_sub(remainder).is_none()` branch. + PendingOverflowed(i64), + /// Waiting for completion of poll_complete. + Pending, +} + +/// Seek to an offset, in bytes, in the underlying reader. +/// +/// The position used for seeking with `SeekFrom::Current(_)` is the +/// position the underlying reader would be at if the `BufReader` had no +/// internal buffer. +/// +/// Seeking always discards the internal buffer, even if the seek position +/// would otherwise fall within it. This guarantees that calling +/// `.into_inner()` immediately after a seek yields the underlying reader +/// at the same position. +/// +/// See [`AsyncSeek`] for more details. +/// +/// Note: In the edge case where you're seeking with `SeekFrom::Current(n)` +/// where `n` minus the internal buffer length overflows an `i64`, two +/// seeks will be performed instead of one. If the second seek returns +/// `Err`, the underlying reader will be left at the same position it would +/// have if you called `seek` with `SeekFrom::Current(0)`. +impl AsyncSeek for BufReader { + fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + // We needs to call seek operation multiple times. + // And we should always call both start_seek and poll_complete, + // as start_seek alone cannot guarantee that the operation will be completed. + // poll_complete receives a Context and returns a Poll, so it cannot be called + // inside start_seek. + *self.project().seek_state = SeekState::Start(pos); + Ok(()) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let res = match mem::replace(self.as_mut().project().seek_state, SeekState::Init) { + SeekState::Init => { + // 1.x AsyncSeek recommends calling poll_complete before start_seek. + // We don't have to guarantee that the value returned by + // poll_complete called without start_seek is correct, + // so we'll return 0. + return Poll::Ready(Ok(0)); + } + SeekState::Start(SeekFrom::Current(n)) => { + let remainder = (self.cap - self.pos) as i64; + // it should be safe to assume that remainder fits within an i64 as the alternative + // means we managed to allocate 8 exbibytes and that's absurd. + // But it's not out of the realm of possibility for some weird underlying reader to + // support seeking by i64::min_value() so we need to handle underflow when subtracting + // remainder. + if let Some(offset) = n.checked_sub(remainder) { + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(offset))?; + self.as_mut().get_pin_mut().poll_complete(cx)? + } else { + // seek backwards by our remainder, and then by the offset + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(-remainder))?; + if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() { + *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n); + return Poll::Pending; + } + + // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676 + self.as_mut().discard_buffer(); + + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(n))?; + self.as_mut().get_pin_mut().poll_complete(cx)? + } + } + SeekState::PendingOverflowed(n) => { + if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() { + *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n); + return Poll::Pending; + } + + // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676 + self.as_mut().discard_buffer(); + + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(n))?; + self.as_mut().get_pin_mut().poll_complete(cx)? + } + SeekState::Start(pos) => { + // Seeking with Start/End doesn't care about our buffer length. + self.as_mut().get_pin_mut().start_seek(pos)?; + self.as_mut().get_pin_mut().poll_complete(cx)? + } + SeekState::Pending => self.as_mut().get_pin_mut().poll_complete(cx)?, + }; + + match res { + Poll::Ready(res) => { + self.discard_buffer(); + Poll::Ready(Ok(res)) + } + Poll::Pending => { + *self.as_mut().project().seek_state = SeekState::Pending; + Poll::Pending + } + } + } +} + impl AsyncWrite for BufReader { fn poll_write( self: Pin<&mut Self>, diff --git a/tokio/src/io/util/buf_stream.rs b/tokio/src/io/util/buf_stream.rs index cc857e225bc..92386658e3d 100644 --- a/tokio/src/io/util/buf_stream.rs +++ b/tokio/src/io/util/buf_stream.rs @@ -94,9 +94,11 @@ impl From>> for BufStream { buf: rbuf, pos, cap, + seek_state: rseek_state, }, buf: wbuf, written, + seek_state: wseek_state, } = b; BufStream { @@ -105,10 +107,12 @@ impl From>> for BufStream { inner, buf: wbuf, written, + seek_state: wseek_state, }, buf: rbuf, pos, cap, + seek_state: rseek_state, }, } } diff --git a/tokio/src/io/util/buf_writer.rs b/tokio/src/io/util/buf_writer.rs index 5e3d4b710f2..4e8e493cefe 100644 --- a/tokio/src/io/util/buf_writer.rs +++ b/tokio/src/io/util/buf_writer.rs @@ -1,9 +1,9 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; -use std::io::{self, Write}; +use std::io::{self, SeekFrom, Write}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -34,6 +34,7 @@ pin_project! { pub(super) inner: W, pub(super) buf: Vec, pub(super) written: usize, + pub(super) seek_state: SeekState, } } @@ -50,6 +51,7 @@ impl BufWriter { inner, buf: Vec::with_capacity(cap), written: 0, + seek_state: SeekState::Init, } } @@ -142,6 +144,62 @@ impl AsyncWrite for BufWriter { } } +#[derive(Debug, Clone, Copy)] +pub(super) enum SeekState { + /// start_seek has not been called. + Init, + /// start_seek has been called, but poll_complete has not yet been called. + Start(SeekFrom), + /// Waiting for completion of poll_complete. + Pending, +} + +/// Seek to the offset, in bytes, in the underlying writer. +/// +/// Seeking always writes out the internal buffer before seeking. +impl AsyncSeek for BufWriter { + fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + // We need to flush the internal buffer before seeking. + // It receives a `Context` and returns a `Poll`, so it cannot be called + // inside `start_seek`. + *self.project().seek_state = SeekState::Start(pos); + Ok(()) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let pos = match self.seek_state { + SeekState::Init => { + return self.project().inner.poll_complete(cx); + } + SeekState::Start(pos) => Some(pos), + SeekState::Pending => None, + }; + + // Flush the internal buffer before seeking. + ready!(self.as_mut().flush_buf(cx))?; + + let mut me = self.project(); + if let Some(pos) = pos { + // Ensure previous seeks have finished before starting a new one + ready!(me.inner.as_mut().poll_complete(cx))?; + if let Err(e) = me.inner.as_mut().start_seek(pos) { + *me.seek_state = SeekState::Init; + return Poll::Ready(Err(e)); + } + } + match me.inner.poll_complete(cx) { + Poll::Ready(res) => { + *me.seek_state = SeekState::Init; + Poll::Ready(res) + } + Poll::Pending => { + *me.seek_state = SeekState::Pending; + Poll::Pending + } + } + } +} + impl AsyncRead for BufWriter { fn poll_read( self: Pin<&mut Self>, diff --git a/tokio/tests/io_buf_reader.rs b/tokio/tests/io_buf_reader.rs new file mode 100644 index 00000000000..ac5f11c727c --- /dev/null +++ b/tokio/tests/io_buf_reader.rs @@ -0,0 +1,362 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +// https://github.com/rust-lang/futures-rs/blob/1803948ff091b4eabf7f3bf39e16bbbdefca5cc8/futures/tests/io_buf_reader.rs + +use futures::task::{noop_waker_ref, Context, Poll}; +use std::cmp; +use std::io::{self, Cursor}; +use std::pin::Pin; +use tokio::io::{ + AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, BufReader, + ReadBuf, SeekFrom, +}; + +macro_rules! run_fill_buf { + ($reader:expr) => {{ + let mut cx = Context::from_waker(noop_waker_ref()); + loop { + if let Poll::Ready(x) = Pin::new(&mut $reader).poll_fill_buf(&mut cx) { + break x; + } + } + }}; +} + +struct MaybePending<'a> { + inner: &'a [u8], + ready_read: bool, + ready_fill_buf: bool, +} + +impl<'a> MaybePending<'a> { + fn new(inner: &'a [u8]) -> Self { + Self { + inner, + ready_read: false, + ready_fill_buf: false, + } + } +} + +impl AsyncRead for MaybePending<'_> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.ready_read { + self.ready_read = false; + Pin::new(&mut self.inner).poll_read(cx, buf) + } else { + self.ready_read = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } +} + +impl AsyncBufRead for MaybePending<'_> { + fn poll_fill_buf(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + if self.ready_fill_buf { + self.ready_fill_buf = false; + if self.inner.is_empty() { + return Poll::Ready(Ok(&[])); + } + let len = cmp::min(2, self.inner.len()); + Poll::Ready(Ok(&self.inner[0..len])) + } else { + self.ready_fill_buf = true; + Poll::Pending + } + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + self.inner = &self.inner[amt..]; + } +} + +#[tokio::test] +async fn test_buffered_reader() { + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(2, inner); + + let mut buf = [0, 0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 3); + assert_eq!(buf, [5, 6, 7]); + assert_eq!(reader.buffer(), []); + + let mut buf = [0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 2); + assert_eq!(buf, [0, 1]); + assert_eq!(reader.buffer(), []); + + let mut buf = [0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [2]); + assert_eq!(reader.buffer(), [3]); + + let mut buf = [0, 0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [3, 0, 0]); + assert_eq!(reader.buffer(), []); + + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [4, 0, 0]); + assert_eq!(reader.buffer(), []); + + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); +} + +#[tokio::test] +async fn test_buffered_reader_seek() { + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(2, Cursor::new(inner)); + + assert_eq!(reader.seek(SeekFrom::Start(3)).await.unwrap(), 3); + assert_eq!(run_fill_buf!(reader).unwrap(), &[0, 1][..]); + assert!(reader + .seek(SeekFrom::Current(i64::min_value())) + .await + .is_err()); + assert_eq!(run_fill_buf!(reader).unwrap(), &[0, 1][..]); + assert_eq!(reader.seek(SeekFrom::Current(1)).await.unwrap(), 4); + assert_eq!(run_fill_buf!(reader).unwrap(), &[1, 2][..]); + Pin::new(&mut reader).consume(1); + assert_eq!(reader.seek(SeekFrom::Current(-2)).await.unwrap(), 3); +} + +#[tokio::test] +async fn test_buffered_reader_seek_underflow() { + // gimmick reader that yields its position modulo 256 for each byte + struct PositionReader { + pos: u64, + } + impl AsyncRead for PositionReader { + fn poll_read( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let b = buf.initialize_unfilled(); + let len = b.len(); + for x in b { + *x = self.pos as u8; + self.pos = self.pos.wrapping_add(1); + } + buf.advance(len); + Poll::Ready(Ok(())) + } + } + impl AsyncSeek for PositionReader { + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + match pos { + SeekFrom::Start(n) => { + self.pos = n; + } + SeekFrom::Current(n) => { + self.pos = self.pos.wrapping_add(n as u64); + } + SeekFrom::End(n) => { + self.pos = u64::max_value().wrapping_add(n as u64); + } + } + Ok(()) + } + fn poll_complete(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(self.pos)) + } + } + + let mut reader = BufReader::with_capacity(5, PositionReader { pos: 0 }); + assert_eq!(run_fill_buf!(reader).unwrap(), &[0, 1, 2, 3, 4][..]); + assert_eq!( + reader.seek(SeekFrom::End(-5)).await.unwrap(), + u64::max_value() - 5 + ); + assert_eq!(run_fill_buf!(reader).unwrap().len(), 5); + // the following seek will require two underlying seeks + let expected = 9_223_372_036_854_775_802; + assert_eq!( + reader + .seek(SeekFrom::Current(i64::min_value())) + .await + .unwrap(), + expected + ); + assert_eq!(run_fill_buf!(reader).unwrap().len(), 5); + // seeking to 0 should empty the buffer. + assert_eq!(reader.seek(SeekFrom::Current(0)).await.unwrap(), expected); + assert_eq!(reader.get_ref().pos, expected); +} + +#[tokio::test] +async fn test_short_reads() { + /// A dummy reader intended at testing short-reads propagation. + struct ShortReader { + lengths: Vec, + } + + impl AsyncRead for ShortReader { + fn poll_read( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if !self.lengths.is_empty() { + buf.advance(self.lengths.remove(0)); + } + Poll::Ready(Ok(())) + } + } + + let inner = ShortReader { + lengths: vec![0, 1, 2, 0, 1, 0], + }; + let mut reader = BufReader::new(inner); + let mut buf = [0, 0]; + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); + assert_eq!(reader.read(&mut buf).await.unwrap(), 1); + assert_eq!(reader.read(&mut buf).await.unwrap(), 2); + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); + assert_eq!(reader.read(&mut buf).await.unwrap(), 1); + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); +} + +#[tokio::test] +async fn maybe_pending() { + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(2, MaybePending::new(inner)); + + let mut buf = [0, 0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 3); + assert_eq!(buf, [5, 6, 7]); + assert_eq!(reader.buffer(), []); + + let mut buf = [0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 2); + assert_eq!(buf, [0, 1]); + assert_eq!(reader.buffer(), []); + + let mut buf = [0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [2]); + assert_eq!(reader.buffer(), [3]); + + let mut buf = [0, 0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [3, 0, 0]); + assert_eq!(reader.buffer(), []); + + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [4, 0, 0]); + assert_eq!(reader.buffer(), []); + + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); +} + +#[tokio::test] +async fn maybe_pending_buf_read() { + let inner = MaybePending::new(&[0, 1, 2, 3, 1, 0]); + let mut reader = BufReader::with_capacity(2, inner); + let mut v = Vec::new(); + reader.read_until(3, &mut v).await.unwrap(); + assert_eq!(v, [0, 1, 2, 3]); + v.clear(); + reader.read_until(1, &mut v).await.unwrap(); + assert_eq!(v, [1]); + v.clear(); + reader.read_until(8, &mut v).await.unwrap(); + assert_eq!(v, [0]); + v.clear(); + reader.read_until(9, &mut v).await.unwrap(); + assert_eq!(v, []); +} + +// https://github.com/rust-lang/futures-rs/pull/1573#discussion_r281162309 +#[tokio::test] +async fn maybe_pending_seek() { + struct MaybePendingSeek<'a> { + inner: Cursor<&'a [u8]>, + ready: bool, + seek_res: Option>, + } + + impl<'a> MaybePendingSeek<'a> { + fn new(inner: &'a [u8]) -> Self { + Self { + inner: Cursor::new(inner), + ready: true, + seek_res: None, + } + } + } + + impl AsyncRead for MaybePendingSeek<'_> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } + } + + impl AsyncBufRead for MaybePendingSeek<'_> { + fn poll_fill_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this: *mut Self = &mut *self as *mut _; + Pin::new(&mut unsafe { &mut *this }.inner).poll_fill_buf(cx) + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + Pin::new(&mut self.inner).consume(amt) + } + } + + impl AsyncSeek for MaybePendingSeek<'_> { + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + self.seek_res = Some(Pin::new(&mut self.inner).start_seek(pos)); + Ok(()) + } + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.ready { + self.ready = false; + self.seek_res.take().unwrap_or(Ok(()))?; + Pin::new(&mut self.inner).poll_complete(cx) + } else { + self.ready = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(2, MaybePendingSeek::new(inner)); + + assert_eq!(reader.seek(SeekFrom::Current(3)).await.unwrap(), 3); + assert_eq!(run_fill_buf!(reader).unwrap(), &[0, 1][..]); + assert!(reader + .seek(SeekFrom::Current(i64::min_value())) + .await + .is_err()); + assert_eq!(run_fill_buf!(reader).unwrap(), &[0, 1][..]); + assert_eq!(reader.seek(SeekFrom::Current(1)).await.unwrap(), 4); + assert_eq!(run_fill_buf!(reader).unwrap(), &[1, 2][..]); + Pin::new(&mut reader).consume(1); + assert_eq!(reader.seek(SeekFrom::Current(-2)).await.unwrap(), 3); +} diff --git a/tokio/tests/io_buf_writer.rs b/tokio/tests/io_buf_writer.rs new file mode 100644 index 00000000000..6f4f10a8e2e --- /dev/null +++ b/tokio/tests/io_buf_writer.rs @@ -0,0 +1,251 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +// https://github.com/rust-lang/futures-rs/blob/1803948ff091b4eabf7f3bf39e16bbbdefca5cc8/futures/tests/io_buf_writer.rs + +use futures::task::{Context, Poll}; +use std::io::{self, Cursor}; +use std::pin::Pin; +use tokio::io::{AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter, SeekFrom}; + +struct MaybePending { + inner: Vec, + ready: bool, +} + +impl MaybePending { + fn new(inner: Vec) -> Self { + Self { + inner, + ready: false, + } + } +} + +impl AsyncWrite for MaybePending { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if self.ready { + self.ready = false; + Pin::new(&mut self.inner).poll_write(cx, buf) + } else { + self.ready = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +#[tokio::test] +async fn buf_writer() { + let mut writer = BufWriter::with_capacity(2, Vec::new()); + + writer.write(&[0, 1]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(*writer.get_ref(), [0, 1]); + + writer.write(&[2]).await.unwrap(); + assert_eq!(writer.buffer(), [2]); + assert_eq!(*writer.get_ref(), [0, 1]); + + writer.write(&[3]).await.unwrap(); + assert_eq!(writer.buffer(), [2, 3]); + assert_eq!(*writer.get_ref(), [0, 1]); + + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3]); + + writer.write(&[4]).await.unwrap(); + writer.write(&[5]).await.unwrap(); + assert_eq!(writer.buffer(), [4, 5]); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3]); + + writer.write(&[6]).await.unwrap(); + assert_eq!(writer.buffer(), [6]); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5]); + + writer.write(&[7, 8]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5, 6, 7, 8]); + + writer.write(&[9, 10, 11]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); + + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); +} + +#[tokio::test] +async fn buf_writer_inner_flushes() { + let mut w = BufWriter::with_capacity(3, Vec::new()); + w.write(&[0, 1]).await.unwrap(); + assert_eq!(*w.get_ref(), []); + w.flush().await.unwrap(); + let w = w.into_inner(); + assert_eq!(w, [0, 1]); +} + +#[tokio::test] +async fn buf_writer_seek() { + let mut w = BufWriter::with_capacity(3, Cursor::new(Vec::new())); + w.write_all(&[0, 1, 2, 3, 4, 5]).await.unwrap(); + w.write_all(&[6, 7]).await.unwrap(); + assert_eq!(w.seek(SeekFrom::Current(0)).await.unwrap(), 8); + assert_eq!(&w.get_ref().get_ref()[..], &[0, 1, 2, 3, 4, 5, 6, 7][..]); + assert_eq!(w.seek(SeekFrom::Start(2)).await.unwrap(), 2); + w.write_all(&[8, 9]).await.unwrap(); + w.flush().await.unwrap(); + assert_eq!(&w.into_inner().into_inner()[..], &[0, 1, 8, 9, 4, 5, 6, 7]); +} + +#[tokio::test] +async fn maybe_pending_buf_writer() { + let mut writer = BufWriter::with_capacity(2, MaybePending::new(Vec::new())); + + writer.write(&[0, 1]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(&writer.get_ref().inner, &[0, 1]); + + writer.write(&[2]).await.unwrap(); + assert_eq!(writer.buffer(), [2]); + assert_eq!(&writer.get_ref().inner, &[0, 1]); + + writer.write(&[3]).await.unwrap(); + assert_eq!(writer.buffer(), [2, 3]); + assert_eq!(&writer.get_ref().inner, &[0, 1]); + + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(&writer.get_ref().inner, &[0, 1, 2, 3]); + + writer.write(&[4]).await.unwrap(); + writer.write(&[5]).await.unwrap(); + assert_eq!(writer.buffer(), [4, 5]); + assert_eq!(&writer.get_ref().inner, &[0, 1, 2, 3]); + + writer.write(&[6]).await.unwrap(); + assert_eq!(writer.buffer(), [6]); + assert_eq!(writer.get_ref().inner, &[0, 1, 2, 3, 4, 5]); + + writer.write(&[7, 8]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(writer.get_ref().inner, &[0, 1, 2, 3, 4, 5, 6, 7, 8]); + + writer.write(&[9, 10, 11]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!( + writer.get_ref().inner, + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + ); + + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!( + &writer.get_ref().inner, + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + ); +} + +#[tokio::test] +async fn maybe_pending_buf_writer_inner_flushes() { + let mut w = BufWriter::with_capacity(3, MaybePending::new(Vec::new())); + w.write(&[0, 1]).await.unwrap(); + assert_eq!(&w.get_ref().inner, &[]); + w.flush().await.unwrap(); + let w = w.into_inner().inner; + assert_eq!(w, [0, 1]); +} + +#[tokio::test] +async fn maybe_pending_buf_writer_seek() { + struct MaybePendingSeek { + inner: Cursor>, + ready_write: bool, + ready_seek: bool, + seek_res: Option>, + } + + impl MaybePendingSeek { + fn new(inner: Vec) -> Self { + Self { + inner: Cursor::new(inner), + ready_write: false, + ready_seek: false, + seek_res: None, + } + } + } + + impl AsyncWrite for MaybePendingSeek { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if self.ready_write { + self.ready_write = false; + Pin::new(&mut self.inner).poll_write(cx, buf) + } else { + self.ready_write = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + } + + impl AsyncSeek for MaybePendingSeek { + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + self.seek_res = Some(Pin::new(&mut self.inner).start_seek(pos)); + Ok(()) + } + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.ready_seek { + self.ready_seek = false; + self.seek_res.take().unwrap_or(Ok(()))?; + Pin::new(&mut self.inner).poll_complete(cx) + } else { + self.ready_seek = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + + let mut w = BufWriter::with_capacity(3, MaybePendingSeek::new(Vec::new())); + w.write_all(&[0, 1, 2, 3, 4, 5]).await.unwrap(); + w.write_all(&[6, 7]).await.unwrap(); + assert_eq!(w.seek(SeekFrom::Current(0)).await.unwrap(), 8); + assert_eq!( + &w.get_ref().inner.get_ref()[..], + &[0, 1, 2, 3, 4, 5, 6, 7][..] + ); + assert_eq!(w.seek(SeekFrom::Start(2)).await.unwrap(), 2); + w.write_all(&[8, 9]).await.unwrap(); + w.flush().await.unwrap(); + assert_eq!( + &w.into_inner().inner.into_inner()[..], + &[0, 1, 8, 9, 4, 5, 6, 7] + ); +}