Skip to content

Commit 37ffa13

Browse files
committed
Merge pull request #8432 from JuliaLang/teh/cartesian_iteration
Efficient cartesian iteration (new version of #6437)
2 parents 4f7b787 + 2fa852d commit 37ffa13

7 files changed

+179
-11
lines changed

base/abstractarray.jl

+12-1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,16 @@ function trailingsize(A, n)
5858
return s
5959
end
6060

61+
## Traits for array types ##
62+
63+
abstract LinearIndexing
64+
immutable LinearFast <: LinearIndexing end
65+
immutable LinearSlow <: LinearIndexing end
66+
67+
linearindexing(::AbstractArray) = LinearSlow()
68+
linearindexing(::Array) = LinearFast()
69+
linearindexing(::Range) = LinearFast()
70+
6171
## Bounds checking ##
6272
checkbounds(sz::Int, i::Int) = 1 <= i <= sz || throw(BoundsError())
6373
checkbounds(sz::Int, i::Real) = checkbounds(sz, to_index(i))
@@ -241,7 +251,8 @@ zero{T}(x::AbstractArray{T}) = fill!(similar(x), zero(T))
241251

242252
## iteration support for arrays as ranges ##
243253

244-
start(a::AbstractArray) = 1
254+
start(A::AbstractArray) = _start(A,linearindexing(A))
255+
_start(::AbstractArray,::LinearFast) = 1
245256
next(a::AbstractArray,i) = (a[i],i+1)
246257
done(a::AbstractArray,i) = (i > length(a))
247258
isempty(a::AbstractArray) = (length(a) == 0)

base/dates/ranges.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@ function in{T<:TimeType}(x::T, r::StepRange{T})
2929
end
3030

3131
Base.start{T<:TimeType}(r::StepRange{T}) = 0
32-
Base.next{T<:TimeType}(r::StepRange{T}, i) = (r.start+r.step*i,i+1)
32+
Base.next{T<:TimeType}(r::StepRange{T}, i::Int) = (r.start+r.step*i,i+1)
3333
Base.done{T<:TimeType,S<:Period}(r::StepRange{T,S}, i::Integer) = length(r) <= i

base/exports.jl

+2
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,8 @@ export
515515
cumsum,
516516
cumsum!,
517517
cumsum_kbn,
518+
eachelement,
519+
eachindex,
518520
extrema,
519521
fill!,
520522
fill,

base/multidimensional.jl

+106
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,109 @@
1+
### Multidimensional iterators
2+
module IteratorsMD
3+
4+
import Base: start, _start, done, next, getindex, setindex!, linearindexing
5+
import Base: @nref, @ncall, @nif, @nexprs, LinearFast, LinearSlow
6+
7+
export eachindex
8+
9+
# Traits for linear indexing
10+
linearindexing(::BitArray) = LinearFast()
11+
12+
# Iterator/state
13+
abstract CartesianIndex{N} # the state for all multidimensional iterators
14+
abstract IndexIterator{N} # Iterator that visits the index associated with each element
15+
16+
stagedfunction Base.call{N}(::Type{CartesianIndex},index::NTuple{N,Int})
17+
indextype,itertype=gen_cartesian(N)
18+
return :($indextype(index))
19+
end
20+
stagedfunction Base.call{N}(::Type{IndexIterator},index::NTuple{N,Int})
21+
indextype,itertype=gen_cartesian(N)
22+
return :($itertype(index))
23+
end
24+
25+
let implemented = IntSet()
26+
global gen_cartesian
27+
function gen_cartesian(N::Int, with_shared=Base.is_unix(OS_NAME))
28+
# Create the types
29+
indextype = symbol("CartesianIndex_$N")
30+
itertype = symbol("IndexIterator_$N")
31+
if !in(N,implemented)
32+
fieldnames = [symbol("I_$i") for i = 1:N]
33+
fields = [Expr(:(::), fieldnames[i], :Int) for i = 1:N]
34+
extype = Expr(:type, false, Expr(:(<:), indextype, Expr(:curly, :CartesianIndex, N)), Expr(:block, fields...))
35+
exindices = Expr[:(index[$i]) for i = 1:N]
36+
37+
onesN = ones(Int, N)
38+
infsN = fill(typemax(Int), N)
39+
anyzero = Expr(:(||), [:(iter.dims.$(fieldnames[i]) == 0) for i = 1:N]...)
40+
41+
# Some necessary ambiguity resolution
42+
exrange = N != 1 ? nothing : quote
43+
next(R::StepRange, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1)
44+
next{T}(R::UnitRange{T}, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1)
45+
end
46+
exshared = !with_shared ? nothing : quote
47+
getindex{T}(S::SharedArray{T,$N}, I::$indextype) = S.s[I]
48+
setindex!{T}(S::SharedArray{T,$N}, v, I::$indextype) = S.s[I] = v
49+
end
50+
totalex = quote
51+
# type definition
52+
$extype
53+
# extra constructor from tuple
54+
$indextype(index::NTuple{$N,Int}) = $indextype($(exindices...))
55+
56+
immutable $itertype <: IndexIterator{$N}
57+
dims::$indextype
58+
end
59+
$itertype(dims::NTuple{$N,Int})=$itertype($indextype(dims))
60+
61+
# getindex and setindex!
62+
$exshared
63+
getindex{T}(A::AbstractArray{T,$N}, index::$indextype) = @nref $N A d->getfield(index,d)
64+
setindex!{T}(A::AbstractArray{T,$N}, v, index::$indextype) = (@nref $N A d->getfield(index,d)) = v
65+
66+
# next iteration
67+
$exrange
68+
@inline function next{T}(A::AbstractArray{T,$N}, state::$indextype)
69+
@inbounds v = A[state]
70+
newstate = @nif $N d->(getfield(state,d) < size(A, d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
71+
v, newstate
72+
end
73+
@inline function next(iter::$itertype, state::$indextype)
74+
newstate = @nif $N d->(getfield(state,d) < getfield(iter.dims,d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1)))
75+
state, newstate
76+
end
77+
78+
# start
79+
start(iter::$itertype) = $anyzero ? $indextype($(infsN...)) : $indextype($(onesN...))
80+
end
81+
eval(totalex)
82+
push!(implemented,N)
83+
end
84+
return indextype, itertype
85+
end
86+
end
87+
88+
# Iteration
89+
eachindex(A::AbstractArray) = IndexIterator(size(A))
90+
91+
# start iteration
92+
_start{T,N}(A::AbstractArray{T,N},::LinearSlow) = CartesianIndex(ntuple(N,n->ifelse(isempty(A),typemax(Int),1))::NTuple{N,Int})
93+
94+
# Ambiguity resolution
95+
done(R::StepRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R)
96+
done(R::UnitRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R)
97+
done(R::FloatRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R)
98+
99+
done{T,N}(A::AbstractArray{T,N}, I::CartesianIndex{N}) = getfield(I, N) > size(A, N)
100+
done{N}(iter::IndexIterator{N}, I::CartesianIndex{N}) = getfield(I, N) > getfield(iter.dims, N)
101+
102+
end # IteratorsMD
103+
104+
using .IteratorsMD
105+
106+
1107
### From array.jl
2108

3109
@ngenerate N Void function checksize(A::AbstractArray, I::NTuple{N, Any}...)

base/range.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ copy(r::Range) = r
235235
## iteration
236236

237237
start(r::FloatRange) = 0
238-
next{T}(r::FloatRange{T}, i) = (convert(T, (r.start + i*r.step)/r.divisor), i+1)
239-
done(r::FloatRange, i) = (length(r) <= i)
238+
next{T}(r::FloatRange{T}, i::Int) = (convert(T, (r.start + i*r.step)/r.divisor), i+1)
239+
done(r::FloatRange, i::Int) = (length(r) <= i)
240240

241241
# NOTE: For ordinal ranges, we assume start+step might be from a
242242
# lifted domain (e.g. Int8+Int8 => Int); use that for iterating.

base/sharedarray.jl

+5-7
Original file line numberDiff line numberDiff line change
@@ -206,16 +206,16 @@ end
206206

207207
convert(::Type{Array}, S::SharedArray) = S.s
208208

209-
# # pass through getindex and setindex! - they always work on the complete array unlike DArrays
209+
# pass through getindex and setindex! - they always work on the complete array unlike DArrays
210210
getindex(S::SharedArray) = getindex(S.s)
211211
getindex(S::SharedArray, I::Real) = getindex(S.s, I)
212212
getindex(S::SharedArray, I::AbstractArray) = getindex(S.s, I)
213213
@nsplat N 1:5 getindex(S::SharedArray, I::NTuple{N,Any}...) = getindex(S.s, I...)
214214

215-
setindex!(S::SharedArray, x) = (setindex!(S.s, x); S)
216-
setindex!(S::SharedArray, x, I::Real) = (setindex!(S.s, x, I); S)
217-
setindex!(S::SharedArray, x, I::AbstractArray) = (setindex!(S.s, x, I); S)
218-
@nsplat N 1:5 setindex!(S::SharedArray, x, I::NTuple{N,Any}...) = (setindex!(S.s, x, I...); S)
215+
setindex!(S::SharedArray, x) = setindex!(S.s, x)
216+
setindex!(S::SharedArray, x, I::Real) = setindex!(S.s, x, I)
217+
setindex!(S::SharedArray, x, I::AbstractArray) = setindex!(S.s, x, I)
218+
@nsplat N 1:5 setindex!(S::SharedArray, x, I::NTuple{N,Any}...) = setindex!(S.s, x, I...)
219219

220220
function fill!(S::SharedArray, v)
221221
f = S->fill!(S.loc_subarr_1d, v)
@@ -377,5 +377,3 @@ end
377377
end
378378

379379
@unix_only shm_open(shm_seg_name, oflags, permissions) = ccall(:shm_open, Int, (Ptr{UInt8}, Int, Int), shm_seg_name, oflags, permissions)
380-
381-

test/arrayops.jl

+51
Original file line numberDiff line numberDiff line change
@@ -925,3 +925,54 @@ end
925925
b718cbc = 5
926926
@test b718cbc[1.0] == 5
927927
@test_throws InexactError b718cbc[1.1]
928+
929+
# Multidimensional iterators
930+
function mdsum(A)
931+
s = 0.0
932+
for a in A
933+
s += a
934+
end
935+
s
936+
end
937+
938+
function mdsum2(A)
939+
s = 0.0
940+
@inbounds for I in eachindex(A)
941+
s += A[I]
942+
end
943+
s
944+
end
945+
946+
a = [1:5]
947+
@test isa(Base.linearindexing(a), Base.LinearFast)
948+
b = sub(a, :)
949+
@test isa(Base.linearindexing(b), Base.IteratorsMD.LinearSlow)
950+
shp = [5]
951+
for i = 1:10
952+
A = reshape(a, tuple(shp...))
953+
@test mdsum(A) == 15
954+
@test mdsum2(A) == 15
955+
B = sub(A, ntuple(i, i->Colon())...)
956+
@test mdsum(B) == 15
957+
@test mdsum2(B) == 15
958+
unshift!(shp, 1)
959+
end
960+
961+
a = [1:10]
962+
shp = [2,5]
963+
for i = 2:10
964+
A = reshape(a, tuple(shp...))
965+
@test mdsum(A) == 55
966+
@test mdsum2(A) == 55
967+
B = sub(A, ntuple(i, i->Colon())...)
968+
@test mdsum(B) == 55
969+
@test mdsum2(B) == 55
970+
insert!(shp, 2, 1)
971+
end
972+
973+
a = ones(0,5)
974+
b = sub(a, :, :)
975+
@test mdsum(b) == 0
976+
a = ones(5,0)
977+
b = sub(a, :, :)
978+
@test mdsum(b) == 0

0 commit comments

Comments
 (0)