Skip to content

Commit d6e85e8

Browse files
committed
Extend sparse broadcast to VecOrMats
1 parent b8971ea commit d6e85e8

File tree

3 files changed

+42
-39
lines changed

3 files changed

+42
-39
lines changed

base/broadcast.jl

+15-9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ using Base: linearindices, tail, OneTo, to_shape,
99
import Base: broadcast, broadcast!
1010
export broadcast_getindex, broadcast_setindex!, dotview
1111

12+
immutable OneOrTwoD end
13+
typealias ArrayType Union{Type{AbstractArray}, Type{OneOrTwoD}}
1214
typealias ScalarType Union{Type{Any}, Type{Nullable}}
1315

1416
## Broadcasting utilities ##
@@ -24,16 +26,19 @@ broadcast!(f, X::AbstractArray, x::Number...) = (@inbounds for I in eachindex(X)
2426
_containertype(::Type) = Any
2527
_containertype{T<:Ptr}(::Type{T}) = Any
2628
_containertype{T<:Tuple}(::Type{T}) = Tuple
27-
_containertype{T<:Ref}(::Type{T}) = Array
28-
_containertype{T<:AbstractArray}(::Type{T}) = Array
29+
_containertype{T<:Ref}(::Type{T}) = AbstractArray
30+
_containertype{T<:AbstractArray}(::Type{T}) = AbstractArray
31+
_containertype{T<:AbstractVecOrMat}(::Type{T}) = OneOrTwoD
2932
_containertype{T<:Nullable}(::Type{T}) = Nullable
3033
containertype(x) = _containertype(typeof(x))
3134
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2))
3235
@inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...))
3336

34-
promote_containertype(::Type{Array}, ::Type{Array}) = Array
35-
promote_containertype(::Type{Array}, ct) = Array
36-
promote_containertype(ct, ::Type{Array}) = Array
37+
promote_containertype(ct1, ct2) = AbstractArray
38+
promote_containertype(::ArrayType, ::Type{Tuple}) = AbstractArray
39+
promote_containertype(::Type{Tuple}, ::ArrayType) = AbstractArray
40+
promote_containertype(::ArrayType, ::ScalarType) = AbstractArray
41+
promote_containertype(::ScalarType, ::ArrayType) = AbstractArray
3742
promote_containertype(::Type{Tuple}, ::ScalarType) = Tuple
3843
promote_containertype(::ScalarType, ::Type{Tuple}) = Tuple
3944
promote_containertype(::Type{Any}, ::Type{Nullable}) = Nullable
@@ -46,8 +51,8 @@ broadcast_indices() = ()
4651
broadcast_indices(A) = broadcast_indices(containertype(A), A)
4752
broadcast_indices(::ScalarType, A) = ()
4853
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
49-
broadcast_indices(::Type{Array}, A::Ref) = ()
50-
broadcast_indices(::Type{Array}, A) = indices(A)
54+
broadcast_indices(::Type{AbstractArray}, A::Ref) = ()
55+
broadcast_indices(::ArrayType, A) = indices(A)
5156
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
5257
# shape (i.e., tuple-of-indices) inputs
5358
broadcast_shape(shape::Tuple) = shape
@@ -125,7 +130,7 @@ end
125130
end
126131

127132
Base.@propagate_inbounds _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
128-
Base.@propagate_inbounds _broadcast_getindex(::Type{Array}, A::Ref, I) = A[]
133+
Base.@propagate_inbounds _broadcast_getindex(::Type{AbstractArray}, A::Ref, I) = A[]
129134
Base.@propagate_inbounds _broadcast_getindex(::ScalarType, A, I) = A
130135
Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
131136

