Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for NaN values; return error instead of panicking #12

Merged
merged 4 commits into from
Nov 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions varopt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package varopt
import (
"container/heap"
"fmt"
"math"
"math/rand"
)

Expand Down Expand Up @@ -48,6 +49,8 @@ type vsample struct {

type largeHeap []vsample

var ErrInvalidWeight = fmt.Errorf("Negative, zero, or NaN weight")

// New returns a new Varopt sampler with given capacity (i.e.,
// reservoir size) and random number generator.
func New(capacity int, rnd *rand.Rand) *Varopt {
Expand All @@ -58,22 +61,24 @@ func New(capacity int, rnd *rand.Rand) *Varopt {
}

// Add considers a new observation for the sample with given weight.
func (s *Varopt) Add(sample Sample, weight float64) {
//
// An error will be returned if the weight is either negative or NaN.
func (s *Varopt) Add(sample Sample, weight float64) error {
individual := vsample{
sample: sample,
weight: weight,
}

if weight <= 0 {
panic(fmt.Sprint("Invalid weight <= 0: ", weight))
if weight <= 0 || math.IsNaN(weight) {
return ErrInvalidWeight
}

s.totalCount++
s.totalWeight += weight

if s.Size() < s.capacity {
heap.Push(&s.L, individual)
return
return nil
}

// the X <- {} step from the paper is not done here,
Expand Down Expand Up @@ -115,6 +120,7 @@ func (s *Varopt) Add(sample Sample, weight float64) {
}
s.T = append(s.T, s.X...)
s.X = s.X[:0]
return nil
}

func (s *Varopt) uniform() float64 {
Expand Down
14 changes: 14 additions & 0 deletions varopt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,17 @@ func testUnbiased(t *testing.T, bbr, bsr float64) {
[][][]varopt.Sample{smallBlocks, bigBlocks},
)
}

func TestInvalidWeight(t *testing.T) {
rnd := rand.New(rand.NewSource(98887))
v := varopt.New(1, rnd)

err := v.Add(nil, math.NaN())
require.Equal(t, err, varopt.ErrInvalidWeight)

err = v.Add(nil, -1)
require.Equal(t, err, varopt.ErrInvalidWeight)

err = v.Add(nil, 0)
require.Equal(t, err, varopt.ErrInvalidWeight)
}