Skip to content

Commit 0314913

Browse files
committed
Add a multidimensional (cartesian) iterator
Closes #1917, closes #6437
1 parent cad4eaa commit 0314913

File tree

6 files changed

+183
-3
lines changed

6 files changed

+183
-3
lines changed

base/dates/ranges.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ function in{T<:TimeType}(x::T, r::StepRange{T})
2929
end
3030

3131
Base.start{T<:TimeType}(r::StepRange{T}) = 0
32-
Base.next{T<:TimeType}(r::StepRange{T}, i) = (r.start+r.step*i,i+1)
32+
Base.next{T<:TimeType}(r::StepRange{T}, i::Int) = (r.start+r.step*i,i+1)
3333
Base.done{T<:TimeType,S<:Period}(r::StepRange{T,S}, i::Integer) = length(r) <= i

base/exports.jl

+2
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,8 @@ export
515515
cumsum,
516516
cumsum!,
517517
cumsum_kbn,
518+
eachelement,
519+
eachindex,
518520
extrema,
519521
fill!,
520522
fill,

base/multidimensional.jl

+119
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,122 @@
1+
### Multidimensional iterators
2+
module IteratorsMD
3+
4+
import Base: start, done, next, getindex, setindex!
5+
import Base: @nref, @ncall, @nif, @nexprs
6+
7+
export eachelement, eachindex, linearindexing, LinearFast
8+
9+
# Traits for linear indexing
10+
abstract LinearIndexing
11+
immutable LinearFast <: LinearIndexing end
12+
immutable LinearSlow <: LinearIndexing end
13+
14+
linearindexing(::AbstractArray) = LinearSlow()
15+
linearindexing(::Array) = LinearFast()
16+
linearindexing(::BitArray) = LinearFast()
17+
linearindexing(::Range) = LinearFast()
18+
19+
# this generates types like this:
20+
# immutable Subscripts_3 <: Subscripts{3}
21+
# I_1::Int
22+
# I_2::Int
23+
# I_3::Int
24+
# end
25+
# they are used as iterator states
26+
# TODO: when tuples get improved, replace with a tuple-based implementation. See #6437.
27+
28+
abstract Subscripts{N} # the state for all multidimensional iterators
29+
abstract SizeIterator{N} # Iterator that visits the index associated with each element
30+
31+
function gen_iterators(N::Int, with_shared=true)
32+
# Create the types
33+
namestate = symbol("Subscripts_$N")
34+
namesize = symbol("SizeIterator_$N")
35+
fieldnames = [symbol("I_$i") for i = 1:N]
36+
fields = [Expr(:(::), fieldnames[i], :Int) for i = 1:N]
37+
exstate = Expr(:type, false, Expr(:(<:), namestate, Expr(:curly, :Subscripts, N)), Expr(:block, fields...))
38+
dimsindexes = Expr[:(dims[$i]) for i = 1:N]
39+
onesN = ones(Int, N)
40+
infsN = fill(typemax(Int), N)
41+
anyzero = Expr(:(||), [:(SZ.I.$(fieldnames[i]) == 0) for i = 1:N]...)
42+
# Some necessary ambiguity resolution
43+
exrange = N != 1 ? nothing : quote
44+
next(R::StepRange, I::Subscripts_1) = R[I.I_1], Subscripts_1(I.I_1+1)
45+
next{T}(R::UnitRange{T}, I::Subscripts_1) = R[I.I_1], Subscripts_1(I.I_1+1)
46+
end
47+
exshared = !with_shared ? nothing : quote
48+
getindex{T}(S::SharedArray{T,$N}, state::$namestate) = S.s[state]
49+
setindex!{T}(S::SharedArray{T,$N}, v, state::$namestate) = S.s[state] = v
50+
end
51+
quote
52+
$exstate
53+
immutable $namesize <: SizeIterator{$N}
54+
I::$namestate
55+
end
56+
$namestate(dims::NTuple{$N,Int}) = $namestate($(dimsindexes...))
57+
_eachindex(dims::NTuple{$N,Int}) = $namesize($namestate(dims))
58+
59+
start{T}(AT::(AbstractArray{T,$N},LinearSlow)) = isempty(AT[1]) ? $namestate($(infsN...)) : $namestate($(onesN...))
60+
start(SZ::$namesize) = $anyzero ? $namestate($(infsN...)) : $namestate($(onesN...))
61+
62+
$exrange
63+
64+
@inline function next{T}(A::AbstractArray{T,$N}, state::$namestate)
65+
@inbounds v = A[state]
66+
newstate = @nif $N d->(getfield(state,d) < size(A, d)) d->(@ncall($N, $namestate, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
67+
v, newstate
68+
end
69+
@inline function next(iter::$namesize, state::$namestate)
70+
newstate = @nif $N d->(getfield(state,d) < getfield(iter.I,d)) d->(@ncall($N, $namestate, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
71+
state, newstate
72+
end
73+
74+
$exshared
75+
getindex{T}(A::AbstractArray{T,$N}, state::$namestate) = @nref $N A d->getfield(state,d)
76+
setindex!{T}(A::AbstractArray{T,$N}, v, state::$namestate) = (@nref $N A d->getfield(state,d)) = v
77+
end
78+
end
79+
80+
# Ambiguity resolution
81+
done(R::StepRange, I::Subscripts{1}) = getfield(I, 1) > length(R)
82+
done(R::UnitRange, I::Subscripts{1}) = getfield(I, 1) > length(R)
83+
84+
Base.start(A::AbstractArray) = start((A,linearindexing(A)))
85+
start(::(AbstractArray,LinearFast)) = 1
86+
done{T,N}(A::AbstractArray{T,N}, I::Subscripts{N}) = getfield(I, N) > size(A, N)
87+
done{N}(iter::SizeIterator{N}, I::Subscripts{N}) = getfield(I, N) > getfield(iter.I, N)
88+
89+
eachindex(A::AbstractArray) = eachindex(size(A))
90+
91+
let implemented = IntSet()
92+
global eachindex
93+
global eachelement
94+
function eachindex{N}(t::NTuple{N,Int})
95+
if !in(N, implemented)
96+
eval(gen_iterators(N))
97+
end
98+
_eachindex(t)
99+
end
100+
function eachelement{T,N}(A::AbstractArray{T,N})
101+
if !in(N, implemented)
102+
eval(gen_iterators(N))
103+
end
104+
A
105+
end
106+
end
107+
108+
# Pre-generate for low dimensions
109+
for N = 1:8
110+
eval(gen_iterators(N, false))
111+
eval(:(eachindex(t::NTuple{$N,Int}) = _eachindex(t)))
112+
eval(:(eachelement{T}(A::AbstractArray{T,$N}) = A))
113+
end
114+
115+
end # IteratorsMD
116+
117+
using .IteratorsMD
118+
119+
1120
### From array.jl
2121

3122
@ngenerate N Void function checksize(A::AbstractArray, I::NTuple{N, Any}...)

base/range.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ copy(r::Range) = r
235235
## iteration
236236

237237
start(r::FloatRange) = 0
238-
next{T}(r::FloatRange{T}, i) = (convert(T, (r.start + i*r.step)/r.divisor), i+1)
239-
done(r::FloatRange, i) = (length(r) <= i)
238+
next{T}(r::FloatRange{T}, i::Int) = (convert(T, (r.start + i*r.step)/r.divisor), i+1)
239+
done(r::FloatRange, i::Int) = (length(r) <= i)
240240

241241
# NOTE: For ordinal ranges, we assume start+step might be from a
242242
# lifted domain (e.g. Int8+Int8 => Int); use that for iterating.

base/sharedarray.jl

+8
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,14 @@ end
207207
convert(::Type{Array}, S::SharedArray) = S.s
208208

209209
# # pass through getindex and setindex! - they always work on the complete array unlike DArrays
210+
for N = 1:8
211+
name = symbol("Subscripts_$N")
212+
@eval begin
213+
getindex{T}(S::SharedArray{T,$N}, I::IteratorsMD.$name) = getindex(S.s, I)
214+
setindex!{T}(S::SharedArray{T,$N}, v, I::IteratorsMD.$name) = setindex!(S.s, v, I)
215+
end
216+
end
217+
210218
getindex(S::SharedArray) = getindex(S.s)
211219
getindex(S::SharedArray, I::Real) = getindex(S.s, I)
212220
getindex(S::SharedArray, I::AbstractArray) = getindex(S.s, I)

test/arrayops.jl

+51
Original file line numberDiff line numberDiff line change
@@ -925,3 +925,54 @@ end
925925
b718cbc = 5
926926
@test b718cbc[1.0] == 5
927927
@test_throws InexactError b718cbc[1.1]
928+
929+
# Multidimensional iterators
930+
function mdsum(A)
931+
s = 0.0
932+
for a in eachelement(A)
933+
s += a
934+
end
935+
s
936+
end
937+
938+
function mdsum2(A)
939+
s = 0.0
940+
@inbounds for I in eachindex(A)
941+
s += A[I]
942+
end
943+
s
944+
end
945+
946+
a = [1:5]
947+
@test isa(Base.linearindexing(a), Base.LinearFast)
948+
b = sub(a, :)
949+
@test isa(Base.linearindexing(b), Base.IteratorsMD.LinearSlow)
950+
shp = [5]
951+
for i = 1:10
952+
A = reshape(a, tuple(shp...))
953+
@test mdsum(A) == 15
954+
@test mdsum2(A) == 15
955+
B = sub(A, ntuple(i, i->Colon())...)
956+
@test mdsum(B) == 15
957+
@test mdsum2(B) == 15
958+
unshift!(shp, 1)
959+
end
960+
961+
a = [1:10]
962+
shp = [2,5]
963+
for i = 2:10
964+
A = reshape(a, tuple(shp...))
965+
@test mdsum(A) == 55
966+
@test mdsum2(A) == 55
967+
B = sub(A, ntuple(i, i->Colon())...)
968+
@test mdsum(B) == 55
969+
@test mdsum2(B) == 55
970+
insert!(shp, 2, 1)
971+
end
972+
973+
a = ones(0,5)
974+
b = sub(a, :, :)
975+
@test mdsum(b) == 0
976+
a = ones(5,0)
977+
b = sub(a, :, :)
978+
@test mdsum(b) == 0

0 commit comments

Comments
 (0)