@@ -283,7 +288,8 @@ eltypestuple(a, b...) = (Base.@_pure_meta; Tuple{eltypestuple(a).types..., eltyp
283288
_broadcast_eltype(f, A, Bs...) = Base._return_type(f, eltypestuple(A, Bs...))
284289

285290
# broadcast methods that dispatch on the type of the final container
286-
@inline function broadcast_c(f, ::Type{Array}, A, Bs...)
291+
@inline broadcast_c(f, ::Type{OneOrTwoD}, A, Bs...) = broadcast_c(f, AbstractArray, A, Bs...)
292+
@inline function broadcast_c(f, ::Type{AbstractArray}, A, Bs...)
287293
T = _broadcast_eltype(f, A, Bs...)
288294
shape = broadcast_indices(A, Bs...)
289295
iter = CartesianRange(shape)

base/sparse/higherorderfns.jl

+24-27
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ module HigherOrderFns
55
# This module provides higher order functions specialized for sparse arrays,
66
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
77
import Base: map, map!, broadcast, broadcast!
8-
import Base.Broadcast: _containertype, promote_containertype,
9-
broadcast_indices, broadcast_c, broadcast_c!
8+
import Base.Broadcast: ScalarType, OneOrTwoD, _containertype,
9+
promote_containertype, broadcast_indices, broadcast_c, broadcast_c!
1010

1111
using Base: front, tail, to_shape
1212
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype
@@ -848,21 +848,23 @@ broadcast_indices(::Type{AbstractSparseArray}, A) = indices(A)
848848
# broadcast container type promotion for combinations of sparse arrays and other types
849849
_containertype{T<:SparseVecOrMat}(::Type{T}) = AbstractSparseArray
850850
# combinations of sparse arrays with broadcast scalars should yield sparse arrays
851-
promote_containertype(::Type{Any}, ::Type{AbstractSparseArray}) = AbstractSparseArray
852-
promote_containertype(::Type{AbstractSparseArray}, ::Type{Any}) = AbstractSparseArray
851+
promote_containertype(::Type{AbstractSparseArray}, ::ScalarType) = AbstractSparseArray
852+
promote_containertype(::ScalarType, ::Type{AbstractSparseArray}) = AbstractSparseArray
853+
promote_containertype(::Type{AbstractSparseArray}, ::Type{OneOrTwoD}) = AbstractSparseArray
854+
promote_containertype(::Type{OneOrTwoD}, ::Type{AbstractSparseArray}) = AbstractSparseArray
853855
# combinations of sparse arrays with anything else should fall back to generic dense broadcast
854-
promote_containertype(::Type{Array}, ::Type{AbstractSparseArray}) = Array
855-
promote_containertype(::Type{Tuple}, ::Type{AbstractSparseArray}) = Array
856-
promote_containertype(::Type{AbstractSparseArray}, ::Type{Array}) = Array
857-
promote_containertype(::Type{AbstractSparseArray}, ::Type{Tuple}) = Array
856+
promote_containertype(::Type{AbstractSparseArray}, ::Type{AbstractArray}) = AbstractArray
857+
promote_containertype(::Type{AbstractArray}, ::Type{AbstractSparseArray}) = AbstractArray
858+
promote_containertype(::Type{AbstractSparseArray}, ::Type{Tuple}) = AbstractArray
859+
promote_containertype(::Type{Tuple}, ::Type{AbstractSparseArray}) = AbstractArray
858860

859861
# broadcast[!] entry points for combinations of sparse arrays and other (scalar) types
860862
@inline function broadcast_c{N}(f, ::Type{AbstractSparseArray}, mixedargs::Vararg{Any,N})
861-
parevalf, passedargstup = capturescalars(f, mixedargs)
863+
parevalf, passedargstup = capturescalars(f, map(_sparsifystructured, mixedargs))
862864
return broadcast(parevalf, passedargstup...)
863865
end
864866
@inline function broadcast_c!{N}(f, ::Type{AbstractSparseArray}, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N})
865-
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
867+
parevalf, passedsrcargstup = capturescalars(f, map(_sparsifystructured, mixedsrcargs))
866868
return broadcast!(parevalf, dest, passedsrcargstup...)
867869
end
868870
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
@@ -900,32 +902,27 @@ _containertype{T<:Diagonal}(::Type{T}) = StructuredArray
900902
_containertype{T<:Bidiagonal}(::Type{T}) = StructuredArray
901903
_containertype{T<:Tridiagonal}(::Type{T}) = StructuredArray
902904
_containertype{T<:SymTridiagonal}(::Type{T}) = StructuredArray
903-
promote_containertype(::Type{StructuredArray}, ::Type{StructuredArray}) = StructuredArray
904905
# combinations involving sparse arrays continue in the structured array funnel
905-
promote_containertype(::Type{StructuredArray}, ::Type{AbstractSparseArray}) = StructuredArray
906-
promote_containertype(::Type{AbstractSparseArray}, ::Type{StructuredArray}) = StructuredArray
906+
promote_containertype(::Type{StructuredArray}, ::Type{AbstractSparseArray}) = AbstractSparseArray
907+
promote_containertype(::Type{AbstractSparseArray}, ::Type{StructuredArray}) = AbstractSparseArray
908+
promote_containertype(::Type{StructuredArray}, ::Type{OneOrTwoD}) = AbstractSparseArray
909+
promote_containertype(::Type{OneOrTwoD}, ::Type{StructuredArray}) = AbstractSparseArray
907910
# combinations involving scalars continue in the structured array funnel
908-
promote_containertype(::Type{StructuredArray}, ::Type{Any}) = StructuredArray
909-
promote_containertype(::Type{Any}, ::Type{StructuredArray}) = StructuredArray
911+
promote_containertype(::Type{StructuredArray}, ::ScalarType) = AbstractSparseArray
912+
promote_containertype(::ScalarType, ::Type{StructuredArray}) = AbstractSparseArray
910913
# combinations involving arrays divert to the generic array code
911-
promote_containertype(::Type{StructuredArray}, ::Type{Array}) = Array
912-
promote_containertype(::Type{Array}, ::Type{StructuredArray}) = Array
914+
promote_containertype(::Type{StructuredArray}, ::Type{AbstractArray}) = AbstractArray
915+
promote_containertype(::Type{AbstractArray}, ::Type{StructuredArray}) = AbstractArray
913916
# combinations involving tuples divert to the generic array code
914-
promote_containertype(::Type{StructuredArray}, ::Type{Tuple}) = Array
915-
promote_containertype(::Type{Tuple}, ::Type{StructuredArray}) = Array
917+
promote_containertype(::Type{StructuredArray}, ::Type{Tuple}) = AbstractArray
918+
promote_containertype(::Type{Tuple}, ::Type{StructuredArray}) = AbstractArray
916919

917-
# for combinations involving sparse/structured arrays and scalars only,
918-
# promote all structured arguments to sparse and then rebroadcast
919-
@inline broadcast_c{N}(f, ::Type{StructuredArray}, As::Vararg{Any,N}) =
920-
broadcast(f, map(_sparsifystructured, As)...)
921-
@inline broadcast_c!{N}(f, ::Type{AbstractSparseArray}, ::Type{StructuredArray}, C, B, As::Vararg{Any,N}) =
922-
broadcast!(f, C, _sparsifystructured(B), map(_sparsifystructured, As)...)
923-
@inline broadcast_c!{N}(f, CT::Type, ::Type{StructuredArray}, C, B, As::Vararg{Any,N}) =
924-
broadcast_c!(f, CT, Array, C, B, As...)
925920
@inline _sparsifystructured(S::SymTridiagonal) = SparseMatrixCSC(S)
926921
@inline _sparsifystructured(T::Tridiagonal) = SparseMatrixCSC(T)
927922
@inline _sparsifystructured(B::Bidiagonal) = SparseMatrixCSC(B)
928923
@inline _sparsifystructured(D::Diagonal) = SparseMatrixCSC(D)
924+
@inline _sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M)
925+
@inline _sparsifystructured(V::AbstractVector) = SparseVector(V)
929926
@inline _sparsifystructured(A::AbstractSparseArray) = A
930927
@inline _sparsifystructured(x) = x
931928

