diff --git a/src/broadcast.jl b/src/broadcast.jl index e892fd48..4c66c105 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -10,6 +10,7 @@ import Base.Broadcast: # This isn't the precise output type, just a placeholder to return from # promote_containertype, which will control dispatch to our broadcast_c. _containertype(::Type{<:StaticArray}) = StaticArray +_containertype(::Type{<:RowVector{<:Any,<:SVector}}) = StaticArray # With the above, the default promote_containertype gives reasonable defaults: # StaticArray, StaticArray -> StaticArray @@ -32,6 +33,7 @@ broadcast_indices(::Type{StaticArray}, A) = indices(A) _broadcast(f, broadcast_sizes(as...), as...) end +@inline broadcast_sizes(a::RowVector{<:Any,<:SVector}, as...) = (Size(a), broadcast_sizes(as...)...) @inline broadcast_sizes(a::StaticArray, as...) = (Size(a), broadcast_sizes(as...)...) @inline broadcast_sizes(a, as...) = (Size(), broadcast_sizes(as...)...) @inline broadcast_sizes() = () @@ -66,9 +68,9 @@ end for i = 1:length(sizes) s = sizes[i] for j = 1:length(s) - if newsize[j] == 1 || newsize[j] == s[j] + if newsize[j] == 1 newsize[j] = s[j] - else + elseif newsize[j] ≠ s[j] && s[j] ≠ 1 throw(DimensionMismatch("Tried to broadcast on inputs sized $sizes")) end end diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 12188c3e..aa2af46d 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -105,7 +105,7 @@ end N = length(S) Snew = ([n==D ? 1 : S[n] for n = 1:N]...) T0 = eltype(a) - T = :((T1 = Base.promote_op(f, $T0); Base.promote_op(op, T1, T1))) + T = :((T1 = Core.Inference.return_type(f, Tuple{$T0}); Core.Inference.return_type(op, Tuple{T1,T1}))) exprs = Array{Expr}(Snew) itr = [1:n for n ∈ Snew] @@ -235,7 +235,7 @@ end @generated function _diff(::Size{S}, a::StaticArray, ::Type{Val{D}}) where {S,D} N = length(S) Snew = ([n==D ? S[n]-1 : S[n] for n = 1:N]...) - T = Base.promote_op(-, eltype(a), eltype(a)) + T = typeof(one(eltype(a)) - one(eltype(a))) exprs = Array{Expr}(Snew) itr = [1:n for n = Snew] diff --git a/test/broadcast.jl b/test/broadcast.jl index b80708a4..b22916d3 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -43,34 +43,36 @@ end end @testset "2x2 StaticMatrix with 1x2 StaticMatrix" begin + # Issues #197, #242: broadcast between SArray and row-like SMatrix m1 = @SMatrix [1 2; 3 4] m2 = @SMatrix [1 4] - @test_broken @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8] #197 - @test_broken @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8] #197 + @test @inferred(broadcast(+, m1, m2)) === @SMatrix [2 6; 4 8] + @test @inferred(m1 .+ m2) === @SMatrix [2 6; 4 8] @test @inferred(m2 .+ m1) === @SMatrix [2 6; 4 8] - @test_broken @inferred(m1 .* m2) === @SMatrix [1 8; 3 16] #197 + @test @inferred(m1 .* m2) === @SMatrix [1 8; 3 16] @test @inferred(m2 .* m1) === @SMatrix [1 8; 3 16] - @test_broken @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1] #197 + @test @inferred(m1 ./ m2) === @SMatrix [1 1/2; 3 1] @test @inferred(m2 ./ m1) === @SMatrix [1 2; 1/3 1] - @test_broken @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0] #197 + @test @inferred(m1 .- m2) === @SMatrix [0 -2; 2 0] @test @inferred(m2 .- m1) === @SMatrix [0 2; -2 0] - @test_broken @inferred(m1 .^ m2) === @SMatrix [1 16; 1 256] #197 + @test @inferred(m1 .^ m2) === @SMatrix [1 16; 3 256] end @testset "1x2 StaticMatrix with StaticVector" begin + # Issues #197, #242: broadcast between SVector and row-like SMatrix m = @SMatrix [1 2] v = SVector(1, 4) @test @inferred(broadcast(+, m, v)) === @SMatrix [2 3; 5 6] @test @inferred(m .+ v) === @SMatrix [2 3; 5 6] - @test_broken @inferred(v .+ m) === @SMatrix [2 3; 5 6] #197 + @test @inferred(v .+ m) === @SMatrix [2 3; 5 6] @test @inferred(m .* v) === @SMatrix [1 2; 4 8] - @test_broken @inferred(v .* m) === @SMatrix [1 2; 4 8] #197 + @test @inferred(v .* m) === @SMatrix [1 2; 4 8] @test @inferred(m ./ v) === @SMatrix [1 2; 1/4 1/2] - @test_broken @inferred(v ./ m) === @SMatrix [1 1/2; 4 2] #197 + @test @inferred(v ./ m) === @SMatrix [1 1/2; 4 2] @test @inferred(m .- v) === @SMatrix [0 1; -3 -2] - @test_broken @inferred(v .- m) === @SMatrix [0 -1; 3 2] #197 + @test @inferred(v .- m) === @SMatrix [0 -1; 3 2] @test @inferred(m .^ v) === @SMatrix [1 2; 1 16] - @test_broken @inferred(v .^ m) === @SMatrix [1 1; 4 16] #197 + @test @inferred(v .^ m) === @SMatrix [1 1; 4 16] end @testset "StaticVector with StaticVector" begin @@ -87,11 +89,11 @@ end @test @inferred(v2 .- v1) === SVector(0, 2) @test @inferred(v1 .^ v2) === SVector(1, 16) @test @inferred(v2 .^ v1) === SVector(1, 16) - # test case issue #199 + # Issue #199: broadcast with empty SArray @test @inferred(SVector(1) .+ SVector()) === SVector() - @test_broken @inferred(SVector() .+ SVector(1)) === SVector() - # test case issue #200 - @test_broken @inferred(v1 .+ v2') === @SMatrix [2 5; 3 5] + @test @inferred(SVector() .+ SVector(1)) === SVector() + # Issue #200: broadcast with RowVector + @test @inferred(v1 .+ v2') === @SMatrix [2 5; 3 6] end @testset "StaticVector with Scalar" begin