Skip to content

Commit bc8337a

Browse files
authoredNov 16, 2021
Fix ==/sum/issymmetric for (Sym)Tridiagonal with non-number eltype (#43066)
1 parent 9affe2f commit bc8337a

File tree

3 files changed

+98
-24
lines changed

3 files changed

+98
-24
lines changed
 

‎stdlib/LinearAlgebra/src/tridiag.jl

+27-24
Original file line numberDiff line numberDiff line change
@@ -125,17 +125,13 @@ AbstractMatrix{T}(S::SymTridiagonal) where {T} =
125125
function Matrix{T}(M::SymTridiagonal) where T
126126
n = size(M, 1)
127127
Mf = zeros(T, n, n)
128-
if n == 0
129-
return Mf
130-
end
131-
@inbounds begin
132-
@simd for i = 1:n-1
133-
Mf[i,i] = M.dv[i]
134-
Mf[i+1,i] = M.ev[i]
135-
Mf[i,i+1] = M.ev[i]
136-
end
137-
Mf[n,n] = M.dv[n]
128+
n == 0 && return Mf
129+
@inbounds for i = 1:n-1
130+
Mf[i,i] = symmetric(M.dv[i], :U)
131+
Mf[i+1,i] = transpose(M.ev[i])
132+
Mf[i,i+1] = M.ev[i]
138133
end
134+
Mf[n,n] = symmetric(M.dv[n], :U)
139135
return Mf
140136
end
141137
Matrix(M::SymTridiagonal{T}) where {T} = Matrix{T}(M)
@@ -612,8 +608,8 @@ transpose(S::Tridiagonal{<:Number}) = Tridiagonal(S.du, S.d, S.dl)
612608
Base.copy(aS::Adjoint{<:Any,<:Tridiagonal}) = (S = aS.parent; Tridiagonal(map(x -> copy.(adjoint.(x)), (S.du, S.d, S.dl))...))
613609
Base.copy(tS::Transpose{<:Any,<:Tridiagonal}) = (S = tS.parent; Tridiagonal(map(x -> copy.(transpose.(x)), (S.du, S.d, S.dl))...))
614610

615-
ishermitian(S::Tridiagonal) = isreal(S.d) && S.du == adjoint.(S.dl)
616-
issymmetric(S::Tridiagonal) = S.du == S.dl
611+
ishermitian(S::Tridiagonal) = all(ishermitian, S.d) && all(Iterators.map((x, y) -> x == y', S.du, S.dl))
612+
issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y) -> x == transpose(y), S.du, S.dl))
617613

618614
\(A::Adjoint{<:Any,<:Tridiagonal}, B::Adjoint{<:Any,<:StridedVecOrMat}) = copy(A) \ B
619615

@@ -744,8 +740,12 @@ end
744740
\(B::Number, A::Tridiagonal) = Tridiagonal(B\A.dl, B\A.d, B\A.du)
745741

746742
==(A::Tridiagonal, B::Tridiagonal) = (A.dl==B.dl) && (A.d==B.d) && (A.du==B.du)
747-
==(A::Tridiagonal, B::SymTridiagonal) = (A.dl==A.du==B.ev) && (A.d==B.dv)
748-
==(A::SymTridiagonal, B::Tridiagonal) = (B.dl==B.du==A.ev) && (B.d==A.dv)
743+
function ==(A::Tridiagonal, B::SymTridiagonal)
744+
iseq = all(Iterators.map((x, y) -> x == transpose(y), A.du, A.dl))
745+
iseq = iseq && A.du == _evview(B)
746+
iseq && all(Iterators.map((x, y) -> x == symmetric(y, :U), A.d, B.dv))
747+
end
748+
==(A::SymTridiagonal, B::Tridiagonal) = B == A
749749

750750
det(A::Tridiagonal) = det_usmani(A.dl, A.d, A.du)
751751

@@ -760,7 +760,10 @@ function SymTridiagonal{T}(M::Tridiagonal) where T
760760
end
761761

762762
Base._sum(A::Tridiagonal, ::Colon) = sum(A.d) + sum(A.dl) + sum(A.du)
763-
Base._sum(A::SymTridiagonal, ::Colon) = sum(A.dv) + 2sum(A.ev)
763+
function Base._sum(A::SymTridiagonal, ::Colon)
764+
se = sum(_evview(A))
765+
symmetric(sum(A.dv), :U) + se + transpose(se)
766+
end
764767

765768
function Base._sum(A::Tridiagonal, dims::Integer)
766769
res = Base.reducedim_initarray(A, dims, zero(eltype(A)))
@@ -807,24 +810,24 @@ function Base._sum(A::SymTridiagonal, dims::Integer)
807810
end
808811
@inbounds begin
809812
if dims == 1
810-
res[1] = A.ev[1] + A.dv[1]
813+
res[1] = transpose(A.ev[1]) + symmetric(A.dv[1], :U)
811814
for i = 2:n-1
812-
res[i] = A.ev[i] + A.dv[i] + A.ev[i-1]
815+
res[i] = transpose(A.ev[i]) + symmetric(A.dv[i], :U) + A.ev[i-1]
813816
end
814-
res[n] = A.dv[n] + A.ev[n-1]
817+
res[n] = symmetric(A.dv[n], :U) + A.ev[n-1]
815818
elseif dims == 2
816-
res[1] = A.dv[1] + A.ev[1]
819+
res[1] = symmetric(A.dv[1], :U) + A.ev[1]
817820
for i = 2:n-1
818-
res[i] = A.ev[i-1] + A.dv[i] + A.ev[i]
821+
res[i] = transpose(A.ev[i-1]) + symmetric(A.dv[i], :U) + A.ev[i]
819822
end
820-
res[n] = A.ev[n-1] + A.dv[n]
823+
res[n] = transpose(A.ev[n-1]) + symmetric(A.dv[n], :U)
821824
elseif dims >= 3
822825
for i = 1:n-1
823826
res[i,i+1] = A.ev[i]
824-
res[i,i] = A.dv[i]
825-
res[i+1,i] = A.ev[i]
827+
res[i,i] = symmetric(A.dv[i], :U)
828+
res[i+1,i] = transpose(A.ev[i])
826829
end
827-
res[n,n] = A.dv[n]
830+
res[n,n] = symmetric(A.dv[n], :U)
828831
end
829832
end
830833
res

