From f77ade0df7cdeaad3ab388e4ad07fa26ffe422fa Mon Sep 17 00:00:00 2001 From: Aatif Syed Date: Fri, 1 Nov 2024 20:25:43 +0000 Subject: [PATCH] fix: Encoder state machine (#308) --- src/tokio/write/generic/encoder.rs | 151 ++++++++--------------------- tests/issues.rs | 11 +-- 2 files changed, 47 insertions(+), 115 deletions(-) diff --git a/src/tokio/write/generic/encoder.rs b/src/tokio/write/generic/encoder.rs index f5a83aa..421f064 100644 --- a/src/tokio/write/generic/encoder.rs +++ b/src/tokio/write/generic/encoder.rs @@ -13,20 +13,13 @@ use futures_core::ready; use pin_project_lite::pin_project; use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; -#[derive(Debug)] -enum State { - Encoding, - Finishing, - Done, -} - pin_project! { #[derive(Debug)] pub struct Encoder { #[pin] writer: BufWriter, encoder: E, - state: State, + finished: bool } } @@ -35,7 +28,7 @@ impl Encoder { Self { writer: BufWriter::new(writer), encoder, - state: State::Encoding, + finished: false, } } } @@ -62,97 +55,6 @@ impl Encoder { } } -impl Encoder { - fn do_poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - input: &mut PartialBuffer<&[u8]>, - ) -> Poll> { - let mut this = self.project(); - - loop { - let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = PartialBuffer::new(output); - - *this.state = match this.state { - State::Encoding => { - this.encoder.encode(input, &mut output)?; - State::Encoding - } - - State::Finishing | State::Done => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "Write after shutdown", - ))) - } - }; - - let produced = output.written().len(); - this.writer.as_mut().produce(produced); - - if input.unwritten().is_empty() { - return Poll::Ready(Ok(())); - } - } - } - - fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - loop { - let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = PartialBuffer::new(output); - - let done = match this.state { - State::Encoding => this.encoder.flush(&mut output)?, - - State::Finishing | State::Done => { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::Other, - "Flush after shutdown", - ))) - } - }; - - let produced = output.written().len(); - this.writer.as_mut().produce(produced); - - if done { - return Poll::Ready(Ok(())); - } - } - } - - fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - loop { - let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?; - let mut output = PartialBuffer::new(output); - - *this.state = match this.state { - State::Encoding | State::Finishing => { - if this.encoder.finish(&mut output)? { - State::Done - } else { - State::Finishing - } - } - - State::Done => State::Done, - }; - - let produced = output.written().len(); - this.writer.as_mut().produce(produced); - - if let State::Done = this.state { - return Poll::Ready(Ok(())); - } - } - } -} - impl AsyncWrite for Encoder { fn poll_write( self: Pin<&mut Self>, @@ -163,24 +65,55 @@ impl AsyncWrite for Encoder { return Poll::Ready(Ok(0)); } - let mut input = PartialBuffer::new(buf); + let mut this = self.project(); + + let mut encodeme = PartialBuffer::new(buf); - match self.do_poll_write(cx, &mut input)? { - Poll::Pending if input.written().is_empty() => Poll::Pending, - _ => Poll::Ready(Ok(input.written().len())), + loop { + let mut space = + PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?); + this.encoder.encode(&mut encodeme, &mut space)?; + let bytes_encoded = space.written().len(); + this.writer.as_mut().produce(bytes_encoded); + if encodeme.unwritten().is_empty() { + break; + } } + + Poll::Ready(Ok(encodeme.written().len())) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().do_poll_flush(cx))?; - ready!(self.project().writer.as_mut().poll_flush(cx))?; + let mut this = self.project(); + loop { + let mut space = + PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?); + let flushed = this.encoder.flush(&mut space)?; + let bytes_encoded = space.written().len(); + this.writer.as_mut().produce(bytes_encoded); + if flushed { + break; + } + } Poll::Ready(Ok(())) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - ready!(self.as_mut().do_poll_shutdown(cx))?; - ready!(self.project().writer.as_mut().poll_shutdown(cx))?; - Poll::Ready(Ok(())) + let mut this = self.project(); + if !*this.finished { + loop { + let mut space = + PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?); + let finished = this.encoder.finish(&mut space)?; + let bytes_encoded = space.written().len(); + this.writer.as_mut().produce(bytes_encoded); + if finished { + *this.finished = true; + break; + } + } + } + this.writer.poll_shutdown(cx) } } diff --git a/tests/issues.rs b/tests/issues.rs index a913a96..7bdcc37 100644 --- a/tests/issues.rs +++ b/tests/issues.rs @@ -23,7 +23,6 @@ use tracing_subscriber::fmt::format::FmtSpan; /// [`tokio_util::codec`](https://docs.rs/tokio-util/latest/tokio_util/codec) /// [`poll_shutdown`](AsyncWrite::poll_shutdown) /// [`poll_flush`](AsyncWrite::poll_flush) -#[should_panic = "Flush after shutdown"] // TODO: this should be removed when the bug is fixed #[test] fn issue_246() { tracing_subscriber::fmt() @@ -34,25 +33,25 @@ fn issue_246() { .with_target(false) .with_span_events(FmtSpan::NEW) .init(); - let mut zstd_encoder = - Transparent::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default()))); + let mut zstd_encoder = Wrapper::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default()))); futures::executor::block_on(zstd_encoder.shutdown()).unwrap(); } pin_project_lite::pin_project! { /// A simple wrapper struct that follows the [`AsyncWrite`] protocol. - struct Transparent { + /// This is a stand-in for combinators like `tokio_util::codec`s + struct Wrapper { #[pin] inner: T } } -impl Transparent { +impl Wrapper { fn new(inner: T) -> Self { Self { inner } } } -impl AsyncWrite for Transparent { +impl AsyncWrite for Wrapper { #[tracing::instrument(name = "Transparent::poll_write", skip_all, ret)] fn poll_write( self: Pin<&mut Self>,