Skip to content

Commit c7cf208

Browse files
committed
bindings/go: parallelize MSM for N<32.
1 parent 1f49aa0 commit c7cf208

File tree

4 files changed

+212
-3
lines changed

4 files changed

+212
-3
lines changed

bindings/go/blst.go

+132-2
Original file line numberDiff line numberDiff line change
@@ -2112,7 +2112,7 @@ func P1AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P1 {
21122112

21132113
numThreads := numThreads(0)
21142114

2115-
if numThreads < 2 || npoints < 32 {
2115+
if numThreads < 2 {
21162116
sz := int(C.blst_p1s_mult_pippenger_scratch_sizeof(C.size_t(npoints))) / 8
21172117
scratch := make([]uint64, sz)
21182118

@@ -2161,6 +2161,71 @@ func P1AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P1 {
21612161
return &ret
21622162
}
21632163

2164+
if npoints < 32 {
2165+
if numThreads > npoints {
2166+
numThreads = npoints
2167+
}
2168+
2169+
curItem := uint32(0)
2170+
msgs := make(chan P1, numThreads)
2171+
2172+
for tid := 0; tid < numThreads; tid++ {
2173+
go func() {
2174+
var acc P1
2175+
2176+
for {
2177+
workItem := int(atomic.AddUint32(&curItem, 1) - 1)
2178+
if workItem >= npoints {
2179+
break
2180+
}
2181+
2182+
var point *P1Affine
2183+
switch val := pointsIf.(type) {
2184+
case []*P1Affine:
2185+
point = val[workItem]
2186+
case []P1Affine:
2187+
point = &val[workItem]
2188+
case P1Affines:
2189+
point = &val[workItem]
2190+
}
2191+
2192+
var scalar *C.byte
2193+
switch val := scalarsIf.(type) {
2194+
case []byte:
2195+
scalar = (*C.byte)(&val[workItem*nbytes])
2196+
case [][]byte:
2197+
scalar = scalars[workItem]
2198+
case []Scalar:
2199+
if nbits > 248 {
2200+
scalar = &val[workItem].b[0]
2201+
} else {
2202+
scalar = scalars[workItem]
2203+
}
2204+
case []*Scalar:
2205+
scalar = scalars[workItem]
2206+
}
2207+
2208+
C.go_p1_mult_n_acc(&acc, &point.x, true,
2209+
scalar, C.size_t(nbits))
2210+
}
2211+
2212+
msgs <- acc
2213+
}()
2214+
}
2215+
2216+
ret := <-msgs
2217+
for tid := 1; tid < numThreads; tid++ {
2218+
point := <-msgs
2219+
C.blst_p1_add_or_double(&ret, &ret, &point)
2220+
}
2221+
2222+
for i := range scalars {
2223+
scalars[i] = nil
2224+
}
2225+
2226+
return &ret
2227+
}
2228+
21642229
// this is sizeof(scratch[0])
21652230
sz := int(C.blst_p1s_mult_pippenger_scratch_sizeof(0)) / 8
21662231

@@ -2852,7 +2917,7 @@ func P2AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P2 {
28522917

28532918
numThreads := numThreads(0)
28542919

2855-
if numThreads < 2 || npoints < 32 {
2920+
if numThreads < 2 {
28562921
sz := int(C.blst_p2s_mult_pippenger_scratch_sizeof(C.size_t(npoints))) / 8
28572922
scratch := make([]uint64, sz)
28582923

@@ -2901,6 +2966,71 @@ func P2AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P2 {
29012966
return &ret
29022967
}
29032968

2969+
if npoints < 32 {
2970+
if numThreads > npoints {
2971+
numThreads = npoints
2972+
}
2973+
2974+
curItem := uint32(0)
2975+
msgs := make(chan P2, numThreads)
2976+
2977+
for tid := 0; tid < numThreads; tid++ {
2978+
go func() {
2979+
var acc P2
2980+
2981+
for {
2982+
workItem := int(atomic.AddUint32(&curItem, 1) - 1)
2983+
if workItem >= npoints {
2984+
break
2985+
}
2986+
2987+
var point *P2Affine
2988+
switch val := pointsIf.(type) {
2989+
case []*P2Affine:
2990+
point = val[workItem]
2991+
case []P2Affine:
2992+
point = &val[workItem]
2993+
case P2Affines:
2994+
point = &val[workItem]
2995+
}
2996+
2997+
var scalar *C.byte
2998+
switch val := scalarsIf.(type) {
2999+
case []byte:
3000+
scalar = (*C.byte)(&val[workItem*nbytes])
3001+
case [][]byte:
3002+
scalar = scalars[workItem]
3003+
case []Scalar:
3004+
if nbits > 248 {
3005+
scalar = &val[workItem].b[0]
3006+
} else {
3007+
scalar = scalars[workItem]
3008+
}
3009+
case []*Scalar:
3010+
scalar = scalars[workItem]
3011+
}
3012+
3013+
C.go_p2_mult_n_acc(&acc, &point.x, true,
3014+
scalar, C.size_t(nbits))
3015+
}
3016+
3017+
msgs <- acc
3018+
}()
3019+
}
3020+
3021+
ret := <-msgs
3022+
for tid := 1; tid < numThreads; tid++ {
3023+
point := <-msgs
3024+
C.blst_p2_add_or_double(&ret, &ret, &point)
3025+
}
3026+
3027+
for i := range scalars {
3028+
scalars[i] = nil
3029+
}
3030+
3031+
return &ret
3032+
}
3033+
29043034
// this is sizeof(scratch[0])
29053035
sz := int(C.blst_p2s_mult_pippenger_scratch_sizeof(0)) / 8
29063036

bindings/go/blst_minpk_test.go

+7
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,13 @@ func TestMultiScalarP1(t *testing.T) {
649649
for i := range points {
650650
points[i] = *generator.Mult(scalars[i*4:(i+1)*4])
651651
refs[i] = *points[i].Mult(scalars[i*16:(i+1)*16], 128)
652+
if i < 27 {
653+
ref := P1s(refs[:i+1]).Add()
654+
ret := P1s(points[:i+1]).Mult(scalars, 128)
655+
if !ref.Equals(ret) {
656+
t.Errorf("failed self-consistency multi-scalar test")
657+
}
658+
}
652659
}
653660
ref := P1s(refs).Add()
654661
ret := P1s(points).Mult(scalars, 128)

bindings/go/blst_minsig_test.go

+7
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,13 @@ func TestMultiScalarP2(t *testing.T) {
653653
for i := range points {
654654
points[i] = *generator.Mult(scalars[i*4:(i+1)*4])
655655
refs[i] = *points[i].Mult(scalars[i*16:(i+1)*16], 128)
656+
if i < 27 {
657+
ref := P2s(refs[:i+1]).Add()
658+
ret := P2s(points[:i+1]).Mult(scalars, 128)
659+
if !ref.Equals(ret) {
660+
t.Errorf("failed self-consistency multi-scalar test")
661+
}
662+
}
656663
}
657664
ref := P2s(refs).Add()
658665
ret := P2s(points).Mult(scalars, 128)

bindings/go/blst_px.tgo

+66-1
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ func P1AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P1 {
459459

460460
numThreads := numThreads(0)
461461

462-
if numThreads < 2 || npoints < 32 {
462+
if numThreads < 2 {
463463
sz := int(C.blst_p1s_mult_pippenger_scratch_sizeof(C.size_t(npoints)))/8
464464
scratch := make([]uint64, sz)
465465

@@ -508,6 +508,71 @@ func P1AffinesMult(pointsIf interface{}, scalarsIf interface{}, nbits int) *P1 {
508508
return &ret
509509
}
510510

511+
if npoints < 32 {
512+
if numThreads > npoints {
513+
numThreads = npoints
514+
}
515+
516+
curItem := uint32(0)
517+
msgs := make(chan P1, numThreads)
518+
519+
for tid := 0; tid < numThreads; tid++ {
520+
go func() {
521+
var acc P1
522+
523+
for {
524+
workItem := int(atomic.AddUint32(&curItem, 1) - 1)
525+
if workItem >= npoints {
526+
break
527+
}
528+
529+
var point *P1Affine
530+
switch val := pointsIf.(type) {
531+
case []*P1Affine:
532+
point = val[workItem]
533+
case []P1Affine:
534+
point = &val[workItem]
535+
case P1Affines:
536+
point = &val[workItem]
537+
}
538+
539+
var scalar *C.byte
540+
switch val := scalarsIf.(type) {
541+
case []byte:
542+
scalar = (*C.byte)(&val[workItem*nbytes])
543+
case [][]byte:
544+
scalar = scalars[workItem]
545+
case []Scalar:
546+
if nbits > 248 {
547+
scalar = &val[workItem].b[0]
548+
} else {
549+
scalar = scalars[workItem]
550+
}
551+
case []*Scalar:
552+
scalar = scalars[workItem]
553+
}
554+
555+
C.go_p1_mult_n_acc(&acc, &point.x, true,
556+
scalar, C.size_t(nbits))
557+
}
558+
559+
msgs <- acc
560+
}()
561+
}
562+
563+
ret := <-msgs
564+
for tid := 1; tid < numThreads; tid++ {
565+
point := <- msgs
566+
C.blst_p1_add_or_double(&ret, &ret, &point);
567+
}
568+
569+
for i := range(scalars) {
570+
scalars[i] = nil
571+
}
572+
573+
return &ret
574+
}
575+
511576
// this is sizeof(scratch[0])
512577
sz := int(C.blst_p1s_mult_pippenger_scratch_sizeof(0))/8
513578

0 commit comments

Comments
 (0)