
Commit 617727c

s2: Add ReaderMaxBlockSize (#311)
ReaderMaxBlockSize makes it possible to control allocations when the stream has been compressed with a smaller WriterBlockSize, or with the default 1MB. Blocks must be this size or smaller to decompress; otherwise the decoder returns ErrUnsupported. Fixes incorrect validation of the WriterBlockSize parameter. Changes NewReader to take options as varargs.
1 parent bf241f6 commit 617727c
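
For context, a minimal usage sketch of the new varargs API described above — assuming the github.com/klauspost/compress/s2 import path; the file name and the 64KB limit are illustrative, not part of the commit:

package main

import (
	"io"
	"log"
	"os"

	"github.com/klauspost/compress/s2"
)

func main() {
	// Open an S2/Snappy framed stream (path is illustrative).
	f, err := os.Open("data.s2")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Cap the decoder's block buffer at 64KB. The buffer is allocated as
	// MaxEncodedLen(64KB)+checksumSize instead of the 4MB worst case, and any
	// block larger than the limit makes the decoder return ErrUnsupported.
	dec := s2.NewReader(f, s2.ReaderMaxBlockSize(64<<10))
	if _, err := io.Copy(io.Discard, dec); err != nil {
		log.Fatal(err)
	}
}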

File tree

4 files changed, +222 −28 lines changed

s2/decode.go (+98 −20)

@@ -77,10 +77,38 @@ func Decode(dst, src []byte) ([]byte, error) {
 // NewReader returns a new Reader that decompresses from r, using the framing
 // format described at
 // https://github.com/google/snappy/blob/master/framing_format.txt with S2 changes.
-func NewReader(r io.Reader) *Reader {
-	return &Reader{
-		r:   r,
-		buf: make([]byte, MaxEncodedLen(maxBlockSize)+checksumSize),
+func NewReader(r io.Reader, opts ...ReaderOption) *Reader {
+	nr := Reader{
+		r:        r,
+		maxBlock: maxBlockSize,
+	}
+	for _, opt := range opts {
+		if err := opt(&nr); err != nil {
+			nr.err = err
+			return &nr
+		}
+	}
+	nr.buf = make([]byte, MaxEncodedLen(nr.maxBlock)+checksumSize)
+	nr.paramsOK = true
+	return &nr
+}
+
+// ReaderOption is an option for creating a decoder.
+type ReaderOption func(*Reader) error
+
+// ReaderMaxBlockSize allows to control allocations if the stream
+// has been compressed with a smaller WriterBlockSize, or with the default 1MB.
+// Blocks must be this size or smaller to decompress,
+// otherwise the decoder will return ErrUnsupported.
+//
+// Default is the maximum limit of 4MB.
+func ReaderMaxBlockSize(n int) ReaderOption {
+	return func(r *Reader) error {
+		if n > maxBlockSize || n <= 0 {
+			return errors.New("s2: block size too large. Must be <= 4MB and > 0")
+		}
+		r.maxBlock = n
+		return nil
 	}
 }
 
@@ -92,13 +120,18 @@ type Reader struct {
 	buf []byte
 	// decoded[i:j] contains decoded bytes that have not yet been passed on.
 	i, j       int
+	maxBlock   int
 	readHeader bool
+	paramsOK   bool
 }
 
 // Reset discards any buffered data, resets all state, and switches the Snappy
 // reader to read from r. This permits reusing a Reader rather than allocating
 // a new one.
 func (r *Reader) Reset(reader io.Reader) {
+	if !r.paramsOK {
+		return
+	}
 	r.r = reader
 	r.err = nil
 	r.i = 0
@@ -116,6 +149,36 @@ func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) {
 	return true
 }
 
+// skipN will skip n bytes.
+// If the supplied reader supports seeking that is used.
+// tmp is used as a temporary buffer for reading.
+// The supplied slice does not need to be the size of the read.
+func (r *Reader) skipN(tmp []byte, n int, allowEOF bool) (ok bool) {
+	if rs, ok := r.r.(io.ReadSeeker); ok {
+		_, err := rs.Seek(int64(n), io.SeekCurrent)
+		if err == nil {
+			return true
+		}
+		if err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
+			r.err = ErrCorrupt
+			return false
+		}
+	}
+	for n > 0 {
+		if n < len(tmp) {
+			tmp = tmp[:n]
+		}
+		if _, r.err = io.ReadFull(r.r, tmp); r.err != nil {
+			if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
+				r.err = ErrCorrupt
+			}
+			return false
+		}
+		n -= len(tmp)
+	}
+	return true
+}
+
 // Read satisfies the io.Reader interface.
 func (r *Reader) Read(p []byte) (int, error) {
 	if r.err != nil {
@@ -139,10 +202,6 @@ func (r *Reader) Read(p []byte) (int, error) {
 			r.readHeader = true
 		}
 		chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
-		if chunkLen > len(r.buf) {
-			r.err = ErrUnsupported
-			return 0, r.err
-		}
 
 		// The chunk types are specified at
 		// https://github.com/google/snappy/blob/master/framing_format.txt
@@ -153,6 +212,10 @@ func (r *Reader) Read(p []byte) (int, error) {
 				r.err = ErrCorrupt
 				return 0, r.err
 			}
+			if chunkLen > len(r.buf) {
+				r.err = ErrUnsupported
+				return 0, r.err
+			}
 			buf := r.buf[:chunkLen]
 			if !r.readFull(buf, false) {
 				return 0, r.err
@@ -166,7 +229,7 @@ func (r *Reader) Read(p []byte) (int, error) {
 				return 0, r.err
 			}
 			if n > len(r.decoded) {
-				if n > maxBlockSize {
+				if n > r.maxBlock {
 					r.err = ErrCorrupt
 					return 0, r.err
 				}
@@ -189,6 +252,10 @@ func (r *Reader) Read(p []byte) (int, error) {
 				r.err = ErrCorrupt
 				return 0, r.err
 			}
+			if chunkLen > len(r.buf) {
+				r.err = ErrUnsupported
+				return 0, r.err
+			}
 			buf := r.buf[:checksumSize]
 			if !r.readFull(buf, false) {
 				return 0, r.err
@@ -197,7 +264,7 @@ func (r *Reader) Read(p []byte) (int, error) {
 			// Read directly into r.decoded instead of via r.buf.
 			n := chunkLen - checksumSize
 			if n > len(r.decoded) {
-				if n > maxBlockSize {
+				if n > r.maxBlock {
 					r.err = ErrCorrupt
 					return 0, r.err
 				}
@@ -238,7 +305,12 @@ func (r *Reader) Read(p []byte) (int, error) {
 		}
 		// Section 4.4 Padding (chunk type 0xfe).
 		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
-		if !r.readFull(r.buf[:chunkLen], false) {
+		if chunkLen > maxBlockSize {
+			r.err = ErrUnsupported
+			return 0, r.err
+		}
+
+		if !r.skipN(r.buf, chunkLen, false) {
 			return 0, r.err
 		}
 	}
@@ -286,10 +358,6 @@ func (r *Reader) Skip(n int64) error {
 			r.readHeader = true
 		}
 		chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
-		if chunkLen > len(r.buf) {
-			r.err = ErrUnsupported
-			return r.err
-		}
 
 		// The chunk types are specified at
 		// https://github.com/google/snappy/blob/master/framing_format.txt
@@ -300,6 +368,10 @@ func (r *Reader) Skip(n int64) error {
 				r.err = ErrCorrupt
 				return r.err
 			}
+			if chunkLen > len(r.buf) {
+				r.err = ErrUnsupported
+				return r.err
+			}
 			buf := r.buf[:chunkLen]
 			if !r.readFull(buf, false) {
 				return r.err
@@ -312,7 +384,7 @@ func (r *Reader) Skip(n int64) error {
 				r.err = err
 				return r.err
 			}
-			if dLen > maxBlockSize {
+			if dLen > r.maxBlock {
 				r.err = ErrCorrupt
 				return r.err
 			}
@@ -342,6 +414,10 @@ func (r *Reader) Skip(n int64) error {
 				r.err = ErrCorrupt
 				return r.err
 			}
+			if chunkLen > len(r.buf) {
+				r.err = ErrUnsupported
+				return r.err
+			}
 			buf := r.buf[:checksumSize]
 			if !r.readFull(buf, false) {
 				return r.err
@@ -350,7 +426,7 @@ func (r *Reader) Skip(n int64) error {
 			// Read directly into r.decoded instead of via r.buf.
 			n2 := chunkLen - checksumSize
 			if n2 > len(r.decoded) {
-				if n2 > maxBlockSize {
+				if n2 > r.maxBlock {
 					r.err = ErrCorrupt
 					return r.err
 				}
@@ -391,13 +467,15 @@ func (r *Reader) Skip(n int64) error {
 			r.err = ErrUnsupported
 			return r.err
 		}
+		if chunkLen > maxBlockSize {
+			r.err = ErrUnsupported
+			return r.err
+		}
 		// Section 4.4 Padding (chunk type 0xfe).
 		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
-		if !r.readFull(r.buf[:chunkLen], false) {
+		if !r.skipN(r.buf, chunkLen, false) {
 			return r.err
 		}
-
-		return io.ErrUnexpectedEOF
 	}
 	return nil
 }
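
Worth noting from the NewReader change above: an option error is not returned directly but stored on the Reader, so it surfaces on the first Read (and Reset becomes a no-op because paramsOK is never set). A small sketch of that behaviour, assuming the github.com/klauspost/compress/s2 import path:

package main

import (
	"fmt"
	"io/ioutil"
	"strings"

	"github.com/klauspost/compress/s2"
)

func main() {
	// 8MB exceeds the 4MB maximum, so ReaderMaxBlockSize returns an error,
	// which NewReader stores in the Reader instead of returning it.
	dec := s2.NewReader(strings.NewReader(""), s2.ReaderMaxBlockSize(8<<20))

	// The stored error is reported on first use.
	_, err := ioutil.ReadAll(dec)
	fmt.Println(err) // s2: block size too large. Must be <= 4MB and > 0
}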

s2/decode_test.go (+118 −0)

@@ -6,6 +6,7 @@ package s2
 
 import (
 	"bytes"
+	"fmt"
 	"io/ioutil"
 	"strings"
 	"testing"
@@ -41,3 +42,120 @@ func TestDecodeRegression(t *testing.T) {
 		})
 	}
 }
+
+func TestDecoderMaxBlockSize(t *testing.T) {
+	data, err := ioutil.ReadFile("testdata/enc_regressions.zip")
+	if err != nil {
+		t.Fatal(err)
+	}
+	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
+	if err != nil {
+		t.Fatal(err)
+	}
+	sizes := []int{4 << 10, 10 << 10, 1 << 20, 4 << 20}
+	test := func(t *testing.T, data []byte) {
+		for _, size := range sizes {
+			t.Run(fmt.Sprintf("%d", size), func(t *testing.T) {
+				var buf bytes.Buffer
+				dec := NewReader(nil, ReaderMaxBlockSize(size))
+				enc := NewWriter(&buf, WriterBlockSize(size), WriterPadding(16<<10), WriterPaddingSrc(zeroReader{}))
+
+				// Test writer.
+				n, err := enc.Write(data)
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if n != len(data) {
+					t.Error(fmt.Errorf("Write: Short write, want %d, got %d", len(data), n))
+					return
+				}
+				err = enc.Close()
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				// Calling close twice should not affect anything.
+				err = enc.Close()
+				if err != nil {
+					t.Error(err)
+					return
+				}
+
+				dec.Reset(&buf)
+				got, err := ioutil.ReadAll(dec)
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if !bytes.Equal(data, got) {
+					t.Error("block (reset) decoder mismatch")
+					return
+				}
+
+				// Test Reset on both and use ReadFrom instead.
+				buf.Reset()
+				enc.Reset(&buf)
+				n2, err := enc.ReadFrom(bytes.NewBuffer(data))
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if n2 != int64(len(data)) {
+					t.Error(fmt.Errorf("ReadFrom: Short read, want %d, got %d", len(data), n2))
+					return
+				}
+				// Encode twice...
+				n2, err = enc.ReadFrom(bytes.NewBuffer(data))
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if n2 != int64(len(data)) {
+					t.Error(fmt.Errorf("ReadFrom: Short read, want %d, got %d", len(data), n2))
+					return
+				}
+
+				err = enc.Close()
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if enc.pad > 0 && buf.Len()%enc.pad != 0 {
+					t.Error(fmt.Errorf("wanted size to be mutiple of %d, got size %d with remainder %d", enc.pad, buf.Len(), buf.Len()%enc.pad))
+					return
+				}
+				dec.Reset(&buf)
+				// Skip first...
+				dec.Skip(int64(len(data)))
+				got, err = ioutil.ReadAll(dec)
+				if err != nil {
+					t.Error(err)
+					return
+				}
+				if !bytes.Equal(data, got) {
+					t.Error("frame (reset) decoder mismatch")
+					return
+				}
+			})
+		}
+	}
+	for _, tt := range zr.File {
+		if !strings.HasSuffix(t.Name(), "") {
+			continue
+		}
+		t.Run(tt.Name, func(t *testing.T) {
+			r, err := tt.Open()
+			if err != nil {
+				t.Error(err)
+				return
+			}
+			b, err := ioutil.ReadAll(r)
+			if err != nil {
+				t.Error(err)
+				return
+			}
+			test(t, b[:len(b):len(b)])
+		})
+	}
+}

s2/encode.go (+1 −1)

@@ -983,7 +983,7 @@ func WriterUncompressed() WriterOption {
 // Default block size is 1MB.
 func WriterBlockSize(n int) WriterOption {
 	return func(w *Writer) error {
-		if w.blockSize > maxBlockSize || w.blockSize < minBlockSize {
+		if n > maxBlockSize || n < minBlockSize {
 			return errors.New("s2: block size too large. Must be <= 4MB and >=4KB")
 		}
 		w.blockSize = n
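
The corrected check validates the incoming n rather than the not-yet-assigned w.blockSize. Mirroring the new test, a sketch of pairing WriterBlockSize with a matching ReaderMaxBlockSize — import path assumed, sizes and payload illustrative:

package main

import (
	"bytes"
	"io/ioutil"
	"log"

	"github.com/klauspost/compress/s2"
)

func main() {
	const blockSize = 64 << 10 // within the validated 4KB..4MB range

	var buf bytes.Buffer
	enc := s2.NewWriter(&buf, s2.WriterBlockSize(blockSize))
	if _, err := enc.Write(bytes.Repeat([]byte("s2 block size demo "), 1<<10)); err != nil {
		log.Fatal(err)
	}
	if err := enc.Close(); err != nil {
		log.Fatal(err)
	}

	// A reader capped at the same block size can decode the stream;
	// streams with larger blocks would fail with ErrUnsupported.
	dec := s2.NewReader(&buf, s2.ReaderMaxBlockSize(blockSize))
	out, err := ioutil.ReadAll(dec)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("decoded %d bytes", len(out))
}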
