Skip to content

Commit cf50bb9

Browse files
committed
bindings/rust: add MultiScalar trait.
This allows to perform multi-scalar operations directly on slices of affine points without going through p{12}_affines classes.
1 parent afd60c5 commit cf50bb9

File tree

3 files changed

+43
-14
lines changed

3 files changed

+43
-14
lines changed

bindings/rust/src/lib.rs

+7
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,13 @@ pub mod min_sig {
19701970
);
19711971
}
19721972

1973+
pub trait MultiScalar {
1974+
type Output;
1975+
1976+
fn mult(&self, scalars: &[u8], nbits: usize) -> Self::Output;
1977+
fn add(&self) -> Self::Output;
1978+
}
1979+
19731980
#[cfg(feature = "std")]
19741981
include!("pippenger.rs");
19751982

bindings/rust/src/pippenger-no_std.rs

+17-6
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,26 @@ macro_rules! pippenger_mult_impl {
6060
}
6161

6262
pub fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
63-
let npoints = self.points.len();
63+
self.as_slice().mult(scalars, nbits)
64+
}
65+
66+
pub fn add(&self) -> $point {
67+
self.as_slice().add()
68+
}
69+
}
70+
71+
impl MultiScalar for [$point_affine] {
72+
type Output = $point;
73+
74+
fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
75+
let npoints = self.len();
6476
let nbytes = (nbits + 7) / 8;
6577

6678
if scalars.len() < nbytes * npoints {
6779
panic!("scalars length mismatch");
6880
}
6981

70-
let p: [*const $point_affine; 2] =
71-
[&self.points[0], ptr::null()];
82+
let p: [*const $point_affine; 2] = [&self[0], ptr::null()];
7283
let s: [*const u8; 2] = [&scalars[0], ptr::null()];
7384

7485
let mut ret = <$point>::default();
@@ -89,10 +100,10 @@ macro_rules! pippenger_mult_impl {
89100
ret
90101
}
91102

92-
pub fn add(&self) -> $point {
93-
let npoints = self.points.len();
103+
fn add(&self) -> $point {
104+
let npoints = self.len();
94105

95-
let p: [*const _; 2] = [&self.points[0], ptr::null()];
106+
let p: [*const _; 2] = [&self[0], ptr::null()];
96107
let mut ret = <$point>::default();
97108
unsafe { $add(&mut ret, &p[0], npoints) };
98109

bindings/rust/src/pippenger.rs

+19-8
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,19 @@ macro_rules! pippenger_mult_impl {
114114
}
115115

116116
pub fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
117-
let npoints = self.points.len();
117+
self.as_slice().mult(scalars, nbits)
118+
}
119+
120+
pub fn add(&self) -> $point {
121+
self.as_slice().add()
122+
}
123+
}
124+
125+
impl MultiScalar for [$point_affine] {
126+
type Output = $point;
127+
128+
fn mult(&self, scalars: &[u8], nbits: usize) -> $point {
129+
let npoints = self.len();
118130
let nbytes = (nbits + 7) / 8;
119131

120132
if scalars.len() < nbytes * npoints {
@@ -124,8 +136,7 @@ macro_rules! pippenger_mult_impl {
124136
let pool = mt::da_pool();
125137
let ncpus = pool.max_count();
126138
if ncpus < 2 || npoints < 32 {
127-
let p: [*const $point_affine; 2] =
128-
[&self.points[0], ptr::null()];
139+
let p: [*const $point_affine; 2] = [&self[0], ptr::null()];
129140
let s: [*const u8; 2] = [&scalars[0], ptr::null()];
130141

131142
unsafe {
@@ -178,7 +189,7 @@ macro_rules! pippenger_mult_impl {
178189
}
179190
let grid = &grid[..];
180191

181-
let points = &self.points[..];
192+
let points = &self[..];
182193
let sz = unsafe { $scratch_sizeof(0) / 8 };
183194

184195
let mut row_sync: Vec<AtomicUsize> = Vec::with_capacity(ny);
@@ -262,13 +273,13 @@ macro_rules! pippenger_mult_impl {
262273
ret
263274
}
264275

265-
pub fn add(&self) -> $point {
266-
let npoints = self.points.len();
276+
fn add(&self) -> $point {
277+
let npoints = self.len();
267278

268279
let pool = mt::da_pool();
269280
let ncpus = pool.max_count();
270281
if ncpus < 2 || npoints < 384 {
271-
let p: [*const _; 2] = [&self.points[0], ptr::null()];
282+
let p: [*const _; 2] = [&self[0], ptr::null()];
272283
let mut ret = <$point>::default();
273284
unsafe { $add(&mut ret, &p[0], npoints) };
274285
return ret;
@@ -295,7 +306,7 @@ macro_rules! pippenger_mult_impl {
295306
if work >= npoints {
296307
break;
297308
}
298-
p[0] = &self.points[work];
309+
p[0] = &self[work];
299310
if work + chunk > npoints {
300311
chunk = npoints - work;
301312
}

0 commit comments

Comments
 (0)