Skip to content

Commit a6290fe

Browse files
committed
Optimize AdditiveMonoids sum() method
1 parent 7d83063 commit a6290fe

File tree

3 files changed

+43
-17
lines changed

3 files changed

+43
-17
lines changed

src/sage/arith/misc.py

+1
Original file line numberDiff line numberDiff line change
@@ -3625,6 +3625,7 @@ def CRT_list(values, moduli=None):
36253625

36263626
# The result is computed using a binary tree. In typical cases,
36273627
# this scales much better than folding the list from one side.
3628+
# See also sage.misc.misc_c.balanced_list_prod
36283629
from sage.arith.functions import lcm
36293630
while len(values) > 1:
36303631
vs, ms = values[::2], moduli[::2]

src/sage/categories/additive_monoids.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,24 @@ def sum(self, args):
6868
0
6969
sage: S.sum(()).parent() == S
7070
True
71+
72+
TESTS:
73+
74+
The following should be reasonably fast (0.5s each)::
75+
76+
sage: R.<x,y> = QQ[]
77+
sage: ignore = R.sum(
78+
....: QQ.random_element()*x^i*y^j for i in range(200) for j in range(200))
79+
sage: ignore = R.sum([
80+
....: QQ.random_element()*x^i*y^j for i in range(200) for j in range(200)])
81+
82+
Summing an empty iterator::
83+
84+
sage: R.sum(1 for i in range(0))
85+
0
7186
"""
72-
return sum(args, self.zero())
87+
from sage.misc.misc_c import balanced_sum
88+
return balanced_sum(args, self.zero(), 20)
7389

7490
class Homsets(HomsetsCategory):
7591

src/sage/misc/misc_c.pyx

+25-16
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ cdef balanced_list_prod(L, Py_ssize_t offset, Py_ssize_t count, Py_ssize_t cutof
185185
return balanced_list_prod(L, offset, k, cutoff) * balanced_list_prod(L, offset + k, count - k, cutoff)
186186

187187

188-
cpdef iterator_prod(L, z=None):
188+
cpdef iterator_prod(L, z=None, bint multiply=True):
189189
"""
190190
Attempt to do a balanced product of an arbitrary and unknown length
191191
sequence (such as a generator). Intermediate multiplications are always
@@ -207,11 +207,18 @@ cpdef iterator_prod(L, z=None):
207207
sage: L = [NonAssociative(label) for label in 'abcdef']
208208
sage: iterator_prod(L)
209209
(((a*b)*(c*d))*(e*f))
210+
211+
When ``multiply=False``, the items are added up instead (however this
212+
interface should not be used directly; use :func:`balanced_sum` instead)::
213+
214+
sage: iterator_prod((1..5), multiply=False)
215+
15
210216
"""
211-
# TODO: declaring sub_prods as a list should speed much of this up.
217+
cdef list sub_prods
212218
L = iter(L)
213219
if z is None:
214-
sub_prods = [next(L)] * 10
220+
sub_prods = [next(L)] * 10 # only take one element from L, the rest are just placeholders
221+
# the list size can be dynamically increased later
215222
else:
216223
sub_prods = [z] * 10
217224

@@ -232,17 +239,26 @@ cpdef iterator_prod(L, z=None):
232239
else:
233240
# for even i we combine (multiply, or add when multiply=False) the stack down
234241
# by the number of factors of 2 in i
235-
x = sub_prods[tip] * x
242+
if multiply:
243+
x = sub_prods[tip] * x
244+
else:
245+
x = sub_prods[tip] + x
236246
for j from 1 <= j < 64:
237247
if i & (1 << j):
238248
break
239249
tip -= 1
240-
x = sub_prods[tip] * x
250+
if multiply:
251+
x = sub_prods[tip] * x
252+
else:
253+
x = sub_prods[tip] + x
241254
sub_prods[tip] = x
242255

243256
while tip > 0:
244257
tip -= 1
245-
sub_prods[tip] *= sub_prods[tip + 1]
258+
if multiply:
259+
sub_prods[tip] *= sub_prods[tip + 1]
260+
else:
261+
sub_prods[tip] += sub_prods[tip + 1]
246262

247263
return sub_prods[0]
248264

@@ -366,14 +382,7 @@ def balanced_sum(x, z=None, Py_ssize_t recursion_cutoff=5):
366382
if type(x) is not list and type(x) is not tuple:
367383

368384
if PyGen_Check(x):
369-
# lazy list, do lazy product
370-
try:
371-
sum = copy(next(x)) if z is None else z + next(x)
372-
for a in x:
373-
sum += a
374-
return sum
375-
except StopIteration:
376-
x = []
385+
return iterator_prod(x, z, multiply=False)
377386
else:
378387
try:
379388
return x.sum()
@@ -405,8 +414,8 @@ cdef balanced_list_sum(L, Py_ssize_t offset, Py_ssize_t count, Py_ssize_t cutoff
405414
406415
- ``L`` -- the terms (MUST be a tuple or list)
407416
- ``offset`` -- offset in the list from which to start
408-
- ``count`` -- how many terms in the sum
409-
- ``cutoff`` -- the minimum count to recurse on. Must be at least 2
417+
- ``count`` -- how many terms in the sum. Must be positive
418+
- ``cutoff`` -- the minimum count to recurse on. Must be at least 2
410419
411420
OUTPUT:
412421

0 commit comments

Comments
 (0)