Skip to content

Commit 8af0baa

Browse files
committed
fix: Encoder state machine (Nullus157#308)
1 parent e259060 commit 8af0baa

File tree

2 files changed

+50
-118
lines changed

2 files changed

+50
-118
lines changed

src/tokio/write/generic/encoder.rs

+42-109
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,13 @@ use futures_core::ready;
1313
use pin_project_lite::pin_project;
1414
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
1515

16-
#[derive(Debug)]
17-
enum State {
18-
Encoding,
19-
Finishing,
20-
Done,
21-
}
22-
2316
pin_project! {
2417
#[derive(Debug)]
2518
pub struct Encoder<W, E> {
2619
#[pin]
2720
writer: BufWriter<W>,
2821
encoder: E,
29-
state: State,
22+
finished: bool
3023
}
3124
}
3225

@@ -35,7 +28,7 @@ impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
3528
Self {
3629
writer: BufWriter::new(writer),
3730
encoder,
38-
state: State::Encoding,
31+
finished: false,
3932
}
4033
}
4134
}
@@ -62,97 +55,6 @@ impl<W, E> Encoder<W, E> {
6255
}
6356
}
6457

65-
impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
66-
fn do_poll_write(
67-
self: Pin<&mut Self>,
68-
cx: &mut Context<'_>,
69-
input: &mut PartialBuffer<&[u8]>,
70-
) -> Poll<io::Result<()>> {
71-
let mut this = self.project();
72-
73-
loop {
74-
let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
75-
let mut output = PartialBuffer::new(output);
76-
77-
*this.state = match this.state {
78-
State::Encoding => {
79-
this.encoder.encode(input, &mut output)?;
80-
State::Encoding
81-
}
82-
83-
State::Finishing | State::Done => {
84-
return Poll::Ready(Err(io::Error::new(
85-
io::ErrorKind::Other,
86-
"Write after shutdown",
87-
)))
88-
}
89-
};
90-
91-
let produced = output.written().len();
92-
this.writer.as_mut().produce(produced);
93-
94-
if input.unwritten().is_empty() {
95-
return Poll::Ready(Ok(()));
96-
}
97-
}
98-
}
99-
100-
fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
101-
let mut this = self.project();
102-
103-
loop {
104-
let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
105-
let mut output = PartialBuffer::new(output);
106-
107-
let done = match this.state {
108-
State::Encoding => this.encoder.flush(&mut output)?,
109-
110-
State::Finishing | State::Done => {
111-
return Poll::Ready(Err(io::Error::new(
112-
io::ErrorKind::Other,
113-
"Flush after shutdown",
114-
)))
115-
}
116-
};
117-
118-
let produced = output.written().len();
119-
this.writer.as_mut().produce(produced);
120-
121-
if done {
122-
return Poll::Ready(Ok(()));
123-
}
124-
}
125-
}
126-
127-
fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
128-
let mut this = self.project();
129-
130-
loop {
131-
let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
132-
let mut output = PartialBuffer::new(output);
133-
134-
*this.state = match this.state {
135-
State::Encoding | State::Finishing => {
136-
if this.encoder.finish(&mut output)? {
137-
State::Done
138-
} else {
139-
State::Finishing
140-
}
141-
}
142-
143-
State::Done => State::Done,
144-
};
145-
146-
let produced = output.written().len();
147-
this.writer.as_mut().produce(produced);
148-
149-
if let State::Done = this.state {
150-
return Poll::Ready(Ok(()));
151-
}
152-
}
153-
}
154-
}
155-
15658
impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
15759
fn poll_write(
15860
self: Pin<&mut Self>,
@@ -163,24 +65,55 @@ impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
16365
return Poll::Ready(Ok(0));
16466
}
16567

166-
let mut input = PartialBuffer::new(buf);
68+
let mut this = self.project();
69+
70+
let mut encodeme = PartialBuffer::new(buf);
16771

168-
match self.do_poll_write(cx, &mut input)? {
169-
Poll::Pending if input.written().is_empty() => Poll::Pending,
170-
_ => Poll::Ready(Ok(input.written().len())),
72+
loop {
73+
let mut space =
74+
PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?);
75+
this.encoder.encode(&mut encodeme, &mut space)?;
76+
let bytes_encoded = space.written().len();
77+
this.writer.as_mut().produce(bytes_encoded);
78+
if encodeme.unwritten().is_empty() {
79+
break;
80+
}
17181
}
82+
83+
Poll::Ready(Ok(encodeme.written().len()))
17284
}
17385

