|
| 1 | +### Multidimensional iterators |
| 2 | +module IteratorsMD |
| 3 | + |
| 4 | +import Base: start, _start, done, next, getindex, setindex!, linearindexing |
| 5 | +import Base: @nref, @ncall, @nif, @nexprs, LinearFast, LinearSlow |
| 6 | + |
| 7 | +export eachindex |
| 8 | + |
| 9 | +# Traits for linear indexing |
| 10 | +linearindexing(::BitArray) = LinearFast() |
| 11 | + |
| 12 | +# Iterator/state |
| 13 | +abstract CartesianIndex{N} # the state for all multidimensional iterators |
| 14 | +abstract IndexIterator{N} # Iterator that visits the index associated with each element |
| 15 | + |
| 16 | +stagedfunction Base.call{N}(::Type{CartesianIndex},index::NTuple{N,Int}) |
| 17 | + indextype,itertype=gen_cartesian(N) |
| 18 | + return :($indextype(index)) |
| 19 | +end |
| 20 | +stagedfunction Base.call{N}(::Type{IndexIterator},index::NTuple{N,Int}) |
| 21 | + indextype,itertype=gen_cartesian(N) |
| 22 | + return :($itertype(index)) |
| 23 | +end |
| 24 | + |
| 25 | +let implemented = IntSet() |
| 26 | +global gen_cartesian |
| 27 | +function gen_cartesian(N::Int, with_shared=Base.is_unix(OS_NAME)) |
| 28 | + # Create the types |
| 29 | + indextype = symbol("CartesianIndex_$N") |
| 30 | + itertype = symbol("IndexIterator_$N") |
| 31 | + if !in(N,implemented) |
| 32 | + fieldnames = [symbol("I_$i") for i = 1:N] |
| 33 | + fields = [Expr(:(::), fieldnames[i], :Int) for i = 1:N] |
| 34 | + extype = Expr(:type, false, Expr(:(<:), indextype, Expr(:curly, :CartesianIndex, N)), Expr(:block, fields...)) |
| 35 | + exindices = Expr[:(index[$i]) for i = 1:N] |
| 36 | + |
| 37 | + onesN = ones(Int, N) |
| 38 | + infsN = fill(typemax(Int), N) |
| 39 | + anyzero = Expr(:(||), [:(iter.dims.$(fieldnames[i]) == 0) for i = 1:N]...) |
| 40 | + |
| 41 | + # Some necessary ambiguity resolution |
| 42 | + exrange = N != 1 ? nothing : quote |
| 43 | + next(R::StepRange, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1) |
| 44 | + next{T}(R::UnitRange{T}, I::CartesianIndex_1) = R[I.I_1], CartesianIndex_1(I.I_1+1) |
| 45 | + end |
| 46 | + exshared = !with_shared ? nothing : quote |
| 47 | + getindex{T}(S::SharedArray{T,$N}, I::$indextype) = S.s[I] |
| 48 | + setindex!{T}(S::SharedArray{T,$N}, v, I::$indextype) = S.s[I] = v |
| 49 | + end |
| 50 | + totalex = quote |
| 51 | + # type definition |
| 52 | + $extype |
| 53 | + # extra constructor from tuple |
| 54 | + $indextype(index::NTuple{$N,Int}) = $indextype($(exindices...)) |
| 55 | + |
| 56 | + immutable $itertype <: IndexIterator{$N} |
| 57 | + dims::$indextype |
| 58 | + end |
| 59 | + $itertype(dims::NTuple{$N,Int})=$itertype($indextype(dims)) |
| 60 | + |
| 61 | + # getindex and setindex! |
| 62 | + $exshared |
| 63 | + getindex{T}(A::AbstractArray{T,$N}, index::$indextype) = @nref $N A d->getfield(index,d) |
| 64 | + setindex!{T}(A::AbstractArray{T,$N}, v, index::$indextype) = (@nref $N A d->getfield(index,d)) = v |
| 65 | + |
| 66 | + # next iteration |
| 67 | + $exrange |
| 68 | + @inline function next{T}(A::AbstractArray{T,$N}, state::$indextype) |
| 69 | + @inbounds v = A[state] |
| 70 | + newstate = @nif $N d->(getfield(state,d) < size(A, d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1))) |
| 71 | + v, newstate |
| 72 | + end |
| 73 | + @inline function next(iter::$itertype, state::$indextype) |
| 74 | + newstate = @nif $N d->(getfield(state,d) < getfield(iter.dims,d)) d->(@ncall($N, $indextype, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1))) |
| 75 | + state, newstate |
| 76 | + end |
| 77 | + |
| 78 | + # start |
| 79 | + start(iter::$itertype) = $anyzero ? $indextype($(infsN...)) : $indextype($(onesN...)) |
| 80 | + end |
| 81 | + eval(totalex) |
| 82 | + push!(implemented,N) |
| 83 | + end |
| 84 | + return indextype, itertype |
| 85 | +end |
| 86 | +end |
| 87 | + |
| 88 | +# Iteration |
| 89 | +eachindex(A::AbstractArray) = IndexIterator(size(A)) |
| 90 | + |
| 91 | +# start iteration |
| 92 | +_start{T,N}(A::AbstractArray{T,N},::LinearSlow) = CartesianIndex(ntuple(N,n->ifelse(isempty(A),typemax(Int),1))::NTuple{N,Int}) |
| 93 | + |
| 94 | +# Ambiguity resolution |
| 95 | +done(R::StepRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R) |
| 96 | +done(R::UnitRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R) |
| 97 | +done(R::FloatRange, I::CartesianIndex{1}) = getfield(I, 1) > length(R) |
| 98 | + |
| 99 | +done{T,N}(A::AbstractArray{T,N}, I::CartesianIndex{N}) = getfield(I, N) > size(A, N) |
| 100 | +done{N}(iter::IndexIterator{N}, I::CartesianIndex{N}) = getfield(I, N) > getfield(iter.dims, N) |
| 101 | + |
| 102 | +end # IteratorsMD |
| 103 | + |
| 104 | +using .IteratorsMD |
| 105 | + |
| 106 | + |
1 | 107 | ### From array.jl
|
2 | 108 |
|
3 | 109 | @ngenerate N Void function checksize(A::AbstractArray, I::NTuple{N, Any}...)
|
|
0 commit comments