Skip to content

Commit c0f3c32

Browse files
authored
Fix mixed Write and ReadFrom calls (#282)
Flush any pending writes when ReadFrom is used.
1 parent f5ee0f4 commit c0f3c32

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

zstd/encoder.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,13 @@ func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
311311
if debug {
312312
println("Using ReadFrom")
313313
}
314-
// Maybe handle stuff queued?
314+
315+
// Flush any current writes.
316+
if len(e.state.filling) > 0 {
317+
if err := e.nextBlock(false); err != nil {
318+
return 0, err
319+
}
320+
}
315321
e.state.filling = e.state.filling[:e.o.blockSize]
316322
src := e.state.filling
317323
for {
@@ -328,7 +334,7 @@ func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
328334
if debug {
329335
println("ReadFrom: got EOF final block:", len(e.state.filling))
330336
}
331-
return n, e.nextBlock(true)
337+
return n, nil
332338
default:
333339
if debug {
334340
println("ReadFrom: got error:", err)

zstd/encoder_test.go

+42
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,48 @@ func TestEncoderReadFrom(t *testing.T) {
695695
dec.Close()
696696
}
697697

698+
func TestInterleavedWriteReadFrom(t *testing.T) {
699+
var encoded bytes.Buffer
700+
701+
enc, err := NewWriter(&encoded)
702+
if err != nil {
703+
t.Fatal(err)
704+
}
705+
706+
if _, err := enc.Write([]byte("write1")); err != nil {
707+
t.Fatal(err)
708+
}
709+
if _, err := enc.Write([]byte("write2")); err != nil {
710+
t.Fatal(err)
711+
}
712+
if _, err := enc.ReadFrom(strings.NewReader("readfrom1")); err != nil {
713+
t.Fatal(err)
714+
}
715+
if _, err := enc.Write([]byte("write3")); err != nil {
716+
t.Fatal(err)
717+
}
718+
719+
if err := enc.Close(); err != nil {
720+
t.Fatal(err)
721+
}
722+
723+
dec, err := NewReader(&encoded)
724+
if err != nil {
725+
t.Fatal(err)
726+
}
727+
defer dec.Close()
728+
729+
gotb, err := ioutil.ReadAll(dec)
730+
if err != nil {
731+
t.Fatal(err)
732+
}
733+
got := string(gotb)
734+
735+
if want := "write1write2readfrom1write3"; got != want {
736+
t.Errorf("got decoded %q, want %q", got, want)
737+
}
738+
}
739+
698740
func TestEncoder_EncodeAllEmpty(t *testing.T) {
699741
if testing.Short() {
700742
t.SkipNow()

0 commit comments

Comments
 (0)