Skip to content

Commit d7a00ad

Browse files
authored
Merge pull request #63 from Tokazama/master
Safe co-iteration across an axis for 1+ arrays
2 parents 4a959b1 + 95c0102 commit d7a00ad

File tree

4 files changed

+277
-39
lines changed

4 files changed

+277
-39
lines changed

README.md

+8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ Returns the parent array that `x` wraps.
2020
Returns `true` if the size of `T` can change, in which case operations
2121
such as `pop!` and `popfirst!` are available for collections of type `T`.
2222

23+
## indices(x[, d])
24+
25+
Given an array `x`, this returns the indices along dimension `d`. If `x` is a tuple
26+
of arrays then the indices corresponding to dimension `d` of all arrays in `x` are
27+
returned. If any indices are not equal along dimension `d` an error is thrown. A
28+
tuple may be used to specify a different dimension for each array. If `d` is not
29+
specified then indices for visiting each index of `x` is returned.
30+
2331
## ismutable(x)
2432

2533
A trait function for whether `x` is a mutable or immutable array. Used for

src/ArrayInterface.jl

+18-39
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using Requires
44
using LinearAlgebra
55
using SparseArrays
66

7+
using Base: OneTo
8+
79
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
810
parameterless_type(x) = parameterless_type(typeof(x))
911
parameterless_type(x::Type) = __parameterless_type(x)
@@ -20,8 +22,21 @@ parent_type(::Type{Adjoint{T,S}}) where {T,S} = S
2022
parent_type(::Type{Transpose{T,S}}) where {T,S} = S
2123
parent_type(::Type{Symmetric{T,S}}) where {T,S} = S
2224
parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S
25+
parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A
26+
parent_type(::Type{Base.Slice{T}}) where {T} = T
2327
parent_type(::Type{T}) where {T} = T
2428

