Skip to content

Commit 3de4d53

Browse files
committed
Merge pull request #9650 from stevengj/cumsum_pairwise
get pairwise-sum accuracy for cumsum (fix #9648)
2 parents d74d8cc + c7c89a4 commit 3de4d53

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

base/array.jl

+13-9
Original file line numberDiff line numberDiff line change
@@ -1455,20 +1455,24 @@ symdiff(a, b, rest...) = symdiff(a, symdiff(b, rest...))
14551455
_cumsum_type{T<:Number}(v::AbstractArray{T}) = typeof(+zero(T))
14561456
_cumsum_type(v) = typeof(v[1]+v[1])
14571457

1458-
for (f, fp, op) = ((:cumsum, :cumsum_pairwise, :+),
1459-
(:cumprod, :cumprod_pairwise, :*) )
1460-
# in-place cumsum of c = s+v(i1:n), using pairwise summation as for sum
1461-
@eval function ($fp)(v::AbstractVector, c::AbstractVector, s, i1, n)
1458+
for (f, fp, op) = ((:cumsum, :cumsum_pairwise!, :+),
1459+
(:cumprod, :cumprod_pairwise!, :*) )
1460+
# in-place cumsum of c = s+v[range(i1,n)], using pairwise summation
1461+
@eval function ($fp){T}(v::AbstractVector, c::AbstractVector{T}, s, i1, n)
1462+
local s_::T # for sum(v[range(i1,n)]), i.e. sum without s
14621463
if n < 128
1463-
@inbounds c[i1] = ($op)(s, v[i1])
1464+
@inbounds s_ = v[i1]
1465+
@inbounds c[i1] = ($op)(s, s_)
14641466
for i = i1+1:i1+n-1
1465-
@inbounds c[i] = $(op)(c[i-1], v[i])
1467+
@inbounds s_ = $(op)(s_, v[i])
1468+
@inbounds c[i] = $(op)(s, s_)
14661469
end
14671470
else
1468-
n2 = div(n,2)
1469-
($fp)(v, c, s, i1, n2)
1470-
($fp)(v, c, c[(i1+n2)-1], i1+n2, n-n2)
1471+
n2 = n >> 1
1472+
s_ = ($fp)(v, c, s, i1, n2)
1473+
s_ = $(op)(s_, ($fp)(v, c, s + s_, i1+n2, n-n2))
14711474
end
1475+
return s_
14721476
end
14731477

14741478
@eval function ($f)(v::AbstractVector)

test/arrayops.jl

+5
Original file line numberDiff line numberDiff line change
@@ -974,3 +974,8 @@ a = [ [ 1 0 0 ], [ 0 0 0 ] ]
974974
@test rotl90(a,4) == a
975975
@test rotr90(a,4) == a
976976
@test rot180(a,2) == a
977+
978+
# issue #9648
979+
let x = fill(1.5f0, 10^7)
980+
@test abs(1.5f7 - cumsum(x)[end]) < 3*eps(1.5f7)
981+
end

0 commit comments

Comments
 (0)