Skip to content

Commit 0098f07

Browse files
committed
Implement accumulate and friends for arbitrary iterators
1 parent c2cd601 commit 0098f07

File tree

4 files changed

+42
-12
lines changed

4 files changed

+42
-12
lines changed

NEWS.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ New library features
6565

6666
* `isapprox` (or ``) now has a one-argument "curried" method `isapprox(x)` which returns a function, like `isequal` (or `==`)` ([#32305]).
6767
* `Ref{NTuple{N,T}}` can be passed to `Ptr{T}`/`Ref{T}` `ccall` signatures ([#34199])
68-
* `accumulate`, `cumsum`, and `cumprod` now support `Tuple` ([#34654]).
68+
* `accumulate`, `cumsum`, and `cumprod` now support `Tuple` ([#34654]) and arbitrary iterators ([#34656]).
6969

7070

7171
Standard library changes

base/accumulate.jl

+20-4
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ function cumsum(A::AbstractArray{T}; dims::Integer) where T
9292
end
9393

9494
"""
95-
cumsum(itr::Union{AbstractVector,Tuple})
95+
cumsum(itr)
9696
9797
Cumulative sum an iterator. See also [`cumsum!`](@ref)
9898
to use a preallocated output array, both for performance and to control the precision of the
9999
output (e.g. to avoid overflow).
100100
101101
!!! compat "Julia 1.5"
102-
`cumsum` on a tuple requires at least Julia 1.5.
102+
`cumsum` on a non-array iterator requires at least Julia 1.5.
103103
104104
# Examples
105105
```jldoctest
@@ -117,6 +117,12 @@ julia> cumsum([fill(1, 2) for i in 1:3])
117117
118118
julia> cumsum((1, 1, 1))
119119
(1, 2, 3)
120+
121+
julia> cumsum(x^2 for x in 1:3)
122+
3-element Array{Int64,1}:
123+
1
124+
5
125+
14
120126
```
121127
"""
122128
cumsum(x::AbstractVector) = cumsum(x, dims=1)
@@ -170,14 +176,14 @@ function cumprod(A::AbstractArray; dims::Integer)
170176
end
171177

172178
"""
173-
cumprod(itr::Union{AbstractVector,Tuple})
179+
cumprod(itr)
174180
175181
Cumulative product of an iterator. See also
176182
[`cumprod!`](@ref) to use a preallocated output array, both for performance and
177183
to control the precision of the output (e.g. to avoid overflow).
178184
179185
!!! compat "Julia 1.5"
180-
`cumprod` on a tuple requires at least Julia 1.5.
186+
`cumprod` on a non-array iterator requires at least Julia 1.5.
181187
182188
# Examples
183189
```jldoctest
@@ -195,6 +201,12 @@ julia> cumprod([fill(1//3, 2, 2) for i in 1:3])
195201
196202
julia> cumprod((1, 2, 1))
197203
(1, 2, 2)
204+
205+
julia> cumprod(x^2 for x in 1:3)
206+
3-element Array{Int64,1}:
207+
1
208+
4
209+
36
198210
```
199211
"""
200212
cumprod(x::AbstractVector) = cumprod(x, dims=1)
@@ -250,6 +262,10 @@ julia> accumulate(+, fill(1, 3, 3), dims=2)
250262
```
251263
"""
252264
function accumulate(op, A; dims::Union{Nothing,Integer}=nothing, kw...)
265+
if dims === nothing && !(A isa AbstractVector)
266+
# This branch takes care of the cases not handled by `_accumulate!`.
267+
return collect(Iterators.accumulate(op, A; kw...))
268+
end
253269
nt = kw.data
254270
if nt isa NamedTuple{()}
255271
out = similar(A, promote_op(op, eltype(A), eltype(A)))

base/iterators.jl

+15-7
Original file line numberDiff line numberDiff line change
@@ -443,13 +443,14 @@ reverse(f::Filter) = Filter(f.flt, reverse(f.itr))
443443

444444
# Accumulate -- partial reductions of a function over an iterator
445445

446-
struct Accumulate{F,I}
446+
struct Accumulate{F,I,T}
447447
f::F
448448
itr::I
449+
init::T
449450
end
450451

451452
"""
452-
Iterators.accumulate(f, itr)
453+
Iterators.accumulate(f, itr; [init])
453454
454455
Given a 2-argument function `f` and an iterator `itr`, return a new
455456
iterator that successively applies `f` to the previous value and the
@@ -459,24 +460,31 @@ This is effectively a lazy version of [`Base.accumulate`](@ref).
459460
460461
# Examples
461462
```jldoctest
462-
julia> f = Iterators.accumulate(+, [1,2,3,4])
463-
Base.Iterators.Accumulate{typeof(+),Array{Int64,1}}(+, [1, 2, 3, 4])
463+
julia> f = Iterators.accumulate(+, [1,2,3,4]);
464464
465465
julia> foreach(println, f)
466466
1
467467
3
468468
6
469469
10
470+
471+
julia> f = Iterators.accumulate(+, [1,2,3]; init = 100);
472+
473+
julia> foreach(println, f)
474+
101
475+
103
476+
106
470477
```
471478
"""
472-
accumulate(f, itr) = Accumulate(f, itr)
479+
accumulate(f, itr; init = Base._InitialValue()) = Accumulate(f, itr, init)
473480

474481
function iterate(itr::Accumulate)
475482
state = iterate(itr.itr)
476483
if state === nothing
477484
return nothing
478485
end
479-
return (state[1], state)
486+
val = Base.BottomRF(itr.f)(itr.init, state[1])
487+
return (val, (val, state[2]))
480488
end
481489

482490
function iterate(itr::Accumulate, state)
@@ -491,7 +499,7 @@ end
491499
length(itr::Accumulate) = length(itr.itr)
492500
size(itr::Accumulate) = size(itr.itr)
493501

494-
IteratorSize(::Type{Accumulate{F,I}}) where {F,I} = IteratorSize(I)
502+
IteratorSize(::Type{<:Accumulate{F,I}}) where {F,I} = IteratorSize(I)
495503
IteratorEltype(::Type{<:Accumulate}) = EltypeUnknown()
496504

497505
# Rest -- iterate starting at the given state

test/iterators.jl

+6
Original file line numberDiff line numberDiff line change
@@ -798,3 +798,9 @@ end
798798
@test Base.IteratorSize(Iterators.accumulate(max, rand(2,3))) === Base.IteratorSize(rand(2,3))
799799
@test Base.IteratorEltype(Iterators.accumulate(*, ())) isa Base.EltypeUnknown
800800
end
801+
802+
@testset "Base.accumulate" begin
803+
@test cumsum(x^2 for x in 1:3) == [1, 5, 14]
804+
@test cumprod(x + 1 for x in 1:3) == [2, 6, 24]
805+
@test accumulate(+, (x^2 for x in 1:3); init=100) == [101, 105, 114]
806+
end

0 commit comments

Comments
 (0)