‎stdlib/LinearAlgebra/test/tridiag.jl

+31
Original file line numberDiff line numberDiff line change
@@ -695,4 +695,35 @@ end
695695
end
696696
end
697697

698+
isdefined(Main, :SizedArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "SizedArrays.jl"))
699+
using .Main.SizedArrays
700+
@testset "non-number eltype" begin
701+
@testset "sum for SymTridiagonal" begin
702+
dv = [SizedArray{(2,2)}(rand(1:2048,2,2)) for i in 1:10]
703+
ev = [SizedArray{(2,2)}(rand(1:2048,2,2)) for i in 1:10]
704+
S = SymTridiagonal(dv, ev)
705+
Sdense = Matrix(S)
706+
@test Sdense == collect(S)
707+
@test sum(S) == sum(Sdense)
708+
@test sum(S, dims = 1) == sum(Sdense, dims = 1)
709+
@test sum(S, dims = 2) == sum(Sdense, dims = 2)
710+
end
711+
@testset "issymmetric/ishermitian for Tridiagonal" begin
712+
@test !issymmetric(Tridiagonal([[1 2;3 4]], [[1 2;2 3], [1 2;2 3]], [[1 2;3 4]]))
713+
@test !issymmetric(Tridiagonal([[1 3;2 4]], [[1 2;3 4], [1 2;3 4]], [[1 2;3 4]]))
714+
@test issymmetric(Tridiagonal([[1 3;2 4]], [[1 2;2 3], [1 2;2 3]], [[1 2;3 4]]))
715+
716+
@test ishermitian(Tridiagonal([[1 3;2 4].+im], [[1 2;2 3].+0im, [1 2;2 3].+0im], [[1 2;3 4].-im]))
717+
@test !ishermitian(Tridiagonal([[1 3;2 4].+im], [[1 2;2 3].+0im, [1 2;2 3].+0im], [[1 2;3 4].+im]))
718+
@test !ishermitian(Tridiagonal([[1 3;2 4].+im], [[1 2;2 3].+im, [1 2;2 3].+0im], [[1 2;3 4].-im]))
719+
end
720+
@testset "== between Tridiagonal and SymTridiagonal" begin
721+
dv = [SizedArray{(2,2)}([1 2;3 4]) for i in 1:4]
722+
ev = [SizedArray{(2,2)}([3 4;1 2]) for i in 1:4]
723+
S = SymTridiagonal(dv, ev)
724+
Sdense = Matrix(S)
725+
@test S == Tridiagonal(diag(Sdense, -1), diag(Sdense), diag(Sdense, 1)) == S
726+
@test S !== Tridiagonal(diag(Sdense, 1), diag(Sdense), diag(Sdense, 1)) !== S
727+
end
728+
end
698729
end # module TestTridiagonal

‎test/testhelpers/SizedArrays.jl

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
# SizedArrays
4+
5+
# This test file defines an array wrapper with statical size. It can be used to
6+
# test the action of LinearAlgebra with non-number eltype.
7+
8+
module SizedArrays
9+
10+
import Base: +, *, ==
11+
12+
export SizedArray
13+
14+
struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
15+
data::A
16+
function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}
17+
SZ == size(data) || throw(ArgumentError("size mismatch!"))
18+
new{SZ,T,N,typeof(data)}(data)
19+
end
20+
function SizedArray{SZ,T,N,A}(data::AbstractArray{T,N}) where {SZ,T,N,A}
21+
SZ == size(data) || throw(ArgumentError("size mismatch!"))
22+
new{SZ,T,N,A}(A(data))
23+
end
24+
end
25+
Base.convert(::Type{SizedArray{SZ,T,N,A}}, data::AbstractArray) where {SZ,T,N,A} = SizedArray{SZ,T,N,A}(data)
26+
27+
# Minimal AbstractArray interface
28+
Base.size(a::SizedArray) = size(typeof(a))
29+
Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ
30+
Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
31+
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
32+
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
33+
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data
34+
function *(S1::SizedArray, S2::SizedArray)
35+
0 < ndims(S1) < 3 && 0 < ndims(S2) < 3 && size(S1, 2) == size(S2, 1) || throw(ArgumentError("size mismatch!"))
36+
data = S1.data * S2.data
37+
SZ = ndims(data) == 1 ? (size(S1, 1), ) : (size(S1, 1), size(S2, 2))
38+
SizedArray{SZ}(data)
39+
end
40+
end

0 commit comments

Comments
 (0)
Please sign in to comment.