Skip to content

Commit 928a836

Browse files
authoredMay 12, 2019
Merge pull request #13 from fairingrey/zstd
Zstd support over Streams
2 parents a36090e + 6cb2feb commit 928a836

File tree

5 files changed

+266
-4
lines changed

5 files changed

+266
-4
lines changed
 

‎Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ bytes = "0.4.12"
1111
flate2 = "1.0.7"
1212
futures-preview = "0.3.0-alpha.16"
1313
pin-project = "0.3.2"
14+
zstd = "0.4"
15+
zstd-safe = "1.4"
1416

1517
[dev-dependencies]
1618
proptest = "0.9.3"

‎src/stream/mod.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ mod deflate;
1313
mod flate;
1414
mod gzip;
1515
mod zlib;
16+
mod zstd;
1617

17-
pub use brotli::{BrotliDecoder, BrotliEncoder};
18-
pub use deflate::{DeflateDecoder, DeflateEncoder};
19-
pub use gzip::{GzipDecoder, GzipEncoder};
20-
pub use zlib::{ZlibDecoder, ZlibEncoder};
18+
pub use self::brotli::{BrotliDecoder, BrotliEncoder};
19+
pub use self::deflate::{DeflateDecoder, DeflateEncoder};
20+
pub use self::gzip::{GzipDecoder, GzipEncoder};
21+
pub use self::zlib::{ZlibDecoder, ZlibEncoder};
22+
pub use self::zstd::{ZstdDecoder, ZstdEncoder};

‎src/stream/zstd.rs

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
use std::{
2+
io::Result,
3+
mem,
4+
pin::Pin,
5+
task::{Context, Poll},
6+
};
7+
8+
use bytes::{Bytes, BytesMut};
9+
use futures::{ready, stream::Stream};
10+
use pin_project::unsafe_project;
11+
use zstd::stream::raw::{Decoder, Encoder, Operation};
12+
13+
#[derive(Debug)]
14+
enum State {
15+
Reading,
16+
Writing(Bytes),
17+
Flushing,
18+
Done,
19+
Invalid,
20+
}
21+
22+
#[derive(Debug)]
23+
enum DeState {
24+
Reading,
25+
Writing(Bytes),
26+
Done,
27+
Invalid,
28+
}
29+
30+
/// A zstd encoder, or compressor.
31+
///
32+
/// This structure implements a [`Stream`] interface and will read uncompressed data from an
33+
/// underlying stream and emit a stream of compressed data.
34+
#[unsafe_project(Unpin)]
35+
pub struct ZstdEncoder<S: Stream<Item = Result<Bytes>>> {
36+
#[pin]
37+
inner: S,
38+
state: State,
39+
output: BytesMut,
40+
encoder: Encoder,
41+
}
42+
43+
/// A zstd decoder, or decompressor.
44+
///
45+
/// This structure implements a [`Stream`] interface and will read compressed data from an
46+
/// underlying stream and emit a stream of uncompressed data.
47+
#[unsafe_project(Unpin)]
48+
pub struct ZstdDecoder<S: Stream<Item = Result<Bytes>>> {
49+
#[pin]
50+
inner: S,
51+
state: DeState,
52+
output: BytesMut,
53+
decoder: Decoder,
54+
}
55+
56+
impl<S: Stream<Item = Result<Bytes>>> ZstdEncoder<S> {
57+
/// Creates a new encoder which will read uncompressed data from the given stream and emit a
58+
/// compressed stream.
59+
///
60+
/// The `level` argument here can range from 1-21. A level of `0` will use zstd's default, which is `3`.
61+
pub fn new(stream: S, level: i32) -> ZstdEncoder<S> {
62+
ZstdEncoder {
63+
inner: stream,
64+
state: State::Reading,
65+
output: BytesMut::new(),
66+
encoder: Encoder::new(level).unwrap(),
67+
}
68+
}
69+
}
70+
71+
impl<S: Stream<Item = Result<Bytes>>> ZstdDecoder<S> {
72+
/// Creates a new decoder which will read compressed data from the given stream and emit an
73+
/// uncompressed stream.
74+
pub fn new(stream: S) -> ZstdDecoder<S> {
75+
ZstdDecoder {
76+
inner: stream,
77+
state: DeState::Reading,
78+
output: BytesMut::new(),
79+
decoder: Decoder::new().unwrap(),
80+
}
81+
}
82+
}
83+
84+
impl<S: Stream<Item = Result<Bytes>>> Stream for ZstdEncoder<S> {
85+
type Item = Result<Bytes>;
86+
87+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
88+
let mut this = self.project();
89+
90+
fn compress(
91+
encoder: &mut Encoder,
92+
input: &mut Bytes,
93+
output: &mut BytesMut,
94+
) -> Result<Bytes> {
95+
const OUTPUT_BUFFER_SIZE: usize = 8_000;
96+
97+
if output.len() < OUTPUT_BUFFER_SIZE {
98+
output.resize(OUTPUT_BUFFER_SIZE, 0);
99+
}
100+
101+
let status = encoder.run_on_buffers(input, output)?;
102+
input.advance(status.bytes_read);
103+
Ok(output.split_to(status.bytes_written).freeze())
104+
}
105+
106+
#[allow(clippy::never_loop)] // https://github.com/rust-lang/rust-clippy/issues/4058
107+
loop {
108+
break match mem::replace(this.state, State::Invalid) {
109+
State::Reading => {
110+
*this.state = State::Reading;
111+
*this.state = match ready!(this.inner.as_mut().poll_next(cx)) {
112+
Some(chunk) => State::Writing(chunk?),
113+
None => State::Flushing,
114+
};
115+
continue;
116+
}
117+
State::Writing(mut input) => {
118+
if input.is_empty() {
119+
*this.state = State::Reading;
120+
continue;
121+
}
122+
123+
let chunk = compress(&mut this.encoder, &mut input, &mut this.output)?;
124+
125+
*this.state = State::Writing(input);
126+
127+
Poll::Ready(Some(Ok(chunk)))
128+
}
129+
State::Flushing => {
130+
let mut output = zstd_safe::OutBuffer::around(this.output);
131+
132+
let bytes_left = this.encoder.flush(&mut output).unwrap();
133+
*this.state = if bytes_left == 0 {
134+
let _ = this.encoder.finish(&mut output, true);
135+
State::Done
136+
} else {
137+
State::Flushing
138+
};
139+
Poll::Ready(Some(Ok(output.as_slice().into())))
140+
}
141+
State::Done => Poll::Ready(None),
142+
State::Invalid => panic!("ZstdEncoder reached invalid state"),
143+
};
144+
}
145+
}
146+
}
147+
148+
impl<S: Stream<Item = Result<Bytes>>> Stream for ZstdDecoder<S> {
149+
type Item = Result<Bytes>;
150+
151+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes>>> {
152+
let mut this = self.project();
153+
154+
fn decompress(
155+
decoder: &mut Decoder,
156+
input: &mut Bytes,
157+
output: &mut BytesMut,
158+
) -> Result<Bytes> {
159+
const OUTPUT_BUFFER_SIZE: usize = 8_000;
160+
161+
if output.len() < OUTPUT_BUFFER_SIZE {
162+
output.resize(OUTPUT_BUFFER_SIZE, 0);
163+
}
164+
165+
let status = decoder.run_on_buffers(input, output)?;
166+
input.advance(status.bytes_read);
167+
Ok(output.split_to(status.bytes_written).freeze())
168+
}
169+
170+
#[allow(clippy::never_loop)] // https://github.com/rust-lang/rust-clippy/issues/4058
171+
loop {
172+
break match mem::replace(this.state, DeState::Invalid) {
173+
DeState::Reading => {
174+
*this.state = DeState::Reading;
175+
*this.state = match ready!(this.inner.as_mut().poll_next(cx)) {
176+
Some(chunk) => DeState::Writing(chunk?),
177+
None => DeState::Done,
178+
};
179+
continue;
180+
}
181+
DeState::Writing(mut input) => {
182+
if input.is_empty() {
183+
*this.state = DeState::Reading;
184+
continue;
185+
}
186+
187+
let chunk = decompress(&mut this.decoder, &mut input, &mut this.output)?;
188+
189+
*this.state = DeState::Writing(input);
190+
191+
Poll::Ready(Some(Ok(chunk)))
192+
}
193+
DeState::Done => Poll::Ready(None),
194+
DeState::Invalid => panic!("ZstdDecoder reached invalid state"),
195+
};
196+
}
197+
}
198+
}

