Skip to content

Commit dc20816

Browse files
committed
Use destination type to determine output of cumsum! and cumprod!
1 parent 512fbcd commit dc20816

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed

base/accumulate.jl

+26-7
Original file line numberDiff line numberDiff line change
@@ -416,15 +416,34 @@ function _accumulate_pairwise_small!(op, dest::AbstractArray{T}, itr, accv, w, i
416416
end
417417
end
418418

419+
"""
420+
Base.ConvertOp{T}(op)(x,y)
421+
422+
An operator which converts `x` and `y` to type `T` before performing the `op`.
423+
424+
The main purpose is for use in [`cumsum!`](@ref) and [`cumprod!`](@ref), where `T` is determined by the output array.
425+
"""
426+
struct ConvertOp{T,O} <: Function
427+
op::O
428+
end
429+
ConvertOp{T}(op::O) where {T,O} = ConvertOp{T,O}(op)
430+
(c::ConvertOp{T})(x,y) where {T} = c.op(convert(T,x),convert(T,y))
431+
432+
reduce_first(c::ConvertOp{T},x) = reduce_first(c.op, convert(T,x))
433+
434+
419435

420436

421437
function cumsum!(out, v::AbstractVector{T}) where T
422438
# we dispatch on the possibility of numerical accuracy issues
423439
cumsum!(out, v, ArithmeticStyle(T))
424440
end
425-
cumsum!(out, v::AbstractVector, ::ArithmeticRounds) = accumulate_pairwise!(+, out, v)
426-
cumsum!(out, v::AbstractVector, ::ArithmeticUnknown) = accumulate_pairwise!(+, out, v)
427-
cumsum!(out, v::AbstractVector, ::ArithmeticStyle) = accumulate!(+, out, v)
441+
cumsum!(out::AbstractVector{T}, v::AbstractVector, ::ArithmeticRounds) where {T} =
442+
accumulate_pairwise!(ConvertOp{T}(+), out, v)
443+
cumsum!(out::AbstractVector{T}, v::AbstractVector, ::ArithmeticUnknown) where {T} =
444+
accumulate_pairwise!(ConvertOp{T}(+), out, v)
445+
cumsum!(out::AbstractVector{T}, v::AbstractVector, ::ArithmeticStyle) where {T} =
446+
accumulate!(ConvertOp{T}(+), out, v)
428447

429448
"""
430449
cumsum(A, dim::Integer)
@@ -488,14 +507,14 @@ cumsum(v::AbstractVector, ::ArithmeticStyle) = accumulate(add_sum, v)
488507
489508
Cumulative sum of `A` along the dimension `dim`, storing the result in `B`. See also [`cumsum`](@ref).
490509
"""
491-
cumsum!(dest, A, dim::Integer) = accumulate!(+, dest, A, dim)
510+
cumsum!(dest::AbstractArray{T}, A, dim::Integer) where {T} = accumulate!(ConvertOp{T}(+), dest, A, dim)
492511

493512
"""
494513
cumsum!(y::AbstractVector, x::AbstractVector)
495514
496515
Cumulative sum of a vector `x`, storing the result in `y`. See also [`cumsum`](@ref).
497516
"""
498-
cumsum!(dest, itr) = accumulate!(+, dest, src)
517+
cumsum!(dest::AbstractArray{T}, itr) = accumulate!(ConvertOp{T}(+), dest, src)
499518

500519
"""
501520
cumprod(A, dim::Integer)
@@ -555,12 +574,12 @@ cumprod(x::AbstractVector) = accumulate(mul_prod, x)
555574
Cumulative product of `A` along the dimension `dim`, storing the result in `B`.
556575
See also [`cumprod`](@ref).
557576
"""
558-
cumprod!(dest, A, dim::Integer) = accumulate!(*, dest, A, dim)
577+
cumprod!(dest::AbstractArray{T}, A, dim::Integer) where {T} = accumulate!(ConvertOp{T}(*), dest, A, dim)
559578

560579
"""
561580
cumprod!(y::AbstractVector, x::AbstractVector)
562581
563582
Cumulative product of a vector `x`, storing the result in `y`.
564583
See also [`cumprod`](@ref).
565584
"""
566-
cumprod!(dest, itr) = accumulate!(*, dest, itr)
585+
cumprod!(dest::AbstractArray{T}, itr) where {T} = accumulate!(ConvertOp{T}(*), dest, itr)

base/reduce.jl

+27
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ add_sum(x::SmallSigned) = Int(x)
2828
add_sum(x::SmallUnsigned) = UInt(x)
2929
add_sum(X::AbstractArray) = broadcast(add_sum, X)
3030

31+
3132
"""
3233
Base.mul_prod(x,y)
3334
@@ -42,6 +43,32 @@ mul_prod(x) = *(x)
4243
mul_prod(x::SmallSigned) = Int(x)
4344
mul_prod(x::SmallUnsigned) = UInt(x)
4445

46+
"""
47+
Base.ConvertOp{T}(op)(x,y)
48+
49+
An operator which converts `x` and `y` to type `T` before performing the `op`.
50+
51+
The main purpose is for use in [`cumsum!`](@ref) and [`cumprod!`](@ref), where `T` is determined by the output array.
52+
"""
53+
struct ConvertOp{T,O} <: Function
54+
op::O
55+
end
56+
ConvertOp{T}(op::O) where {T,O} = ConvertOp{T,O}(op)
57+
(c::ConvertOp{T})(x,y) where {T} = c.op(convert(T,x),convert(T,y))
58+
reduce_first(c::ConvertOp{T},x) = reduce_first(c.op, convert(T,x))
59+
60+
61+
"""
62+
Base.ConvertAdd{T}()(x,y)
63+
64+
An addition operator which converts `x` and `y` to type `T` before performing the addition.
65+
The main purpose is for in [`cumsum!`](@ref), where `T` is determined by the output array.
66+
"""
67+
struct ConvertAdd{T} end
68+
(::ConvertAdd{T})(x) where {T} = +(convert(T,x))
69+
(::ConvertAdd{T})(x,y) where {T} = convert(T,x) + convert(T,y)
70+
71+
4572
## foldl && mapfoldl
4673

4774
@noinline function mapfoldl_impl(f, op, v0, itr, i)

0 commit comments

Comments
 (0)