Skip to content

Commit 6f239f1

Browse files
committed
improve cat design / performance
This used to make a lot of references to design issues with the SparseArrays package (#2326 / #20815), which result in a non-sensical dispatch arrangement, and contribute to a slow loading experience do to the nonsense Unions that must be checked by subtyping.
1 parent 4995d3f commit 6f239f1

File tree

9 files changed

+19
-51
lines changed

9 files changed

+19
-51
lines changed

base/abstractarray.jl

+5-7
Original file line numberDiff line numberDiff line change
@@ -1984,16 +1984,14 @@ julia> cat(1, [2], [3;;]; dims=Val(2))
19841984

19851985
# The specializations for 1 and 2 inputs are important
19861986
# especially when running with --inline=no, see #11158
1987-
# The specializations for Union{AbstractVecOrMat,Number} are necessary
1988-
# to have more specialized methods here than in LinearAlgebra/uniformscaling.jl
19891987
vcat(A::AbstractArray) = cat(A; dims=Val(1))
19901988
vcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(1))
19911989
vcat(A::AbstractArray...) = cat(A...; dims=Val(1))
1992-
vcat(A::Union{AbstractVecOrMat,Number}...) = cat(A...; dims=Val(1))
1990+
vcat(A::Union{AbstractArray,Number}...) = cat(A...; dims=Val(1))
19931991
hcat(A::AbstractArray) = cat(A; dims=Val(2))
19941992
hcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(2))
19951993
hcat(A::AbstractArray...) = cat(A...; dims=Val(2))
1996-
hcat(A::Union{AbstractVecOrMat,Number}...) = cat(A...; dims=Val(2))
1994+
hcat(A::Union{AbstractArray,Number}...) = cat(A...; dims=Val(2))
19971995

