Skip to content

Commit 04259da

Browse files
authored
Reroute (Upper/Lower)Triangular * Diagonal through __muldiag (#55984)
Currently, `::Diagonal * ::AbstractMatrix` calls the method `LinearAlgebra.__muldiag!` in general that scales the rows, and similarly for the diagonal on the right. The implementation of `__muldiag` was duplicating the logic in `LinearAlgebra.modify!` and the methods for `MulAddMul`. This PR replaces the various branches with calls to `modify!` instead. I've also extracted the multiplication logic into its own function `__muldiag_nonzeroalpha!` so that this may be specialized for matrix types, such as triangular ones. Secondly, `::Diagonal * ::UpperTriangular` (and similarly, other triangular matrices) was specialized to forward the multiplication to the parent of the triangular. For strided matrices, however, it makes more sense to use the structure and scale only the filled half of the matrix. Firstly, this improves performance, and secondly, this avoids errors in case the parent isn't fully initialized corresponding to the structural zero elements. Performance improvement: ```julia julia> D = Diagonal(1:400); julia> U = UpperTriangular(zeros(size(D))); julia> @Btime $D * $U; 314.944 μs (3 allocations: 1.22 MiB) # v"1.12.0-DEV.1288" 195.960 μs (3 allocations: 1.22 MiB) # This PR ``` Fix: ```julia julia> M = Matrix{BigFloat}(undef, 2, 2); julia> M[1,1] = M[2,2] = M[1,2] = 3; julia> U = UpperTriangular(M) 2×2 UpperTriangular{BigFloat, Matrix{BigFloat}}: 3.0 3.0 ⋅ 3.0 julia> D = Diagonal(1:2); julia> U * D # works after this PR 2×2 UpperTriangular{BigFloat, Matrix{BigFloat}}: 3.0 6.0 ⋅ 6.0 ```
1 parent b01095e commit 04259da

File tree

3 files changed

+134
-66
lines changed

3 files changed

+134
-66
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

+2
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,8 @@ matprod_dest(A::StructuredMatrix, B::Diagonal, TS) = _matprod_dest_diag(A, TS)
655655
matprod_dest(A::Diagonal, B::StructuredMatrix, TS) = _matprod_dest_diag(B, TS)
656656
matprod_dest(A::Diagonal, B::Diagonal, TS) = _matprod_dest_diag(B, TS)
657657
_matprod_dest_diag(A, TS) = similar(A, TS)
658+
_matprod_dest_diag(A::UnitUpperTriangular, TS) = UpperTriangular(similar(parent(A), TS))
659+
_matprod_dest_diag(A::UnitLowerTriangular, TS) = LowerTriangular(similar(parent(A), TS))
658660
function _matprod_dest_diag(A::SymTridiagonal, TS)
659661
n = size(A, 1)
660662
ev = similar(A, TS, max(0, n-1))

stdlib/LinearAlgebra/src/diagonal.jl

+93-65
Original file line numberDiff line numberDiff line change
@@ -396,82 +396,120 @@ function lmul!(D::Diagonal, T::Tridiagonal)
396396
return T
397397
end
398398

399-
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
399+
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B, _add::MulAddMul)
400+
@inbounds for j in axes(B, 2)
401+
@simd for i in axes(B, 1)
402+
_modify!(_add, D.diag[i] * B[i,j], out, (i,j))
403+
end
404+
end
405+
out
406+
end
407+
_maybe_unwrap_tri(out, A) = out, A
408+
_maybe_unwrap_tri(out::UpperTriangular, A::UpperOrUnitUpperTriangular) = parent(out), parent(A)
409+
_maybe_unwrap_tri(out::LowerTriangular, A::LowerOrUnitLowerTriangular) = parent(out), parent(A)
410+
@inline function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, _add::MulAddMul)
411+
isunit = B isa Union{UnitUpperTriangular, UnitLowerTriangular}
412+
# if both B and out have the same upper/lower triangular structure,
413+
# we may directly read and write from the parents
414+
out_maybeparent, B_maybeparent = _maybe_unwrap_tri(out, B)
415+
for j in axes(B, 2)
416+
if isunit
417+
_modify!(_add, D.diag[j] * B[j,j], out, (j,j))
418+
end
419+
rowrange = B isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(B,1))) : (j+isunit:size(B,1))
420+
@inbounds @simd for i in rowrange
421+
_modify!(_add, D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
422+
end
423+
end
424+
out
425+
end
426+
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul)
400427
require_one_based_indexing(out, B)
401428
alpha, beta = _add.alpha, _add.beta
402429
if iszero(alpha)
403430
_rmul_or_fill!(out, beta)
404431
else
405-
if bis0
406-
@inbounds for j in axes(B, 2)
407-
@simd for i in axes(B, 1)
408-
out[i,j] = D.diag[i] * B[i,j] * alpha
409-
end
410-
end
411-
else
412-
@inbounds for j in axes(B, 2)
413-
@simd for i in axes(B, 1)
414-
out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta
415-
end
416-
end
417-
end
432+
__muldiag_nonzeroalpha!(out, D, B, _add)
418433
end
419434
return out
420435
end
421-
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
436+
437+
@inline function __muldiag_nonzeroalpha!(out, A, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
438+
beta = _add.beta
439+
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
440+
@inbounds for j in axes(A, 2)
441+
dja = _add(D.diag[j])
442+
@simd for i in axes(A, 1)
443+
_modify!(_add_aisone, A[i,j] * dja, out, (i,j))
444+
end
445+
end
446+
out
447+
end
448+
@inline function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
449+
isunit = A isa Union{UnitUpperTriangular, UnitLowerTriangular}
450+
beta = _add.beta
451+
# since alpha is multiplied to the diagonal element of D,
452+
# we may skip alpha in the second multiplication by setting ais1 to true
453+
_add_aisone = MulAddMul{true,bis0,Bool,typeof(beta)}(true, beta)
454+
# if both A and out have the same upper/lower triangular structure,
455+
# we may directly read and write from the parents
456+
out_maybeparent, A_maybeparent = _maybe_unwrap_tri(out, A)
457+
@inbounds for j in axes(A, 2)
458+
dja = _add(D.diag[j])
459+
if isunit
460+
_modify!(_add_aisone, A[j,j] * dja, out, (j,j))
461+
end
462+
rowrange = A isa UpperOrUnitUpperTriangular ? (1:min(j-isunit, size(A,1))) : (j+isunit:size(A,1))
463+
@simd for i in rowrange
464+
_modify!(_add_aisone, A_maybeparent[i,j] * dja, out_maybeparent, (i,j))
465+
end
466+
end
467+
out
468+
end
469+
function __muldiag!(out, A, D::Diagonal, _add::MulAddMul)
422470
require_one_based_indexing(out, A)
423471
alpha, beta = _add.alpha, _add.beta
424472
if iszero(alpha)
425473
_rmul_or_fill!(out, beta)
426474
else
427-
if bis0
428-
@inbounds for j in axes(A, 2)
429-
dja = D.diag[j] * alpha
430-
@simd for i in axes(A, 1)
431-
out[i,j] = A[i,j] * dja
432-
end
433-
end
434-
else
435-
@inbounds for j in axes(A, 2)
436-
dja = D.diag[j] * alpha
437-
@simd for i in axes(A, 1)
438-
out[i,j] = A[i,j] * dja + out[i,j] * beta
439-
end
440-
end
441-
end
475+
__muldiag_nonzeroalpha!(out, A, D, _add)
442476
end
443477
return out
444478
end
445-
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
479+
480+
@inline function __muldiag_nonzeroalpha!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
446481
d1 = D1.diag
447482
d2 = D2.diag
483+
outd = out.diag
484+
@inbounds @simd for i in eachindex(d1, d2, outd)
485+
_modify!(_add, d1[i] * d2[i], outd, i)
486+
end
487+
out
488+
end
489+
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
448490
alpha, beta = _add.alpha, _add.beta
449491
if iszero(alpha)
450492
_rmul_or_fill!(out.diag, beta)
451493
else
452-
if bis0
453-
@inbounds @simd for i in eachindex(out.diag)
454-
out.diag[i] = d1[i] * d2[i] * alpha
455-
end
456-
else
457-
@inbounds @simd for i in eachindex(out.diag)
458-
out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta
459-
end
460-
end
494+
__muldiag_nonzeroalpha!(out, D1, D2, _add)
461495
end
462496
return out
463497
end
464-
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
465-
require_one_based_indexing(out)
466-
alpha, beta = _add.alpha, _add.beta
467-
mA = size(D1, 1)
498+
@inline function __muldiag_nonzeroalpha!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul)
468499
d1 = D1.diag
469500
d2 = D2.diag
501+
@inbounds @simd for i in eachindex(d1, d2)
502+
_modify!(_add, d1[i] * d2[i], out, (i,i))
503+
end
504+
out
505+
end
506+
function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1}) where {ais1}
507+
require_one_based_indexing(out)
508+
alpha, beta = _add.alpha, _add.beta
470509
_rmul_or_fill!(out, beta)
471510
if !iszero(alpha)
472-
@inbounds @simd for i in 1:mA
473-
out[i,i] += d1[i] * d2[i] * alpha
474-
end
511+
_add_bis1 = MulAddMul{ais1,false,typeof(alpha),Bool}(alpha,true)
512+
__muldiag_nonzeroalpha!(out, D1, D2, _add_bis1)
475513
end
476514
return out
477515
end
@@ -658,31 +696,21 @@ for Tri in (:UpperTriangular, :LowerTriangular)
658696
@eval $fun(A::$Tri, D::Diagonal) = $Tri($fun(A.data, D))
659697
@eval $fun(A::$UTri, D::Diagonal) = $Tri(_setdiag!($fun(A.data, D), $f, D.diag))
660698
end
699+
@eval *(A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
700+
@invoke *(A::AbstractMatrix, D::Diagonal)
701+
@eval *(A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}, D::Diagonal) =
702+
@invoke *(A::AbstractMatrix, D::Diagonal)
661703
for (fun, f) in zip((:*, :lmul!, :ldiv!, :\), (:identity, :identity, :inv, :inv))
662704
@eval $fun(D::Diagonal, A::$Tri) = $Tri($fun(D, A.data))
663705
@eval $fun(D::Diagonal, A::$UTri) = $Tri(_setdiag!($fun(D, A.data), $f, D.diag))
664706
end
707+
@eval *(D::Diagonal, A::$Tri{<:Any, <:StridedMaybeAdjOrTransMat}) =
708+
@invoke *(D::Diagonal, A::AbstractMatrix)
709+
@eval *(D::Diagonal, A::$UTri{<:Any, <:StridedMaybeAdjOrTransMat}) =
710+
@invoke *(D::Diagonal, A::AbstractMatrix)
665711
# 3-arg ldiv!
666712
@eval ldiv!(C::$Tri, D::Diagonal, A::$Tri) = $Tri(ldiv!(C.data, D, A.data))
667713
@eval ldiv!(C::$Tri, D::Diagonal, A::$UTri) = $Tri(_setdiag!(ldiv!(C.data, D, A.data), inv, D.diag))
668-
# 3-arg mul! is disambiguated in special.jl
669-
# 5-arg mul!
670-
@eval _mul!(C::$Tri, D::Diagonal, A::$Tri, _add) = $Tri(mul!(C.data, D, A.data, _add.alpha, _add.beta))
671-
@eval function _mul!(C::$Tri, D::Diagonal, A::$UTri, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
672-
α, β = _add.alpha, _add.beta
673-
iszero(α) && return _rmul_or_fill!(C, β)
674-
diag′ = bis0 ? nothing : diag(C)
675-
data = mul!(C.data, D, A.data, α, β)
676-
$Tri(_setdiag!(data, _add, D.diag, diag′))
677-
end
678-
@eval _mul!(C::$Tri, A::$Tri, D::Diagonal, _add) = $Tri(mul!(C.data, A.data, D, _add.alpha, _add.beta))
679-
@eval function _mul!(C::$Tri, A::$UTri, D::Diagonal, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
680-
α, β = _add.alpha, _add.beta
681-
iszero(α) && return _rmul_or_fill!(C, β)
682-
diag′ = bis0 ? nothing : diag(C)
683-
data = mul!(C.data, A.data, D, α, β)
684-
$Tri(_setdiag!(data, _add, D.diag, diag′))
685-
end
686714
end
687715

688716
@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)

stdlib/LinearAlgebra/test/diagonal.jl

+39-1
Original file line numberDiff line numberDiff line change
@@ -1188,7 +1188,7 @@ end
11881188
@test oneunit(D3) isa typeof(D3)
11891189
end
11901190

1191-
@testset "AbstractTriangular" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
1191+
@testset "$Tri" for (Tri, UTri) in ((UpperTriangular, UnitUpperTriangular), (LowerTriangular, UnitLowerTriangular))
11921192
A = randn(4, 4)
11931193
TriA = Tri(A)
11941194
UTriA = UTri(A)
@@ -1218,6 +1218,44 @@ end
12181218
@test outTri === mul!(outTri, D, UTriA, 2, 1)::Tri == mul!(out, D, Matrix(UTriA), 2, 1)
12191219
@test outTri === mul!(outTri, TriA, D, 2, 1)::Tri == mul!(out, Matrix(TriA), D, 2, 1)
12201220
@test outTri === mul!(outTri, UTriA, D, 2, 1)::Tri == mul!(out, Matrix(UTriA), D, 2, 1)
1221+
1222+
# we may write to a Unit triangular if the diagonal is preserved
1223+
ID = Diagonal(ones(size(UTriA,2)))
1224+
@test mul!(copy(UTriA), UTriA, ID) == UTriA
1225+
@test mul!(copy(UTriA), ID, UTriA) == UTriA
1226+
1227+
@testset "partly filled parents" begin
1228+
M = Matrix{BigFloat}(undef, 2, 2)
1229+
M[1,1] = M[2,2] = 3
1230+
isupper = Tri == UpperTriangular
1231+
M[1+!isupper, 1+isupper] = 3
1232+
D = Diagonal(1:2)
1233+
T = Tri(M)
1234+
TA = Array(T)
1235+
@test T * D == TA * D
1236+
@test D * T == D * TA
1237+
@test mul!(copy(T), T, D, 2, 3) == 2T * D + 3T
1238+
@test mul!(copy(T), D, T, 2, 3) == 2D * T + 3T
1239+
1240+
U = UTri(M)
1241+
UA = Array(U)
1242+
@test U * D == UA * D
1243+
@test D * U == D * UA
1244+
@test mul!(copy(T), U, D, 2, 3) == 2 * UA * D + 3TA
1245+
@test mul!(copy(T), D, U, 2, 3) == 2 * D * UA + 3TA
1246+
1247+
M2 = Matrix{BigFloat}(undef, 2, 2)
1248+
M2[1+!isupper, 1+isupper] = 3
1249+
U = UTri(M2)
1250+
UA = Array(U)
1251+
@test U * D == UA * D
1252+
@test D * U == D * UA
1253+
ID = Diagonal(ones(size(U,2)))
1254+
@test mul!(copy(U), U, ID) == U
1255+
@test mul!(copy(U), ID, U) == U
1256+
@test mul!(copy(U), U, ID, 2, -1) == U
1257+
@test mul!(copy(U), ID, U, 2, -1) == U
1258+
end
12211259
end
12221260

12231261
struct SMatrix1{T} <: AbstractArray{T,2}

0 commit comments

Comments
 (0)