17486
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
175-
ready!(self.as_mut().do_poll_flush(cx))?;
176-
ready!(self.project().writer.as_mut().poll_flush(cx))?;
87+
let mut this = self.project();
88+
loop {
89+
let mut space =
90+
PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?);
91+
let flushed = this.encoder.flush(&mut space)?;
92+
let bytes_encoded = space.written().len();
93+
this.writer.as_mut().produce(bytes_encoded);
94+
if flushed {
95+
break;
96+
}
97+
}
17798
Poll::Ready(Ok(()))
17899
}
179100

180101
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
181-
ready!(self.as_mut().do_poll_shutdown(cx))?;
182-
ready!(self.project().writer.as_mut().poll_shutdown(cx))?;
183-
Poll::Ready(Ok(()))
102+
let mut this = self.project();
103+
if !*this.finished {
104+
loop {
105+
let mut space =
106+
PartialBuffer::new(ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?);
107+
let finished = this.encoder.finish(&mut space)?;
108+
let bytes_encoded = space.written().len();
109+
this.writer.as_mut().produce(bytes_encoded);
110+
if finished {
111+
*this.finished = true;
112+
break;
113+
}
114+
}
115+
}
116+
this.writer.poll_shutdown(cx)
184117
}
185118
}
186119

tests/issues.rs

+8-9
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ use tracing_subscriber::fmt::format::FmtSpan;
2323
/// [`tokio_util::codec`](https://docs.rs/tokio-util/latest/tokio_util/codec)
2424
/// [`poll_shutdown`](AsyncWrite::poll_shutdown)
2525
/// [`poll_flush`](AsyncWrite::poll_flush)
26-
#[should_panic = "Flush after shutdown"] // TODO: this should be removed when the bug is fixed
2726
#[test]
2827
fn issue_246() {
2928
tracing_subscriber::fmt()
@@ -34,26 +33,26 @@ fn issue_246() {
3433
.with_target(false)
3534
.with_span_events(FmtSpan::NEW)
3635
.init();
37-
let mut zstd_encoder =
38-
Transparent::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default())));
36+
let mut zstd_encoder = Wrapper::new(Trace::new(ZstdEncoder::new(DelayedShutdown::default())));
3937
futures::executor::block_on(zstd_encoder.shutdown()).unwrap();
4038
}
4139

4240
pin_project_lite::pin_project! {
4341
/// A simple wrapper struct that follows the [`AsyncWrite`] protocol.
44-
struct Transparent<T> {
42+
/// This is a stand-in for combinators like `tokio_util::codec`s
43+
struct Wrapper<T> {
4544
#[pin] inner: T
4645
}
4746
}
4847

49-
impl<T> Transparent<T> {
48+
impl<T> Wrapper<T> {
5049
fn new(inner: T) -> Self {
5150
Self { inner }
5251
}
5352
}
5453

55-
impl<T: AsyncWrite> AsyncWrite for Transparent<T> {
56-
#[tracing::instrument(name = "Transparent::poll_write", skip_all, ret)]
54+
impl<T: AsyncWrite> AsyncWrite for Wrapper<T> {
55+
#[tracing::instrument(name = "Wrapper::poll_write", skip_all, ret)]
5756
fn poll_write(
5857
self: Pin<&mut Self>,
5958
cx: &mut Context<'_>,
@@ -62,7 +61,7 @@ impl<T: AsyncWrite> AsyncWrite for Transparent<T> {
6261
self.project().inner.poll_write(cx, buf)
6362
}
6463

65-
#[tracing::instrument(name = "Transparent::poll_flush", skip_all, ret)]
64+
#[tracing::instrument(name = "Wrapper::poll_flush", skip_all, ret)]
6665
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
6766
self.project().inner.poll_flush(cx)
6867
}
@@ -72,7 +71,7 @@ impl<T: AsyncWrite> AsyncWrite for Transparent<T> {
7271
/// > Once this method returns Ready it implies that a flush successfully happened before the shutdown happened.
7372
/// > That is, callers don't need to call flush before calling shutdown.
7473
/// > They can rely that by calling shutdown any pending buffered data will be written out.
75-
#[tracing::instrument(name = "Transparent::poll_shutdown", skip_all, ret)]
74+
#[tracing::instrument(name = "Wrapper::poll_shutdown", skip_all, ret)]
7675
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
7776
let mut this = self.project();
7877
ready!(this.inner.as_mut().poll_flush(cx))?;

0 commit comments

Comments
 (0)