@@ -5,8 +5,8 @@ module HigherOrderFns
5
5
# This module provides higher order functions specialized for sparse arrays,
6
6
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
7
7
import Base: map, map!, broadcast, broadcast!
8
- import Base. Broadcast: _containertype, promote_containertype ,
9
- broadcast_indices, broadcast_c, broadcast_c!
8
+ import Base. Broadcast: ScalarType, OneOrTwoD, _containertype ,
9
+ promote_containertype, broadcast_indices, broadcast_c, broadcast_c!
10
10
11
11
using Base: front, tail, to_shape
12
12
using .. SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype
@@ -848,21 +848,23 @@ broadcast_indices(::Type{AbstractSparseArray}, A) = indices(A)
848
848
# broadcast container type promotion for combinations of sparse arrays and other types
849
849
_containertype {T<:SparseVecOrMat} (:: Type{T} ) = AbstractSparseArray
850
850
# combinations of sparse arrays with broadcast scalars should yield sparse arrays
851
- promote_containertype (:: Type{Any} , :: Type{AbstractSparseArray} ) = AbstractSparseArray
852
- promote_containertype (:: Type{AbstractSparseArray} , :: Type{Any} ) = AbstractSparseArray
851
+ promote_containertype (:: Type{AbstractSparseArray} , :: ScalarType ) = AbstractSparseArray
852
+ promote_containertype (:: ScalarType , :: Type{AbstractSparseArray} ) = AbstractSparseArray
853
+ promote_containertype (:: Type{AbstractSparseArray} , :: Type{OneOrTwoD} ) = AbstractSparseArray
854
+ promote_containertype (:: Type{OneOrTwoD} , :: Type{AbstractSparseArray} ) = AbstractSparseArray
853
855
# combinations of sparse arrays with anything else should fall back to generic dense broadcast
854
- promote_containertype (:: Type{Array } , :: Type{AbstractSparseArray } ) = Array
855
- promote_containertype (:: Type{Tuple } , :: Type{AbstractSparseArray} ) = Array
856
- promote_containertype (:: Type{AbstractSparseArray} , :: Type{Array } ) = Array
857
- promote_containertype (:: Type{AbstractSparseArray } , :: Type{Tuple } ) = Array
856
+ promote_containertype (:: Type{AbstractSparseArray } , :: Type{AbstractArray } ) = AbstractArray
857
+ promote_containertype (:: Type{AbstractArray } , :: Type{AbstractSparseArray} ) = AbstractArray
858
+ promote_containertype (:: Type{AbstractSparseArray} , :: Type{Tuple } ) = AbstractArray
859
+ promote_containertype (:: Type{Tuple } , :: Type{AbstractSparseArray } ) = AbstractArray
858
860
859
861
# broadcast[!] entry points for combinations of sparse arrays and other (scalar) types
860
862
@inline function broadcast_c {N} (f, :: Type{AbstractSparseArray} , mixedargs:: Vararg{Any,N} )
861
- parevalf, passedargstup = capturescalars (f, mixedargs)
863
+ parevalf, passedargstup = capturescalars (f, map (_sparsifystructured, mixedargs) )
862
864
return broadcast (parevalf, passedargstup... )
863
865
end
864
866
@inline function broadcast_c! {N} (f, :: Type{AbstractSparseArray} , dest:: SparseVecOrMat , mixedsrcargs:: Vararg{Any,N} )
865
- parevalf, passedsrcargstup = capturescalars (f, mixedsrcargs)
867
+ parevalf, passedsrcargstup = capturescalars (f, map (_sparsifystructured, mixedsrcargs) )
866
868
return broadcast! (parevalf, dest, passedsrcargstup... )
867
869
end
868
870
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
@@ -900,32 +902,27 @@ _containertype{T<:Diagonal}(::Type{T}) = StructuredArray
900
902
_containertype {T<:Bidiagonal} (:: Type{T} ) = StructuredArray
901
903
_containertype {T<:Tridiagonal} (:: Type{T} ) = StructuredArray
902
904
_containertype {T<:SymTridiagonal} (:: Type{T} ) = StructuredArray
903
- promote_containertype (:: Type{StructuredArray} , :: Type{StructuredArray} ) = StructuredArray
904
905
# combinations involving sparse arrays continue in the structured array funnel
905
- promote_containertype (:: Type{StructuredArray} , :: Type{AbstractSparseArray} ) = StructuredArray
906
- promote_containertype (:: Type{AbstractSparseArray} , :: Type{StructuredArray} ) = StructuredArray
906
+ promote_containertype (:: Type{StructuredArray} , :: Type{AbstractSparseArray} ) = AbstractSparseArray
907
+ promote_containertype (:: Type{AbstractSparseArray} , :: Type{StructuredArray} ) = AbstractSparseArray
908
+ promote_containertype (:: Type{StructuredArray} , :: Type{OneOrTwoD} ) = AbstractSparseArray
909
+ promote_containertype (:: Type{OneOrTwoD} , :: Type{StructuredArray} ) = AbstractSparseArray
907
910
# combinations involving scalars continue in the structured array funnel
908
- promote_containertype (:: Type{StructuredArray} , :: Type{Any} ) = StructuredArray
909
- promote_containertype (:: Type{Any} , :: Type{StructuredArray} ) = StructuredArray
911
+ promote_containertype (:: Type{StructuredArray} , :: ScalarType ) = AbstractSparseArray
912
+ promote_containertype (:: ScalarType , :: Type{StructuredArray} ) = AbstractSparseArray
910
913
# combinations involving arrays divert to the generic array code
911
- promote_containertype (:: Type{StructuredArray} , :: Type{Array } ) = Array
912
- promote_containertype (:: Type{Array } , :: Type{StructuredArray} ) = Array
914
+ promote_containertype (:: Type{StructuredArray} , :: Type{AbstractArray } ) = AbstractArray
915
+ promote_containertype (:: Type{AbstractArray } , :: Type{StructuredArray} ) = AbstractArray
913
916
# combinations involving tuples divert to the generic array code
914
- promote_containertype (:: Type{StructuredArray} , :: Type{Tuple} ) = Array
915
- promote_containertype (:: Type{Tuple} , :: Type{StructuredArray} ) = Array
917
+ promote_containertype (:: Type{StructuredArray} , :: Type{Tuple} ) = AbstractArray
918
+ promote_containertype (:: Type{Tuple} , :: Type{StructuredArray} ) = AbstractArray
916
919
917
- # for combinations involving sparse/structured arrays and scalars only,
918
- # promote all structured arguments to sparse and then rebroadcast
919
- @inline broadcast_c {N} (f, :: Type{StructuredArray} , As:: Vararg{Any,N} ) =
920
- broadcast (f, map (_sparsifystructured, As)... )
921
- @inline broadcast_c! {N} (f, :: Type{AbstractSparseArray} , :: Type{StructuredArray} , C, B, As:: Vararg{Any,N} ) =
922
- broadcast! (f, C, _sparsifystructured (B), map (_sparsifystructured, As)... )
923
- @inline broadcast_c! {N} (f, CT:: Type , :: Type{StructuredArray} , C, B, As:: Vararg{Any,N} ) =
924
- broadcast_c! (f, CT, Array, C, B, As... )
925
920
@inline _sparsifystructured (S:: SymTridiagonal ) = SparseMatrixCSC (S)
926
921
@inline _sparsifystructured (T:: Tridiagonal ) = SparseMatrixCSC (T)
927
922
@inline _sparsifystructured (B:: Bidiagonal ) = SparseMatrixCSC (B)
928
923
@inline _sparsifystructured (D:: Diagonal ) = SparseMatrixCSC (D)
924
+ @inline _sparsifystructured (M:: AbstractMatrix ) = SparseMatrixCSC (M)
925
+ @inline _sparsifystructured (V:: AbstractVector ) = SparseVector (V)
929
926
@inline _sparsifystructured (A:: AbstractSparseArray ) = A
930
927
@inline _sparsifystructured (x) = x
931
928
0 commit comments