Skip to content

Commit 80e29ed

Browse files
committed
Concat iterator that concatenates the elements of an iterator
1 parent a7fa96f commit 80e29ed

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

base/iterator.jl

+57
Original file line numberDiff line numberDiff line change
@@ -559,3 +559,60 @@ function next(itr::PartitionIterator, state)
559559
end
560560
return resize!(v, i), state
561561
end
562+
563+
564+
immutable Concat{I,T,S}
565+
first::T
566+
itr::I
567+
st::S
568+
end
569+
570+
function Concat(c)
571+
s = start(c)
572+
done(c, s) && throw(ArgumentError("argument to Concat must contain at least one element"))
573+
head, s = next(c, s)
574+
Concat(head, c, s)
575+
end
576+
577+
"""
578+
concat(iter)
579+
580+
Given an iterator with trait `HasShape()` that yields iterators with trait
581+
`HasShape()`, return an iterator that yields the elements of those iterators
582+
concatenated.
583+
584+
```jldoctest
585+
julia> collect(Base.concat(((3*(i-1)+j for j in 1:3) for i in 1:4)))
586+
3×4 Array{Int64,2}:
587+
1 4 7 10
588+
2 5 8 11
589+
3 6 9 12
590+
```
591+
"""
592+
concat(it) = Concat(it)
593+
594+
eltype{I,T,S}(::Type{Concat{I,T,S}}) = eltype(T)
595+
iteratorsize{I,T,S}(::Type{Concat{I,T,S}}) = HasShape()
596+
iteratoreltype{I,T,S}(::Type{Concat{I,T,S}}) = iteratoreltype(T)
597+
size(it::Concat) = (size(it.first)...,size(it.itr)...)
598+
length(it::Concat) = prod(size(it))
599+
600+
function start(c::Concat)
601+
return c.st, c.first, start(c.first)
602+
end
603+
604+
function next(c::Concat, state)
605+
s, inner, s2 = state
606+
val, s2 = next(inner, s2)
607+
while done(inner, s2) && !done(c.itr, s)
608+
inner, s = next(c.itr, s)
609+
size(inner) == size(c.first) || throw(DimensionMismatch("elements of different size in argument to Concat"))
610+
s2 = start(inner)
611+
end
612+
return val, (s, inner, s2)
613+
end
614+
615+
@inline function done(c::Concat, state)
616+
s, inner, s2 = state
617+
return done(c.itr, s) && done(inner, s2)
618+
end

test/functional.jl

+14
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,20 @@ import Base.flatten
374374
@test eltype(flatten(UnitRange{Int8}[1:2, 3:4])) == Int8
375375
@test_throws ArgumentError collect(flatten(Any[]))
376376

377+
378+
# concat
379+
# ------
380+
381+
import Base.concat
382+
383+
@test collect(concat(i:i+10 for i in 1:3)) == [1:11 2:12 3:13]
384+
@test typeof(collect(concat(i:i+10 for i in 1:3))) == Array{Int,2}
385+
@test_throws DimensionMismatch collect(concat(i:10 for i in 1:3))
386+
@test_throws ArgumentError collect(concat(Any[]))
387+
388+
@test [reshape(1:6,3,2)[k,l]*i+j for k in 1:3, l in 1:2, i in 1:4, j in 1:5] == collect(Base.concat([reshape(1:6,3,2)*i+j for i in 1:4, j in 1:5]))
389+
390+
377391
# foreach
378392
let
379393
a = []

0 commit comments

Comments
 (0)