@@ -416,15 +416,34 @@ function _accumulate_pairwise_small!(op, dest::AbstractArray{T}, itr, accv, w, i
416
416
end
417
417
end
418
418
419
+ """
420
+ Base.ConvertOp{T}(op)(x,y)
421
+
422
+ An operator which converts `x` and `y` to type `T` before performing the `op`.
423
+
424
+ The main purpose is for use in [`cumsum!`](@ref) and [`cumprod!`](@ref), where `T` is determined by the output array.
425
+ """
426
+ struct ConvertOp{T,O} <: Function
427
+ op:: O
428
+ end
429
+ ConvertOp {T} (op:: O ) where {T,O} = ConvertOp {T,O} (op)
430
+ (c:: ConvertOp{T} )(x,y) where {T} = c. op (convert (T,x),convert (T,y))
431
+
432
+ reduce_first (c:: ConvertOp{T} ,x) = reduce_first (c. op, convert (T,x))
433
+
434
+
419
435
420
436
421
437
function cumsum! (out, v:: AbstractVector{T} ) where T
422
438
# we dispatch on the possibility of numerical accuracy issues
423
439
cumsum! (out, v, ArithmeticStyle (T))
424
440
end
425
- cumsum! (out, v:: AbstractVector , :: ArithmeticRounds ) = accumulate_pairwise! (+ , out, v)
426
- cumsum! (out, v:: AbstractVector , :: ArithmeticUnknown ) = accumulate_pairwise! (+ , out, v)
427
- cumsum! (out, v:: AbstractVector , :: ArithmeticStyle ) = accumulate! (+ , out, v)
441
+ cumsum! (out:: AbstractVector{T} , v:: AbstractVector , :: ArithmeticRounds ) where {T} =
442
+ accumulate_pairwise! (ConvertOp {T} (+ ), out, v)
443
+ cumsum! (out:: AbstractVector{T} , v:: AbstractVector , :: ArithmeticUnknown ) where {T} =
444
+ accumulate_pairwise! (ConvertOp {T} (+ ), out, v)
445
+ cumsum! (out:: AbstractVector{T} , v:: AbstractVector , :: ArithmeticStyle ) where {T} =
446
+ accumulate! (ConvertOp {T} (+ ), out, v)
428
447
429
448
"""
430
449
cumsum(A, dim::Integer)
@@ -488,14 +507,14 @@ cumsum(v::AbstractVector, ::ArithmeticStyle) = accumulate(add_sum, v)
488
507
489
508
Cumulative sum of `A` along the dimension `dim`, storing the result in `B`. See also [`cumsum`](@ref).
490
509
"""
491
- cumsum! (dest, A, dim:: Integer ) = accumulate! (+ , dest, A, dim)
510
+ cumsum! (dest:: AbstractArray{T} , A, dim:: Integer ) where {T} = accumulate! (ConvertOp {T} ( + ) , dest, A, dim)
492
511
493
512
"""
494
513
cumsum!(y::AbstractVector, x::AbstractVector)
495
514
496
515
Cumulative sum of a vector `x`, storing the result in `y`. See also [`cumsum`](@ref).
497
516
"""
498
- cumsum! (dest, itr) = accumulate! (+ , dest, src)
517
+ cumsum! (dest:: AbstractArray{T} , itr) = accumulate! (ConvertOp {T} ( + ) , dest, src)
499
518
500
519
"""
501
520
cumprod(A, dim::Integer)
@@ -555,12 +574,12 @@ cumprod(x::AbstractVector) = accumulate(mul_prod, x)
555
574
Cumulative product of `A` along the dimension `dim`, storing the result in `B`.
556
575
See also [`cumprod`](@ref).
557
576
"""
558
- cumprod! (dest, A, dim:: Integer ) = accumulate! (* , dest, A, dim)
577
+ cumprod! (dest:: AbstractArray{T} , A, dim:: Integer ) where {T} = accumulate! (ConvertOp {T} ( * ) , dest, A, dim)
559
578
560
579
"""
561
580
cumprod!(y::AbstractVector, x::AbstractVector)
562
581
563
582
Cumulative product of a vector `x`, storing the result in `y`.
564
583
See also [`cumprod`](@ref).
565
584
"""
566
- cumprod! (dest, itr) = accumulate! (* , dest, itr)
585
+ cumprod! (dest:: AbstractArray{T} , itr) where {T} = accumulate! (ConvertOp {T} ( * ) , dest, itr)
0 commit comments