Skip to content

Commit 0c483bf

Browse files
committed
s2: Fix DecodeConcurrent deadlock on errors
When DecodeConcurrent encounters an error it can lock up in some cases. Fix and add fuzz test for stream decoding. Fixes #920
1 parent 32f34cf commit 0c483bf

File tree

3 files changed

+104
-3
lines changed

3 files changed

+104
-3
lines changed

internal/fuzz/helpers.go

+58
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,64 @@ func AddFromZip(f *testing.F, filename string, t InputType, short bool) {
8989
}
9090
}
9191

92+
// ReturnFromZip will read the supplied zip and add all as corpus for f.
93+
// Byte slices only.
94+
func ReturnFromZip(tb testing.TB, filename string, t InputType, fn func([]byte)) {
95+
file, err := os.Open(filename)
96+
if err != nil {
97+
tb.Fatal(err)
98+
}
99+
fi, err := file.Stat()
100+
if err != nil {
101+
tb.Fatal(err)
102+
}
103+
zr, err := zip.NewReader(file, fi.Size())
104+
if err != nil {
105+
tb.Fatal(err)
106+
}
107+
for _, file := range zr.File {
108+
rc, err := file.Open()
109+
if err != nil {
110+
tb.Fatal(err)
111+
}
112+
113+
b, err := io.ReadAll(rc)
114+
if err != nil {
115+
tb.Fatal(err)
116+
}
117+
rc.Close()
118+
t := t
119+
if t == TypeOSSFuzz {
120+
t = TypeRaw // Fallback
121+
if len(b) >= 4 {
122+
sz := binary.BigEndian.Uint32(b)
123+
if sz <= uint32(len(b))-4 {
124+
fn(b[4 : 4+sz])
125+
continue
126+
}
127+
}
128+
}
129+
130+
if bytes.HasPrefix(b, []byte("go test fuzz")) {
131+
t = TypeGoFuzz
132+
} else {
133+
t = TypeRaw
134+
}
135+
136+
if t == TypeRaw {
137+
fn(b)
138+
continue
139+
}
140+
vals, err := unmarshalCorpusFile(b)
141+
if err != nil {
142+
tb.Fatal(err)
143+
}
144+
for _, v := range vals {
145+
fn(v)
146+
}
147+
}
148+
}
149+
92150
// unmarshalCorpusFile decodes corpus bytes into their respective values.
93151
func unmarshalCorpusFile(b []byte) ([][]byte, error) {
94152
if len(b) == 0 {

s2/fuzz_test.go

+34
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package s2
66
import (
77
"bytes"
88
"fmt"
9+
"io"
910
"testing"
1011

1112
"github.com/klauspost/compress/internal/fuzz"
@@ -148,3 +149,36 @@ func FuzzEncodingBlocks(f *testing.F) {
148149
}
149150
})
150151
}
152+
153+
func FuzzStreamDecode(f *testing.F) {
154+
enc := NewWriter(nil, WriterBlockSize(8<<10))
155+
addCompressed := func(b []byte) {
156+
var buf bytes.Buffer
157+
enc.Reset(&buf)
158+
enc.Write(b)
159+
enc.Close()
160+
f.Add(buf.Bytes())
161+
}
162+
fuzz.ReturnFromZip(f, "testdata/enc_regressions.zip", fuzz.TypeRaw, addCompressed)
163+
fuzz.ReturnFromZip(f, "testdata/fuzz/block-corpus-raw.zip", fuzz.TypeRaw, addCompressed)
164+
fuzz.ReturnFromZip(f, "testdata/fuzz/block-corpus-enc.zip", fuzz.TypeGoFuzz, addCompressed)
165+
dec := NewReader(nil, ReaderIgnoreCRC())
166+
f.Fuzz(func(t *testing.T, data []byte) {
167+
// Using Read
168+
dec.Reset(bytes.NewReader(data))
169+
io.Copy(io.Discard, dec)
170+
171+
// Using DecodeConcurrent
172+
dec.Reset(bytes.NewReader(data))
173+
dec.DecodeConcurrent(io.Discard, 2)
174+
175+
// Use ByteReader.
176+
dec.Reset(bytes.NewReader(data))
177+
for {
178+
_, err := dec.ReadByte()
179+
if err != nil {
180+
break
181+
}
182+
}
183+
})
184+
}

s2/reader.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,12 @@ func (r *Reader) DecodeConcurrent(w io.Writer, concurrent int) (written int64, e
452452
for toWrite := range queue {
453453
entry := <-toWrite
454454
reUse <- toWrite
455+
if hasErr() || entry == nil {
456+
if entry != nil {
457+
writtenBlocks <- entry
458+
}
459+
continue
460+
}
455461
if hasErr() {
456462
writtenBlocks <- entry
457463
continue
@@ -471,20 +477,21 @@ func (r *Reader) DecodeConcurrent(w io.Writer, concurrent int) (written int64, e
471477
}
472478
}()
473479

474-
// Reader
475480
defer func() {
476-
close(queue)
477481
if r.err != nil {
478-
err = r.err
479482
setErr(r.err)
483+
} else if err != nil {
484+
setErr(err)
480485
}
486+
close(queue)
481487
wg.Wait()
482488
if err == nil {
483489
err = aErr
484490
}
485491
written = aWritten
486492
}()
487493

494+
// Reader
488495
for !hasErr() {
489496
if !r.readFull(r.buf[:4], true) {
490497
if r.err == io.EOF {
@@ -553,11 +560,13 @@ func (r *Reader) DecodeConcurrent(w io.Writer, concurrent int) (written int64, e
553560
if err != nil {
554561
writtenBlocks <- decoded
555562
setErr(err)
563+
entry <- nil
556564
return
557565
}
558566
if !r.ignoreCRC && crc(decoded) != checksum {
559567
writtenBlocks <- decoded
560568
setErr(ErrCRC)
569+
entry <- nil
561570
return
562571
}
563572
entry <- decoded

0 commit comments

Comments
 (0)