Skip to content

Commit 04cb028

Browse files
committed
add zstd stream support
1 parent f6e0cba commit 04cb028

File tree

5 files changed

+268
-4
lines changed

5 files changed

+268
-4
lines changed

Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ bytes = "0.4.12"
1111
flate2 = "1.0.7"
1212
futures-preview = "0.3.0-alpha.16"
1313
pin-project = "0.3.2"
14+
zstd = "0.4"
15+
zstd-safe = "1.4"
1416

1517
[dev-dependencies]
1618
proptest = "0.9.3"

src/stream/mod.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ mod deflate;
1313
mod flate;
1414
mod gzip;
1515
mod zlib;
16+
mod zstd;
1617

17-
pub use brotli::{BrotliDecoder, BrotliEncoder};
18-
pub use deflate::{DeflateDecoder, DeflateEncoder};
19-
pub use gzip::{GzipDecoder, GzipEncoder};
20-
pub use zlib::{ZlibDecoder, ZlibEncoder};
18+
pub use self::brotli::{BrotliDecoder, BrotliEncoder};
19+
pub use self::deflate::{DeflateDecoder, DeflateEncoder};
20+
pub use self::gzip::{GzipDecoder, GzipEncoder};
21+
pub use self::zlib::{ZlibDecoder, ZlibEncoder};
22+
pub use self::zstd::{ZstdDecoder, ZstdEncoder};

src/stream/zstd.rs

