From a6290fe9ae2f5a068bc61219c83a565d48d10a01 Mon Sep 17 00:00:00 2001 From: user202729 <25191436+user202729@users.noreply.github.com> Date: Mon, 17 Mar 2025 16:02:51 +0700 Subject: [PATCH] Optimize AdditiveMonoids sum() method --- src/sage/arith/misc.py | 1 + src/sage/categories/additive_monoids.py | 18 ++++++++++- src/sage/misc/misc_c.pyx | 41 +++++++++++++++---------- 3 files changed, 43 insertions(+), 17 deletions(-) diff --git a/src/sage/arith/misc.py b/src/sage/arith/misc.py index 6a6b42e26f7..310256f02f6 100644 --- a/src/sage/arith/misc.py +++ b/src/sage/arith/misc.py @@ -3625,6 +3625,7 @@ def CRT_list(values, moduli=None): # The result is computed using a binary tree. In typical cases, # this scales much better than folding the list from one side. + # See also sage.misc.misc_c.balanced_list_prod from sage.arith.functions import lcm while len(values) > 1: vs, ms = values[::2], moduli[::2] diff --git a/src/sage/categories/additive_monoids.py b/src/sage/categories/additive_monoids.py index 70a195f582e..0782c4f1540 100644 --- a/src/sage/categories/additive_monoids.py +++ b/src/sage/categories/additive_monoids.py @@ -68,8 +68,24 @@ def sum(self, args): 0 sage: S.sum(()).parent() == S True + + TESTS: + + The following should be reasonably fast (0.5s each):: + + sage: R. 
<x,y> = QQ[] + sage: ignore = R.sum( + ....: QQ.random_element()*x^i*y^j for i in range(200) for j in range(200)) + sage: ignore = R.sum([ + ....: QQ.random_element()*x^i*y^j for i in range(200) for j in range(200)]) + + Summing an empty iterator:: + + sage: R.sum(1 for i in range(0)) + 0 """ - return sum(args, self.zero()) + from sage.misc.misc_c import balanced_sum + return balanced_sum(args, self.zero(), 20) class Homsets(HomsetsCategory): diff --git a/src/sage/misc/misc_c.pyx b/src/sage/misc/misc_c.pyx index 2b7136ce584..b499075c23c 100644 --- a/src/sage/misc/misc_c.pyx +++ b/src/sage/misc/misc_c.pyx @@ -185,7 +185,7 @@ cdef balanced_list_prod(L, Py_ssize_t offset, Py_ssize_t count, Py_ssize_t cutof return balanced_list_prod(L, offset, k, cutoff) * balanced_list_prod(L, offset + k, count - k, cutoff) -cpdef iterator_prod(L, z=None): +cpdef iterator_prod(L, z=None, bint multiply=True): """ Attempt to do a balanced product of an arbitrary and unknown length sequence (such as a generator). Intermediate multiplications are always @@ -207,11 +207,18 @@ cpdef iterator_prod(L, z=None): sage: L = [NonAssociative(label) for label in 'abcdef'] sage: iterator_prod(L) (((a*b)*(c*d))*(e*f)) + + When ``multiply=False``, the items are added up instead (however this + interface should not be used directly, use :func:`balanced_sum` instead):: + + sage: iterator_prod((1..5), multiply=False) + 15 """ - # TODO: declaring sub_prods as a list should speed much of this up. 
+ cdef list sub_prods L = iter(L) if z is None: - sub_prods = [next(L)] * 10 + sub_prods = [next(L)] * 10 # only take one element from L, the rest are just placeholders + # the list size can be dynamically increased later else: sub_prods = [z] * 10 @@ -232,17 +239,26 @@ cpdef iterator_prod(L, z=None): else: # for even i we multiply the stack down # by the number of factors of 2 in i - x = sub_prods[tip] * x + if multiply: + x = sub_prods[tip] * x + else: + x = sub_prods[tip] + x for j from 1 <= j < 64: if i & (1 << j): break tip -= 1 - x = sub_prods[tip] * x + if multiply: + x = sub_prods[tip] * x + else: + x = sub_prods[tip] + x sub_prods[tip] = x while tip > 0: tip -= 1 - sub_prods[tip] *= sub_prods[tip + 1] + if multiply: + sub_prods[tip] *= sub_prods[tip + 1] + else: + sub_prods[tip] += sub_prods[tip + 1] return sub_prods[0] @@ -366,14 +382,7 @@ def balanced_sum(x, z=None, Py_ssize_t recursion_cutoff=5): if type(x) is not list and type(x) is not tuple: if PyGen_Check(x): - # lazy list, do lazy product - try: - sum = copy(next(x)) if z is None else z + next(x) - for a in x: - sum += a - return sum - except StopIteration: - x = [] + return iterator_prod(x, z, multiply=False) else: try: return x.sum() @@ -405,8 +414,8 @@ cdef balanced_list_sum(L, Py_ssize_t offset, Py_ssize_t count, Py_ssize_t cutoff - ``L`` -- the terms (MUST be a tuple or list) - ``off`` -- offset in the list from which to start - - ``count`` -- how many terms in the sum - - ``cutoff`` -- the minimum count to recurse on. Must be at least 2 + - ``count`` -- how many terms in the sum. Must be positive + - ``cutoff`` -- the minimum count to recurse on. Must be at least 2 OUTPUT: