Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize AdditiveMonoids sum() method #39726

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sage/arith/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3625,6 +3625,7 @@ def CRT_list(values, moduli=None):

# The result is computed using a binary tree. In typical cases,
# this scales much better than folding the list from one side.
# See also sage.misc.misc_c.balanced_list_prod
from sage.arith.functions import lcm
while len(values) > 1:
vs, ms = values[::2], moduli[::2]
Expand Down
18 changes: 17 additions & 1 deletion src/sage/categories/additive_monoids.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,24 @@ def sum(self, args):
0
sage: S.sum(()).parent() == S
True

TESTS:

The following should be reasonably fast (0.5s each)::

sage: R.<x,y> = QQ[]
sage: ignore = R.sum(
....: QQ.random_element()*x^i*y^j for i in range(200) for j in range(200))
sage: ignore = R.sum([
....: QQ.random_element()*x^i*y^j for i in range(200) for j in range(200)])

Summing an empty iterator::

sage: R.sum(1 for i in range(0))
0
"""
return sum(args, self.zero())
from sage.misc.misc_c import balanced_sum
return balanced_sum(args, self.zero(), 20)

class Homsets(HomsetsCategory):

Expand Down
41 changes: 25 additions & 16 deletions src/sage/misc/misc_c.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ cdef balanced_list_prod(L, Py_ssize_t offset, Py_ssize_t count, Py_ssize_t cutof
return balanced_list_prod(L, offset, k, cutoff) * balanced_list_prod(L, offset + k, count - k, cutoff)


cpdef iterator_prod(L, z=None):
cpdef iterator_prod(L, z=None, bint multiply=True):
"""
Attempt to do a balanced product of an arbitrary and unknown length
sequence (such as a generator). Intermediate multiplications are always
Expand All @@ -207,11 +207,18 @@ cpdef iterator_prod(L, z=None):
sage: L = [NonAssociative(label) for label in 'abcdef']
sage: iterator_prod(L)
(((a*b)*(c*d))*(e*f))

When ``multiply=False``, the items are added up instead (however this
interface should not be used directly, use :func:`balanced_sum` instead)::

sage: iterator_prod((1..5), multiply=False)
15
"""
# TODO: declaring sub_prods as a list should speed much of this up.
cdef list sub_prods
L = iter(L)
if z is None:
sub_prods = [next(L)] * 10
sub_prods = [next(L)] * 10 # only take one element from L, the rest are just placeholders
# the list size can be dynamically increased later
else:
sub_prods = [z] * 10

Expand All @@ -232,17 +239,26 @@ cpdef iterator_prod(L, z=None):
else:
# for even i we multiply the stack down
# by the number of factors of 2 in i
x = sub_prods[tip] * x
if multiply:
x = sub_prods[tip] * x
else:
x = sub_prods[tip] + x
for j from 1 <= j < 64:
if i & (1 << j):
break
tip -= 1
x = sub_prods[tip] * x
if multiply:
x = sub_prods[tip] * x
else:
x = sub_prods[tip] + x
sub_prods[tip] = x

while tip > 0:
tip -= 1
sub_prods[tip] *= sub_prods[tip + 1]
if multiply:
sub_prods[tip] *= sub_prods[tip + 1]
else:
sub_prods[tip] += sub_prods[tip + 1]

return sub_prods[0]

Expand Down Expand Up @@ -366,14 +382,7 @@ def balanced_sum(x, z=None, Py_ssize_t recursion_cutoff=5):
if type(x) is not list and type(x) is not tuple:

if PyGen_Check(x):
# lazy list, do lazy product
try:
sum = copy(next(x)) if z is None else z + next(x)
for a in x:
sum += a
return sum
except StopIteration:
x = []
return iterator_prod(x, z, multiply=False)
else:
try:
return x.sum()
Expand Down Expand Up @@ -405,8 +414,8 @@ cdef balanced_list_sum(L, Py_ssize_t offset, Py_ssize_t count, Py_ssize_t cutoff

- ``L`` -- the terms (MUST be a tuple or list)
- ``off`` -- offset in the list from which to start
- ``count`` -- how many terms in the sum
- ``cutoff`` -- the minimum count to recurse on. Must be at least 2
- ``count`` -- how many terms in the sum. Must be positive
- ``cutoff`` -- the minimum count to recurse on. Must be at least 2

OUTPUT:

Expand Down
Loading