Skip to content

Commit 9738bc7

Browse files
authored
Fix tr for Symmetric/Hermitian block matrices (#55522)
Since `Symmetric` and `Hermitian` symmetrize the diagonal elements of the parent, we can't forward `tr` to the parent unless it is already symmetric. This limits the existing `tr` methods to matrices of `Number`s, which is the common use-case. `tr` for `Symmetric` block matrices would now use the fallback implementation that explicitly computes the `diag`. This resolves the following discrepancy: ```julia julia> S = Symmetric(fill([1 2; 3 4], 3, 3)) 3×3 Symmetric{AbstractMatrix, Matrix{Matrix{Int64}}}: [1 2; 2 4] [1 2; 3 4] [1 2; 3 4] [1 3; 2 4] [1 2; 2 4] [1 2; 3 4] [1 3; 2 4] [1 3; 2 4] [1 2; 2 4] julia> tr(S) 2×2 Matrix{Int64}: 3 6 9 12 julia> sum(diag(S)) 2×2 Symmetric{Int64, Matrix{Int64}}: 3 6 6 12 ```
1 parent 306cee7 commit 9738bc7

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

stdlib/LinearAlgebra/src/symmetric.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,8 @@ Base.copy(A::Adjoint{<:Any,<:Symmetric}) =
449449
Base.copy(A::Transpose{<:Any,<:Hermitian}) =
450450
Hermitian(copy(transpose(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U))
451451

452-
tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
453-
tr(A::Hermitian) = real(tr(A.data))
452+
tr(A::Symmetric{<:Number}) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
453+
tr(A::Hermitian{<:Number}) = real(tr(A.data))
454454

455455
Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), sym_uplo(A.uplo))
456456
Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), sym_uplo(A.uplo))

stdlib/LinearAlgebra/test/symmetric.jl

+11
Original file line numberDiff line numberDiff line change
@@ -1116,4 +1116,15 @@ end
11161116
end
11171117
end
11181118

1119+
@testset "tr for block matrices" begin
1120+
m = [1 2; 3 4]
1121+
for b in (m, m * (1 + im))
1122+
M = fill(b, 3, 3)
1123+
for ST in (Symmetric, Hermitian)
1124+
S = ST(M)
1125+
@test tr(S) == sum(diag(S))
1126+
end
1127+
end
1128+
end
1129+
11191130
end # module TestSymmetric

0 commit comments

Comments
 (0)