Skip to content

Commit 4da7f45

Browse files
sethaxen authored and antoine-levitt committed
Compute real matrix logarithm and matrix square root using real arithmetic (JuliaLang#39973)
* Add failing test * Add sylvester methods for small matrices * Add 2x2 real matrix square root * Add real square root of quasitriangular matrix * Simplify 2x2 real square root * Rename functions to use quasitriu * Avoid NaNs when eigenvalues are all zero * Reuse ranges * Add clarifying comments * Unify real and complex matrix square root * Add reference for real sqrt * Move quasitriu auxiliary functions to triangular.jl * Ensure loops are type-stable and use simd * Remove duplicate computation * Correctly promote for dimensionful A * Use simd directive * Test that UpperTriangular is returned by sqrt * Test sqrt for UnitUpperTriangular * Test that return type is complex when input type is * Test that output is complex when input is * Add failing test * Separate type-stable from type-unstable part * Use generic sqrt_quasitriu for sqrt triu * Avoid redundant matmul * Clarify comment * Return complex output for complex input * Call log_quasitriu * Add failing test for log type-inferrability * Realify or complexify as necessary * Call sqrt_quasitriu directly * Refactor sqrt_diag! * Simplify utility function * Add comment * Compute accurate block-diagonal * Compute superdiagonal for quasi triu A0 * Compute accurate block superdiagonal * Avoid full LU decomposition in inner loop * Avoid promotion to improve type-stability * Modify return type if necessary * Clarify comment * Add comments * Call log_quasitriu on quasitriu matrices * Document quasi-triangular algorithm * Remove test This matrix has eigenvalues so close to zero that its eltype is not stable * Rearrange definition * Add compatibility for unit triangular matrices * Release constraints on tests * Separate copying of A from log computation * Revert "Separate copying of A from log computation" This reverts commit 23becc5. 
* Use Givens rotations * Compute Schur in-place when possible * Always allocate a copy * Fix block indexing * Compute sqrt in-place * Overwrite AmI * Reduce allocations in Pade approximation * Use T * Don't unnecessarily unwrap * Test remaining log branches * Add additional matrix square root tests * Separate type-unstable from type-stable part This substantially reduces allocations for some reason * Use Ref instead of a Vector * Eliminate allocation in checksquare * Refactor param choosing code to own function * Comment section * Use more descriptive variable name * Reuse temporaries * Add reference * More accurately describe condition
1 parent bd80fef commit 4da7f45

File tree

5 files changed

+743
-187
lines changed

5 files changed

+743
-187
lines changed

stdlib/LinearAlgebra/src/dense.jl

+82-43
Original file line numberDiff line numberDiff line change
@@ -679,7 +679,7 @@ function rcswap!(i::Integer, j::Integer, X::StridedMatrix{<:Number})
679679
end
680680

681681
"""
682-
log(A{T}::StridedMatrix{T})
682+
log(A::StridedMatrix)
683683
684684
If `A` has no negative real eigenvalue, compute the principal matrix logarithm of `A`, i.e.
685685
the unique matrix ``X`` such that ``e^X = A`` and ``-\\pi < Im(\\lambda) < \\pi`` for all
@@ -688,9 +688,10 @@ matrix function is returned whenever possible.
688688
689689
If `A` is symmetric or Hermitian, its eigendecomposition ([`eigen`](@ref)) is
690690
used, if `A` is triangular an improved version of the inverse scaling and squaring method is
691-
employed (see [^AH12] and [^AHR13]). For general matrices, the complex Schur form
692-
([`schur`](@ref)) is computed and the triangular algorithm is used on the
693-
triangular factor.
691+
employed (see [^AH12] and [^AHR13]). If `A` is real with no negative eigenvalues, then
692+
the real Schur form is computed. Otherwise, the complex Schur form is computed. Then
693+
the upper (quasi-)triangular algorithm in [^AHR13] is used on the upper (quasi-)triangular
694+
factor.
694695
695696
[^AH12]: Awad H. Al-Mohy and Nicholas J. Higham, "Improved inverse scaling and squaring algorithms for the matrix logarithm", SIAM Journal on Scientific Computing, 34(4), 2012, C153-C169. [doi:10.1137/110852553](https://doi.org/10.1137/110852553)
696697
@@ -713,27 +714,28 @@ function log(A::StridedMatrix)
713714
# If possible, use diagonalization
714715
if ishermitian(A)
715716
logHermA = log(Hermitian(A))
716-
return isa(logHermA, Hermitian) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
717-
end
718-
719-
# Use Schur decomposition
720-
n = checksquare(A)
721-
if istriu(A)
722-
return triu!(parent(log(UpperTriangular(complex(A)))))
723-
else
724-
if isreal(A)
725-
SchurF = schur(real(A))
717+
return ishermitian(logHermA) ? copytri!(parent(logHermA), 'U', true) : parent(logHermA)
718+
elseif istriu(A)
719+
return triu!(parent(log(UpperTriangular(A))))
720+
elseif isreal(A)
721+
SchurF = schur(real(A))
722+
if istriu(SchurF.T)
723+
logA = SchurF.Z * log(UpperTriangular(SchurF.T)) * SchurF.Z'
726724
else
727-
SchurF = schur(A)
728-
end
729-
if !istriu(SchurF.T)
730-
SchurS = schur(complex(SchurF.T))
731-
logT = SchurS.Z * log(UpperTriangular(SchurS.T)) * SchurS.Z'
732-
return SchurF.Z * logT * SchurF.Z'
733-
else
734-
R = log(UpperTriangular(complex(SchurF.T)))
735-
return SchurF.Z * R * SchurF.Z'
725+
# real log exists whenever all eigenvalues are positive
726+
is_log_real = !any(x -> isreal(x) && real(x) ≤ 0, SchurF.values)
727+
if is_log_real
728+
logA = SchurF.Z * log_quasitriu(SchurF.T) * SchurF.Z'
729+
else
730+
SchurS = schur!(complex(SchurF.T))
731+
Z = SchurF.Z * SchurS.Z
732+
logA = Z * log(UpperTriangular(SchurS.T)) * Z'
733+
end
736734
end
735+
return eltype(A) <: Complex ? complex(logA) : logA
736+
else
737+
SchurF = schur(A)
738+
return SchurF.vectors * log(UpperTriangular(SchurF.T)) * SchurF.vectors'
737739
end
738740
end
739741

@@ -755,13 +757,21 @@ defaults to machine precision scaled by `size(A,1)`.
755757
Otherwise, the square root is determined by means of the
756758
Björck-Hammarling method [^BH83], which computes the complex Schur form ([`schur`](@ref))
757759
and then the complex square root of the triangular factor.
760+
If a real square root exists, then an extension of this method [^H87] that computes the real
761+
Schur form and then the real square root of the quasi-triangular factor is instead used.
758762
759763
[^BH83]:
760764
761765
Åke Björck and Sven Hammarling, "A Schur method for the square root of a matrix",
762766
Linear Algebra and its Applications, 52-53, 1983, 127-140.
763767
[doi:10.1016/0024-3795(83)80010-X](https://doi.org/10.1016/0024-3795(83)80010-X)
764768
769+
[^H87]:
770+
771+
Nicholas J. Higham, "Computing real square roots of a real matrix",
772+
Linear Algebra and its Applications, 88-89, 1987, 405-430.
773+
[doi:10.1016/0024-3795(87)90118-2](https://doi.org/10.1016/0024-3795(87)90118-2)
774+
765775
# Examples
766776
```jldoctest
767777
julia> A = [4 0; 0 4]
@@ -775,31 +785,32 @@ julia> sqrt(A)
775785
0.0 2.0
776786
```
777787
"""
778-
function sqrt(A::StridedMatrix{<:Real})
779-
if issymmetric(A)
780-
return copytri!(parent(sqrt(Symmetric(A))), 'U')
781-
end
782-
n = checksquare(A)
783-
if istriu(A)
784-
return triu!(parent(sqrt(UpperTriangular(A))))
785-
else
786-
SchurF = schur(complex(A))
787-
R = triu!(parent(sqrt(UpperTriangular(SchurF.T)))) # unwrapping unnecessary?
788-
return SchurF.vectors * R * SchurF.vectors'
789-
end
790-
end
791-
function sqrt(A::StridedMatrix{<:Complex})
788+
function sqrt(A::StridedMatrix{T}) where {T<:Union{Real,Complex}}
792789
if ishermitian(A)
793790
sqrtHermA = sqrt(Hermitian(A))
794-
return isa(sqrtHermA, Hermitian) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
795-
end
796-
n = checksquare(A)
797-
if istriu(A)
791+
return ishermitian(sqrtHermA) ? copytri!(parent(sqrtHermA), 'U', true) : parent(sqrtHermA)
792+
elseif istriu(A)
798793
return triu!(parent(sqrt(UpperTriangular(A))))
794+
elseif isreal(A)
795+
SchurF = schur(real(A))
796+
if istriu(SchurF.T)
797+
sqrtA = SchurF.Z * sqrt(UpperTriangular(SchurF.T)) * SchurF.Z'
798+
else
799+
# real sqrt exists whenever no eigenvalues are negative
800+
is_sqrt_real = !any(x -> isreal(x) && real(x) < 0, SchurF.values)
801+
# sqrt_quasitriu uses LAPACK functions for non-triu inputs
802+
if typeof(sqrt(zero(T))) <: BlasFloat && is_sqrt_real
803+
sqrtA = SchurF.Z * sqrt_quasitriu(SchurF.T) * SchurF.Z'
804+
else
805+
SchurS = schur!(complex(SchurF.T))
806+
Z = SchurF.Z * SchurS.Z
807+
sqrtA = Z * sqrt(UpperTriangular(SchurS.T)) * Z'
808+
end
809+
end
810+
return eltype(A) <: Complex ? complex(sqrtA) : sqrtA
799811
else
800812
SchurF = schur(A)
801-
R = triu!(parent(sqrt(UpperTriangular(SchurF.T)))) # unwrapping unnecessary?
802-
return SchurF.vectors * R * SchurF.vectors'
813+
return SchurF.vectors * sqrt(UpperTriangular(SchurF.T)) * SchurF.vectors'
803814
end
804815
end
805816

@@ -1526,6 +1537,34 @@ function sylvester(A::StridedMatrix{T},B::StridedMatrix{T},C::StridedMatrix{T})
15261537
end
15271538
sylvester(A::StridedMatrix{T}, B::StridedMatrix{T}, C::StridedMatrix{T}) where {T<:Integer} = sylvester(float(A), float(B), float(C))
15281539

1540+
Base.@propagate_inbounds function _sylvester_2x1!(A, B, C)
1541+
b = B[1]
1542+
a21, a12 = A[2, 1], A[1, 2]
1543+
m11 = b + A[1, 1]
1544+
m22 = b + A[2, 2]
1545+
d = m11 * m22 - a12 * a21
1546+
c1, c2 = C
1547+
C[1] = (a12 * c2 - m22 * c1) / d
1548+
C[2] = (a21 * c1 - m11 * c2) / d
1549+
return C
1550+
end
1551+
Base.@propagate_inbounds function _sylvester_1x2!(A, B, C)
1552+
a = A[1]
1553+
b21, b12 = B[2, 1], B[1, 2]
1554+
m11 = a + B[1, 1]
1555+
m22 = a + B[2, 2]
1556+
d = m11 * m22 - b21 * b12
1557+
c1, c2 = C
1558+
C[1] = (b21 * c2 - m22 * c1) / d
1559+
C[2] = (b12 * c1 - m11 * c2) / d
1560+
return C
1561+
end
1562+
function _sylvester_2x2!(A, B, C)
1563+
_, scale = LAPACK.trsyl!('N', 'N', A, B, C)
1564+
rmul!(C, -inv(scale))
1565+
return C
1566+
end
1567+
15291568
sylvester(a::Union{Real,Complex}, b::Union{Real,Complex}, c::Union{Real,Complex}) = -c / (a + b)
15301569

15311570
# AX + XA' + C = 0

stdlib/LinearAlgebra/src/lapack.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -6449,15 +6449,15 @@ for (fn, elty, relty) in ((:dtrsyl_, :Float64, :Float64),
64496449
B::AbstractMatrix{$elty}, C::AbstractMatrix{$elty}, isgn::Int=1)
64506450
require_one_based_indexing(A, B, C)
64516451
chkstride1(A, B, C)
6452-
m, n = checksquare(A, B)
6452+
m, n = checksquare(A), checksquare(B)
64536453
lda = max(1, stride(A, 2))
64546454
ldb = max(1, stride(B, 2))
64556455
m1, n1 = size(C)
64566456
if m != m1 || n != n1
64576457
throw(DimensionMismatch("dimensions of A, ($m,$n), and C, ($m1,$n1), must match"))
64586458
end
64596459
ldc = max(1, stride(C, 2))
6460-
scale = Vector{$relty}(undef, 1)
6460+
scale = Ref{$relty}()
64616461
info = Ref{BlasInt}()
64626462
ccall((@blasfunc($fn), libblastrampoline), Cvoid,
64636463
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{BlasInt},
@@ -6467,7 +6467,7 @@ for (fn, elty, relty) in ((:dtrsyl_, :Float64, :Float64),
64676467
A, lda, B, ldb, C, ldc,
64686468
scale, info, 1, 1)
64696469
chklapackerror(info[])
6470-
C, scale[1]
6470+
C, scale[]
64716471
end
64726472
end
64736473
end

0 commit comments

Comments
 (0)