19981996
typed_vcat(T::Type, A::AbstractArray) = _cat_t(Val(1), T, A)
19991997
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(1), T, A, B)
@@ -2055,8 +2053,8 @@ julia> hvcat((2,2,2), a,b,c,d,e,f) == hvcat(2, a,b,c,d,e,f)
20552053
true
20562054
```
20572055
"""
2058-
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractVecOrMat...) = typed_hvcat(promote_eltype(xs...), rows, xs...)
2059-
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractVecOrMat{T}...) where {T} = typed_hvcat(T, rows, xs...)
2056+
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractArray...) = typed_hvcat(promote_eltype(xs...), rows, xs...)
2057+
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractArray{T}...) where {T} = typed_hvcat(T, rows, xs...)
20602058

20612059
function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as::AbstractVecOrMat...) where T
20622060
nbr = length(rows) # number of block rows
@@ -2144,7 +2142,7 @@ end
21442142
hvcat(rows::Tuple{Vararg{Int}}, xs::Number...) = typed_hvcat(promote_typeof(xs...), rows, xs...)
21452143
hvcat(rows::Tuple{Vararg{Int}}, xs...) = typed_hvcat(promote_eltypeof(xs...), rows, xs...)
21462144
# the following method is needed to provide a more specific one compared to LinearAlgebra/uniformscaling.jl
2147-
hvcat(rows::Tuple{Vararg{Int}}, xs::Union{AbstractVecOrMat,Number}...) = typed_hvcat(promote_eltypeof(xs...), rows, xs...)
2145+
hvcat(rows::Tuple{Vararg{Int}}, xs::Union{AbstractArray,Number}...) = typed_hvcat(promote_eltypeof(xs...), rows, xs...)
21482146

21492147
function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, xs::Number...) where T
21502148
nr = length(rows)

base/array.jl

-12
Original file line numberDiff line numberDiff line change
@@ -2041,18 +2041,6 @@ function vcat(arrays::Vector{T}...) where T
20412041
end
20422042
vcat(A::Vector...) = cat(A...; dims=Val(1)) # more special than SparseArrays's vcat
20432043

2044-
# disambiguation with LinAlg/special.jl
2045-
# Union{Number,Vector,Matrix} is for LinearAlgebra._DenseConcatGroup
2046-
# VecOrMat{T} is for LinearAlgebra._TypedDenseConcatGroup
2047-
hcat(A::Union{Number,Vector,Matrix}...) = cat(A...; dims=Val(2))
2048-
hcat(A::VecOrMat{T}...) where {T} = typed_hcat(T, A...)
2049-
vcat(A::Union{Number,Vector,Matrix}...) = cat(A...; dims=Val(1))
2050-
vcat(A::VecOrMat{T}...) where {T} = typed_vcat(T, A...)
2051-
hvcat(rows::Tuple{Vararg{Int}}, xs::Union{Number,Vector,Matrix}...) =
2052-
typed_hvcat(promote_eltypeof(xs...), rows, xs...)
2053-
hvcat(rows::Tuple{Vararg{Int}}, xs::VecOrMat{T}...) where {T} =
2054-
typed_hvcat(T, rows, xs...)
2055-
20562044
_cat(n::Integer, x::Integer...) = reshape([x...], (ntuple(Returns(1), n-1)..., length(x)))
20572045

20582046
## find ##

deps/checksums/SparseArrays-2c7f4d6d839e9a97027454a037bfa004c1eb34b0.tar.gz/md5

-1
This file was deleted.

deps/checksums/SparseArrays-2c7f4d6d839e9a97027454a037bfa004c1eb34b0.tar.gz/sha512

-1
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
7f3a39f5b8420928993c4efb2d77626e
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1b020752a5681bc0c5de79f4ff2e80528f2899347b4b8b665fa339b191b4d2c71207cc80ddbba765dd8c049a9c99d589b5d917a6facda23eb7ef6c0213bf72af

stdlib/LinearAlgebra/src/special.jl

+5-21
Original file line numberDiff line numberDiff line change
@@ -330,27 +330,11 @@ end
330330
==(A::Bidiagonal, B::SymTridiagonal) = iszero(_evview(B)) && iszero(A.ev) && A.dv == B.dv
331331
==(B::SymTridiagonal, A::Bidiagonal) = A == B
332332

333-
# concatenation
334-
const _SpecialArrays = Union{Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal}
335-
const _Symmetric_DenseArrays{T,A<:Matrix} = Symmetric{T,A}
336-
const _Hermitian_DenseArrays{T,A<:Matrix} = Hermitian{T,A}
337-
const _Triangular_DenseArrays{T,A<:Matrix} = UpperOrLowerTriangular{T,A}
338-
const _Annotated_DenseArrays = Union{_SpecialArrays, _Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays}
339-
const _Annotated_Typed_DenseArrays{T} = Union{_Triangular_DenseArrays{T}, _Symmetric_DenseArrays{T}, _Hermitian_DenseArrays{T}}
340-
const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}
341-
const _TypedDenseConcatGroup{T} = Union{Vector{T}, Adjoint{T,Vector{T}}, Transpose{T,Vector{T}}, Matrix{T}, _Annotated_Typed_DenseArrays{T}}
342-
343-
promote_to_array_type(::Tuple{Vararg{Union{_DenseConcatGroup,UniformScaling}}}) = Matrix
344-
345-
Base._cat(dims, xs::_DenseConcatGroup...) = Base._cat_t(dims, promote_eltype(xs...), xs...)
346-
vcat(A::_DenseConcatGroup...) = Base.typed_vcat(promote_eltype(A...), A...)
347-
hcat(A::_DenseConcatGroup...) = Base.typed_hcat(promote_eltype(A...), A...)
348-
hvcat(rows::Tuple{Vararg{Int}}, xs::_DenseConcatGroup...) = Base.typed_hvcat(promote_eltype(xs...), rows, xs...)
349-
# For performance, specially handle the case where the matrices/vectors have homogeneous eltype
350-
Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base._cat_t(dims, T, xs...)
351-
vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...)
352-
hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...)
353-
hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...)
333+
# TODO: remove these deprecations (used by SparseArrays in the past)
334+
const _DenseConcatGroup = Union{}
335+
const _SpecialArrays = Union{}
336+
337+
promote_to_array_type(::Tuple) = Matrix
354338

355339
# factorizations
356340
function cholesky(S::RealHermSymComplexHerm{<:Real,<:SymTridiagonal}, ::NoPivot = NoPivot(); check::Bool = true)

stdlib/LinearAlgebra/src/uniformscaling.jl

+6-8
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ end
408408
# so that we can re-use this code for sparse-matrix hcat etcetera.
409409
promote_to_arrays_(n::Int, ::Type, a::Number) = a
410410
promote_to_arrays_(n::Int, ::Type{Matrix}, J::UniformScaling{T}) where {T} = Matrix(J, n, n)
411-
promote_to_arrays_(n::Int, ::Type, A::AbstractVecOrMat) = A
411+
promote_to_arrays_(n::Int, ::Type, A::AbstractArray) = A
412412
promote_to_arrays(n,k, ::Type) = ()
413413
promote_to_arrays(n,k, ::Type{T}, A) where {T} = (promote_to_arrays_(n[k], T, A),)
414414
promote_to_arrays(n,k, ::Type{T}, A, B) where {T} =
@@ -417,17 +417,16 @@ promote_to_arrays(n,k, ::Type{T}, A, B, C) where {T} =
417417
(promote_to_arrays_(n[k], T, A), promote_to_arrays_(n[k+1], T, B), promote_to_arrays_(n[k+2], T, C))
418418
promote_to_arrays(n,k, ::Type{T}, A, B, Cs...) where {T} =
419419
(promote_to_arrays_(n[k], T, A), promote_to_arrays_(n[k+1], T, B), promote_to_arrays(n,k+2, T, Cs...)...)
420-
promote_to_array_type(A::Tuple{Vararg{Union{AbstractVecOrMat,UniformScaling,Number}}}) = Matrix
421420

422421
_us2number(A) = A
423422
_us2number(J::UniformScaling) = J.λ
424423

425424
for (f, _f, dim, name) in ((:hcat, :_hcat, 1, "rows"), (:vcat, :_vcat, 2, "cols"))
426425
@eval begin
427-
@inline $f(A::Union{AbstractVecOrMat,UniformScaling}...) = $_f(A...)
426+
@inline $f(A::Union{AbstractArray,UniformScaling}...) = $_f(A...)
428427
# if there's a Number present, J::UniformScaling must be 1x1-dimensional
429-
@inline $f(A::Union{AbstractVecOrMat,UniformScaling,Number}...) = $f(map(_us2number, A)...)
430-
function $_f(A::Union{AbstractVecOrMat,UniformScaling,Number}...; array_type = promote_to_array_type(A))
428+
@inline $f(A::Union{AbstractArray,UniformScaling,Number}...) = $f(map(_us2number, A)...)
429+
function $_f(A::Union{AbstractArray,UniformScaling,Number}...; array_type = promote_to_array_type(A))
431430
n = -1
432431
for a in A
433432
if !isa(a, UniformScaling)
@@ -445,9 +444,8 @@ for (f, _f, dim, name) in ((:hcat, :_hcat, 1, "rows"), (:vcat, :_vcat, 2, "cols"
445444
end
446445
end
447446

448-
hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling}...) = _hvcat(rows, A...)
449-
hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling,Number}...) = _hvcat(rows, A...)
450-
function _hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling,Number}...; array_type = promote_to_array_type(A))
447+
hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractArray,UniformScaling,Number}...) = _hvcat(rows, A...)
448+
function _hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractArray,UniformScaling,Number}...; array_type = promote_to_array_type(A))
451449
require_one_based_indexing(A...)
452450
nr = length(rows)
453451
sum(rows) == length(A) || throw(ArgumentError("mismatch between row sizes and number of arguments"))

stdlib/SparseArrays.version

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
SPARSEARRAYS_BRANCH = main
2-
SPARSEARRAYS_SHA1 = 2c7f4d6d839e9a97027454a037bfa004c1eb34b0
2+
SPARSEARRAYS_SHA1 = 78b1321ddc107370252fcc11b992f5c8bbd8f62f
33
SPARSEARRAYS_GIT_URL := https://github.com/JuliaSparse/SparseArrays.jl.git
44
SPARSEARRAYS_TAR_URL = https://api.github.com/repos/JuliaSparse/SparseArrays.jl/tarball/$1

0 commit comments

Comments
 (0)