Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sized AbstractArray #783

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 169 additions & 45 deletions src/SizedArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""
SizedArray{Tuple{dims...}}(array)

Wraps an `Array` with a static size, so to take advantage of the (faster)
Wraps an `AbstractArray` with a static size, so to take advantage of the (faster)
methods defined by the static array package. The size is checked once upon
construction to determine if the number of elements (`length`) match, but the
array may be reshaped.
Expand All @@ -11,37 +11,48 @@ The aliases `SizedVector{N}` and `SizedMatrix{N,M}` are provided as more
convenient names for one and two dimensional `SizedArray`s. For example, to
wrap a 2x3 array `a` in a `SizedArray`, use `SizedMatrix{2,3}(a)`.
"""
struct SizedArray{S <: Tuple, T, N, M} <: StaticArray{S, T, N}
data::Array{T, M}
struct SizedArray{S<:Tuple,T,N,M,TData<:AbstractArray{T,M}} <: StaticArray{S,T,N}
data::TData

function SizedArray{S, T, N, M}(a::Array) where {S, T, N, M}
if length(a) != tuple_prod(S)
function SizedArray{S,T,N,M,TData}(a::TData) where {S,T,N,M,TData<:AbstractArray{T,M}}
if size(a) != size_to_tuple(S) && size(a) != (tuple_prod(S),)
throw(DimensionMismatch("Dimensions $(size(a)) don't match static size $S"))
end
if size(a) != size_to_tuple(S)
Base.depwarn("Construction of `SizedArray` with an `Array` of a different
size is deprecated. If you need this functionality report it at
https://github.com/JuliaArrays/StaticArrays.jl/pull/666 .
Calling `sa = reshape(a::Array, s::Size)` will actually reshape
array `a` in the future and converting `sa` back to `Array` will
return an `Array` of shape `s`.", :SizedArray)
end
new{S,T,N,M}(a)
return new{S,T,N,M,TData}(a)
end

function SizedArray{S, T, N, M}(::UndefInitializer) where {S, T, N, M}
new{S, T, N, M}(Array{T, M}(undef, size_to_tuple(S)...))
function SizedArray{S,T,N,1,TData}(::UndefInitializer) where {S,T,N,TData<:AbstractArray{T,1}}
return new{S,T,N,1,TData}(TData(undef, tuple_prod(S)))
end
function SizedArray{S,T,N,N,TData}(::UndefInitializer) where {S,T,N,TData<:AbstractArray{T,N}}
return new{S,T,N,N,TData}(TData(undef, size_to_tuple(S)...))
end
end

@inline SizedArray{S,T,N}(a::Array{T,M}) where {S,T,N,M} = SizedArray{S,T,N,M}(a)
@inline SizedArray{S,T}(a::Array{T,M}) where {S,T,M} = SizedArray{S,T,tuple_length(S),M}(a)
@inline SizedArray{S}(a::Array{T,M}) where {S,T,M} = SizedArray{S,T,tuple_length(S),M}(a)

@inline SizedArray{S,T,N}(::UndefInitializer) where {S,T,N} = SizedArray{S,T,N,N}(undef)
@inline SizedArray{S,T}(::UndefInitializer) where {S,T} = SizedArray{S,T,tuple_length(S),tuple_length(S)}(undef)

@generated function SizedArray{S,T,N,M}(x::NTuple{L,Any}) where {S,T,N,M,L}
@inline function SizedArray{S,T,N}(
a::TData,
) where {S,T,N,M,TData<:AbstractArray{T,M}}
return SizedArray{S,T,N,M,TData}(a)
end
@inline function SizedArray{S,T}(a::TData) where {S,T,M,TData<:AbstractArray{T,M}}
return SizedArray{S,T,tuple_length(S),M,TData}(a)
end
@inline function SizedArray{S}(a::TData) where {S,T,M,TData<:AbstractArray{T,M}}
return SizedArray{S,T,tuple_length(S),M,TData}(a)
end
function SizedArray{S,T,N,N}(::UndefInitializer) where {S,T,N}
return SizedArray{S,T,N,N,Array{T,N}}(undef)
end
function SizedArray{S,T,N,1}(::UndefInitializer) where {S,T,N}
return SizedArray{S,T,N,1,Vector{T}}(undef)
end
@inline function SizedArray{S,T,N}(::UndefInitializer) where {S,T,N}
return SizedArray{S,T,N,N}(undef)
end
@inline function SizedArray{S,T}(::UndefInitializer) where {S,T}
return SizedArray{S,T,tuple_length(S)}(undef)
end
@generated function (::Type{SizedArray{S,T,N,M,TData}})(x::NTuple{L,Any}) where {S,T,N,M,TData<:AbstractArray{T,M},L}
if L != tuple_prod(S)
error("Dimension mismatch")
end
Expand All @@ -53,43 +64,156 @@ end
return a
end
end

@inline SizedArray{S,T,N}(x::Tuple) where {S,T,N} = SizedArray{S,T,N,N}(x)
@inline SizedArray{S,T}(x::Tuple) where {S,T} = SizedArray{S,T,tuple_length(S),tuple_length(S)}(x)
@inline SizedArray{S}(x::NTuple{L,T}) where {S,T,L} = SizedArray{S,T,tuple_length(S),tuple_length(S)}(x)
@inline function SizedArray{S,T,N,M}(x::Tuple) where {S,T,N,M}
return SizedArray{S,T,N,M,Array{T,M}}(x)
end
@inline function SizedArray{S,T,N}(x::Tuple) where {S,T,N}
return SizedArray{S,T,N,N,Array{T,N}}(x)
end
@inline function SizedArray{S,T}(x::Tuple) where {S,T}
return SizedArray{S,T,tuple_length(S)}(x)
end
@inline function SizedArray{S}(x::NTuple{L,T}) where {S,T,L}
return SizedArray{S,T}(x)
end

# Overide some problematic default behaviour
@inline convert(::Type{SA}, sa::SizedArray) where {SA<:SizedArray} = SA(sa.data)
@inline convert(::Type{SA}, sa::SA) where {SA<:SizedArray} = sa

# Back to Array (unfortunately need both convert and construct to overide other methods)
@inline Array(sa::SizedArray) = Array(sa.data)
@inline Array{T}(sa::SizedArray{S,T}) where {T,S} = Array{T}(sa.data)
@inline Array{T,N}(sa::SizedArray{S,T,N}) where {T,S,N} = Array{T,N}(sa.data)
@inline function Base.Array(sa::SizedArray{S}) where {S}
return Array(reshape(sa.data, size_to_tuple(S)))
end
@inline function Base.Array{T}(sa::SizedArray{S,T}) where {T,S}
return Array(reshape(sa.data, size_to_tuple(S)))
end
@inline function Base.Array{T,N}(sa::SizedArray{S,T,N}) where {T,S,N}
return Array(reshape(sa.data, size_to_tuple(S)))
end

@inline convert(::Type{Array}, sa::SizedArray) = sa.data
@inline convert(::Type{Array{T}}, sa::SizedArray{S,T}) where {T,S} = sa.data
@inline convert(::Type{Array{T,N}}, sa::SizedArray{S,T,N}) where {T,S,N} = sa.data
@inline function convert(::Type{Array}, sa::SizedArray{S}) where {S}
return Array(reshape(sa.data, size_to_tuple(S)))
end
@inline function convert(::Type{Array}, sa::SizedArray{S,T,N,M,Array{T,M}}) where {S,T,N,M}
return sa.data
end
@inline function convert(::Type{Array{T}}, sa::SizedArray{S,T}) where {T,S}
return Array(reshape(sa.data, size_to_tuple(S)))
end
@inline function convert(::Type{Array{T}}, sa::SizedArray{S,T,N,M,Array{T,M}}) where {S,T,N,M}
return sa.data
end
@inline function convert(
::Type{Array{T,N}},
sa::SizedArray{S,T,N},
) where {T,S,N}
return Array(reshape(sa.data, size_to_tuple(S)))
end
@inline function convert(::Type{Array{T,N}}, sa::SizedArray{S,T,N,N,Array{T,N}}) where {S,T,N}
return sa.data
end

@propagate_inbounds getindex(a::SizedArray, i::Int) = getindex(a.data, i)
@propagate_inbounds setindex!(a::SizedArray, v, i::Int) = setindex!(a.data, v, i)

SizedVector{S,T,M} = SizedArray{Tuple{S},T,1,M}
@inline SizedVector{S}(a::Array{T,M}) where {S,T,M} = SizedArray{Tuple{S},T,1,M}(a)
@inline SizedVector{S}(x::NTuple{L,T}) where {S,T,L} = SizedArray{Tuple{S},T,1,1}(x)
const SizedVector{S,T} = SizedArray{Tuple{S},T,1,1}

@inline function SizedVector{S}(a::TData) where {S,T,TData<:AbstractVector{T}}
return SizedArray{Tuple{S},T,1,1,TData}(a)
end
@inline function SizedVector(x::NTuple{S,T}) where {S,T}
return SizedArray{Tuple{S},T,1,1,Vector{T}}(x)
end
@inline function SizedVector{S}(x::NTuple{S,T}) where {S,T}
return SizedArray{Tuple{S},T,1,1,Vector{T}}(x)
end
@inline function SizedVector{S,T}(x::NTuple{S}) where {S,T}
return SizedArray{Tuple{S},T,1,1,Vector{T}}(x)
end
# disambiguation
@inline function SizedVector{S}(a::StaticVector{S,T}) where {S,T}
return SizedVector{S,T}(a.data)
end

SizedMatrix{S1,S2,T,M} = SizedArray{Tuple{S1,S2},T,2,M}
@inline SizedMatrix{S1,S2}(a::Array{T,M}) where {S1,S2,T,M} = SizedArray{Tuple{S1,S2},T,2,M}(a)
@inline SizedMatrix{S1,S2}(x::NTuple{L,T}) where {S1,S2,T,L} = SizedArray{Tuple{S1,S2},T,2,2}(x)
const SizedMatrix{S1,S2,T} = SizedArray{Tuple{S1,S2},T,2}

@inline function SizedMatrix{S1,S2}(
a::TData,
) where {S1,S2,T,M,TData<:AbstractArray{T,M}}
return SizedArray{Tuple{S1,S2},T,2,M,TData}(a)
end
@inline function SizedMatrix{S1,S2}(x::NTuple{L,T}) where {S1,S2,T,L}
return SizedArray{Tuple{S1,S2},T,2,2,Matrix{T}}(x)
end
@inline function SizedMatrix{S1,S2,T}(x::NTuple{L}) where {S1,S2,T,L}
return SizedArray{Tuple{S1,S2},T,2,2,Matrix{T}}(x)
end
# disambiguation
@inline function SizedMatrix{S1,S2}(a::StaticMatrix{S1,S2,T}) where {S1,S2,T}
return SizedMatrix{S1,S2,T}(a.data)
end

Base.dataids(sa::SizedArray) = Base.dataids(sa.data)

function (::Size{S})(a::Array) where {S}
Base.depwarn("`Size{S}(a::Array)` is deprecated, use `SizedVector{N}(a)`, `SizedMatrix{N,M}(a)` or `SizedArray{Tuple{S}}(a)` instead", :Size)
SizedArray{Tuple{S...}}(a)
function promote_rule(
::Type{SizedArray{S,T,N,M,TDataA}},
::Type{SizedArray{S,U,N,M,TDataB}},
) where {S,T,U,N,M,TDataA,TDataB}
TU = promote_type(T, U)
return SizedArray{S, TU, N, M, promote_type(TDataA, TDataB)}
end

function promote_rule(
::Type{SizedArray{S,T,N,M}},
::Type{SizedArray{S,U,N,M}},
) where {S,T,U,N,M,}
TU = promote_type(T, U)
return SizedArray{S, TU, N, M}
end

function promote_rule(
::Type{SizedArray{S,T,N}},
::Type{SizedArray{S,U,N}},
) where {S,T,U,N}
TU = promote_type(T, U)
return SizedArray{S, TU, N}
end


### Code that makes views of statically sized arrays also statically sized (where possible)

@generated function new_out_size(::Type{Size}, inds...) where Size
os = []
map(Size.parameters, inds) do s, i
if i <: Integer
# dimension is fixed
elseif i <: StaticVector
push!(os, i.parameters[1].parameters[1])
elseif i == Colon || i <: Base.Slice
push!(os, s)
elseif i <: SOneTo
push!(os, i.parameters[1])
else
error("Unknown index type: $i")
end
end
return Tuple{os...}
end

function Base.view(
a::SizedArray{S},
indices::Union{Integer, Colon, StaticVector, Base.Slice, SOneTo}...,
) where {S}
new_size = new_out_size(S, indices...)
return SizedArray{new_size}(view(a.data, indices...))
end

function promote_rule(::Type{<:SizedArray{S,T,N,M}}, ::Type{<:SizedArray{S,U,N,M}}) where {S,T,U,N,M}
SizedArray{S,promote_type(T,U),N,M}
function Base.view(
a::MArray{S},
indices::Union{Integer, Colon, StaticVector, Base.Slice, SOneTo}...,
) where {S}
new_size = new_out_size(S, indices...)
view_from_invoke = invoke(view, Tuple{AbstractArray, typeof(indices).parameters...}, a, indices...)
return SizedArray{new_size}(view_from_invoke)
end
Loading