|
| 1 | +### Multidimensional iterators |
| 2 | +module IteratorsMD |
| 3 | + |
| 4 | +import Base: start, done, next, getindex, setindex! |
| 5 | +import Base: @nref, @ncall, @nif, @nexprs |
| 6 | + |
| 7 | +export eachelement, eachindex, linearindexing, LinearFast |
| 8 | + |
| 9 | +# Traits for linear indexing |
| 10 | +abstract LinearIndexing |
| 11 | +immutable LinearFast <: LinearIndexing end |
| 12 | +immutable LinearSlow <: LinearIndexing end |
| 13 | + |
| 14 | +linearindexing(::AbstractArray) = LinearSlow() |
| 15 | +linearindexing(::Array) = LinearFast() |
| 16 | +linearindexing(::BitArray) = LinearFast() |
| 17 | +linearindexing(::Range) = LinearFast() |
| 18 | + |
| 19 | +# this generates types like this: |
| 20 | +# immutable Subscripts_3 <: Subscripts{3} |
| 21 | +# I_1::Int |
| 22 | +# I_2::Int |
| 23 | +# I_3::Int |
| 24 | +# end |
| 25 | +# they are used as iterator states |
| 26 | +# TODO: when tuples get improved, replace with a tuple-based implementation. See #6437. |
| 27 | + |
| 28 | +abstract Subscripts{N} # the state for all multidimensional iterators |
| 29 | +abstract SizeIterator{N} # Iterator that visits the index associated with each element |
| 30 | + |
| 31 | +function gen_iterators(N::Int, with_shared=true) |
| 32 | + # Create the types |
| 33 | + namestate = symbol("Subscripts_$N") |
| 34 | + namesize = symbol("SizeIterator_$N") |
| 35 | + fieldnames = [symbol("I_$i") for i = 1:N] |
| 36 | + fields = [Expr(:(::), fieldnames[i], :Int) for i = 1:N] |
| 37 | + exstate = Expr(:type, false, Expr(:(<:), namestate, Expr(:curly, :Subscripts, N)), Expr(:block, fields...)) |
| 38 | + dimsindexes = Expr[:(dims[$i]) for i = 1:N] |
| 39 | + onesN = ones(Int, N) |
| 40 | + infsN = fill(typemax(Int), N) |
| 41 | + anyzero = Expr(:(||), [:(SZ.I.$(fieldnames[i]) == 0) for i = 1:N]...) |
| 42 | + # Some necessary ambiguity resolution |
| 43 | + exrange = N != 1 ? nothing : quote |
| 44 | + next(R::StepRange, I::Subscripts_1) = R[I.I_1], Subscripts_1(I.I_1+1) |
| 45 | + next{T}(R::UnitRange{T}, I::Subscripts_1) = R[I.I_1], Subscripts_1(I.I_1+1) |
| 46 | + end |
| 47 | + exshared = !with_shared ? nothing : quote |
| 48 | + getindex{T}(S::SharedArray{T,$N}, state::$namestate) = S.s[state] |
| 49 | + setindex!{T}(S::SharedArray{T,$N}, v, state::$namestate) = S.s[state] = v |
| 50 | + end |
| 51 | + quote |
| 52 | + $exstate |
| 53 | + immutable $namesize <: SizeIterator{$N} |
| 54 | + I::$namestate |
| 55 | + end |
| 56 | + $namestate(dims::NTuple{$N,Int}) = $namestate($(dimsindexes...)) |
| 57 | + _eachindex(dims::NTuple{$N,Int}) = $namesize($namestate(dims)) |
| 58 | + |
| 59 | + start{T}(AT::(AbstractArray{T,$N},LinearSlow)) = isempty(AT[1]) ? $namestate($(infsN...)) : $namestate($(onesN...)) |
| 60 | + start(SZ::$namesize) = $anyzero ? $namestate($(infsN...)) : $namestate($(onesN...)) |
| 61 | + |
| 62 | + $exrange |
| 63 | + |
| 64 | + @inline function next{T}(A::AbstractArray{T,$N}, state::$namestate) |
| 65 | + @inbounds v = A[state] |
| 66 | + newstate = @nif $N d->(getfield(state,d) < size(A, d)) d->(@ncall($N, $namestate, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1))) |
| 67 | + v, newstate |
| 68 | + end |
| 69 | + @inline function next(iter::$namesize, state::$namestate) |
| 70 | + newstate = @nif $N d->(getfield(state,d) < getfield(iter.I,d)) d->(@ncall($N, $namestate, k->(k>d ? getfield(state,k) : k==d ? getfield(state,k)+1 : 1))) |
| 71 | + state, newstate |
| 72 | + end |
| 73 | + |
| 74 | + $exshared |
| 75 | + getindex{T}(A::AbstractArray{T,$N}, state::$namestate) = @nref $N A d->getfield(state,d) |
| 76 | + setindex!{T}(A::AbstractArray{T,$N}, v, state::$namestate) = (@nref $N A d->getfield(state,d)) = v |
| 77 | + end |
| 78 | +end |
| 79 | + |
| 80 | +# Ambiguity resolution |
| 81 | +done(R::StepRange, I::Subscripts{1}) = getfield(I, 1) > length(R) |
| 82 | +done(R::UnitRange, I::Subscripts{1}) = getfield(I, 1) > length(R) |
| 83 | + |
| 84 | +Base.start(A::AbstractArray) = start((A,linearindexing(A))) |
| 85 | +start(::(AbstractArray,LinearFast)) = 1 |
| 86 | +done{T,N}(A::AbstractArray{T,N}, I::Subscripts{N}) = getfield(I, N) > size(A, N) |
| 87 | +done{N}(iter::SizeIterator{N}, I::Subscripts{N}) = getfield(I, N) > getfield(iter.I, N) |
| 88 | + |
| 89 | +eachindex(A::AbstractArray) = eachindex(size(A)) |
| 90 | + |
| 91 | +let implemented = IntSet() |
| 92 | +global eachindex |
| 93 | +global eachelement |
| 94 | +function eachindex{N}(t::NTuple{N,Int}) |
| 95 | + if !in(N, implemented) |
| 96 | + eval(gen_iterators(N)) |
| 97 | + end |
| 98 | + _eachindex(t) |
| 99 | +end |
| 100 | +function eachelement{T,N}(A::AbstractArray{T,N}) |
| 101 | + if !in(N, implemented) |
| 102 | + eval(gen_iterators(N)) |
| 103 | + end |
| 104 | + A |
| 105 | +end |
| 106 | +end |
| 107 | + |
| 108 | +# Pre-generate for low dimensions |
| 109 | +for N = 1:8 |
| 110 | + eval(gen_iterators(N, false)) |
| 111 | + eval(:(eachindex(t::NTuple{$N,Int}) = _eachindex(t))) |
| 112 | + eval(:(eachelement{T}(A::AbstractArray{T,$N}) = A)) |
| 113 | +end |
| 114 | + |
| 115 | +end # IteratorsMD |
| 116 | + |
| 117 | +using .IteratorsMD |
| 118 | + |
| 119 | + |
1 | 120 | ### From array.jl
|
2 | 121 |
|
3 | 122 | @ngenerate N Void function checksize(A::AbstractArray, I::NTuple{N, Any}...)
|
|
0 commit comments