diff --git a/Cargo.toml b/Cargo.toml
index 0ae108c1..bd52da8f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,3 +11,6 @@ bytes = "0.4.12"
 flate2 = "1.0.7"
 futures-preview = "0.3.0-alpha.15"
 pin-project = "0.3.2"
+
+[dev-dependencies]
+rand = "0.6.5"
diff --git a/src/stream/brotli.rs b/src/stream/brotli.rs
index 75da1631..b39e127f 100644
--- a/src/stream/brotli.rs
+++ b/src/stream/brotli.rs
@@ -6,7 +6,7 @@ use std::io::Result;

 use brotli2::raw::{CoStatus, CompressOp};
 pub use brotli2::{raw::Compress, CompressParams};
-use bytes::{BufMut, Bytes, BytesMut};
+use bytes::{Bytes, BytesMut};
 use futures::{ready, stream::Stream};
 use pin_project::unsafe_project;

@@ -14,7 +14,7 @@ use pin_project::unsafe_project;
 pub struct BrotliStream<S: Stream<Item = Result<Bytes>>> {
     #[pin]
     inner: S,
-    flushing: bool,
+    flush: bool,
     compress: Compress,
 }

@@ -26,14 +26,14 @@ impl<S: Stream<Item = Result<Bytes>>> Stream for BrotliStream<S> {
         let this = self.project();

-        if *this.flushing {
+        if *this.flush {
             return Poll::Ready(None);
         }

         let input_buffer = if let Some(bytes) = ready!(this.inner.poll_next(cx)) {
             bytes?
         } else {
-            *this.flushing = true;
+            *this.flush = true;
             Bytes::new()
         };

@@ -42,7 +42,7 @@ impl<S: Stream<Item = Result<Bytes>>> Stream for BrotliStream<S> {
         let output_ref = &mut &mut [][..];
         loop {
             let status = this.compress.compress(
-                if *this.flushing {
+                if *this.flush {
                     CompressOp::Finish
                 } else {
                     CompressOp::Process
@@ -51,7 +51,7 @@ impl<S: Stream<Item = Result<Bytes>>> Stream for BrotliStream<S> {
                 output_ref,
             )?;
             while let Some(buf) = this.compress.take_output(None) {
-                compressed_output.put(buf);
+                compressed_output.extend_from_slice(buf);
             }
             match status {
                 CoStatus::Finished => break,
@@ -67,7 +67,7 @@ impl<S: Stream<Item = Result<Bytes>>> BrotliStream<S> {
     pub fn new(stream: S, compress: Compress) -> BrotliStream<S> {
         BrotliStream {
             inner: stream,
-            flushing: false,
+            flush: false,
             compress,
         }
     }
diff --git a/src/stream/flate.rs b/src/stream/flate.rs
index b9353b89..383c554d 100644
--- a/src/stream/flate.rs
+++ b/src/stream/flate.rs
@@ -1,22 +1,31 @@
-use core::{
+use std::{
+    io::Result,
+    mem,
     pin::Pin,
     task::{Context, Poll},
 };
-use std::io::Result;

 use bytes::{Bytes, BytesMut};
-use flate2::FlushCompress;
-pub use flate2::{Compress, Compression};
+pub(crate) use flate2::Compress;
+use flate2::{FlushCompress, Status};
 use futures::{ready, stream::Stream};
 use pin_project::unsafe_project;

+#[derive(Debug)]
+enum State {
+    Reading,
+    Writing(Bytes),
+    Flushing,
+    Done,
+    Invalid,
+}
+
 #[unsafe_project(Unpin)]
-pub struct CompressedStream<S: Stream<Item = Result<Bytes>>> {
+pub(crate) struct CompressedStream<S: Stream<Item = Result<Bytes>>> {
     #[pin]
     inner: S,
-    flushing: bool,
-    input_buffer: Bytes,
-    output_buffer: BytesMut,
+    state: State,
+    output: BytesMut,
     compress: Compress,
 }

@@ -24,49 +33,94 @@ impl<S: Stream<Item = Result<Bytes>>> Stream for CompressedStream<S> {
     type Item = Result<Bytes>;

     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
-        const OUTPUT_BUFFER_SIZE: usize = 8_000;
+        let mut this = self.project();

-        let this = self.project();
+        fn compress(
+            compress: &mut Compress,
+            input: &mut Bytes,
+            output: &mut BytesMut,
+            flush: FlushCompress,
+        ) -> Result<(Status, Bytes)> {
+            const OUTPUT_BUFFER_SIZE: usize = 8_000;

-        if this.input_buffer.is_empty() {
-            if *this.flushing {
-                return Poll::Ready(None);
-            } else if let Some(bytes) = ready!(this.inner.poll_next(cx)) {
-                *this.input_buffer = bytes?;
-            } else {
-                *this.flushing = true;
+            if output.len() < OUTPUT_BUFFER_SIZE {
+                output.resize(OUTPUT_BUFFER_SIZE, 0);
             }
+
+            let (prior_in, prior_out) = (compress.total_in(), compress.total_out());
+            let status = compress.compress(input, output, flush)?;
+            let input_len = compress.total_in() - prior_in;
+            let output_len = compress.total_out() - prior_out;
+
+            input.advance(input_len as usize);
+            Ok((status, output.split_to(output_len as usize).freeze()))
         }

-        this.output_buffer.resize(OUTPUT_BUFFER_SIZE, 0);
-
-        let flush = if *this.flushing {
-            FlushCompress::Finish
-        } else {
-            FlushCompress::None
-        };
-
-        let (prior_in, prior_out) = (this.compress.total_in(), this.compress.total_out());
-        this.compress
-            .compress(this.input_buffer, this.output_buffer, flush)?;
-        let input = this.compress.total_in() - prior_in;
-        let output = this.compress.total_out() - prior_out;
-
-        this.input_buffer.advance(input as usize);
-        Poll::Ready(Some(Ok(this
-            .output_buffer
-            .split_to(output as usize)
-            .freeze())))
+        #[allow(clippy::never_loop)] // https://github.com/rust-lang/rust-clippy/issues/4058
+        loop {
+            break match mem::replace(this.state, State::Invalid) {
+                State::Reading => {
+                    *this.state = State::Reading;
+                    *this.state = match ready!(this.inner.as_mut().poll_next(cx)) {
+                        Some(chunk) => State::Writing(chunk?),
+                        None => State::Flushing,
+                    };
+                    continue;
+                }
+
+                State::Writing(mut input) => {
+                    if input.is_empty() {
+                        *this.state = State::Reading;
+                        continue;
+                    }
+
+                    let (status, chunk) = compress(
+                        &mut this.compress,
+                        &mut input,
+                        &mut this.output,
+                        FlushCompress::None,
+                    )?;
+
+                    *this.state = match status {
+                        Status::Ok => State::Writing(input),
+                        Status::StreamEnd => unreachable!(),
+                        Status::BufError => panic!("unexpected BufError"),
+                    };
+
+                    Poll::Ready(Some(Ok(chunk)))
+                }
+
+                State::Flushing => {
+                    let (status, chunk) = compress(
+                        &mut this.compress,
+                        &mut Bytes::new(),
+                        &mut this.output,
+                        FlushCompress::Finish,
+                    )?;
+
+                    *this.state = match status {
+                        Status::Ok => State::Flushing,
+                        Status::StreamEnd => State::Done,
+                        Status::BufError => panic!("unexpected BufError"),
+                    };
+
+                    Poll::Ready(Some(Ok(chunk)))
+                }
+
+                State::Done => Poll::Ready(None),
+
+                State::Invalid => panic!("CompressedStream reached invalid state"),
+            };
+        }
     }
 }

 impl<S: Stream<Item = Result<Bytes>>> CompressedStream<S> {
-    pub fn new(stream: S, compress: Compress) -> CompressedStream<S> {
+    pub(crate) fn new(stream: S, compress: Compress) -> CompressedStream<S> {
         CompressedStream {
             inner: stream,
-            flushing: false,
-            input_buffer: Bytes::new(),
-            output_buffer: BytesMut::new(),
+            state: State::Reading,
+            output: BytesMut::new(),
             compress,
         }
     }
diff --git a/src/stream/gzip.rs b/src/stream/gzip.rs
index c9aeb69a..cbb18342 100644
--- a/src/stream/gzip.rs
+++ b/src/stream/gzip.rs
@@ -1,81 +1,141 @@
-use core::{
+use std::{
+    io::Result,
+    mem,
     pin::Pin,
     task::{Context, Poll},
 };
-use std::io::Result;

-use bytes::{Bytes, BytesMut};
+use bytes::{BufMut, Bytes, BytesMut};
 pub use flate2::Compression;
-use flate2::{Compress, Crc, FlushCompress};
+use flate2::{Compress, Crc, FlushCompress, Status};
 use futures::{ready, stream::Stream};
 use pin_project::unsafe_project;

+#[derive(Debug)]
+enum State {
+    WritingHeader(Compression),
+    Reading,
+    WritingChunk(Bytes),
+    FlushingData,
+    WritingFooter,
+    Done,
+    Invalid,
+}
+
 #[unsafe_project(Unpin)]
+#[derive(Debug)]
 pub struct GzipStream<S: Stream<Item = Result<Bytes>>> {
     #[pin]
     inner: S,
-    flushing: bool,
-    input_buffer: Bytes,
-    output_buffer: BytesMut,
+    state: State,
+    output: BytesMut,
     crc: Crc,
-    header_appended: bool,
-    footer_appended: bool,
     compress: Compress,
-    level: Compression,
 }

 impl<S: Stream<Item = Result<Bytes>>> Stream for GzipStream<S> {
     type Item = Result<Bytes>;

     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
-        const OUTPUT_BUFFER_SIZE: usize = 8_000;
+        let mut this = self.project();

-        let this = self.project();
+        fn compress(
+            compress: &mut Compress,
+            input: &mut Bytes,
+            output: &mut BytesMut,
+            crc: &mut Crc,
+            flush: FlushCompress,
+        ) -> Result<(Status, Bytes)> {
+            const OUTPUT_BUFFER_SIZE: usize = 8_000;
+
+            if output.len() < OUTPUT_BUFFER_SIZE {
+                output.resize(OUTPUT_BUFFER_SIZE, 0);
+            }

-        if !*this.header_appended {
-            let header = get_header(*this.level);
-            *this.header_appended = true;
-            return Poll::Ready(Some(Ok(header)));
+            let (prior_in, prior_out) = (compress.total_in(), compress.total_out());
+            let status = compress.compress(input, output, flush)?;
+            let input_len = compress.total_in() - prior_in;
+            let output_len = compress.total_out() - prior_out;
+
+            crc.update(&input[0..input_len as usize]);
+            input.advance(input_len as usize);
+            Ok((status, output.split_to(output_len as usize).freeze()))
         }

-        if this.input_buffer.is_empty() {
-            if *this.flushing {
-                if !*this.footer_appended {
-                    let mut footer = Bytes::from(&this.crc.sum().to_le_bytes()[..]);
-                    let length_read = &this.crc.amount().to_le_bytes()[..];
-                    footer.extend_from_slice(length_read);
-                    *this.footer_appended = true;
-                    return Poll::Ready(Some(Ok(footer)));
-                } else {
-                    return Poll::Ready(None);
+        #[allow(clippy::never_loop)] // https://github.com/rust-lang/rust-clippy/issues/4058
+        loop {
+            break match mem::replace(this.state, State::Invalid) {
+                State::WritingHeader(level) => {
+                    *this.state = State::Reading;
+                    Poll::Ready(Some(Ok(get_header(level))))
                 }
-            } else if let Some(bytes) = ready!(this.inner.poll_next(cx)) {
-                *this.input_buffer = bytes?;
-            } else {
-                *this.flushing = true;
-            }
-        }

-        this.output_buffer.resize(OUTPUT_BUFFER_SIZE, 0);
-
-        let flush = if *this.flushing {
-            FlushCompress::Finish
-        } else {
-            FlushCompress::None
-        };
-
-        let (prior_in, prior_out) = (this.compress.total_in(), this.compress.total_out());
-        this.compress
-            .compress(this.input_buffer, this.output_buffer, flush)?;
-        let input = this.compress.total_in() - prior_in;
-        let output = this.compress.total_out() - prior_out;
-
-        this.crc.update(&this.input_buffer.slice(0, input as usize));
-        this.input_buffer.advance(input as usize);
-        Poll::Ready(Some(Ok(this
-            .output_buffer
-            .split_to(output as usize)
-            .freeze())))
+                State::Reading => {
+                    *this.state = State::Reading;
+                    *this.state = match ready!(this.inner.as_mut().poll_next(cx)) {
+                        Some(chunk) => State::WritingChunk(chunk?),
+                        None => State::FlushingData,
+                    };
+                    continue;
+                }
+
+                State::WritingChunk(mut input) => {
+                    if input.is_empty() {
+                        *this.state = State::Reading;
+                        continue;
+                    }
+
+                    let (status, chunk) = compress(
+                        &mut this.compress,
+                        &mut input,
+                        &mut this.output,
+                        &mut this.crc,
+                        FlushCompress::None,
+                    )?;
+
+                    *this.state = match status {
+                        Status::Ok => State::WritingChunk(input),
+                        Status::StreamEnd => unreachable!(),
+                        Status::BufError => panic!("unexpected BufError"),
+                    };
+
+                    Poll::Ready(Some(Ok(chunk)))
+                }
+
+                State::FlushingData => {
+                    let (status, chunk) = compress(
+                        &mut this.compress,
+                        &mut Bytes::new(),
+                        &mut this.output,
+                        &mut this.crc,
+                        FlushCompress::Finish,
+                    )?;
+
+                    *this.state = match status {
+                        Status::StreamEnd => State::WritingFooter,
+                        Status::Ok => State::FlushingData,
+                        Status::BufError => panic!("unexpected BufError"),
+                    };
+
+                    Poll::Ready(Some(Ok(chunk)))
+                }
+
+                State::WritingFooter => {
+                    let mut footer = BytesMut::with_capacity(8);
+
+                    footer.put(this.crc.sum().to_le_bytes().as_ref());
+                    footer.put(this.crc.amount().to_le_bytes().as_ref());
+
+                    *this.state = State::Done;
+
+                    Poll::Ready(Some(Ok(footer.freeze())))
+                }
+
+                State::Done => Poll::Ready(None),
+
+                State::Invalid => panic!("GzipStream reached invalid state"),
+            };
+        }
     }
 }

@@ -83,14 +143,10 @@ impl<S: Stream<Item = Result<Bytes>>> GzipStream<S> {
     pub fn new(stream: S, level: Compression) -> GzipStream<S> {
         GzipStream {
             inner: stream,
-            flushing: false,
-            input_buffer: Bytes::new(),
-            output_buffer: BytesMut::new(),
+            state: State::WritingHeader(level),
+            output: BytesMut::new(),
             crc: Crc::new(),
-            header_appended: false,
-            footer_appended: false,
             compress: Compress::new(level, false),
-            level,
         }
     }
 }
diff --git a/tests/brotli.rs b/tests/brotli.rs
index 22908c9c..fc8e0e98 100644
--- a/tests/brotli.rs
+++ b/tests/brotli.rs
@@ -6,6 +6,7 @@ use futures::{
     stream::{self, StreamExt},
 };
 use std::io::{self, Read};
+use std::iter::FromIterator;

 #[test]
 fn brotli_stream() {
@@ -27,6 +28,34 @@ fn brotli_stream() {
     assert_eq!(output, vec![1, 2, 3, 4, 5, 6]);
 }

+#[test]
+fn brotli_stream_large() {
+    use async_compression::stream::brotli;
+
+    let bytes = [
+        Vec::from_iter((0..20_000).map(|_| rand::random())),
+        Vec::from_iter((0..20_000).map(|_| rand::random())),
+    ];
+
+    let stream = stream::iter(vec![
+        Bytes::from(bytes[0].clone()),
+        Bytes::from(bytes[1].clone()),
+    ]);
+    let compress = brotli::Compress::new();
+    let compressed = brotli::BrotliStream::new(stream.map(Ok), compress);
+    let data: Vec<_> = block_on(compressed.collect());
+    let data: io::Result<Vec<_>> = data.into_iter().collect();
+    let data: Vec<u8> = data.unwrap().into_iter().flatten().collect();
+    let mut output = vec![];
+    BrotliDecoder::new(&data[..])
+        .read_to_end(&mut output)
+        .unwrap();
+    assert_eq!(
+        output,
+        Vec::from_iter(bytes[0].iter().chain(bytes[1].iter()).cloned())
+    );
+}
+
 //#[test]
 //fn brotli_read() {
 //    use async_compression::read::brotli;
diff --git a/tests/deflate.rs b/tests/deflate.rs
index 19f28f27..fe8b57fa 100644
--- a/tests/deflate.rs
+++ b/tests/deflate.rs
@@ -6,6 +6,7 @@ use futures::{
     stream::{self, StreamExt},
 };
 use std::io::{self, Read};
+use std::iter::FromIterator;

 #[test]
 fn deflate_stream() {
@@ -26,6 +27,33 @@ fn deflate_stream() {
     assert_eq!(output, vec![1, 2, 3, 4, 5, 6]);
 }

+#[test]
+fn deflate_stream_large() {
+    use async_compression::stream::deflate;
+
+    let bytes = [
+        Vec::from_iter((0..20_000).map(|_| rand::random())),
+        Vec::from_iter((0..20_000).map(|_| rand::random())),
+    ];
+
+    let stream = stream::iter(vec![
+        Bytes::from(bytes[0].clone()),
+        Bytes::from(bytes[1].clone()),
+    ]);
+    let compressed = deflate::DeflateStream::new(stream.map(Ok), deflate::Compression::default());
+    let data: Vec<_> = block_on(compressed.collect());
+    let data: io::Result<Vec<_>> = data.into_iter().collect();
+    let data: Vec<u8> = data.unwrap().into_iter().flatten().collect();
+    let mut output = vec![];
+    DeflateDecoder::new(&data[..])
+        .read_to_end(&mut output)
+        .unwrap();
+    assert_eq!(
+        output,
+        Vec::from_iter(bytes[0].iter().chain(bytes[1].iter()).cloned())
+    );
+}
+
 #[test]
 fn deflate_read() {
     use async_compression::read::deflate;
diff --git a/tests/gzip.rs b/tests/gzip.rs
index 68dfd70b..89553369 100644
--- a/tests/gzip.rs
+++ b/tests/gzip.rs
@@ -5,6 +5,7 @@ use futures::{
     stream::{self, StreamExt},
 };
 use std::io::{self, Read};
+use std::iter::FromIterator;

 #[test]
 fn gzip_stream() {
@@ -22,3 +23,28 @@ fn gzip_stream() {
     GzDecoder::new(&data[..]).read_to_end(&mut output).unwrap();
     assert_eq!(output, vec![1, 2, 3, 4, 5, 6]);
 }
+
+#[test]
+fn gzip_stream_large() {
+    use async_compression::stream::gzip;
+
+    let bytes = [
+        Vec::from_iter((0..20_000).map(|_| rand::random())),
+        Vec::from_iter((0..20_000).map(|_| rand::random())),
+    ];
+
+    let stream = stream::iter(vec![
+        Bytes::from(bytes[0].clone()),
+        Bytes::from(bytes[1].clone()),
+    ]);
+    let compressed = gzip::GzipStream::new(stream.map(Ok), gzip::Compression::default());
+    let data: Vec<_> = block_on(compressed.collect());
+    let data: io::Result<Vec<_>> = data.into_iter().collect();
+    let data: Vec<u8> = data.unwrap().into_iter().flatten().collect();
+    let mut output = vec![];
+    GzDecoder::new(&data[..]).read_to_end(&mut output).unwrap();
+    assert_eq!(
+        output,
+        Vec::from_iter(bytes[0].iter().chain(bytes[1].iter()).cloned())
+    );
+}
diff --git a/tests/zlib.rs b/tests/zlib.rs
index 0c7d75a0..9ba17e81 100644
--- a/tests/zlib.rs
+++ b/tests/zlib.rs
@@ -6,6 +6,7 @@ use futures::{
     stream::{self, StreamExt},
 };
 use std::io::{self, Read};
+use std::iter::FromIterator;

 #[test]
 fn zlib_stream() {
@@ -26,6 +27,33 @@ fn zlib_stream() {
     assert_eq!(output, vec![1, 2, 3, 4, 5, 6]);
 }

+#[test]
+fn zlib_stream_large() {
+    use async_compression::stream::zlib;
+
+    let bytes = [
+        Vec::from_iter((0..20_000).map(|_| rand::random())),
+        Vec::from_iter((0..20_000).map(|_| rand::random())),
+    ];
+
+    let stream = stream::iter(vec![
+        Bytes::from(bytes[0].clone()),
+        Bytes::from(bytes[1].clone()),
+    ]);
+    let compressed = zlib::ZlibStream::new(stream.map(Ok), zlib::Compression::default());
+    let data: Vec<_> = block_on(compressed.collect());
+    let data: io::Result<Vec<_>> = data.into_iter().collect();
+    let data: Vec<u8> = data.unwrap().into_iter().flatten().collect();
+    let mut output = vec![];
+    ZlibDecoder::new(&data[..])
+        .read_to_end(&mut output)
+        .unwrap();
+    assert_eq!(
+        output,
+        Vec::from_iter(bytes[0].iter().chain(bytes[1].iter()).cloned())
+    );
+}
+
 #[test]
 fn zlib_read() {
     use async_compression::read::zlib;