‎tests/utils/mod.rs

+23
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,26 @@ pub fn gzip_stream_decompress(input: impl Stream<Item = io::Result<Bytes>>) -> V
182182
pin_mut!(input);
183183
stream_to_vec(GzipDecoder::new(input))
184184
}
185+
186+
pub fn zstd_compress(bytes: &[u8]) -> Vec<u8> {
187+
use zstd::stream::read::Encoder;
188+
use zstd::DEFAULT_COMPRESSION_LEVEL;
189+
read_to_vec(Encoder::new(bytes, DEFAULT_COMPRESSION_LEVEL).unwrap())
190+
}
191+
192+
pub fn zstd_decompress(bytes: &[u8]) -> Vec<u8> {
193+
use zstd::stream::read::Decoder;
194+
read_to_vec(Decoder::new(bytes).unwrap())
195+
}
196+
197+
pub fn zstd_stream_compress(input: impl Stream<Item = io::Result<Bytes>>) -> Vec<u8> {
198+
use async_compression::stream::ZstdEncoder;
199+
pin_mut!(input);
200+
stream_to_vec(ZstdEncoder::new(input, 0))
201+
}
202+
203+
pub fn zstd_stream_decompress(input: impl Stream<Item = io::Result<Bytes>>) -> Vec<u8> {
204+
use async_compression::stream::ZstdDecoder;
205+
pin_mut!(input);
206+
stream_to_vec(ZstdDecoder::new(input))
207+
}

‎tests/zstd.rs

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use std::iter::FromIterator;
2+
3+
mod utils;
4+
5+
#[test]
6+
fn zstd_stream_compress() {
7+
let input = utils::InputStream::from([[1, 2, 3], [4, 5, 6]]);
8+
9+
let compressed = utils::zstd_stream_compress(input.stream());
10+
let output = utils::zstd_decompress(&compressed);
11+
12+
assert_eq!(output, vec![1, 2, 3, 4, 5, 6]);
13+
}
14+
15+
#[test]
16+
fn zstd_stream_compress_large() {
17+
let input = vec![
18+
Vec::from_iter((0..20_000).map(|_| rand::random())),
19+
Vec::from_iter((0..20_000).map(|_| rand::random())),
20+
];
21+
let input = utils::InputStream::from(input);
22+
23+
let compressed = utils::zstd_stream_compress(input.stream());
24+
let output = utils::zstd_decompress(&compressed);
25+
26+
assert_eq!(output, input.bytes());
27+
}
28+
29+
#[test]
30+
fn zstd_stream_decompress() {
31+
let compressed = utils::zstd_compress(&[1, 2, 3, 4, 5, 6][..]);
32+
33+
let stream = utils::InputStream::from(vec![compressed]);
34+
let output = utils::zstd_stream_decompress(stream.stream());
35+
36+
assert_eq!(output, vec![1, 2, 3, 4, 5, 6]);
37+
}

0 commit comments

Comments
 (0)
Please sign in to comment.