Skip to content

Commit 42a02ae

Browse files
committed
improve cat design
This is type-piracy, but we cannot change that (until JuliaLang/julia#2326), so at least do not make these method intersections unnecessary slow and complicated for everyone who does not care about SparseArrays and does not load it, and unreliable for everyone who does load it.
1 parent 8affe9e commit 42a02ae

File tree

2 files changed

+52
-28
lines changed

2 files changed

+52
-28
lines changed

src/linalg.jl

+7
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,13 @@ const _Triangular_SparseKronArrays = UpperOrLowerTriangular{<:Any,<:_SparseKronA
13231323
const _Annotated_SparseKronArrays = Union{_Triangular_SparseKronArrays, _Symmetric_SparseKronArrays, _Hermitian_SparseKronArrays}
13241324
const _SparseKronGroup = Union{_SparseKronArrays, _Annotated_SparseKronArrays}
13251325

1326+
const _SpecialArrays = Union{Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal}
1327+
const _Symmetric_DenseArrays{T,A<:Matrix} = Symmetric{T,A}
1328+
const _Hermitian_DenseArrays{T,A<:Matrix} = Hermitian{T,A}
1329+
const _Triangular_DenseArrays{T,A<:Matrix} = UpperOrLowerTriangular{<:Any,A} # AbstractTriangular{T,A}
1330+
const _Annotated_DenseArrays = Union{_SpecialArrays, _Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays}
1331+
const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}
1332+
13261333
@inline function kron!(C::SparseMatrixCSC, A::AbstractSparseMatrixCSC, B::AbstractSparseMatrixCSC)
13271334
mA, nA = size(A); mB, nB = size(B)
13281335
mC, nC = mA*mB, nA*nB

src/sparsevector.jl

+45-28
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import Base: sort!, findall, copy!
66
import LinearAlgebra: promote_to_array_type, promote_to_arrays_
7-
using LinearAlgebra: _SpecialArrays, _DenseConcatGroup
87

98
### The SparseVector
109

@@ -1175,24 +1174,10 @@ function _absspvec_vcat(X::AbstractSparseVector{Tv,Ti}...) where {Tv,Ti}
11751174
SparseVector(len, rnzind, rnzval)
11761175
end
11771176

1178-
hcat(Xin::Union{Vector, AbstractSparseVector}...) = hcat(map(sparse, Xin)...)
1179-
vcat(Xin::Union{Vector, AbstractSparseVector}...) = vcat(map(sparse, Xin)...)
1180-
11811177
### Concatenation of un/annotated sparse/special/dense vectors/matrices
1182-
1183-
const _SparseArrays = Union{AbstractSparseVector,
1184-
AbstractSparseMatrixCSC,
1185-
Adjoint{<:Any,<:AbstractSparseVector},
1186-
Transpose{<:Any,<:AbstractSparseVector}}
1187-
const _SparseConcatArrays = Union{_SpecialArrays, _SparseArrays}
1188-
1189-
const _Symmetric_SparseConcatArrays = Symmetric{<:Any,<:_SparseConcatArrays}
1190-
const _Hermitian_SparseConcatArrays = Hermitian{<:Any,<:_SparseConcatArrays}
1191-
const _Triangular_SparseConcatArrays = UpperOrLowerTriangular{<:Any,<:_SparseConcatArrays}
1192-
const _Annotated_SparseConcatArrays = Union{_Triangular_SparseConcatArrays, _Symmetric_SparseConcatArrays, _Hermitian_SparseConcatArrays}
1193-
# It's important that _SparseConcatGroup is a larger union than _DenseConcatGroup to make
1194-
# sparse cat-methods less specific and to kick in only if there is some sparse array present
1195-
const _SparseConcatGroup = Union{_DenseConcatGroup, _SparseConcatArrays, _Annotated_SparseConcatArrays}
1178+
# by type-pirating and subverting the Base.cat design by making these a subtype of the normal methods for it
1179+
# and re-defining all of it here. See https://github.com/JuliaLang/julia/issues/2326
1180+
# for what would have been a more principled way of doing this.
11961181

11971182
# Concatenations involving un/annotated sparse/special matrices/vectors should yield sparse arrays
11981183

@@ -1204,23 +1189,55 @@ _sparse(A) = _makesparse(A)
12041189
_makesparse(x::Number) = x
12051190
_makesparse(x::AbstractVector) = convert(SparseVector, issparse(x) ? x : sparse(x))::SparseVector
12061191
_makesparse(x::AbstractMatrix) = convert(SparseMatrixCSC, issparse(x) ? x : sparse(x))::SparseMatrixCSC
1192+
anysparse() = false
1193+
anysparse(X) = X isa AbstractArray && issparse(X)
1194+
anysparse(X, Xs...) = anysparse(X) || anysparse(Xs...)
1195+
1196+
function hcat(X::Union{Vector, AbstractSparseVector}...)
1197+
if anysparse(X...)
1198+
X = map(sparse, X)
1199+
end
1200+
return cat(X...; dims=Val(2))
1201+
end
1202+
function vcat(X::Union{Vector, AbstractSparseVector}...)
1203+
if anysparse(X...)
1204+
X = map(sparse, X)
1205+
end
1206+
return cat(X...; dims=Val(1))
1207+
end
1208+
1209+
# type-pirate the Base.cat design by making this a subtype of the existing method for it
1210+
# in future versions of Julia (v1.10+), in which https://github.com/JuliaLang/julia/issues/2326 is not fixed yet, the <:Number constraint could be relaxed
1211+
# but see also https://github.com/JuliaSparse/SparseArrays.jl/issues/71
1212+
const _SparseConcatGroup = Union{AbstractVecOrMat{<:Number},Number}
12071213

12081214
# `@constprop :aggressive` allows `dims` to be propagated as constant improving return type inference
1209-
Base.@constprop :aggressive function Base._cat(dims, Xin::_SparseConcatGroup...)
1210-
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
1211-
T = promote_eltype(Xin...)
1215+
Base.@constprop :aggressive function Base._cat(dims, X::_SparseConcatGroup...)
1216+
T = promote_eltype(X...)
1217+
if anysparse(X...)
1218+
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
1219+
end
12121220
return Base._cat_t(dims, T, X...)
12131221
end
1214-
function hcat(Xin::_SparseConcatGroup...)
1215-
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
1222+
function hcat(X::_SparseConcatGroup...)
1223+
if anysparse(X...)
1224+
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
1225+
end
12161226
return cat(X..., dims=Val(2))
12171227
end
1218-
function vcat(Xin::_SparseConcatGroup...)
1219-
X = (_sparse(first(Xin)), map(_makesparse, Base.tail(Xin))...)
1228+
function vcat(X::_SparseConcatGroup...)
1229+
if anysparse(X...)
1230+
X = (_sparse(first(X)), map(_makesparse, Base.tail(X))...)
1231+
end
12201232
return cat(X..., dims=Val(1))
12211233
end
1222-
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
1223-
vcat(_hvcat_rows(rows, X...)...)
1234+
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
1235+
if anysparse(X...)
1236+
vcat(_hvcat_rows(rows, X...)...)
1237+
else
1238+
typed_hvcat(promote_eltypeof(X...), rows, X...)
1239+
end
1240+
end
12241241
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
12251242
if row1 0
12261243
throw(ArgumentError("length of block row must be positive, got $row1"))
@@ -1237,7 +1254,7 @@ end
12371254
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()
12381255

12391256
# make sure UniformScaling objects are converted to sparse matrices for concatenation
1240-
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC
1257+
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix
12411258
promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n)
12421259

12431260
"""

0 commit comments

Comments
 (0)