test/broadcast.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -393,10 +393,10 @@ Base.size(A::Array19745) = size(A.data)
393393

394394
Base.Broadcast._containertype{T<:Array19745}(::Type{T}) = Array19745
395395

396+
# This way of defining promote_containertype methods is discouraged. The recommended
397+
# way is by defining definitions for combinations of tight containers types
396398
Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array19745}) = Array19745
397-
Base.Broadcast.promote_containertype(::Type{Array19745}, ::Type{Array}) = Array19745
398399
Base.Broadcast.promote_containertype(::Type{Array19745}, ct) = Array19745
399-
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{Array19745}) = Array19745
400400
Base.Broadcast.promote_containertype(ct, ::Type{Array19745}) = Array19745
401401

402402
Base.Broadcast.broadcast_indices(::Type{Array19745}, A) = indices(A)
@@ -406,7 +406,7 @@ getfield19745(x::Array19745) = x.data
406406
getfield19745(x) = x
407407

408408
Base.Broadcast.broadcast_c(f, ::Type{Array19745}, A, Bs...) =
409-
Array19745(Base.Broadcast.broadcast_c(f, Array, getfield19745(A), map(getfield19745, Bs)...))
409+
Array19745(Base.Broadcast.broadcast_c(f, AbstractArray, getfield19745(A), map(getfield19745, Bs)...))
410410

411411
@testset "broadcasting for custom AbstractArray" begin
412412
a = randn(10)

0 commit comments

Comments
 (0)