+200
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
use std::{
2+
io::Result,
3+
mem,
4+
pin::Pin,
5+
task::{Context, Poll},
6+
};
7+
8+
use bytes::{Bytes, BytesMut};
9+
use futures::{ready, stream::Stream};
10+
use pin_project::unsafe_project;
11+
use zstd::{
12+
stream::raw::{Decoder, Encoder, Operation},
13+
DEFAULT_COMPRESSION_LEVEL,
14+
};
15+
16+
#[derive(Debug)]
17+
enum State {
18+
Reading,
19+
Writing(Bytes),
20+
Flushing,
21+
Done,
22+
Invalid,
23+
}
24+
25+
#[derive(Debug)]
26+
enum DeState {
27+
Reading,
28+
Writing(Bytes),
29+
Done,
30+
Invalid,
31+
}
32+
33+
/// A zstd encoder, or compressor.
34+
///
35+
/// This structure implements a [`Stream`] interface and will read uncompressed data from an
36+
/// underlying stream and emit a stream of compressed data.
37+
#[unsafe_project(Unpin)]
38+
pub struct ZstdEncoder<S: Stream<Item = Result<Bytes>>> {
39+
#[pin]
40+
inner: S,
41+
state: State,
42+
output: BytesMut,
43+
encoder: Encoder,
44+
}
45+
46+
/// A zstd decoder, or decompressor.
47+
///
48+
/// This structure implements a [`Stream`] interface and will read compressed data from an
49+
/// underlying stream and emit a stream of uncompressed data.
50+
#[unsafe_project(Unpin)]
51+
pub struct ZstdDecoder<S: Stream<Item = Result<Bytes>>> {
52+
#[pin]
53+
inner: S,
54+
state: DeState,
55+
output: BytesMut,
56+
decoder: Decoder,
57+
}
58+
59+
impl<S: Stream<Item = Result<Bytes>>> ZstdEncoder<S> {
60+
/// Creates a new encoder which will read uncompressed data from the given stream and emit a
61+
/// compressed stream.
62+
pub fn new(stream: S) -> ZstdEncoder<S> {
63+
ZstdEncoder {
64+
inner: stream,
65+
state: State::Reading,
66+
output: BytesMut::new(),
67+
encoder: Encoder::new(DEFAULT_COMPRESSION_LEVEL).unwrap(),
68+
}
69+
}
70+
}
71+
72+
impl<S: Stream<Item = Result<Bytes>>> ZstdDecoder<S> {
73+
/// Creates a new decoder which will read compressed data from the given stream and emit an
74+
/// uncompressed stream.
75+
pub fn new(stream: S) -> ZstdDecoder<S> {
76+
ZstdDecoder {
77+
inner: stream,
78+
state: DeState::Reading,
79+
output: BytesMut::new(),
80+
decoder: Decoder::new().unwrap(),
81+
}
82+
}
83+
}
84+
85+
impl<S: Stream<Item = Result<Bytes>>> Stream for ZstdEncoder<S> {
86+
type Item = Result<Bytes>;
87+
88+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
89+
let mut this = self.project();
90+
91+
fn compress(
92+
encoder: &mut Encoder,
93+
input: &mut Bytes,
94+
output: &mut BytesMut,
95+
) -> Result<Bytes> {
96+
const OUTPUT_BUFFER_SIZE: usize = 8_000;
97+
98+
if output.len() < OUTPUT_BUFFER_SIZE {
99+
output.resize(OUTPUT_BUFFER_SIZE, 0);
100+
}
101+
102+
let status = encoder.run_on_buffers(input, output)?;
103+
input.advance(status.bytes_read);
104+
Ok(output.split_to(status.bytes_written).freeze())
105+
}
106+
107+
#[allow(clippy::never_loop)] // https://github.com/rust-lang/rust-clippy/issues/4058
108+
loop {
109+
break match mem::replace(this.state, State::Invalid) {
110+
State::Reading => {
111+
*this.state = State::Reading;
112+
*this.state = match ready!(this.inner.as_mut().poll_next(cx)) {
113+
Some(chunk) => State::Writing(chunk?),
114+
None => State::Flushing,
115+
};
116+
continue;
117+
}
118+
State::Writing(mut input) => {
119+
if input.is_empty() {
120+
*this.state = State::Reading;
121+
continue;
122+
}
123+
124+
let chunk = compress(&mut this.encoder, &mut input, &mut this.output)?;
125+
126+
*this.state = State::Writing(input);
127+
128+
Poll::Ready(Some(Ok(chunk)))
129+
}
130+
State::Flushing => {
131+
let mut outbuffer = zstd_safe::OutBuffer::around(this.output);
132+
133+
let bytes_left = this.encoder.flush(&mut outbuffer).unwrap();
134+
*this.state = if bytes_left == 0 {
135+
let _ = this.encoder.finish(&mut outbuffer, true);
136+
State::Done
137+
} else {
138+
State::Flushing
139+
};
140+
Poll::Ready(Some(Ok(outbuffer.as_slice().into())))
141+
}
142+
State::Done => Poll::Ready(None),
143+
State::Invalid => panic!("ZstdEncoder reached invalid state"),
144+
};
145+
}
146+
}
147+
}
148+
149+
impl<S: Stream<Item = Result<Bytes>>> Stream for ZstdDecoder<S> {
150+
type Item = Result<Bytes>;
151+
152+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
153+
let mut this = self.project();
154+
155+
fn decompress(
156+
decoder: &mut Decoder,
157+
input: &mut Bytes,
158+
output: &mut BytesMut,
159+
) -> Result<Bytes> {
160+
const OUTPUT_BUFFER_SIZE: usize = 8_000;
161+
162+
if output.len() < OUTPUT_BUFFER_SIZE {
163+
output.resize(OUTPUT_BUFFER_SIZE, 0);
164+
}
165+
166+
let status = decoder.run_on_buffers(input, output)?;
167+
dbg!(&status.remaining, &status.bytes_written, &status.bytes_read);
168+
input.advance(status.bytes_read);
169+
Ok(output.split_to(status.bytes_written).freeze())
170+
}
171+
172+
#[allow(clippy::never_loop)] // https://github.com/rust-lang/rust-clippy/issues/4058
173+
loop {
174+
break match mem::replace(this.state, DeState::Invalid) {
175+
DeState::Reading => {
176+
*this.state = DeState::Reading;
177+
*this.state = match ready!(this.inner.as_mut().poll_next(cx)) {
178+
Some(chunk) => DeState::Writing(chunk?),
179+
None => DeState::Done,
180+
};
181+
continue;
182+
}
183+
DeState::Writing(mut input) => {
184+
if input.is_empty() {
185+
*this.state = DeState::Reading;
186+
continue;
187+
}
188+
189+
let chunk = decompress(&mut this.decoder, &mut input, &mut this.output)?;
190+
191+
*this.state = DeState::Writing(input);
192+
193+
Poll::Ready(Some(Ok(chunk)))
194+
}
195+
DeState::Done => Poll::Ready(None),
196+
DeState::Invalid => panic!("ZstdDecoder reached invalid state"),
197+
};
198+
}
199+
}
200+
}

