Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added some methods to help support LoopVectorization.jl #61

Merged
merged 53 commits into from
Sep 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
9a41d08
Added some methods to help support LoopVectorization.jl
chriselrod Aug 11, 2020
597606b
Add a little more to doc on `batch` return of `stridelayout`.
chriselrod Aug 11, 2020
2513045
canavx -> can_avx, add&fix missing stridelayout test, make strideorde…
chriselrod Aug 12, 2020
791cffa
Merge branch 'master' into loopvecsupport
chriselrod Aug 12, 2020
6c36e69
Update tests for stridelayout always using tuples for rank.
chriselrod Aug 12, 2020
8af5986
Add density information per axis.
chriselrod Aug 12, 2020
87e8bd1
Add `axesdense` to stridelayout's returns.
chriselrod Aug 15, 2020
53929ad
Merge branch 'master' into loopvecsupport
chriselrod Aug 18, 2020
572fccc
Add sentinal values to batch parameter definition.
chriselrod Aug 18, 2020
874fee8
Split apart definitions.
chriselrod Aug 20, 2020
358a93c
Updated README.
chriselrod Aug 20, 2020
0ca96b8
More tests, simplify `Device`
chriselrod Aug 20, 2020
9b09212
Add contiguous_batch_size docs.
chriselrod Aug 20, 2020
3417ba0
Better stride_rank.
chriselrod Aug 20, 2020
944fd93
Terrible cumsum(::Tuple) implementation for pre-1.5.
chriselrod Aug 20, 2020
a8567db
Device -> device changes, differentiate between CPUPointer and CPUInd…
chriselrod Aug 22, 2020
bd55e33
Make `stride_rank(::Type{T})`(no dim argument) return a parametrized …
chriselrod Aug 22, 2020
0862eb6
Added hybrid dynamic-static tuple type, and size and strides function…
chriselrod Aug 23, 2020
a3ac561
Added `nothing` fallbacks for sdsize and sdstrides.
chriselrod Aug 23, 2020
2f92c61
Use tuple type instead of value tuples in SDTuple, to facilitate use …
chriselrod Aug 26, 2020
9778d47
unwrap -> _get
chriselrod Aug 28, 2020
ab02af4
Use Static for partially static tuples.
chriselrod Sep 1, 2020
7cfd029
Merge branch 'master' into loopvecsupport
chriselrod Sep 1, 2020
39f414b
Make indices accept heterogenous tuples, and ntuple(f, ::Static{}) fu…
chriselrod Sep 1, 2020
9f852da
Add OffsetArray support.
chriselrod Sep 3, 2020
94eb2a7
Better default sdoffsets.
chriselrod Sep 3, 2020
973a01c
Delete `Static` ntuple.
chriselrod Sep 4, 2020
6e9207b
Better sdoffsets.
chriselrod Sep 4, 2020
00d0a00
Added sdoffset tests and made inference improvements.
chriselrod Sep 6, 2020
5120651
Add more Static methods to avoid ambiguities.
chriselrod Sep 7, 2020
d6130f8
Update src/stridelayout.jl
chriselrod Sep 8, 2020
a88c692
Update README.md
chriselrod Sep 8, 2020
910a42a
Update README.md
chriselrod Sep 8, 2020
6ad4ec5
Added docstrings for sd(size/strides/offsets), made `OptionmallyStati…
chriselrod Sep 8, 2020
28404dc
More indices tests.
chriselrod Sep 8, 2020
1f08ae0
Added some `Static` comparison methods to improve type stability.
chriselrod Sep 8, 2020
00c9112
Merged `static.jl`
chriselrod Sep 9, 2020
5f1f14f
Merge branch 'static' into loopvecsupport
chriselrod Sep 9, 2020
dd7b0ed
Remove extra parameter.
chriselrod Sep 9, 2020
726691e
Merged.
chriselrod Sep 9, 2020
895f9b7
Fix _sdsize (generated functions should rely on methods others will e…
chriselrod Sep 14, 2020
a59e05a
_try_static in _sdsize
chriselrod Sep 14, 2020
7d17247
propagate_inbounds in ranges.jl
chriselrod Sep 14, 2020
18b954b
Finish merging master.
chriselrod Sep 14, 2020
09515fe
Move docstring for rank_to_sortperm.
chriselrod Sep 14, 2020
0ef7653
Drop `sd` prefix to size/strides/offsets.
chriselrod Sep 15, 2020
33b8433
Add indexing size/stride methods.
chriselrod Sep 15, 2020
efe40e1
add where and tests.
chriselrod Sep 15, 2020
e72cfc8
Set generic fallback size and strides to equal base methods.
chriselrod Sep 15, 2020
42258b5
Remove Base.Val(::Static{N}) definition to reduce invalidations.
chriselrod Sep 16, 2020
090dab9
Static -> StaticInt.
chriselrod Sep 18, 2020
ac8033a
Update README for sd -> ArrayInterface. and Static -> StaticInt
chriselrod Sep 18, 2020
21f7d48
More README fixes.
chriselrod Sep 18, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
*.jl.mem
deps/deps.jl
Manifest.toml
*~
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "2.12.1"
version = "2.13.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -16,10 +16,11 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "Aqua"]
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "OffsetArrays", "Aqua"]
76 changes: 72 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,80 @@ Otherwise, returns `nothing`. For example, `known_step(UnitRange{Int})` returns
If `length` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.

## Static(N::Int)
## device(::Type{T})

Indicates the most efficient way to access elements from the collection in low level code.
For `GPUArrays`, will return `ArrayInterface.GPU()`.
For `AbstractArray` supporting a `pointer` method, returns `ArrayInterface.CPUPointer()`.
For other `AbstractArray`s and `Tuple`s, returns `ArrayInterface.CPUIndex()`.
Otherwise, returns `nothing`.

## contiguous_axis(::Type{T})

Returns the axis of an array of type `T` containing contiguous data.
If no axis is contiguous, it returns `Contiguous{-1}`.
If unknown, it returns `nothing`.

## contiguous_axis_indicator(::Type{T})

Returns a tuple of boolean `Val`s indicating whether that axis is contiguous.

## contiguous_batch_size(::Type{T})

Returns the size of contiguous batches if `!isone(stride_rank(T, contiguous_axis(T)))`.
If `isone(stride_rank(T, contiguous_axis(T)))`, then it will return `ContiguousBatch{0}()`.
If `contiguous_axis(T) == -1`, it will return `ContiguousBatch{-1}()`.
If unknown, it will return `nothing`.

## stride_rank(::Type{T})

Returns the rank of each stride.

## dense_dims(::Type{T})
Returns a tuple of indicators for whether each axis is dense.
An axis `i` of array `A` is dense if `stride(A, i) * size(A, i) == stride(A, j)` where `j` is the axis (if it exists) such that `stride_rank(A)[i] + 1 == stride_rank(A)[j]`.

## ArrayInterface.size(A)

Returns the size of `A`. If the size of any axes are known at compile time,
these should be returned as `StaticInt`s. For example:
```julia
julia> using StaticArrays, ArrayInterface

