|
| 1 | + |
| 2 | +""" |
| 3 | + known_first(::Type{T}) |
| 4 | +
|
| 5 | +If `first` of an instance of type `T` is known at compile time, return it. |
| 6 | +Otherwise, return `nothing`. |
| 7 | +
|
| 8 | +@test isnothing(known_first(typeof(1:4))) |
| 9 | +@test isone(known_first(typeof(Base.OneTo(4)))) |
| 10 | +""" |
| 11 | +known_first(x) = known_first(typeof(x)) |
| 12 | +known_first(::Type{T}) where {T} = nothing |
| 13 | +known_first(::Type{Base.OneTo{T}}) where {T} = one(T) |
| 14 | +known_first(::Type{T}) where {T<:Base.Slice} = known_first(parent_type(T)) |
| 15 | + |
| 16 | +""" |
| 17 | + known_last(::Type{T}) |
| 18 | +
|
| 19 | +If `last` of an instance of type `T` is known at compile time, return it. |
| 20 | +Otherwise, return `nothing`. |
| 21 | +
|
| 22 | +@test isnothing(known_last(typeof(1:4))) |
| 23 | +using StaticArrays |
| 24 | +@test known_last(typeof(SOneTo(4))) == 4 |
| 25 | +""" |
| 26 | +known_last(x) = known_last(typeof(x)) |
| 27 | +known_last(::Type{T}) where {T} = nothing |
| 28 | +known_last(::Type{T}) where {T<:Base.Slice} = known_last(parent_type(T)) |
| 29 | + |
| 30 | +""" |
| 31 | + known_step(::Type{T}) |
| 32 | +
|
| 33 | +If `step` of an instance of type `T` is known at compile time, return it. |
| 34 | +Otherwise, return `nothing`. |
| 35 | +
|
| 36 | +@test isnothing(known_step(typeof(1:0.2:4))) |
| 37 | +@test isone(known_step(typeof(1:4))) |
| 38 | +""" |
| 39 | +known_step(x) = known_step(typeof(x)) |
| 40 | +known_step(::Type{T}) where {T} = nothing |
| 41 | +known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T) |
| 42 | + |
| 43 | +# add methods to support ArrayInterface |
| 44 | + |
| 45 | +_get(x) = x |
| 46 | +_get(::Val{V}) where {V} = V |
| 47 | +_convert(::Type{T}, x) where {T} = convert(T, x) |
| 48 | +_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V)) |
| 49 | + |
| 50 | +""" |
| 51 | + OptionallyStaticUnitRange{T<:Integer}(start, stop) <: OrdinalRange{T,T} |
| 52 | +
|
| 53 | +This range permits diverse representations of arrays to comunicate common information |
| 54 | +about their indices. Each field may be an integer or `Val(<:Integer)` if it is known |
| 55 | +at compile time. An `OptionallyStaticUnitRange` is intended to be constructed internally |
| 56 | +from other valid indices. Therefore, users should not expect the same checks are used |
| 57 | +to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`. |
| 58 | +""" |
| 59 | +struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T} |
| 60 | + start::F |
| 61 | + stop::L |
| 62 | + |
| 63 | + function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real} |
| 64 | + if _get(start) isa T |
| 65 | + if _get(stop) isa T |
| 66 | + return new{T,typeof(start),typeof(stop)}(start, stop) |
| 67 | + else |
| 68 | + return OptionallyStaticUnitRange{T}(start, _convert(T, stop)) |
| 69 | + end |
| 70 | + else |
| 71 | + return OptionallyStaticUnitRange{T}(_convert(T, start), stop) |
| 72 | + end |
| 73 | + end |
| 74 | + |
| 75 | + function OptionallyStaticUnitRange(start, stop) |
| 76 | + T = promote_type(typeof(_get(start)), typeof(_get(stop))) |
| 77 | + return OptionallyStaticUnitRange{T}(start, stop) |
| 78 | + end |
| 79 | + |
| 80 | + function OptionallyStaticUnitRange(x::AbstractRange) |
| 81 | + if step(x) == 1 |
| 82 | + fst = known_first(x) |
| 83 | + fst = fst === nothing ? first(x) : Val(fst) |
| 84 | + lst = known_last(x) |
| 85 | + lst = lst === nothing ? last(x) : Val(lst) |
| 86 | + return OptionallyStaticUnitRange(fst, lst) |
| 87 | + else |
| 88 | + throw(ArgumentError("step must be 1, got $(step(r))")) |
| 89 | + end |
| 90 | + end |
| 91 | +end |
| 92 | + |
| 93 | +Base.first(r::OptionallyStaticUnitRange{<:Any,Val{F}}) where {F} = F |
| 94 | +Base.first(r::OptionallyStaticUnitRange{<:Any,<:Any}) = r.start |
| 95 | + |
| 96 | +Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T) |
| 97 | + |
| 98 | +Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}) where {L} = L |
| 99 | +Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,<:Any}) = r.stop |
| 100 | + |
| 101 | +known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Val{F}}}) where {F} = F |
| 102 | +known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T) |
| 103 | +known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}}) where {L} = L |
| 104 | + |
| 105 | +function Base.isempty(r::OptionallyStaticUnitRange) |
| 106 | + if known_first(r) === oneunit(eltype(r)) |
| 107 | + return unsafe_isempty_one_to(last(r)) |
| 108 | + else |
| 109 | + return unsafe_isempty_unit_range(first(r), last(r)) |
| 110 | + end |
| 111 | +end |
| 112 | + |
| 113 | +unsafe_isempty_one_to(lst) = lst <= zero(lst) |
| 114 | +unsafe_isempty_unit_range(fst, lst) = fst > lst |
| 115 | + |
| 116 | +unsafe_isempty_unit_range(fst::T, lst::T) where {T} = Integer(lst - fst + one(T)) |
| 117 | + |
| 118 | +unsafe_length_one_to(lst::T) where {T<:Int} = T(lst) |
| 119 | +unsafe_length_one_to(lst::T) where {T} = Integer(lst - zero(lst)) |
| 120 | + |
| 121 | +Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer) |
| 122 | + if known_first(r) === oneunit(r) |
| 123 | + return get_index_one_to(r, i) |
| 124 | + else |
| 125 | + return get_index_unit_range(r, i) |
| 126 | + end |
| 127 | +end |
| 128 | + |
| 129 | +@inline function get_index_one_to(r, i) |
| 130 | + @boundscheck if ((i > 0) & (i <= last(r))) |
| 131 | + throw(BoundsError(r, i)) |
| 132 | + end |
| 133 | + return convert(eltype(r), i) |
| 134 | +end |
| 135 | + |
| 136 | +@inline function get_index_unit_range(r, i) |
| 137 | + val = first(r) + (i - 1) |
| 138 | + @boundscheck if i > 0 && val <= last(r) && val >= first(r) |
| 139 | + throw(BoundsError(r, i)) |
| 140 | + end |
| 141 | + return convert(eltype(r), val) |
| 142 | +end |
| 143 | + |
| 144 | +_try_static(x, y) = Val(x) |
| 145 | +_try_static(::Nothing, y) = Val(y) |
| 146 | +_try_static(x, ::Nothing) = Val(x) |
| 147 | +_try_static(::Nothing, ::Nothing) = nothing |
| 148 | + |
| 149 | +### |
| 150 | +### length |
| 151 | +### |
| 152 | +@inline function known_length(::Type{T}) where {T<:AbstractUnitRange} |
| 153 | + fst = known_first(T) |
| 154 | + lst = known_last(T) |
| 155 | + if fst === nothing || lst === nothing |
| 156 | + return nothing |
| 157 | + else |
| 158 | + if fst === oneunit(eltype(T)) |
| 159 | + return unsafe_length_one_to(lst) |
| 160 | + else |
| 161 | + return unsafe_length_unit_range(fst, lst) |
| 162 | + end |
| 163 | + end |
| 164 | +end |
| 165 | + |
| 166 | +function Base.length(r::OptionallyStaticUnitRange{T}) where {T} |
| 167 | + if isempty(r) |
| 168 | + return zero(T) |
| 169 | + else |
| 170 | + if known_one(r) === one(T) |
| 171 | + return unsafe_length_one_to(last(r)) |
| 172 | + else |
| 173 | + return unsafe_length_unit_range(first(r), last(r)) |
| 174 | + end |
| 175 | + end |
| 176 | +end |
| 177 | + |
| 178 | +function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{Int,Int64,Int128}} |
| 179 | + return Base.checked_add(Base.checked_sub(lst, fst), one(T)) |
| 180 | +end |
| 181 | +function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{UInt,UInt64,UInt128}} |
| 182 | + return Base.checked_add(lst - fst, one(T)) |
| 183 | +end |
| 184 | + |
| 185 | +""" |
| 186 | + indices(x[, d]) |
| 187 | +
|
| 188 | +Given an array `x`, this returns the indices along dimension `d`. If `x` is a tuple |
| 189 | +of arrays then the indices corresponding to dimension `d` of all arrays in `x` are |
| 190 | +returned. If any indices are not equal along dimension `d` an error is thrown. A |
| 191 | +tuple may be used to specify a different dimension for each array. If `d` is not |
| 192 | +specified then indices for visiting each index of `x` is returned. |
| 193 | +""" |
| 194 | +@inline function indices(x) |
| 195 | + inds = eachindex(x) |
| 196 | + if inds isa AbstractUnitRange{<:Integer} |
| 197 | + return Base.Slice(OptionallyStaticUnitRange(inds)) |
| 198 | + else |
| 199 | + return inds |
| 200 | + end |
| 201 | +end |
| 202 | + |
| 203 | +function indices(x::Tuple) |
| 204 | + inds = map(eachindex, x) |
| 205 | + @assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds" |
| 206 | + return reduce(_pick_range, inds) |
| 207 | +end |
| 208 | + |
| 209 | +indices(x, d) = indices(axes(x, d)) |
| 210 | + |
| 211 | +@inline function indices(x::NTuple{N,<:Any}, dim) where {N} |
| 212 | + inds = map(x_i -> indices(x_i, dim), x) |
| 213 | + @assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds" |
| 214 | + return reduce(_pick_range, inds) |
| 215 | +end |
| 216 | + |
| 217 | +@inline function indices(x::NTuple{N,<:Any}, dim::NTuple{N,<:Any}) where {N} |
| 218 | + inds = map(indices, x, dim) |
| 219 | + @assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds" |
| 220 | + return reduce(_pick_range, inds) |
| 221 | +end |
| 222 | + |
| 223 | +@inline function _pick_range(x, y) |
| 224 | + fst = _try_static(known_first(x), known_first(y)) |
| 225 | + fst = fst === nothing ? first(x) : fst |
| 226 | + |
| 227 | + lst = _try_static(known_last(x), known_last(y)) |
| 228 | + lst = lst === nothing ? last(x) : lst |
| 229 | + return Base.Slice(OptionallyStaticUnitRange(fst, lst)) |
| 230 | +end |
| 231 | + |
0 commit comments