@@ -5,12 +5,11 @@ module HigherOrderFns
5
5
# This module provides higher order functions specialized for sparse arrays,
6
6
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
7
7
import Base: map, map!, broadcast, broadcast!
8
- import Base. Broadcast: _containertype, promote_containertype,
9
- broadcast_indices, broadcast_c, broadcast_c!
10
8
11
9
using Base: front, tail, to_shape
12
10
using .. SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
13
11
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
12
+ using Base. Broadcast: BroadcastStyle
14
13
15
14
# This module is organized as follows:
16
15
# (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
@@ -23,10 +22,10 @@ using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
23
22
# (7) Define _broadcast_[not]zeropres! specialized for a single (input) sparse vector/matrix.
24
23
# (8) Define _broadcast_[not]zeropres! specialized for a pair of (input) sparse vectors/matrices.
25
24
# (9) Define general _broadcast_[not]zeropres! capable of handling >2 (input) sparse vectors/matrices.
26
- # (10) Define ( broadcast[!]) methods handling combinations of broadcast scalars and sparse vectors/matrices.
27
- # (11) Define ( broadcast[!]) methods handling combinations of scalars, sparse vectors/matrices,
25
+ # (10) Define broadcast methods handling combinations of broadcast scalars and sparse vectors/matrices.
26
+ # (11) Define broadcast[!] methods handling combinations of scalars, sparse vectors/matrices,
28
27
# structured matrices, and one- and two-dimensional Arrays.
29
- # (12) Define ( map[!]) methods handling combinations of sparse and structured matrices.
28
+ # (12) Define map[!] methods handling combinations of sparse and structured matrices.
30
29
31
30
32
31
# (1) The definitions below provide a common interface to sparse vectors and matrices
@@ -85,7 +84,7 @@ function _noshapecheck_map(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N
85
84
fofzeros = f (_zeros_eltypes (A, Bs... )... )
86
85
fpreszeros = _iszero (fofzeros)
87
86
maxnnzC = fpreszeros ? min (length (A), _sumnnzs (A, Bs... )) : length (A)
88
- entrytypeC = Base. Broadcast. _broadcast_eltype (f, A, Bs... )
87
+ entrytypeC = Base. Broadcast. combine_eltypes (f, A, Bs... )
89
88
indextypeC = _promote_indtype (A, Bs... )
90
89
C = _allocres (size (A), indextypeC, entrytypeC, maxnnzC)
91
90
return fpreszeros ? _map_zeropres! (f, C, A, Bs... ) :
@@ -126,8 +125,8 @@ function _diffshape_broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMa
126
125
fofzeros = f (_zeros_eltypes (A, Bs... )... )
127
126
fpreszeros = _iszero (fofzeros)
128
127
indextypeC = _promote_indtype (A, Bs... )
129
- entrytypeC = Base. Broadcast. _broadcast_eltype (f, A, Bs... )
130
- shapeC = to_shape (Base. Broadcast. broadcast_indices (A, Bs... ))
128
+ entrytypeC = Base. Broadcast. combine_eltypes (f, A, Bs... )
129
+ shapeC = to_shape (Base. Broadcast. combine_indices (A, Bs... ))
131
130
maxnnzC = fpreszeros ? _checked_maxnnzbcres (shapeC, A, Bs... ) : _densennz (shapeC)
132
131
C = _allocres (shapeC, indextypeC, entrytypeC, maxnnzC)
133
132
return fpreszeros ? _broadcast_zeropres! (f, C, A, Bs... ) :
@@ -897,29 +896,40 @@ end
897
896
end
898
897
899
898
900
- # (10) broadcast[!] over combinations of broadcast scalars and sparse vectors/matrices
899
+ # (10) broadcast over combinations of broadcast scalars and sparse vectors/matrices
901
900
902
- # broadcast shape promotion for combinations of sparse arrays and other types
903
- broadcast_indices (:: Type{AbstractSparseArray} , A) = indices (A)
904
901
# broadcast container type promotion for combinations of sparse arrays and other types
905
- _containertype (:: Type{<:SparseVecOrMat} ) = AbstractSparseArray
906
- # combinations of sparse arrays with broadcast scalars should yield sparse arrays
907
- promote_containertype (:: Type{Any} , :: Type{AbstractSparseArray} ) = AbstractSparseArray
908
- promote_containertype (:: Type{AbstractSparseArray} , :: Type{Any} ) = AbstractSparseArray
909
- # combinations of sparse arrays with tuples should divert to the generic AbstractArray broadcast code
910
- # (we handle combinations involving dense vectors/matrices below)
911
- promote_containertype (:: Type{Tuple} , :: Type{AbstractSparseArray} ) = Array
912
- promote_containertype (:: Type{AbstractSparseArray} , :: Type{Tuple} ) = Array
902
+ struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
903
+ struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
904
+ Broadcast. BroadcastStyle (:: Type{<:SparseVector} ) = SparseVecStyle ()
905
+ Broadcast. BroadcastStyle (:: Type{<:SparseMatrixCSC} ) = SparseMatStyle ()
906
+ const SPVM = Union{SparseVecStyle,SparseMatStyle}
913
907
914
- # broadcast[!] entry points for combinations of sparse arrays and other (scalar) types
915
- @inline function broadcast_c (f, :: Type{AbstractSparseArray} , mixedargs:: Vararg{Any,N} ) where N
908
+ # SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions.
909
+ # SparseVecStyle promotes to SparseMatStyle for 2 dimensions.
910
+ # Fall back to DefaultArrayStyle for higher dimensionality.
911
+ SparseVecStyle (:: Val{0} ) = SparseVecStyle ()
912
+ SparseVecStyle (:: Val{1} ) = SparseVecStyle ()
913
+ SparseVecStyle (:: Val{2} ) = SparseMatStyle ()
914
+ SparseVecStyle (:: Val{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
915
+ SparseMatStyle (:: Val{0} ) = SparseMatStyle ()
916
+ SparseMatStyle (:: Val{1} ) = SparseMatStyle ()
917
+ SparseMatStyle (:: Val{2} ) = SparseMatStyle ()
918
+ SparseMatStyle (:: Val{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
919
+
920
+ Broadcast. BroadcastStyle (:: SparseMatStyle , :: SparseVecStyle ) = SparseMatStyle ()
921
+
922
+ # Tuples promote to dense
923
+ Broadcast. BroadcastStyle (:: SparseVecStyle , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {1} ()
924
+ Broadcast. BroadcastStyle (:: SparseMatStyle , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {2} ()
925
+
926
+ # broadcast entry points for combinations of sparse arrays and other (scalar) types
927
+ function broadcast (f, :: SPVM , :: Void , :: Void , mixedargs:: Vararg{Any,N} ) where N
916
928
parevalf, passedargstup = capturescalars (f, mixedargs)
917
929
return broadcast (parevalf, passedargstup... )
918
930
end
919
- @inline function broadcast_c! (f, :: Type{AbstractSparseArray} , dest:: SparseVecOrMat , mixedsrcargs:: Vararg{Any,N} ) where N
920
- parevalf, passedsrcargstup = capturescalars (f, mixedsrcargs)
921
- return broadcast! (parevalf, dest, passedsrcargstup... )
922
- end
931
+ # for broadcast! see (11)
932
+
923
933
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
924
934
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
925
935
# evaluated f) and a reduced argument tuple (passedargstup) containing only the sparse
@@ -966,99 +976,61 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
966
976
# for combinations involving only scalars, sparse arrays, structured matrices, and dense
967
977
# vectors/matrices, promote all structured matrices and dense vectors/matrices to sparse
968
978
# and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.
969
- #
970
- # this requires three steps: segregate combinations to promote to sparse via Broadcast's
971
- # containertype promotion and dispatch layer (broadcast_c[!], containertype,
972
- # promote_containertype), separate ambiguous cases from the preceding dispatch
973
- # layer in sparse broadcast's internal containertype promotion and dispatch layer
974
- # (spbroadcast_c[!], spcontainertype, promote_spcontainertype), and then promote
975
- # arguments to sparse as appropriate and rebroadcast.
976
-
977
979
978
- # first (Broadcast containertype) dispatch layer's promotion logic
979
- struct PromoteToSparse end
980
-
981
- # broadcast containertype definitions for structured matrices
980
+ struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
982
981
StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
983
- _containertype (:: Type{<:StructuredMatrix} ) = PromoteToSparse
984
- broadcast_indices (:: Type{PromoteToSparse} , A) = indices (A)
982
+ Broadcast. BroadcastStyle (:: Type{<:StructuredMatrix} ) = PromoteToSparse ()
985
983
986
- # combinations explicitly involving Tuples and PromoteToSparse collections
987
- # divert to the generic AbstractArray broadcast code
988
- promote_containertype (:: Type{PromoteToSparse} , :: Type{Tuple} ) = Array
989
- promote_containertype (:: Type{Tuple} , :: Type{PromoteToSparse} ) = Array
990
- # combinations involving scalars and PromoteToSparse collections continue in the promote-to-sparse funnel
991
- promote_containertype (:: Type{PromoteToSparse} , :: Type{Any} ) = PromoteToSparse
992
- promote_containertype (:: Type{Any} , :: Type{PromoteToSparse} ) = PromoteToSparse
993
- # combinations involving sparse arrays and PromoteToSparse collections continue in the promote-to-sparse funnel
994
- promote_containertype (:: Type{PromoteToSparse} , :: Type{AbstractSparseArray} ) = PromoteToSparse
995
- promote_containertype (:: Type{AbstractSparseArray} , :: Type{PromoteToSparse} ) = PromoteToSparse
996
- # combinations involving Arrays and PromoteToSparse collections continue in the promote-to-sparse funnel
997
- promote_containertype (:: Type{PromoteToSparse} , :: Type{Array} ) = PromoteToSparse
998
- promote_containertype (:: Type{Array} , :: Type{PromoteToSparse} ) = PromoteToSparse
999
- # combinations involving Arrays and sparse arrays continue in the promote-to-sparse funnel
1000
- promote_containertype (:: Type{AbstractSparseArray} , :: Type{Array} ) = PromoteToSparse
1001
- promote_containertype (:: Type{Array} , :: Type{AbstractSparseArray} ) = PromoteToSparse
984
+ PromoteToSparse (:: Val{0} ) = PromoteToSparse ()
985
+ PromoteToSparse (:: Val{1} ) = PromoteToSparse ()
986
+ PromoteToSparse (:: Val{2} ) = PromoteToSparse ()
987
+ PromoteToSparse (:: Val{N} ) where N = Broadcast. DefaultArrayStyle {N} ()
1002
988
1003
- # second (internal sparse broadcast containertype) dispatch layer's promotion logic
1004
- # mostly just disambiguates Array from the main containertype promotion mechanism
1005
- # AbstractArray serves as a marker to shunt to the generic AbstractArray broadcast code
1006
- _spcontainertype (x) = _containertype (x)
1007
- _spcontainertype (:: Type{<:Vector} ) = Vector
1008
- _spcontainertype (:: Type{<:Matrix} ) = Matrix
1009
- _spcontainertype (:: Type{<:RowVector} ) = Matrix
1010
- _spcontainertype (:: Type{<:Ref} ) = AbstractArray
1011
- _spcontainertype (:: Type{<:AbstractArray} ) = AbstractArray
1012
- # need the following two methods to override the immediately preceding method
1013
- _spcontainertype (:: Type{<:StructuredMatrix} ) = PromoteToSparse
1014
- _spcontainertype (:: Type{<:SparseVecOrMat} ) = AbstractSparseArray
1015
- spcontainertype (x) = _spcontainertype (typeof (x))
1016
- spcontainertype (ct1, ct2) = promote_spcontainertype (spcontainertype (ct1), spcontainertype (ct2))
1017
- @inline spcontainertype (ct1, ct2, cts... ) = promote_spcontainertype (spcontainertype (ct1), spcontainertype (ct2, cts... ))
989
+ Broadcast. BroadcastStyle (:: PromoteToSparse , :: SPVM ) = PromoteToSparse ()
990
+ Broadcast. BroadcastStyle (:: PromoteToSparse , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {2} ()
1018
991
1019
- promote_spcontainertype (:: Type{T} , :: Type{T} ) where {T} = T
1020
- # combinations involving AbstractArrays and/or Tuples divert to the generic AbstractArray broadcast code
1021
- DivertToAbsArrayBC = Union{Type{AbstractArray},Type{Tuple}}
1022
- promote_spcontainertype (:: DivertToAbsArrayBC , ct) = AbstractArray
1023
- promote_spcontainertype (ct, :: DivertToAbsArrayBC ) = AbstractArray
1024
- promote_spcontainertype (:: DivertToAbsArrayBC , :: DivertToAbsArrayBC ) = AbstractArray
1025
- # combinations involving scalars, sparse arrays, structured matrices (PromoteToSparse),
1026
- # dense vectors/matrices, and PromoteToSparse collections continue in the promote-to-sparse funnel
1027
- FunnelToSparseBC = Union{Type{Any},Type{Vector},Type{Matrix},Type{PromoteToSparse},Type{AbstractSparseArray}}
1028
- promote_spcontainertype (:: FunnelToSparseBC , :: FunnelToSparseBC ) = PromoteToSparse
992
+ Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.DefaultArrayStyle{0} ) = PromoteToSparse ()
993
+ Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.DefaultArrayStyle{1} ) = PromoteToSparse ()
994
+ Broadcast. BroadcastStyle (:: SPVM , :: Broadcast.DefaultArrayStyle{2} ) = PromoteToSparse ()
995
+ Broadcast. BroadcastStyle (:: PromoteToSparse , :: Broadcast.Style{Tuple} ) = Broadcast. DefaultArrayStyle {2} ()
1029
996
997
+ broadcast (f, :: PromoteToSparse , :: Void , :: Void , As:: Vararg{Any,N} ) where {N} =
998
+ broadcast (f, map (_sparsifystructured, As)... )
1030
999
1031
- # first (Broadcast containertype) dispatch layer
1032
- # (broadcast_c[!], containertype, promote_containertype)
1033
- @inline broadcast_c (f, :: Type{PromoteToSparse} , As:: Vararg{Any,N} ) where {N} =
1034
- spbroadcast_c (f, spcontainertype (As... ), As... )
1035
- @inline broadcast_c! (f, :: Type{AbstractSparseArray} , :: Type{PromoteToSparse} , C, B, As:: Vararg{Any,N} ) where {N} =
1036
- spbroadcast_c! (f, AbstractSparseArray, spcontainertype (B, As... ), C, B, As... )
1037
- # where destination C is not an AbstractSparseArray, divert to generic AbstractArray broadcast code
1038
- @inline broadcast_c! (f, CT:: Type , :: Type{PromoteToSparse} , C, B, As:: Vararg{Any,N} ) where {N} =
1039
- broadcast_c! (f, CT, Array, C, B, As... )
1000
+ # ambiguity resolution
1001
+ broadcast! (:: typeof (identity), dest:: SparseVecOrMat , x:: Number ) =
1002
+ fill! (dest, x)
1003
+ broadcast! (f, dest:: SparseVecOrMat , x:: Number... ) =
1004
+ spbroadcast_args! (f, dest, SPVM, mixedsrcargs... )
1040
1005
1041
- # second (internal sparse broadcast containertype) dispatch layer
1042
- # (spbroadcast_c[!], spcontainertype, promote_spcontainertype)
1043
- @inline spbroadcast_c (f, :: Type{PromoteToSparse} , As:: Vararg{Any,N} ) where {N} =
1044
- broadcast (f, map (_sparsifystructured, As)... )
1045
- @inline spbroadcast_c (f, :: Type{AbstractArray} , As:: Vararg{Any,N} ) where {N} =
1046
- broadcast_c (f, Array, As... )
1047
- @inline spbroadcast_c! (f, :: Type{AbstractSparseArray} , :: Type{PromoteToSparse} , C, B, As:: Vararg{Any,N} ) where {N} =
1048
- broadcast! (f, C, _sparsifystructured (B), map (_sparsifystructured, As)... )
1049
- @inline spbroadcast_c! (f, :: Type{AbstractSparseArray} , :: Type{AbstractArray} , C, B, As:: Vararg{Any,N} ) where {N} =
1050
- broadcast_c! (f, Array, Array, C, B, As... )
1006
+ # For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
1007
+ # the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
1008
+ # we can handle it here, otherwise see below for the promotion machinery.
1009
+ broadcast! (f, dest:: SparseVecOrMat , mixedsrcargs:: Vararg{Any,N} ) where N =
1010
+ spbroadcast_args! (f, dest, Broadcast. combine_styles (mixedsrcargs... ), mixedsrcargs... )
1011
+ function spbroadcast_args! (f, dest, :: Type{SPVM} , mixedsrcargs:: Vararg{Any,N} ) where N
1012
+ # mixedsrcargs contains nothing but SparseVecOrMat and scalars
1013
+ parevalf, passedsrcargstup = capturescalars (f, mixedsrcargs)
1014
+ return broadcast! (parevalf, dest, passedsrcargstup... )
1015
+ end
1016
+ function spbroadcast_args! (f, dest, :: PromoteToSparse , mixedsrcargs:: Vararg{Any,N} ) where N
1017
+ broadcast! (f, dest, map (_sparsifystructured, mixedsrcargs)... )
1018
+ end
1019
+ function spbroadcast_args! (f, dest, :: Any , mixedsrcargs:: Vararg{Any,N} ) where N
1020
+ # Fallback. From a performance perspective would it be best to densify?
1021
+ Broadcast. _broadcast! (f, dest, mixedsrcargs... )
1022
+ end
1051
1023
1052
- @inline _sparsifystructured (M:: AbstractMatrix ) = SparseMatrixCSC (M)
1053
- @inline _sparsifystructured (V:: AbstractVector ) = SparseVector (V)
1054
- @inline _sparsifystructured (M:: AbstractSparseMatrix ) = SparseMatrixCSC (M)
1055
- @inline _sparsifystructured (V:: AbstractSparseVector ) = SparseVector (V)
1056
- @inline _sparsifystructured (S:: SparseVecOrMat ) = S
1057
- @inline _sparsifystructured (x) = x
1024
+ _sparsifystructured (M:: AbstractMatrix ) = SparseMatrixCSC (M)
1025
+ _sparsifystructured (V:: AbstractVector ) = SparseVector (V)
1026
+ _sparsifystructured (P:: AbstractArray{<:Any,0} ) = SparseVector (reshape (P, 1 ))
1027
+ _sparsifystructured (M:: AbstractSparseMatrix ) = SparseMatrixCSC (M)
1028
+ _sparsifystructured (V:: AbstractSparseVector ) = SparseVector (V)
1029
+ _sparsifystructured (S:: SparseVecOrMat ) = S
1030
+ _sparsifystructured (x) = x
1058
1031
1059
1032
1060
1033
# (12) map[!] over combinations of sparse and structured matrices
1061
- StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
1062
1034
SparseOrStructuredMatrix = Union{SparseMatrixCSC,StructuredMatrix}
1063
1035
map (f:: Tf , A:: StructuredMatrix ) where {Tf} = _noshapecheck_map (f, _sparsifystructured (A))
1064
1036
map (f:: Tf , A:: SparseOrStructuredMatrix , Bs:: Vararg{SparseOrStructuredMatrix,N} ) where {Tf,N} =
0 commit comments