tests/utils/mod.rs

+23
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,26 @@ pub fn gzip_stream_decompress(input: impl Stream<Item = io::Result<Bytes>>) -> V
182182
pin_mut!(input);
183183
stream_to_vec(GzipDecoder::new(input))
184184
}
185+
186+
pub fn zstd_compress(bytes: &[u8]) -> Vec<u8> {
187+
use zstd::stream::read::Encoder;
188+
use zstd::DEFAULT_COMPRESSION_LEVEL;
189+
read_to_vec(Encoder::new(bytes, DEFAULT_COMPRESSION_LEVEL).unwrap())
190+
}
191+
192+
pub fn zstd_decompress(bytes: &[u8]) -> Vec<u8> {
193+
use zstd::stream::read::Decoder;
194+
read_to_vec(Decoder::new(bytes).unwrap())
195+
}
196+
197+
pub fn zstd_stream_compress(input: impl Stream<Item = io::Result<Bytes>>) -> Vec<u8> {
198+
use async_compression::stream::ZstdEncoder;
199+
pin_mut!(input);
200+
stream_to_vec(ZstdEncoder::new(input))
201+
}
202+
203+
pub fn zstd_stream_decompress(input: impl Stream<Item = io::Result<Bytes>>) -> Vec<u8> {
204+
use async_compression::stream::ZstdDecoder;
205+
pin_mut!(input);
206+
stream_to_vec(ZstdDecoder::new(input))
207+
}

tests/zstd.rs

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use std::iter::FromIterator;
2+
3+
mod utils;
4+
5+
// Two small chunks compressed by the streaming encoder must round-trip through the
// reference zstd decoder.
#[test]
fn zstd_stream_compress() {
    let input = utils::InputStream::from([[1, 2, 3], [4, 5, 6]]);

    let round_tripped = utils::zstd_decompress(&utils::zstd_stream_compress(input.stream()));

    assert_eq!(round_tripped, vec![1, 2, 3, 4, 5, 6]);
}
14+
15+
// Two 20 000-element random chunks (each larger than the encoder's scratch buffer) must
// round-trip through the reference zstd decoder.
#[test]
fn zstd_stream_compress_large() {
    let random_chunk = || Vec::from_iter((0..20_000).map(|_| rand::random()));
    let input = utils::InputStream::from(vec![random_chunk(), random_chunk()]);

    let round_tripped = utils::zstd_decompress(&utils::zstd_stream_compress(input.stream()));

    assert_eq!(round_tripped, input.bytes());
}
28+
29+
// Data compressed by the reference zstd encoder must be restored by the streaming decoder.
#[test]
fn zstd_stream_decompress() {
    let payload = utils::zstd_compress(&[1, 2, 3, 4, 5, 6][..]);
    let stream = utils::InputStream::from(vec![payload]);

    let restored = utils::zstd_stream_decompress(stream.stream());

    assert_eq!(restored, vec![1, 2, 3, 4, 5, 6]);
}

0 commit comments

Comments
 (0)