Skip to content

Commit 3a49b3a

Browse files
committed
Fix several broadcast issues
1 parent da1a371 commit 3a49b3a

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

src/broadcast.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Base.Broadcast:
1010
# This isn't the precise output type, just a placeholder to return from
1111
# promote_containertype, which will control dispatch to our broadcast_c.
1212
_containertype(::Type{<:StaticArray}) = StaticArray
13+
_containertype(::Type{<:RowVector{<:Any,<:SVector}}) = StaticArray
1314

1415
# With the above, the default promote_containertype gives reasonable defaults:
1516
# StaticArray, StaticArray -> StaticArray
@@ -32,6 +33,7 @@ broadcast_indices(::Type{StaticArray}, A) = indices(A)
3233
_broadcast(f, broadcast_sizes(as...), as...)
3334
end
3435

36+
@inline broadcast_sizes(a::RowVector{<:Any,<:SVector}, as...) = (Size(a), broadcast_sizes(as...)...)
3537
@inline broadcast_sizes(a::StaticArray, as...) = (Size(a), broadcast_sizes(as...)...)
3638
@inline broadcast_sizes(a, as...) = (Size(), broadcast_sizes(as...)...)
3739
@inline broadcast_sizes() = ()
@@ -66,9 +68,9 @@ end
6668
for i = 1:length(sizes)
6769
s = sizes[i]
6870
for j = 1:length(s)
69-
if newsize[j] == 1 || newsize[j] == s[j]
71+
if newsize[j] == 1
7072
newsize[j] = s[j]
71-
else
73+
elseif newsize[j] s[j] && s[j] 1
7274
throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes"))
7375
end
7476
end

test/broadcast.jl

+13-13
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,32 @@ end
4545
@testset "2x2 StaticMatrix with 1x2 StaticMatrix" begin
4646
m1 = @SMatrix [1 2; 3 4]
4747
m2 = @SMatrix [1 4]
48-
@test_broken @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8] #197
49-
@test_broken @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8] #197
48+
@test @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8] #197
49+
@test @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8] #197
5050
@test @inferred(m2 .+ m1) === @SMatrix [2 6; 4 8]
51-
@test_broken @inferred(m1 .* m2) === @SMatrix [1 8; 3 16] #197
51+
@test @inferred(m1 .* m2) === @SMatrix [1 8; 3 16] #197
5252
@test @inferred(m2 .* m1) === @SMatrix [1 8; 3 16]
53-
@test_broken @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1] #197
53+
@test @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1] #197
5454
@test @inferred(m2 ./ m1) === @SMatrix [1 2; 1/3 1]
55-
@test_broken @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0] #197
55+
@test @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0] #197
5656
@test @inferred(m2 .- m1) === @SMatrix [0 2; -2 0]
57-
@test_broken @inferred(m1 .^ m2) === @SMatrix [1 16; 1 256] #197
57+
@test @inferred(m1 .^ m2) === @SMatrix [1 16; 3 256] #197
5858
end
5959

6060
@testset "1x2 StaticMatrix with StaticVector" begin
6161
m = @SMatrix [1 2]
6262
v = SVector(1, 4)
6363
@test @inferred(broadcast(+, m, v)) === @SMatrix [2 3; 5 6]
6464
@test @inferred(m .+ v) === @SMatrix [2 3; 5 6]
65-
@test_broken @inferred(v .+ m) === @SMatrix [2 3; 5 6] #197
65+
@test @inferred(v .+ m) === @SMatrix [2 3; 5 6] #197
6666
@test @inferred(m .* v) === @SMatrix [1 2; 4 8]
67-
@test_broken @inferred(v .* m) === @SMatrix [1 2; 4 8] #197
67+
@test @inferred(v .* m) === @SMatrix [1 2; 4 8] #197
6868
@test @inferred(m ./ v) === @SMatrix [1 2; 1/4 1/2]
69-
@test_broken @inferred(v ./ m) === @SMatrix [1 1/2; 4 2] #197
69+
@test @inferred(v ./ m) === @SMatrix [1 1/2; 4 2] #197
7070
@test @inferred(m .- v) === @SMatrix [0 1; -3 -2]
71-
@test_broken @inferred(v .- m) === @SMatrix [0 -1; 3 2] #197
71+
@test @inferred(v .- m) === @SMatrix [0 -1; 3 2] #197
7272
@test @inferred(m .^ v) === @SMatrix [1 2; 1 16]
73-
@test_broken @inferred(v .^ m) === @SMatrix [1 1; 4 16] #197
73+
@test @inferred(v .^ m) === @SMatrix [1 1; 4 16] #197
7474
end
7575

7676
@testset "StaticVector with StaticVector" begin
@@ -89,9 +89,9 @@ end
8989
@test @inferred(v2 .^ v1) === SVector(1, 16)
9090
# test case issue #199
9191
@test @inferred(SVector(1) .+ SVector()) === SVector()
92-
@test_broken @inferred(SVector() .+ SVector(1)) === SVector()
92+
@test @inferred(SVector() .+ SVector(1)) === SVector()
9393
# test case issue #200
94-
@test_broken @inferred(v1 .+ v2') === @SMatrix [2 5; 3 5]
94+
@test @inferred(v1 .+ v2') === @SMatrix [2 5; 3 6]
9595
end
9696

9797
@testset "StaticVector with Scalar" begin

0 commit comments

Comments
 (0)