Skip to content

Commit 1e1f4a5

Browse files
authored
Merge pull request #61 from chriselrod/loopvecsupport
Added some methods to help support LoopVectorization.jl
2 parents e49ec67 + 21f7d48 commit 1e1f4a5

10 files changed

+725
-114
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
*.jl.mem
44
deps/deps.jl
55
Manifest.toml
6+
*~

Project.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "2.12.1"
3+
version = "2.13.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -16,10 +16,11 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
1616
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
1717
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
1818
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
19+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1920
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2021
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2122
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
2223
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2324

2425
[targets]
25-
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "Aqua"]
26+
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua"]

README.md

+72-4
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,80 @@ Otherwise, returns `nothing`. For example, `known_step(UnitRange{Int})` returns
134134
If `length` of an instance of type `T` is known at compile time, return it.
135135
Otherwise, return `nothing`.
136136

137-
## Static(N::Int)
137+
## device(::Type{T})
138+
139+
Indicates the most efficient way to access elements from the collection in low level code.
140+
For `GPUArrays`, will return `ArrayInterface.GPU()`.
141+
For `AbstractArray` supporting a `pointer` method, returns `ArrayInterface.CPUPointer()`.
142+
For other `AbstractArray`s and `Tuple`s, returns `ArrayInterface.CPUIndex()`.
143+
Otherwise, returns `nothing`.
144+
145+
## contiguous_axis(::Type{T})
146+
147+
Returns the axis of an array of type `T` containing contiguous data.
148+
If no axis is contiguous, it returns `Contiguous{-1}()`.
149+
If unknown, it returns `nothing`.
150+
151+
## contiguous_axis_indicator(::Type{T})
152+
153+
Returns a tuple of boolean `Val`s indicating whether that axis is contiguous.
154+
155+
## contiguous_batch_size(::Type{T})
156+
157+
Returns the size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
158+
If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
159+
If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`.
160+
If unknown, it will return `nothing`.
161+
162+
## stride_rank(::Type{T})
163+
164+
Returns the rank of each stride.
165+
166+
## dense_dims(::Type{T})
167+
Returns a tuple of indicators for whether each axis is dense.
168+
An axis `i` of array `A` is dense if `stride(A, i) * size(A, i) == stride(A, j)` where `j` is the axis (if it exists) such that `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.
169+
170+
## ArrayInterface.size(A)
171+
172+
Returns the size of `A`. If the sizes of any axes are known at compile time,
173+
these should be returned as `StaticInt`s. For example:
174+
```julia
175+
julia> using StaticArrays, ArrayInterface
176+
177+
julia> A = @SMatrix rand(3,4);
178+
179+
julia> ArrayInterface.size(A)
180+
(StaticInt{3}(), StaticInt{4}())
181+
```
182+
183+
## ArrayInterface.strides(A)
184+
185+
Returns the strides of array `A`. If any strides are known at compile time,
186+
these should be returned as `StaticInt`s. For example:
187+
```julia
188+
julia> using ArrayInterface
189+
190+
julia> A = rand(3,4);
191+
192+
julia> ArrayInterface.strides(A)
193+
(StaticInt{1}(), 3)
194+
```
195+
## offsets(A)
196+
197+
Returns offsets of indices with respect to 0. If values are known at compile time,
198+
it should return them as `StaticInt`s.
199+
For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`.
200+
201+
## can_avx(f)
202+
203+
Is the function `f` whitelisted for `LoopVectorization.@avx`?
204+
205+
## StaticInt(N::Int)
138206

139207
Creates a static integer with value known at compile time. It is a number,
140-
supporting basic arithmetic. Many operations with two `Static` integers
141-
will produce another `Static` integer. If one of the arguments to a
142-
function call isn't static (e.g., `Static(4) + 3`) then the `Static`
208+
supporting basic arithmetic. Many operations with two `StaticInt` integers
209+
will produce another `StaticInt` integer. If one of the arguments to a
210+
function call isn't static (e.g., `StaticInt(4) + 3`) then the `StaticInt`
143211
number will promote to a dynamic value.
144212

145213
# List of things to add

src/ArrayInterface.jl

+102-15
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ If `length` of an instance of type `T` is known at compile time, return it.
3333
Otherwise, return `nothing`.
3434
"""
3535
known_length(x) = known_length(typeof(x))
36-
known_length(::Type{<:NTuple{N,<:Any}}) where {N} = N
3736
known_length(::Type{<:NamedTuple{L}}) where {L} = length(L)
37+
known_length(::Type{T}) where {T<:Base.Slice} = known_length(parent_type(T))
38+
known_length(::Type{<:Tuple{Vararg{Any,N}}}) where {N} = N
3839
known_length(::Type{<:Number}) = 1
3940
function known_length(::Type{T}) where {T}
4041
if parent_type(T) <: T
@@ -52,7 +53,7 @@ _known_length(x::Tuple{Vararg{Int}}) = prod(x)
5253
"""
5354
can_change_size(::Type{T}) -> Bool
5455
55-
Returns `true` if the size of `T` can change, in which case operations
56+
Returns `true` if the size of `T` can change, in which case operations
5657
such as `pop!` and `popfirst!` are available for collections of type `T`.
5758
"""
5859
can_change_size(x) = can_change_size(typeof(x))
@@ -102,7 +103,7 @@ function Base.setindex(x::AbstractVector,v,i::Int)
102103
end
103104

104105
function Base.setindex(x::AbstractMatrix,v,i::Int,j::Int)
105-
n,m = size(x)
106+
n,m = Base.size(x)
106107
x .* (i .!== 1:n) .* (j .!== i:m)' .+ v .* (i .== 1:n) .* (j .== i:m)'
107108
end
108109

@@ -202,7 +203,7 @@ Return: (I,J) #indexable objects
202203
Find sparsity pattern of special matrices, the same as the first two elements of findnz(::SparseMatrixCSC)
203204
"""
204205
function findstructralnz(x::Diagonal)
205-
n=size(x,1)
206+
n = Base.size(x,1)
206207
(1:n,1:n)
207208
end
208209

@@ -412,15 +413,15 @@ function Base.getindex(ind::BandedBlockBandedMatrixIndex,i::Int)
412413
end
413414

414415
function findstructralnz(x::Bidiagonal)
415-
n=size(x,1)
416+
n= Base.size(x,1)
416417
isup= x.uplo=='U' ? true : false
417418
rowind=BidiagonalIndex(n+n-1,isup)
418419
colind=BidiagonalIndex(n+n-1,!isup)
419420
(rowind,colind)
420421
end
421422

422423
function findstructralnz(x::Union{Tridiagonal,SymTridiagonal})
423-
n=size(x,1)
424+
n= Base.size(x,1)
424425
rowind=TridiagonalIndex(n+n-1+n-1,n,true)
425426
colind=TridiagonalIndex(n+n-1+n-1,n,false)
426427
(rowind,colind)
@@ -447,26 +448,26 @@ fast_matrix_colors(A::Type{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagona
447448
matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular})
448449
449450
The color vector for dense matrix and triangular matrix is simply
450-
`[1,2,3,...,size(A,2)]`
451+
`[1,2,3,..., Base.size(A,2)]`
451452
"""
452453
function matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular})
453-
eachindex(1:size(A,2)) # Vector size matches number of rows
454+
eachindex(1:Base.size(A,2)) # Vector size matches number of columns
454455
end
455456

456457
function _cycle(repetend,len)
457458
repeat(repetend,div(len,length(repetend))+1)[1:len]
458459
end
459460

460461
function matrix_colors(A::Diagonal)
461-
fill(1,size(A,2))
462+
fill(1, Base.size(A,2))
462463
end
463464

464465
function matrix_colors(A::Bidiagonal)
465-
_cycle(1:2,size(A,2))
466+
_cycle(1:2, Base.size(A,2))
466467
end
467468

468469
function matrix_colors(A::Union{Tridiagonal,SymTridiagonal})
469-
_cycle(1:3,size(A,2))
470+
_cycle(1:3, Base.size(A,2))
470471
end
471472

472473
"""
@@ -540,9 +541,64 @@ function restructure(x,y)
540541
end
541542

542543
function restructure(x::Array,y)
543-
reshape(convert(Array,y),size(x)...)
544+
reshape(convert(Array,y), Base.size(x)...)
544545
end
545546

547+
abstract type AbstractDevice end
548+
abstract type AbstractCPU <: AbstractDevice end
549+
struct CPUPointer <: AbstractCPU end
550+
struct CheckParent end
551+
struct CPUIndex <: AbstractCPU end
552+
struct GPU <: AbstractDevice end
553+
"""
554+
device(::Type{T})
555+
556+
Indicates the most efficient way to access elements from the collection in low level code.
557+
For `GPUArrays`, will return `ArrayInterface.GPU()`.
558+
For `AbstractArray` supporting a `pointer` method, returns `ArrayInterface.CPUPointer()`.
559+
For other `AbstractArray`s and `Tuple`s, returns `ArrayInterface.CPUIndex()`.
560+
Otherwise, returns `nothing`.
561+
"""
562+
device(A) = device(typeof(A))
563+
device(::Type) = nothing
564+
device(::Type{<:Tuple}) = CPUIndex()
565+
# Relies on overloading for GPUArrays that have subtyped `StridedArray`.
566+
device(::Type{<:StridedArray}) = CPUPointer()
567+
function device(::Type{T}) where {T <: AbstractArray}
568+
P = parent_type(T)
569+
T === P ? CPUIndex() : device(P)
570+
end
571+
572+
573+
"""
574+
defines_strides(::Type{T}) -> Bool
575+
576+
Is strides(::T) defined?
577+
"""
578+
defines_strides(::Type) = false
579+
defines_strides(x) = defines_strides(typeof(x))
580+
defines_strides(::Type{<:StridedArray}) = true
581+
defines_strides(::Type{A}) where {A <: Union{<:Transpose,<:Adjoint,<:SubArray,<:PermutedDimsArray}} = defines_strides(parent_type(A))
582+
583+
"""
584+
can_avx(f)
585+
586+
Returns `true` if the function `f` is guaranteed to be compatible with `LoopVectorization.@avx` for supported element and array types.
587+
While a return value of `false` does not indicate the function isn't supported, this allows a library to conservatively apply `@avx`
588+
only when it is known to be safe to do so.
589+
590+
```julia
591+
function mymap!(f, y, args...)
592+
if can_avx(f)
593+
@avx @. y = f(args...)
594+
else
595+
@. y = f(args...)
596+
end
597+
end
598+
```
599+
"""
600+
can_avx(::Any) = false
601+
546602
"""
547603
insert(collection, index, item)
548604
@@ -654,6 +710,7 @@ function __init__()
654710
ismutable(::Type{<:StaticArrays.StaticArray}) = false
655711
can_setindex(::Type{<:StaticArrays.StaticArray}) = false
656712
ismutable(::Type{<:StaticArrays.MArray}) = true
713+
ismutable(::Type{<:StaticArrays.SizedArray}) = true
657714

658715
function lu_instance(_A::StaticArrays.StaticMatrix{N,N}) where {N}
659716
A = StaticArrays.SArray(_A)
@@ -675,6 +732,26 @@ function __init__()
675732
known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
676733
known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N
677734

735+
device(::Type{<:StaticArrays.MArray}) = CPUPointer()
736+
contiguous_axis(::Type{<:StaticArrays.StaticArray}) = Contiguous{1}()
737+
contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = ContiguousBatch{0}()
738+
stride_rank(::Type{T}) where {N, T <: StaticArrays.StaticArray{<:Any,<:Any,N}} = StrideRank{ntuple(identity, Val{N}())}()
739+
dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} = DenseDims{ntuple(_ -> true, Val(N))}()
740+
defines_strides(::Type{<:StaticArrays.MArray}) = true
741+
@generated function size(A::StaticArrays.StaticArray{S}) where {S}
742+
t = Expr(:tuple); Sp = S.parameters
743+
for n in 1:length(Sp)
744+
push!(t.args, Expr(:call, Expr(:curly, :StaticInt, Sp[n])))
745+
end
746+
t
747+
end
748+
@generated function strides(A::StaticArrays.StaticArray{S}) where {S}
749+
t = Expr(:tuple, Expr(:call, Expr(:curly, :StaticInt, 1))); Sp = S.parameters; x = 1
750+
for n in 1:length(Sp)-1
751+
push!(t.args, Expr(:call, Expr(:curly, :StaticInt, (x *= Sp[n]))))
752+
end
753+
t
754+
end
678755
@require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
679756
function Adapt.adapt_storage(::Type{<:StaticArrays.SArray{S}},xs::Array) where S
680757
StaticArrays.SArray{S}(xs)
@@ -694,7 +771,7 @@ function __init__()
694771
aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where N = Tracker.collect(x)
695772
end
696773

697-
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
774+
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
698775
@require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
699776
include("cuarrays.jl")
700777
end
@@ -717,7 +794,7 @@ function __init__()
717794
@require BandedMatrices="aae01518-5342-5314-be14-df237901396f" begin
718795
function findstructralnz(x::BandedMatrices.BandedMatrix)
719796
l,u=BandedMatrices.bandwidths(x)
720-
rowsize,colsize=size(x)
797+
rowsize,colsize= Base.size(x)
721798
rowind=BandedMatrixIndex(rowsize,colsize,l,u,true)
722799
colind=BandedMatrixIndex(rowsize,colsize,l,u,false)
723800
(rowind,colind)
@@ -730,7 +807,7 @@ function __init__()
730807
function matrix_colors(A::BandedMatrices.BandedMatrix)
731808
l,u=BandedMatrices.bandwidths(A)
732809
width=u+l+1
733-
_cycle(1:width,size(A,2))
810+
_cycle(1:width, Base.size(A,2))
734811
end
735812

736813
end
@@ -794,9 +871,19 @@ function __init__()
794871
end
795872
end
796873
end
874+
@require OffsetArrays="6fe1bfb0-de20-5000-8ca7-80f57d26f881" begin
875+
size(A::OffsetArrays.OffsetArray) = size(parent(A))
876+
strides(A::OffsetArrays.OffsetArray) = strides(parent(A))
877+
# offsets(A::OffsetArrays.OffsetArray) = map(+, A.offsets, offsets(parent(A)))
878+
device(::OffsetArrays.OffsetArray) = CheckParent()
879+
contiguous_axis(A::OffsetArrays.OffsetArray) = contiguous_axis(parent(A))
880+
contiguous_batch_size(A::OffsetArrays.OffsetArray) = contiguous_batch_size(parent(A))
881+
stride_rank(A::OffsetArrays.OffsetArray) = stride_rank(parent(A))
882+
end
797883
end
798884

799885
include("static.jl")
800886
include("ranges.jl")
887+
include("stridelayout.jl")
801888

802889
end

src/cuarrays.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,8 @@ function Base.setindex(x::CuArrays.CuArray,v,i::Int)
99
end
1010

1111
function restructure(x::CuArrays.CuArray,y)
12-
reshape(Adapt.adapt(parameterless_type(x),y),size(x)...)
12+
reshape(Adapt.adapt(parameterless_type(x),y), Base.size(x)...)
1313
end
14+
15+
# Extend the lowercase `device` trait so CuArrays report `GPU()`; a capitalised
# `Device` would define a separate, never-called function and leave the trait unchanged.
device(::Type{<:CuArrays.CuArray}) = GPU()
16+

src/cuarrays2.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,8 @@ function Base.setindex(x::CUDA.CuArray,v,i::Int)
99
end
1010

1111
function restructure(x::CUDA.CuArray,y)
12-
reshape(Adapt.adapt(parameterless_type(x),y),size(x)...)
12+
reshape(Adapt.adapt(parameterless_type(x),y), Base.size(x)...)
1313
end
14+
15+
# Extend the lowercase `device` trait so CUDA arrays report `GPU()`; a capitalised
# `Device` would define a separate, never-called function and leave the trait unchanged.
device(::Type{<:CUDA.CuArray}) = GPU()
16+

0 commit comments

Comments
 (0)