From 04cb0281799e7ce74be1bf6712ab57b91774114a Mon Sep 17 00:00:00 2001 From: grey Date: Sat, 11 May 2019 23:07:02 -0700 Subject: [PATCH 1/3] add zstd stream support --- Cargo.toml | 2 + src/stream/mod.rs | 10 ++- src/stream/zstd.rs | 200 +++++++++++++++++++++++++++++++++++++++++++++ tests/utils/mod.rs | 23 ++++++ tests/zstd.rs | 37 +++++++++ 5 files changed, 268 insertions(+), 4 deletions(-) create mode 100644 src/stream/zstd.rs create mode 100644 tests/zstd.rs diff --git a/Cargo.toml b/Cargo.toml index 66464e59..5618e5bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,8 @@ bytes = "0.4.12" flate2 = "1.0.7" futures-preview = "0.3.0-alpha.16" pin-project = "0.3.2" +zstd = "0.4" +zstd-safe = "1.4" [dev-dependencies] proptest = "0.9.3" diff --git a/src/stream/mod.rs b/src/stream/mod.rs index a73aeff3..1a97fe76 100644 --- a/src/stream/mod.rs +++ b/src/stream/mod.rs @@ -13,8 +13,10 @@ mod deflate; mod flate; mod gzip; mod zlib; +mod zstd; -pub use brotli::{BrotliDecoder, BrotliEncoder}; -pub use deflate::{DeflateDecoder, DeflateEncoder}; -pub use gzip::{GzipDecoder, GzipEncoder}; -pub use zlib::{ZlibDecoder, ZlibEncoder}; +pub use self::brotli::{BrotliDecoder, BrotliEncoder}; +pub use self::deflate::{DeflateDecoder, DeflateEncoder}; +pub use self::gzip::{GzipDecoder, GzipEncoder}; +pub use self::zlib::{ZlibDecoder, ZlibEncoder}; +pub use self::zstd::{ZstdDecoder, ZstdEncoder}; diff --git a/src/stream/zstd.rs b/src/stream/zstd.rs new file mode 100644 index 00000000..4fe38d64 --- /dev/null +++ b/src/stream/zstd.rs @@ -0,0 +1,200 @@ +use std::{ + io::Result, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::{Bytes, BytesMut}; +use futures::{ready, stream::Stream}; +use pin_project::unsafe_project; +use zstd::{ + stream::raw::{Decoder, Encoder, Operation}, + DEFAULT_COMPRESSION_LEVEL, +}; + +#[derive(Debug)] +enum State { + Reading, + Writing(Bytes), + Flushing, + Done, + Invalid, +} + +#[derive(Debug)] +enum DeState { + Reading, + Writing(Bytes), + Done, + Invalid, +} + +/// A zstd encoder, or compressor. +/// +/// This structure implements a [`Stream`] interface and will read uncompressed data from an +/// underlying stream and emit a stream of compressed data. +#[unsafe_project(Unpin)] +pub struct ZstdEncoder>> { + #[pin] + inner: S, + state: State, + output: BytesMut, + encoder: Encoder, +} + +/// A zstd decoder, or decompressor. +/// +/// This structure implements a [`Stream`] interface and will read compressed data from an +/// underlying stream and emit a stream of uncompressed data. +#[unsafe_project(Unpin)] +pub struct ZstdDecoder>> { + #[pin] + inner: S, + state: DeState, + output: BytesMut, + decoder: Decoder, +} + +impl>> ZstdEncoder { + /// Creates a new encoder which will read uncompressed data from the given stream and emit a + /// compressed stream. + pub fn new(stream: S) -> ZstdEncoder { + ZstdEncoder { + inner: stream, + state: State::Reading, + output: BytesMut::new(), + encoder: Encoder::new(DEFAULT_COMPRESSION_LEVEL).unwrap(), + } + } +} + +impl>> ZstdDecoder { + /// Creates a new decoder which will read compressed data from the given stream and emit an + /// uncompressed stream. + pub fn new(stream: S) -> ZstdDecoder { + ZstdDecoder { + inner: stream, + state: DeState::Reading, + output: BytesMut::new(), + decoder: Decoder::new().unwrap(), + } + } +} + +impl>> Stream for ZstdEncoder { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { + let mut this = self.project(); + + fn compress( + encoder: &mut Encoder, + input: &mut Bytes, + output: &mut BytesMut, + ) -> Result { + const OUTPUT_BUFFER_SIZE: usize = 8_000; + + if output.len() < OUTPUT_BUFFER_SIZE { + output.resize(OUTPUT_BUFFER_SIZE, 0); + } + + let status = encoder.run_on_buffers(input, output)?; + input.advance(status.bytes_read); + Ok(output.split_to(status.bytes_written).freeze()) + } + + #[allow(clippy::never_loop)] // https://github.com/rust-lang/rust-clippy/issues/4058 + loop { + break match mem::replace(this.state, State::Invalid) { + State::Reading => { + *this.state = State::Reading; + *this.state = match ready!(this.inner.as_mut().poll_next(cx)) { + Some(chunk) => State::Writing(chunk?), + None => State::Flushing, + }; + continue; + } + State::Writing(mut input) => { + if input.is_empty() { + *this.state = State::Reading; + continue; + } + + let chunk = compress(&mut this.encoder, &mut input, &mut this.output)?; + + *this.state = State::Writing(input); + + Poll::Ready(Some(Ok(chunk))) + } + State::Flushing => { + let mut outbuffer = zstd_safe::OutBuffer::around(this.output); + + let bytes_left = this.encoder.flush(&mut outbuffer).unwrap(); + *this.state = if bytes_left == 0 { + let _ = this.encoder.finish(&mut outbuffer, true); + State::Done + } else { + State::Flushing + }; + Poll::Ready(Some(Ok(outbuffer.as_slice().into()))) + } + State::Done => Poll::Ready(None), + State::Invalid => panic!("ZstdEncoder reached invalid state"), + }; + } + } +} + +impl>> Stream for ZstdDecoder { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { + let mut this = self.project(); + + fn decompress( + decoder: &mut Decoder, + input: &mut Bytes, + output: &mut BytesMut, + ) -> Result { + const OUTPUT_BUFFER_SIZE: usize = 8_000; + + if output.len() < OUTPUT_BUFFER_SIZE { + output.resize(OUTPUT_BUFFER_SIZE, 0); + } + + let status = decoder.run_on_buffers(input, output)?; + dbg!(&status.remaining, &status.bytes_written, &status.bytes_read); + input.advance(status.bytes_read); + Ok(output.split_to(status.bytes_written).freeze()) + } + + #[allow(clippy::never_loop)] // https://github.com/rust-lang/rust-clippy/issues/4058 + loop { + break match mem::replace(this.state, DeState::Invalid) { + DeState::Reading => { + *this.state = DeState::Reading; + *this.state = match ready!(this.inner.as_mut().poll_next(cx)) { + Some(chunk) => DeState::Writing(chunk?), + None => DeState::Done, + }; + continue; + } + DeState::Writing(mut input) => { + if input.is_empty() { + *this.state = DeState::Reading; + continue; + } + + let chunk = decompress(&mut this.decoder, &mut input, &mut this.output)?; + + *this.state = DeState::Writing(input); + + Poll::Ready(Some(Ok(chunk))) + } + DeState::Done => Poll::Ready(None), + DeState::Invalid => panic!("ZstdDecoder reached invalid state"), + }; + } + } +} diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 9168b8cb..a4e5d83f 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -182,3 +182,26 @@ pub fn gzip_stream_decompress(input: impl Stream>) -> V pin_mut!(input); stream_to_vec(GzipDecoder::new(input)) } + +pub fn zstd_compress(bytes: &[u8]) -> Vec { + use zstd::stream::read::Encoder; + use zstd::DEFAULT_COMPRESSION_LEVEL; + read_to_vec(Encoder::new(bytes, DEFAULT_COMPRESSION_LEVEL).unwrap()) +} + +pub fn zstd_decompress(bytes: &[u8]) -> Vec { + use zstd::stream::read::Decoder; + read_to_vec(Decoder::new(bytes).unwrap()) +} + +pub fn zstd_stream_compress(input: impl Stream>) -> Vec { + use async_compression::stream::ZstdEncoder; + pin_mut!(input); + stream_to_vec(ZstdEncoder::new(input)) +} + +pub fn zstd_stream_decompress(input: impl Stream>) -> Vec { + use async_compression::stream::ZstdDecoder; + pin_mut!(input); + stream_to_vec(ZstdDecoder::new(input)) +} diff --git a/tests/zstd.rs b/tests/zstd.rs new file mode 100644 index 00000000..18b8c2e4 --- /dev/null +++ b/tests/zstd.rs @@ -0,0 +1,37 @@ +use std::iter::FromIterator; + +mod utils; + +#[test] +fn zstd_stream_compress() { + let input = utils::InputStream::from([[1, 2, 3], [4, 5, 6]]); + + let compressed = utils::zstd_stream_compress(input.stream()); + let output = utils::zstd_decompress(&compressed); + + assert_eq!(output, vec![1, 2, 3, 4, 5, 6]); +} + +#[test] +fn zstd_stream_compress_large() { + let input = vec![ + Vec::from_iter((0..20_000).map(|_| rand::random())), + Vec::from_iter((0..20_000).map(|_| rand::random())), + ]; + let input = utils::InputStream::from(input); + + let compressed = utils::zstd_stream_compress(input.stream()); + let output = utils::zstd_decompress(&compressed); + + assert_eq!(output, input.bytes()); +} + +#[test] +fn zstd_stream_decompress() { + let compressed = utils::zstd_compress(&[1, 2, 3, 4, 5, 6][..]); + + let stream = utils::InputStream::from(vec![compressed]); + let output = utils::zstd_stream_decompress(stream.stream()); + + assert_eq!(output, vec![1, 2, 3, 4, 5, 6]); +} From cce68708f48bd56c0094be86ad8bb99f2190e7e8 Mon Sep 17 00:00:00 2001 From: grey Date: Sat, 11 May 2019 23:17:23 -0700 Subject: [PATCH 2/3] sharpen up zstd stream API a bit, add docs for level arg --- src/stream/zstd.rs | 19 +++++++++---------- tests/utils/mod.rs | 2 +- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/stream/zstd.rs b/src/stream/zstd.rs index 4fe38d64..1325ff16 100644 --- a/src/stream/zstd.rs +++ b/src/stream/zstd.rs @@ -8,10 +8,7 @@ use std::{ use bytes::{Bytes, BytesMut}; use futures::{ready, stream::Stream}; use pin_project::unsafe_project; -use zstd::{ - stream::raw::{Decoder, Encoder, Operation}, - DEFAULT_COMPRESSION_LEVEL, -}; +use zstd::stream::raw::{Decoder, Encoder, Operation}; #[derive(Debug)] enum State { @@ -59,12 +56,14 @@ pub struct ZstdDecoder>> { impl>> ZstdEncoder { /// Creates a new encoder which will read uncompressed data from the given stream and emit a /// compressed stream. - pub fn new(stream: S) -> ZstdEncoder { + /// + /// The `level` argument here can range from 1-21. A level of `0` will use zstd's default, which is `3`. + pub fn new(stream: S, level: i32) -> ZstdEncoder { ZstdEncoder { inner: stream, state: State::Reading, output: BytesMut::new(), - encoder: Encoder::new(DEFAULT_COMPRESSION_LEVEL).unwrap(), + encoder: Encoder::new(level).unwrap(), } } } @@ -128,16 +127,16 @@ impl>> Stream for ZstdEncoder { Poll::Ready(Some(Ok(chunk))) } State::Flushing => { - let mut outbuffer = zstd_safe::OutBuffer::around(this.output); + let mut output = zstd_safe::OutBuffer::around(this.output); - let bytes_left = this.encoder.flush(&mut outbuffer).unwrap(); + let bytes_left = this.encoder.flush(&mut output).unwrap(); *this.state = if bytes_left == 0 { - let _ = this.encoder.finish(&mut outbuffer, true); + let _ = this.encoder.finish(&mut output, true); State::Done } else { State::Flushing }; - Poll::Ready(Some(Ok(outbuffer.as_slice().into()))) + Poll::Ready(Some(Ok(output.as_slice().into()))) } State::Done => Poll::Ready(None), State::Invalid => panic!("ZstdEncoder reached invalid state"), diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index a4e5d83f..0cab5077 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -197,7 +197,7 @@ pub fn zstd_decompress(bytes: &[u8]) -> Vec { pub fn zstd_stream_compress(input: impl Stream>) -> Vec { use async_compression::stream::ZstdEncoder; pin_mut!(input); - stream_to_vec(ZstdEncoder::new(input)) + stream_to_vec(ZstdEncoder::new(input, 0)) } pub fn zstd_stream_decompress(input: impl Stream>) -> Vec { From 0cad210f9e963467958ea42b86a0b739b4f88edf Mon Sep 17 00:00:00 2001 From: grey Date: Sat, 11 May 2019 23:23:28 -0700 Subject: [PATCH 3/3] remove dbg call --- src/stream/zstd.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stream/zstd.rs b/src/stream/zstd.rs index 1325ff16..b47408c0 100644 --- a/src/stream/zstd.rs +++ b/src/stream/zstd.rs @@ -163,7 +163,6 @@ impl>> Stream for ZstdDecoder { } let status = decoder.run_on_buffers(input, output)?; - dbg!(&status.remaining, &status.bytes_written, &status.bytes_read); input.advance(status.bytes_read); Ok(output.split_to(status.bytes_written).freeze()) }