Skip to content

Commit be100d6

Browse files
authored
zstd: Check destination buffer size (#171)
When writing FSE tables, check if we have enough space. Fixes VictoriaMetrics/VictoriaMetrics#215
1 parent 169bb21 commit be100d6

File tree

2 files changed

+66
-8
lines changed

2 files changed

+66
-8
lines changed

zstd/encoder_test.go

+53
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
"io/ioutil"
1212
"math/rand"
1313
"os"
14+
"runtime"
1415
"strings"
16+
"sync"
1517
"testing"
1618
"time"
1719

@@ -57,6 +59,57 @@ func TestEncoder_EncodeAllSimple(t *testing.T) {
5759
}
5860
}
5961

62+
func TestEncoder_EncodeAllConcurrent(t *testing.T) {
63+
in, err := ioutil.ReadFile("testdata/z000028")
64+
if err != nil {
65+
t.Fatal(err)
66+
}
67+
in = append(in, in...)
68+
69+
// When running race no more than 8k goroutines allowed.
70+
n := 4000 / runtime.GOMAXPROCS(0)
71+
if testing.Short() {
72+
n = 200 / runtime.GOMAXPROCS(0)
73+
}
74+
dec, err := NewReader(nil)
75+
if err != nil {
76+
t.Fatal(err)
77+
}
78+
defer dec.Close()
79+
for level := EncoderLevel(speedNotSet + 1); level < speedLast; level++ {
80+
t.Run(level.String(), func(t *testing.T) {
81+
rng := rand.New(rand.NewSource(0x1337))
82+
e, err := NewWriter(nil, WithEncoderLevel(level), WithZeroFrames(true))
83+
if err != nil {
84+
t.Fatal(err)
85+
}
86+
defer e.Close()
87+
var wg sync.WaitGroup
88+
wg.Add(n)
89+
for i := 0; i < n; i++ {
90+
in := in[rng.Int()&1023:]
91+
in = in[:rng.Intn(len(in))]
92+
go func() {
93+
defer wg.Done()
94+
dst := e.EncodeAll(in, nil)
95+
//t.Log("Simple Encoder len", len(in), "-> zstd len", len(dst))
96+
decoded, err := dec.DecodeAll(dst, nil)
97+
if err != nil {
98+
t.Error(err, len(decoded))
99+
}
100+
if !bytes.Equal(decoded, in) {
101+
//ioutil.WriteFile("testdata/"+t.Name()+"-z000028.got", decoded, os.ModePerm)
102+
//ioutil.WriteFile("testdata/"+t.Name()+"-z000028.want", in, os.ModePerm)
103+
t.Fatal("Decoded does not match")
104+
}
105+
}()
106+
}
107+
wg.Wait()
108+
t.Log("Encoded content matched.", n, "goroutines")
109+
})
110+
}
111+
}
112+
60113
func TestEncoder_EncodeAllEncodeXML(t *testing.T) {
61114
f, err := os.Open("testdata/xml.zst")
62115
if err != nil {

zstd/fse_encoder.go

+13-8
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,14 @@ func (s *fseEncoder) validateNorm() (err error) {
502502
// writeCount will write the normalized histogram count to header.
503503
// This is read back by readNCount.
504504
func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
505+
if s.useRLE {
506+
return append(out, s.rleVal), nil
507+
}
508+
if s.preDefined || s.reUsed {
509+
// Never write predefined.
510+
return out, nil
511+
}
512+
505513
var (
506514
tableLog = s.actualTableLog
507515
tableSize = 1 << tableLog
@@ -516,15 +524,12 @@ func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
516524
remaining = int16(tableSize + 1) /* +1 for extra accuracy */
517525
threshold = int16(tableSize)
518526
nbBits = uint(tableLog + 1)
527+
outP = len(out)
519528
)
520-
if s.useRLE {
521-
return append(out, s.rleVal), nil
522-
}
523-
if s.preDefined || s.reUsed {
524-
// Never write predefined.
525-
return out, nil
529+
if cap(out) < outP+maxHeaderSize {
530+
out = append(out, make([]byte, maxHeaderSize*3)...)
531+
out = out[:len(out)-maxHeaderSize*3]
526532
}
527-
outP := len(out)
528533
out = out[:outP+maxHeaderSize]
529534

530535
// stops at 1
@@ -598,7 +603,7 @@ func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
598603
out[outP+1] = byte(bitStream >> 8)
599604
outP += int((bitCount + 7) / 8)
600605

601-
if uint16(charnum) > s.symbolLen {
606+
if charnum > s.symbolLen {
602607
return nil, errors.New("internal error: charnum > s.symbolLen")
603608
}
604609
return out[:outP], nil

0 commit comments

Comments
 (0)