Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memory optimization support #13

Merged
merged 3 commits into from
Nov 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions varopt.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,38 @@ func New(capacity int, rnd *rand.Rand) *Varopt {
}
}

// Reset returns the sampler to its initial state, maintaining its
// capacity and random number source.
func (s *Varopt) Reset() {
s.L = s.L[:0]
s.T = s.T[:0]
s.X = s.X[:0]
s.tau = 0
s.totalCount = 0
s.totalWeight = 0
}

// Add considers a new observation for the sample with given weight.
// If there is an item ejected from the sample as a result, the item
// is returned to allow re-use of memory.
//
// An error will be returned if the weight is either negative or NaN.
func (s *Varopt) Add(sample Sample, weight float64) error {
func (s *Varopt) Add(sample Sample, weight float64) (Sample, error) {
individual := internal.Vsample{
Sample: sample,
Weight: weight,
}

if weight <= 0 || math.IsNaN(weight) {
return ErrInvalidWeight
return nil, ErrInvalidWeight
}

s.totalCount++
s.totalWeight += weight

if s.Size() < s.capacity {
s.L.Push(individual)
return nil
return nil, nil
}

// the X <- {} step from the paper is not done here,
Expand Down Expand Up @@ -102,19 +115,22 @@ func (s *Varopt) Add(sample Sample, weight float64) error {
r -= (1 - wxd/s.tau)
d++
}
var eject Sample
if r < 0 {
if d < len(s.X) {
s.X[d], s.X[len(s.X)-1] = s.X[len(s.X)-1], s.X[d]
}
eject = s.X[len(s.X)-1].Sample
s.X = s.X[:len(s.X)-1]
} else {
ti := s.rnd.Intn(len(s.T))
s.T[ti], s.T[len(s.T)-1] = s.T[len(s.T)-1], s.T[ti]
eject = s.T[len(s.T)-1].Sample
s.T = s.T[:len(s.T)-1]
}
s.T = append(s.T, s.X...)
s.X = s.X[:0]
return nil
return eject, nil
}

func (s *Varopt) uniform() float64 {
Expand Down
83 changes: 80 additions & 3 deletions varopt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,89 @@ func TestInvalidWeight(t *testing.T) {
rnd := rand.New(rand.NewSource(98887))
v := varopt.New(1, rnd)

err := v.Add(nil, math.NaN())
_, err := v.Add(nil, math.NaN())
require.Equal(t, err, varopt.ErrInvalidWeight)

err = v.Add(nil, -1)
_, err = v.Add(nil, -1)
require.Equal(t, err, varopt.ErrInvalidWeight)

err = v.Add(nil, 0)
_, err = v.Add(nil, 0)
require.Equal(t, err, varopt.ErrInvalidWeight)
}

func TestReset(t *testing.T) {
const capacity = 10
const insert = 100
rnd := rand.New(rand.NewSource(98887))
v := varopt.New(capacity, rnd)

sum := 0.
for i := 1.; i <= insert; i++ {
v.Add(nil, i)
sum += i
}

require.Equal(t, capacity, v.Size())
require.Equal(t, insert, v.TotalCount())
require.Equal(t, sum, v.TotalWeight())
require.Less(t, 0., v.Tau())

v.Reset()

require.Equal(t, 0, v.Size())
require.Equal(t, 0, v.TotalCount())
require.Equal(t, 0., v.TotalWeight())
require.Equal(t, 0., v.Tau())
}

func TestEject(t *testing.T) {
const capacity = 100
const rounds = 10000
const maxvalue = 10000

entries := make([]int, capacity+1)
freelist := make([]*int, capacity+1)

for i := range entries {
freelist[i] = &entries[i]
}

// Make two deterministically equal samplers
rnd1 := rand.New(rand.NewSource(98887))
rnd2 := rand.New(rand.NewSource(98887))
vsrc := rand.New(rand.NewSource(98887))

expected := varopt.New(capacity, rnd1)
ejector := varopt.New(capacity, rnd2)

for i := 0; i < rounds; i++ {
value := vsrc.Intn(maxvalue)
weight := vsrc.ExpFloat64()

_, _ = expected.Add(&value, weight)

lastitem := len(freelist) - 1
item := freelist[lastitem]
freelist = freelist[:lastitem]

*item = value
eject, _ := ejector.Add(item, weight)

if eject != nil {
freelist = append(freelist, eject.(*int))
}
}

require.Equal(t, expected.Size(), ejector.Size())
require.Equal(t, expected.TotalCount(), ejector.TotalCount())
require.Equal(t, expected.TotalWeight(), ejector.TotalWeight())
require.Equal(t, expected.Tau(), ejector.Tau())

for i := 0; i < capacity; i++ {
expectItem, expectWeight := expected.Get(i)
ejectItem, ejectWeight := expected.Get(i)

require.Equal(t, *expectItem.(*int), *ejectItem.(*int))
require.Equal(t, expectWeight, ejectWeight)
}
}