Skip to content

Commit e67c3ae

Browse files
JeffBezansonsamoconnor
authored andcommitted
add Flatten iterator
1 parent 9784df0 commit e67c3ae

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

base/iterator.jl

+54
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,57 @@ function collect{I<:IteratorND}(g::Generator{I})
341341
dest[1] = first
342342
return map_to!(g.f, 2, st, dest, g.iter)
343343
end
344+
345+
# flatten an iterator of iterators
346+
347+
immutable Flatten{I}
348+
it::I
349+
end
350+
351+
"""
352+
flatten(iter)
353+
354+
Given an iterator that yields iterators, return an iterator that yields the
355+
elements of those iterators.
356+
Put differently, the elements of the argument iterator are concatenated. Example:
357+
358+
julia> collect(flatten((1:2, 8:9)))
359+
4-element Array{Int64,1}:
360+
1
361+
2
362+
8
363+
9
364+
"""
365+
flatten(itr) = Flatten(itr)
366+
367+
eltype{I}(::Type{Flatten{I}}) = eltype(eltype(I))
368+
369+
function start(f::Flatten)
370+
local inner, s2
371+
s = start(f.it)
372+
d = done(f.it, s)
373+
# this is a simple way to make this function type stable
374+
d && error("argument to Flatten must contain at least one iterator")
375+
while !d
376+
inner, s = next(f.it, s)
377+
s2 = start(inner)
378+
!done(inner, s2) && break
379+
d = done(f.it, s)
380+
end
381+
return s, inner, s2
382+
end
383+
384+
function next(f::Flatten, state)
385+
s, inner, s2 = state
386+
val, s2 = next(inner, s2)
387+
while done(inner, s2) && !done(f.it, s)
388+
inner, s = next(f.it, s)
389+
s2 = start(inner)
390+
end
391+
return val, (s, inner, s2)
392+
end
393+
394+
@inline function done(f::Flatten, state)
395+
s, inner, s2 = state
396+
return done(f.it, s) && done(inner, s2)
397+
end

test/functional.jl

+13
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,19 @@ end
166166
@test isempty(collect(Base.product(1:0,1:2)))
167167
@test length(Base.product(1:2,1:10,4:6)) == 60
168168

169+
# flatten
170+
# -------
171+
172+
import Base.flatten
173+
174+
@test collect(flatten(Any[1:2, 4:5])) == Any[1,2,4,5]
175+
@test collect(flatten(Any[flatten(Any[1:2, 6:5]), flatten(Any[10:7, 10:9])])) == Any[1,2]
176+
@test collect(flatten(Any[flatten(Any[1:2, 4:5]), flatten(Any[6:7, 8:9])])) == Any[1,2,4,5,6,7,8,9]
177+
@test collect(flatten(Any[flatten(Any[1:2, 6:5]), flatten(Any[6:7, 8:9])])) == Any[1,2,6,7,8,9]
178+
@test collect(flatten(Any[2:1])) == Any[]
179+
@test eltype(flatten(UnitRange{Int8}[1:2, 3:4])) == Int8
180+
@test_throws ErrorException collect(flatten(Any[]))
181+
169182
# foreach
170183
let
171184
a = []

0 commit comments

Comments
 (0)