julia> A = @SMatrix rand(3,4);

julia> ArrayInterface.size(A)
(StaticInt{3}(), StaticInt{4}())
```

## ArrayInterface.strides(A)

Returns the strides of array `A`. If any strides are known at compile time,
these should be returned as `StaticInt`s. For example:
```julia
julia> using ArrayInterface

julia> A = rand(3,4);

julia> ArrayInterface.strides(A)
(StaticInt{1}(), 3)
```
## offsets(A)

Returns offsets of indices with respect to 0. If values are known at compile time,
it should return them as `StaticInt`s.
For example, if `A isa Base.Matrix`, `offsets(A) === (StaticInt(1), StaticInt(1))`.

## can_avx(f)

Is the function `f` whitelisted for `LoopVectorization.@avx`?

## StaticInt(N::Int)

Creates a static integer with value known at compile time. It is a number,
supporting basic arithmetic. Many operations with two `Static` integers
will produce another `Static` integer. If one of the arguments to a
function call isn't static (e.g., `Static(4) + 3`) then the `Static`
supporting basic arithmetic. Many operations with two `StaticInt` integers
will produce another `StaticInt` integer. If one of the arguments to a
function call isn't static (e.g., `StaticInt(4) + 3`) then the `StaticInt`
number will promote to a dynamic value.

# List of things to add
Expand Down
117 changes: 102 additions & 15 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ If `length` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.
"""
known_length(x) = known_length(typeof(x))
known_length(::Type{<:NTuple{N,<:Any}}) where {N} = N
known_length(::Type{<:NamedTuple{L}}) where {L} = length(L)
known_length(::Type{T}) where {T<:Base.Slice} = known_length(parent_type(T))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking a bit about Base.Slice in the context of HybridArrays and my general conclusion was to treat is the same way as Colon. IIRC there was a case where a statically-known axis ended up being a slice with Base.OneTo instead of SOneTo. Are there any benefits of using this method instead?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking a bit about Base.Slice in the context of HybridArrays and my general conclusion was to treat is the same way as Colon.

We are treating them as traversing the entire axis, so a Base.Slice preserves density.

IIRC there was a case where a statically-known axis ended up being a slice with Base.OneTo instead of SOneTo

SOneTo is preferable, so I'd recommend getting that case to return a Base.Slice(SOneTo(N)) instead.

Are there any benefits of using this method instead?

Using which method instead of which alternative?

Currently, known_length takes a type and returns nothing (if length unknown, e.g. for Base.Arrays), or returns the known length.
A problem is that we can't really use this within @generated functions, lest we face world age issues / defeat the entire purpose of providing an interface others can extend (I just checked and realize I have to fix _sdsize).

I'm considering a breaking change where known_length will return the Static numbers introduced in this PR, which will make it easier to "wrap" @generated functions (i.e., the wrapper can safely call known_length and friends, and with Static, it can pass the information on to the @generated function).
Although this works:

L = known_length(T)
if isnothing(L)
   # handle
else
  Static(L) # do something with this
end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SOneTo is preferable, so I'd recommend getting that case to return a Base.Slice(SOneTo(N)) instead.

OK, I'll keep that in mind.

Using which method instead of which alternative?

In HybridArrays I take into account both array type and types of indices to figure out whether a certain access pattern is statically-sized or not. In some places for example Colon has to be special-cased, so the same special-casing can be used for Slice, like I did here for example: https://github.com/JuliaArrays/StaticArrays.jl/pull/783/files#diff-c1a2ec8ab9a030d018ed5124a63d38a4R193 . I don't see in what circumstances known_length(::Type{<:Base.Slice}) would be used that don't treat Colon in a special way as well.

I'm considering a breaking change where known_length will return the Static numbers introduced in this PR, which will make it easier to "wrap" @generated functions (i.e., the wrapper can safely call known_length and friends, and with Static, it can pass the information on to the @generated function).

I support this change. Generated functions that need static size information should get that information as an argument. That's what StaticArray does in many places. The downside main I think is potentially increasing compilation times.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

@chriselrod chriselrod Sep 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, that slipped my mind. Fixed.

known_length(::Type{<:Tuple{Vararg{Any,N}}}) where {N} = N
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably don't need known_length(::Type{<:NTuple{N,<:Any}}) where {N} = N having this method?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll remove the NTuple method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't have this defined for tuples then the method would return nothing, implying the length isn't known.

Copy link
Collaborator Author

@chriselrod chriselrod Sep 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Tuple{Vararg{Any,N}} definition works for both homogenous and heterogenous tuples, so the NTuple-specific definition was redundant.

julia> NTuple{1423} <: Tuple{Vararg{Any,1423}}
true

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I misunderstood. I thought the final idea was to remove both Tuple methods, sorry.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To expand a little on Chris' answer, in case you didn't know:

julia> (1, 'd') isa NTuple{2,<:Any}
false

julia> (1, 'd') isa NTuple{2,Any}
true

known_length(::Type{<:Number}) = 1
function known_length(::Type{T}) where {T}
if parent_type(T) <: T
Expand All @@ -52,7 +53,7 @@ _known_length(x::Tuple{Vararg{Int}}) = prod(x)
"""
can_change_size(::Type{T}) -> Bool

Returns `true` if the size of `T` can change, in which case operations
Returns `true` if the Base.size of `T` can change, in which case operations
such as `pop!` and `popfirst!` are available for collections of type `T`.
"""
can_change_size(x) = can_change_size(typeof(x))
Expand Down Expand Up @@ -102,7 +103,7 @@ function Base.setindex(x::AbstractVector,v,i::Int)
end

function Base.setindex(x::AbstractMatrix,v,i::Int,j::Int)
n,m = size(x)
n,m = Base.size(x)
x .* (i .!== 1:n) .* (j .!== i:m)' .+ v .* (i .== 1:n) .* (j .== i:m)'
end

Expand Down Expand Up @@ -202,7 +203,7 @@ Return: (I,J) #indexable objects
Find sparsity pattern of special matrices, the same as the first two elements of findnz(::SparseMatrixCSC)
"""
function findstructralnz(x::Diagonal)
n=size(x,1)
n = Base.size(x,1)
(1:n,1:n)
end

Expand Down Expand Up @@ -412,15 +413,15 @@ function Base.getindex(ind::BandedBlockBandedMatrixIndex,i::Int)
end

function findstructralnz(x::Bidiagonal)
n=size(x,1)
n= Base.size(x,1)
isup= x.uplo=='U' ? true : false
rowind=BidiagonalIndex(n+n-1,isup)
colind=BidiagonalIndex(n+n-1,!isup)
(rowind,colind)
end

function findstructralnz(x::Union{Tridiagonal,SymTridiagonal})
n=size(x,1)
n= Base.size(x,1)
rowind=TridiagonalIndex(n+n-1+n-1,n,true)
colind=TridiagonalIndex(n+n-1+n-1,n,false)
(rowind,colind)
Expand All @@ -447,26 +448,26 @@ fast_matrix_colors(A::Type{<:Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagona
matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular})

The color vector for dense matrix and triangular matrix is simply
`[1,2,3,...,size(A,2)]`
`[1,2,3,..., Base.size(A,2)]`
"""
function matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular})
eachindex(1:size(A,2)) # Vector size matches number of rows
eachindex(1:Base.size(A,2)) # Vector Base.size matches number of rows
end

function _cycle(repetend,len)
repeat(repetend,div(len,length(repetend))+1)[1:len]
end

function matrix_colors(A::Diagonal)
fill(1,size(A,2))
fill(1, Base.size(A,2))
end

function matrix_colors(A::Bidiagonal)
_cycle(1:2,size(A,2))
_cycle(1:2, Base.size(A,2))
end

function matrix_colors(A::Union{Tridiagonal,SymTridiagonal})
_cycle(1:3,size(A,2))
_cycle(1:3, Base.size(A,2))
end

"""
Expand Down Expand Up @@ -540,9 +541,64 @@ function restructure(x,y)
end

function restructure(x::Array,y)
reshape(convert(Array,y),size(x)...)
reshape(convert(Array,y), Base.size(x)...)
end

abstract type AbstractDevice end
abstract type AbstractCPU <: AbstractDevice end
struct CPUPointer <: AbstractCPU end
struct CheckParent end
struct CPUIndex <: AbstractCPU end
struct GPU <: AbstractDevice end
"""
device(::Type{T})

Indicates the most efficient way to access elements from the collection in low level code.
For `GPUArrays`, will return `ArrayInterface.GPU()`.
For `AbstractArray` supporting a `pointer` method, returns `ArrayInterface.CPUPointer()`.
For other `AbstractArray`s and `Tuple`s, returns `ArrayInterface.CPUIndex()`.
Otherwise, returns `nothing`.
"""
device(A) = device(typeof(A))
device(::Type) = nothing
device(::Type{<:Tuple}) = CPUIndex()
# Relies on overloading for GPUArrays that have subtyped `StridedArray`.
device(::Type{<:StridedArray}) = CPUPointer()
function device(::Type{T}) where {T <: AbstractArray}
P = parent_type(T)
T === P ? CPUIndex() : device(P)
end


"""
defines_strides(::Type{T}) -> Bool

Is strides(::T) defined?
"""
defines_strides(::Type) = false
defines_strides(x) = defines_strides(typeof(x))
defines_strides(::Type{<:StridedArray}) = true
defines_strides(::Type{A}) where {A <: Union{<:Transpose,<:Adjoint,<:SubArray,<:PermutedDimsArray}} = defines_strides(parent_type(A))

"""
can_avx(f)

Returns `true` if the function `f` is guaranteed to be compatible with `LoopVectorization.@avx` for supported element and array types.
While a return value of `false` does not indicate the function isn't supported, this allows a library to conservatively apply `@avx`
only when it is known to be safe to do so.

```julia
function mymap!(f, y, args...)
if can_avx(f)
@avx @. y = f(args...)
else
@. y = f(args...)
end
end
```
"""
can_avx(::Any) = false

"""
insert(collection, index, item)

Expand Down Expand Up @@ -654,6 +710,7 @@ function __init__()
ismutable(::Type{<:StaticArrays.StaticArray}) = false
can_setindex(::Type{<:StaticArrays.StaticArray}) = false
ismutable(::Type{<:StaticArrays.MArray}) = true
ismutable(::Type{<:StaticArrays.SizedArray}) = true

function lu_instance(_A::StaticArrays.StaticMatrix{N,N}) where {N}
A = StaticArrays.SArray(_A)
Expand All @@ -675,6 +732,26 @@ function __init__()
known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N

device(::Type{<:StaticArrays.MArray}) = CPUPointer()
contiguous_axis(::Type{<:StaticArrays.StaticArray}) = Contiguous{1}()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole part will have to be revisited after merging JuliaArrays/StaticArrays.jl#783.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Preferably it'd define a parent function, otherwise we'll need to access the parent through A.data ourselves, which may not be external API.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I've added parent for SizedArray in my PR.

contiguous_batch_size(::Type{<:StaticArrays.StaticArray}) = ContiguousBatch{0}()
stride_rank(::Type{T}) where {N, T <: StaticArrays.StaticArray{<:Any,<:Any,N}} = StrideRank{ntuple(identity, Val{N}())}()
dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} = DenseDims{ntuple(_ -> true, Val(N))}()
defines_strides(::Type{<:StaticArrays.MArray}) = true
@generated function size(A::StaticArrays.StaticArray{S}) where {S}
t = Expr(:tuple); Sp = S.parameters
for n in 1:length(Sp)
push!(t.args, Expr(:call, Expr(:curly, :StaticInt, Sp[n])))
end
t
end
@generated function strides(A::StaticArrays.StaticArray{S}) where {S}
t = Expr(:tuple, Expr(:call, Expr(:curly, :StaticInt, 1))); Sp = S.parameters; x = 1
for n in 1:length(Sp)-1
push!(t.args, Expr(:call, Expr(:curly, :StaticInt, (x *= Sp[n]))))
end
t
end
@require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
function Adapt.adapt_storage(::Type{<:StaticArrays.SArray{S}},xs::Array) where S
StaticArrays.SArray{S}(xs)
Expand All @@ -694,7 +771,7 @@ function __init__()
aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where N = Tracker.collect(x)
end

@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
@require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
@require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
include("cuarrays.jl")
end
Expand All @@ -717,7 +794,7 @@ function __init__()
@require BandedMatrices="aae01518-5342-5314-be14-df237901396f" begin
function findstructralnz(x::BandedMatrices.BandedMatrix)
l,u=BandedMatrices.bandwidths(x)
rowsize,colsize=size(x)
rowsize,colsize= Base.size(x)
rowind=BandedMatrixIndex(rowsize,colsize,l,u,true)
colind=BandedMatrixIndex(rowsize,colsize,l,u,false)
(rowind,colind)
Expand All @@ -730,7 +807,7 @@ function __init__()
function matrix_colors(A::BandedMatrices.BandedMatrix)
l,u=BandedMatrices.bandwidths(A)
width=u+l+1
_cycle(1:width,size(A,2))
_cycle(1:width, Base.size(A,2))
end

end
Expand Down Expand Up @@ -794,9 +871,19 @@ function __init__()
end
end
end
@require OffsetArrays="6fe1bfb0-de20-5000-8ca7-80f57d26f881" begin
size(A::OffsetArrays.OffsetArray) = size(parent(A))
strides(A::OffsetArrays.OffsetArray) = strides(parent(A))
# offsets(A::OffsetArrays.OffsetArray) = map(+, A.offsets, offsets(parent(A)))
device(::OffsetArrays.OffsetArray) = CheckParent()
contiguous_axis(A::OffsetArrays.OffsetArray) = contiguous_axis(parent(A))
contiguous_batch_size(A::OffsetArrays.OffsetArray) = contiguous_batch_size(parent(A))
stride_rank(A::OffsetArrays.OffsetArray) = stride_rank(parent(A))
end
end

include("static.jl")
include("ranges.jl")
include("stridelayout.jl")

end
5 changes: 4 additions & 1 deletion src/cuarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,8 @@ function Base.setindex(x::CuArrays.CuArray,v,i::Int)
end

function restructure(x::CuArrays.CuArray,y)
reshape(Adapt.adapt(parameterless_type(x),y),size(x)...)
reshape(Adapt.adapt(parameterless_type(x),y), Base.size(x)...)
end

Device(::Type{<:CuArrays.CuArray}) = GPU()

5 changes: 4 additions & 1 deletion src/cuarrays2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,8 @@ function Base.setindex(x::CUDA.CuArray,v,i::Int)
end

function restructure(x::CUDA.CuArray,y)
reshape(Adapt.adapt(parameterless_type(x),y),size(x)...)
reshape(Adapt.adapt(parameterless_type(x),y), Base.size(x)...)
end

Device(::Type{<:CUDA.CuArray}) = GPU()

Loading