29+
"""
30+
known_length(::Type{T})
31+
32+
If `length` of an instance of type `T` is known at compile time, return it.
33+
Otherwise, return `nothing`.
34+
"""
35+
known_length(x) = known_length(typeof(x))
36+
known_length(::Type{<:NTuple{N,<:Any}}) where {N} = N
37+
known_length(::Type{<:NamedTuple{L}}) where {L} = length(L)
38+
known_length(::Type{T}) where {T<:Base.Slice} = known_length(parent_type(T))
39+
2540
"""
2641
can_change_size(::Type{T}) -> Bool
2742
@@ -503,45 +518,6 @@ function restructure(x::Array,y)
503518
reshape(convert(Array,y),size(x)...)
504519
end
505520

506-
"""
507-
known_first(::Type{T})
508-
509-
If `first` of an instance of type `T` is known at compile time, return it.
510-
Otherwise, return `nothing`.
511-
512-
@test isnothing(known_first(typeof(1:4)))
513-
@test isone(known_first(typeof(Base.OneTo(4))))
514-
"""
515-
known_first(x) = known_first(typeof(x))
516-
known_first(::Type{T}) where {T} = nothing
517-
known_first(::Type{Base.OneTo{T}}) where {T} = one(T)
518-
519-
"""
520-
known_last(::Type{T})
521-
522-
If `last` of an instance of type `T` is known at compile time, return it.
523-
Otherwise, return `nothing`.
524-
525-
@test isnothing(known_last(typeof(1:4)))
526-
using StaticArrays
527-
@test known_last(typeof(SOneTo(4))) == 4
528-
"""
529-
known_last(x) = known_last(typeof(x))
530-
known_last(::Type{T}) where {T} = nothing
531-
532-
"""
533-
known_step(::Type{T})
534-
535-
If `step` of an instance of type `T` is known at compile time, return it.
536-
Otherwise, return `nothing`.
537-
538-
@test isnothing(known_step(typeof(1:0.2:4)))
539-
@test isone(known_step(typeof(1:4)))
540-
"""
541-
known_step(x) = known_step(typeof(x))
542-
known_step(::Type{T}) where {T} = nothing
543-
known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
544-
545521
function __init__()
546522

547523
@require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin
@@ -575,6 +551,7 @@ function __init__()
575551

576552
known_first(::Type{<:StaticArrays.SOneTo}) = 1
577553
known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
554+
known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N
578555

579556
@require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
580557
function Adapt.adapt_storage(::Type{<:StaticArrays.SArray{S}},xs::Array) where S
@@ -697,4 +674,6 @@ function __init__()
697674
end
698675
end
699676

677+
include("ranges.jl")
678+
700679
end

src/ranges.jl

+231
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
2+
"""
3+
known_first(::Type{T})
4+
5+
If `first` of an instance of type `T` is known at compile time, return it.
6+
Otherwise, return `nothing`.
7+
8+
@test isnothing(known_first(typeof(1:4)))
9+
@test isone(known_first(typeof(Base.OneTo(4))))
10+
"""
11+
known_first(x) = known_first(typeof(x))
12+
known_first(::Type{T}) where {T} = nothing
13+
known_first(::Type{Base.OneTo{T}}) where {T} = one(T)
14+
known_first(::Type{T}) where {T<:Base.Slice} = known_first(parent_type(T))
15+
16+
"""
17+
known_last(::Type{T})
18+
19+
If `last` of an instance of type `T` is known at compile time, return it.
20+
Otherwise, return `nothing`.
21+
22+
@test isnothing(known_last(typeof(1:4)))
23+
using StaticArrays
24+
@test known_last(typeof(SOneTo(4))) == 4
25+
"""
26+
known_last(x) = known_last(typeof(x))
27+
known_last(::Type{T}) where {T} = nothing
28+
known_last(::Type{T}) where {T<:Base.Slice} = known_last(parent_type(T))
29+
30+
"""
31+
known_step(::Type{T})
32+
33+
If `step` of an instance of type `T` is known at compile time, return it.
34+
Otherwise, return `nothing`.
35+
36+
@test isnothing(known_step(typeof(1:0.2:4)))
37+
@test isone(known_step(typeof(1:4)))
38+
"""
39+
known_step(x) = known_step(typeof(x))
40+
known_step(::Type{T}) where {T} = nothing
41+
known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
42+
43+
# add methods to support ArrayInterface
44+
45+
_get(x) = x
46+
_get(::Val{V}) where {V} = V
47+
_convert(::Type{T}, x) where {T} = convert(T, x)
48+
_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V))
49+
50+
"""
51+
OptionallyStaticUnitRange{T<:Integer}(start, stop) <: OrdinalRange{T,T}
52+
53+
This range permits diverse representations of arrays to comunicate common information
54+
about their indices. Each field may be an integer or `Val(<:Integer)` if it is known
55+
at compile time. An `OptionallyStaticUnitRange` is intended to be constructed internally
56+
from other valid indices. Therefore, users should not expect the same checks are used
57+
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
58+
"""
59+
struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}
60+
start::F
61+
stop::L
62+
63+
function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real}
64+
if _get(start) isa T
65+
if _get(stop) isa T
66+
return new{T,typeof(start),typeof(stop)}(start, stop)
67+
else
68+
return OptionallyStaticUnitRange{T}(start, _convert(T, stop))
69+
end
70+
else
71+
return OptionallyStaticUnitRange{T}(_convert(T, start), stop)
72+
end
73+
end
74+
75+
function OptionallyStaticUnitRange(start, stop)
76+
T = promote_type(typeof(_get(start)), typeof(_get(stop)))
77+
return OptionallyStaticUnitRange{T}(start, stop)
78+
end
79+
80+
function OptionallyStaticUnitRange(x::AbstractRange)
81+
if step(x) == 1
82+
fst = known_first(x)
83+
fst = fst === nothing ? first(x) : Val(fst)
84+
lst = known_last(x)
85+
lst = lst === nothing ? last(x) : Val(lst)
86+
return OptionallyStaticUnitRange(fst, lst)
87+
else
88+
throw(ArgumentError("step must be 1, got $(step(r))"))
89+
end
90+
end
91+
end
92+
93+
Base.first(r::OptionallyStaticUnitRange{<:Any,Val{F}}) where {F} = F
94+
Base.first(r::OptionallyStaticUnitRange{<:Any,<:Any}) = r.start
95+
96+
Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T)
97+
98+
Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}) where {L} = L
99+
Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,<:Any}) = r.stop
100+
101+
known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Val{F}}}) where {F} = F
102+
known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
103+
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}}) where {L} = L
104+
105+
function Base.isempty(r::OptionallyStaticUnitRange)
106+
if known_first(r) === oneunit(eltype(r))
107+
return unsafe_isempty_one_to(last(r))
108+
else
109+
return unsafe_isempty_unit_range(first(r), last(r))
110+
end
111+
end
112+
113+
unsafe_isempty_one_to(lst) = lst <= zero(lst)
114+
unsafe_isempty_unit_range(fst, lst) = fst > lst
115+
116+
unsafe_isempty_unit_range(fst::T, lst::T) where {T} = Integer(lst - fst + one(T))
117+
118+
unsafe_length_one_to(lst::T) where {T<:Int} = T(lst)
119+
unsafe_length_one_to(lst::T) where {T} = Integer(lst - zero(lst))
120+
121+
Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
122+
if known_first(r) === oneunit(r)
123+
return get_index_one_to(r, i)
124+
else
125+
return get_index_unit_range(r, i)
126+
end
127+
end
128+
129+
@inline function get_index_one_to(r, i)
130+
@boundscheck if ((i > 0) & (i <= last(r)))
131+
throw(BoundsError(r, i))
132+
end
133+
return convert(eltype(r), i)
134+
end
135+
136+
@inline function get_index_unit_range(r, i)
137+
val = first(r) + (i - 1)
138+
@boundscheck if i > 0 && val <= last(r) && val >= first(r)
139+
throw(BoundsError(r, i))
140+
end
141+
return convert(eltype(r), val)
142+
end
143+
144+
_try_static(x, y) = Val(x)
145+
_try_static(::Nothing, y) = Val(y)
146+
_try_static(x, ::Nothing) = Val(x)
147+
_try_static(::Nothing, ::Nothing) = nothing
148+
149+
###
150+
### length
151+
###
152+
@inline function known_length(::Type{T}) where {T<:AbstractUnitRange}
153+
fst = known_first(T)
154+
lst = known_last(T)
155+
if fst === nothing || lst === nothing
156+
return nothing
157+
else
158+
if fst === oneunit(eltype(T))
159+
return unsafe_length_one_to(lst)
160+
else
161+
return unsafe_length_unit_range(fst, lst)
162+
end
163+
end
164+
end
165+
166+
function Base.length(r::OptionallyStaticUnitRange{T}) where {T}
167+
if isempty(r)
168+
return zero(T)
169+
else
170+
if known_one(r) === one(T)
171+
return unsafe_length_one_to(last(r))
172+
else
173+
return unsafe_length_unit_range(first(r), last(r))
174+
end
175+
end
176+
end
177+
178+
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{Int,Int64,Int128}}
179+
return Base.checked_add(Base.checked_sub(lst, fst), one(T))
180+
end
181+
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{UInt,UInt64,UInt128}}
182+
return Base.checked_add(lst - fst, one(T))
183+
end
184+
185+
"""
186+
indices(x[, d])
187+
188+
Given an array `x`, this returns the indices along dimension `d`. If `x` is a tuple
189+
of arrays then the indices corresponding to dimension `d` of all arrays in `x` are
190+
returned. If any indices are not equal along dimension `d` an error is thrown. A
191+
tuple may be used to specify a different dimension for each array. If `d` is not
192+
specified then indices for visiting each index of `x` is returned.
193+
"""
194+
@inline function indices(x)
195+
inds = eachindex(x)
196+
if inds isa AbstractUnitRange{<:Integer}
197+
return Base.Slice(OptionallyStaticUnitRange(inds))
198+
else
199+
return inds
200+
end
201+
end
202+
203+
function indices(x::Tuple)
204+
inds = map(eachindex, x)
205+
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
206+
return reduce(_pick_range, inds)
207+
end
208+
209+
indices(x, d) = indices(axes(x, d))
210+
211+
@inline function indices(x::NTuple{N,<:Any}, dim) where {N}
212+
inds = map(x_i -> indices(x_i, dim), x)
213+
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
214+
return reduce(_pick_range, inds)
215+
end
216+
217+
@inline function indices(x::NTuple{N,<:Any}, dim::NTuple{N,<:Any}) where {N}
218+
inds = map(indices, x, dim)
219+
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
220+
return reduce(_pick_range, inds)
221+
end
222+
223+
@inline function _pick_range(x, y)
224+
fst = _try_static(known_first(x), known_first(y))
225+
fst = fst === nothing ? first(x) : fst
226+
227+
lst = _try_static(known_last(x), known_last(y))
228+
lst = lst === nothing ? last(x) : lst
229+
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
230+
end
231+

test/runtests.jl

+20
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using StaticArrays
1010
@test ArrayInterface.ismutable((0.1,1.0)) == false
1111
@test isone(ArrayInterface.known_first(typeof(StaticArrays.SOneTo(7))))
1212
@test ArrayInterface.known_last(typeof(StaticArrays.SOneTo(7))) == 7
13+
@test ArrayInterface.known_length(typeof(StaticArrays.SOneTo(7))) == 7
1314

1415
using LinearAlgebra, SparseArrays
1516

@@ -173,6 +174,8 @@ using ArrayInterface: parent_type
173174
@test parent_type(transpose(x)) <: typeof(x)
174175
@test parent_type(Symmetric(x)) <: typeof(x)
175176
@test parent_type(UpperTriangular(x)) <: typeof(x)
177+
@test parent_type(PermutedDimsArray(x, (2,1))) <: typeof(x)
178+
@test parent_type(Base.Slice(1:10)) <: UnitRange{Int}
176179
end
177180

178181
@testset "Range Interface" begin
@@ -196,3 +199,20 @@ end
196199
@test !ArrayInterface.can_change_size(Tuple{})
197200
end
198201

202+
@testset "known_length" begin
203+
@test ArrayInterface.known_length(ArrayInterface.indices(SOneTo(7))) == 7
204+
@test ArrayInterface.known_length(1:2) == nothing
205+
@test ArrayInterface.known_length((1,)) == 1
206+
@test ArrayInterface.known_length((a=1,b=2)) == 2
207+
end
208+
209+
@testset "indices" begin
210+
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)))) == 1:6
211+
@test @inferred(ArrayInterface.indices(ones(2, 3))) == 1:6
212+
@test @inferred(ArrayInterface.indices(ones(2, 3), 1)) == 1:2
213+
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)), (1, 2))) == 1:2
214+
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(2, 3)), 1)) == 1:2
215+
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), 1)
216+
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), (1, 2))
217+
end
218+

0 commit comments

Comments
 (0)