Skip to content

Commit d47c771

Browse files
authored
Generalize broadcasting to StaticMatrixLike (#1220)
* Generalize broadcasting to `StaticMatrixLike` * Bump version to 1.7.0
1 parent d419e21 commit d47c771

File tree

5 files changed

+50
-37
lines changed

5 files changed

+50
-37
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.6.5"
3+
version = "1.7.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/abstractarray.jl

+1-4
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,10 @@ length(a::Type{SA}) where {SA <: StaticArrayLike} = prod(Size(SA))::Int
88
end
99
@inline size(a::StaticArrayLike) = Tuple(Size(a))
1010

11-
Base.axes(s::StaticArray) = _axes(Size(s))
11+
Base.axes(s::StaticArrayLike) = _axes(Size(s))
1212
@pure function _axes(::Size{sizes}) where {sizes}
1313
map(SOneTo, sizes)
1414
end
15-
Base.axes(rv::Adjoint{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
16-
Base.axes(rv::Transpose{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...)
17-
Base.axes(d::Diagonal{<:Any,<:StaticVector}) = (ax = axes(d.diag, 1); (ax, ax))
1815

1916
Base.eachindex(::IndexLinear, a::StaticArray) = SOneTo(length(a))
2017

src/broadcast.jl

+2-7
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@ import Base.Broadcast: _bcs1 # for SOneTo axis information
99
using Base.Broadcast: _bcsm
1010

1111
BroadcastStyle(::Type{<:StaticArray{<:Tuple, <:Any, N}}) where {N} = StaticArrayStyle{N}()
12-
BroadcastStyle(::Type{<:Transpose{<:Any, <:StaticArray}}) = StaticArrayStyle{2}()
13-
BroadcastStyle(::Type{<:Adjoint{<:Any, <:StaticArray}}) = StaticArrayStyle{2}()
14-
BroadcastStyle(::Type{<:Diagonal{<:Any, <:StaticArray{<:Tuple, <:Any, 1}}}) = StaticArrayStyle{2}()
12+
BroadcastStyle(::Type{<:StaticMatrixLike}) = StaticArrayStyle{2}()
1513
# Precedence rules
1614
BroadcastStyle(::StaticArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
1715
DefaultArrayStyle(Val(max(M, N)))
@@ -104,10 +102,7 @@ function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
104102
return :(a[$i][$ind])
105103
end
106104

107-
isstatic(::StaticArray) = true
108-
isstatic(::Transpose{<:Any, <:StaticArray}) = true
109-
isstatic(::Adjoint{<:Any, <:StaticArray}) = true
110-
isstatic(::Diagonal{<:Any, <:StaticArray}) = true
105+
isstatic(::StaticArrayLike) = true
111106
isstatic(_) = false
112107

113108
@inline first_statictype(x, y...) = isstatic(x) ? typeof(x) : first_statictype(y...)

test/abstractarray.jl

+14
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@ using StaticArrays, Test, LinearAlgebra
1010
@test eachindex(m) isa SOneTo
1111
end
1212

13+
@testset "axes" begin
14+
v = @SVector [1, 2, 3]
15+
@test @inferred(axes(v)) == (SOneTo(3),)
16+
for T in (Adjoint, Transpose)
17+
@test @inferred(axes(T(v))) == (SOneTo(1), SOneTo(3))
18+
end
19+
20+
m = @SMatrix [1 2; 3 4]
21+
@test @inferred(axes(m)) == (SOneTo(2), SOneTo(2))
22+
for T in (Adjoint, Transpose, Diagonal, Symmetric, Hermitian, UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
23+
@test @inferred(axes(T(m))) == (SOneTo(2), SOneTo(2))
24+
end
25+
end
26+
1327
@testset "strides" begin
1428
m1 = MArray{Tuple{3, 4, 5}}(rand(Int, 3, 4, 5))
1529
m2 = SizedArray{Tuple{3,4,5}}(rand(Int, 3, 4, 5))

test/broadcast.jl

+32-25
Original file line numberDiff line numberDiff line change
@@ -48,37 +48,44 @@ end
4848
@testset "2x2 StaticMatrix with StaticVector" begin
4949
m = @SMatrix [1 2; 3 4]
5050
v = SVector(1, 4)
51-
@test @inferred(broadcast(+, m, v)) === @SMatrix [2 3; 7 8]
52-
@test @inferred(m .+ v) === @SMatrix [2 3; 7 8]
53-
@test @inferred(v .+ m) === @SMatrix [2 3; 7 8]
54-
@test @inferred(m .* v) === @SMatrix [1 2; 12 16]
55-
@test @inferred(v .* m) === @SMatrix [1 2; 12 16]
56-
@test @inferred(m ./ v) === @SMatrix [1 2; 3/4 1]
57-
@test @inferred(v ./ m) === @SMatrix [1 1/2; 4/3 1]
58-
@test @inferred(m .- v) === @SMatrix [0 1; -1 0]
59-
@test @inferred(v .- m) === @SMatrix [0 -1; 1 0]
60-
@test @inferred(m .^ v) === @SMatrix [1 2; 81 256]
61-
@test @inferred(v .^ m) === @SMatrix [1 1; 64 256]
62-
# Issue #546
63-
@test @inferred(m ./ (v .* v')) === @SMatrix [1.0 0.5; 0.75 0.25]
64-
testinf(m, v) = m ./ (v .* v')
65-
@test @inferred(testinf(m, v)) === @SMatrix [1.0 0.5; 0.75 0.25]
51+
vrep = @SMatrix [1 1; 4 4]
52+
for m in (m, Transpose(m), Adjoint(m), Diagonal(m), Symmetric(m, :U), Symmetric(m, :L), Hermitian(m, :U), Hermitian(m, :L), UpperTriangular(m), LowerTriangular(m), UnitUpperTriangular(m), UnitLowerTriangular(m))
53+
@test @inferred(broadcast(+, m, v)) === map(+, m, vrep)::SMatrix
54+
@test @inferred(m .+ v) === map(+, m, vrep)::SMatrix
55+
@test @inferred(v .+ m) === map(+, vrep, m)::SMatrix
56+
@test @inferred(m .* v) === map(*, m, vrep)::SMatrix
57+
@test @inferred(v .* m) === map(*, vrep, m)::SMatrix
58+
@test @inferred(m ./ v) === map(/, m, vrep)::SMatrix
59+
@test @inferred(v ./ m) === map(/, vrep, m)::SMatrix
60+
@test @inferred(m .- v) === map(-, m, vrep)::SMatrix
61+
@test @inferred(v .- m) === map(-, vrep, m)::SMatrix
62+
@test @inferred(m .^ v) === map(^, m, vrep)::SMatrix
63+
@test @inferred(v .^ m) === map(^, vrep, m)::SMatrix
64+
# Issue #546
65+
@test @inferred(m ./ (v .* v')) === map(/, m, v .* v')::SMatrix
66+
testinf(m, v) = m ./ (v .* v')
67+
@test @inferred(testinf(m, v)) === map(/, m, v .* v')::SMatrix
68+
end
6669
end
6770

6871
@testset "2x2 StaticMatrix with 1x2 StaticMatrix" begin
6972
# Issues #197, #242: broadcast between SArray and row-like SMatrix
7073
m1 = @SMatrix [1 2; 3 4]
7174
m2 = @SMatrix [1 4]
72-
@test @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8]
73-
@test @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8]
74-
@test @inferred(m2 .+ m1) === @SMatrix [2 6; 4 8]
75-
@test @inferred(m1 .* m2) === @SMatrix [1 8; 3 16]
76-
@test @inferred(m2 .* m1) === @SMatrix [1 8; 3 16]
77-
@test @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1]
78-
@test @inferred(m2 ./ m1) === @SMatrix [1 2; 1/3 1]
79-
@test @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0]
80-
@test @inferred(m2 .- m1) === @SMatrix [0 2; -2 0]
81-
@test @inferred(m1 .^ m2) === @SMatrix [1 16; 3 256]
75+
m2rep = @SMatrix [1 4; 1 4]
76+
m1s = (m1, Transpose(m1), Adjoint(m1), Diagonal(m1), Symmetric(m1, :U), Symmetric(m1, :L), Hermitian(m1, :U), Hermitian(m1, :L), UpperTriangular(m1), LowerTriangular(m1), UnitUpperTriangular(m1), UnitLowerTriangular(m1))
77+
for m1 in m1s
78+
@test @inferred(broadcast(+, m1, m2)) === map(+, m1, m2rep)::SMatrix
79+
@test @inferred(m1 .+ m2) === map(+, m1, m2rep)::SMatrix
80+
@test @inferred(m2 .+ m1) === map(+, m2rep, m1)::SMatrix
81+
@test @inferred(m1 .* m2) === map(*, m1, m2rep)::SMatrix
82+
@test @inferred(m2 .* m1) === map(*, m2rep, m1)::SMatrix
83+
@test @inferred(m1 ./ m2) === map(/, m1, m2rep)::SMatrix
84+
@test @inferred(m2 ./ m1) === map(/, m2rep, m1)::SMatrix
85+
@test @inferred(m1 .- m2) === map(-, m1, m2rep)::SMatrix
86+
@test @inferred(m2 .- m1) === map(-, m2rep, m1)::SMatrix
87+
@test @inferred(m1 .^ m2) === map(^, m1, m2rep)::SMatrix
88+
end
8289
end
8390

8491
@testset "1x2 StaticMatrix with StaticVector" begin

0 commit comments

Comments
 (0)