Skip to content

Commit ed987f2

Browse files
authored
Bidiagonal to Tridiagonal with immutable bands (#55059)
Using `similar` to generate the zero band necessarily allocates a mutable vector, which would lead to an error if the other bands are immutable. This PR changes this to use `zero` instead, which usually produces a vector of the same type. There are occasions where `zero(v)` produces a different type from `v`, so an extra conversion is added to obtain a zero vector of the same type. The following works after this: ```julia julia> using FillArrays, LinearAlgebra julia> n = 4; B = Bidiagonal(Fill(3, n), Fill(2, n-1), :U) 4×4 Bidiagonal{Int64, Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}}: 3 2 ⋅ ⋅ ⋅ 3 2 ⋅ ⋅ ⋅ 3 2 ⋅ ⋅ ⋅ 3 julia> Tridiagonal(B) 4×4 Tridiagonal{Int64, Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}}: 3 2 ⋅ ⋅ 0 3 2 ⋅ ⋅ 0 3 2 ⋅ ⋅ 0 3 julia> Tridiagonal{Float64}(B) 4×4 Tridiagonal{Float64, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}: 3.0 2.0 ⋅ ⋅ 0.0 3.0 2.0 ⋅ ⋅ 0.0 3.0 2.0 ⋅ ⋅ 0.0 3.0 ```
1 parent 23dabef commit ed987f2

File tree

4 files changed

+27
-4
lines changed

4 files changed

+27
-4
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ promote_rule(::Type{<:Matrix}, ::Type{<:Bidiagonal}) = Matrix
231231
function Tridiagonal{T}(A::Bidiagonal) where T
232232
dv = convert(AbstractVector{T}, A.dv)
233233
ev = convert(AbstractVector{T}, A.ev)
234-
z = fill!(similar(ev), zero(T))
234+
# ensure that the types are identical, even if zero returns a different type
235+
z = oftype(ev, zero(ev))
235236
A.uplo == 'U' ? Tridiagonal(z, dv, ev) : Tridiagonal(ev, dv, z)
236237
end
237238
promote_rule(::Type{<:Tridiagonal{T}}, ::Type{<:Bidiagonal{S}}) where {T,S} =

stdlib/LinearAlgebra/src/special.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ Diagonal(A::Bidiagonal) = Diagonal(A.dv)
1515
SymTridiagonal(A::Bidiagonal) =
1616
iszero(A.ev) ? SymTridiagonal(A.dv, A.ev) :
1717
throw(ArgumentError("matrix cannot be represented as SymTridiagonal"))
18-
Tridiagonal(A::Bidiagonal) =
19-
Tridiagonal(A.uplo == 'U' ? fill!(similar(A.ev), 0) : A.ev, A.dv,
20-
A.uplo == 'U' ? A.ev : fill!(similar(A.ev), 0))
18+
function Tridiagonal(A::Bidiagonal)
19+
# ensure that the types are identical, even if zero returns a different type
20+
z = oftype(A.ev, zero(A.ev))
21+
Tridiagonal(A.uplo == 'U' ? z : A.ev, A.dv, A.uplo == 'U' ? A.ev : z)
22+
end
2123

2224
# conversions from SymTridiagonal to other special matrix types
2325
Diagonal(A::SymTridiagonal) = Diagonal(A.dv)

stdlib/LinearAlgebra/test/bidiag.jl

+16
Original file line numberDiff line numberDiff line change
@@ -933,4 +933,20 @@ end
933933
@test B[1,2] == B[Int8(1),UInt16(2)] == B[big(1), Int16(2)]
934934
end
935935

936+
@testset "conversion to Tridiagonal for immutable bands" begin
937+
n = 4
938+
dv = FillArrays.Fill(3, n)
939+
ev = FillArrays.Fill(2, n-1)
940+
z = FillArrays.Fill(0, n-1)
941+
dvf = FillArrays.Fill(Float64(3), n)
942+
evf = FillArrays.Fill(Float64(2), n-1)
943+
zf = FillArrays.Fill(Float64(0), n-1)
944+
B = Bidiagonal(dv, ev, :U)
945+
@test Tridiagonal{Int}(B) === Tridiagonal(B) === Tridiagonal(z, dv, ev)
946+
@test Tridiagonal{Float64}(B) === Tridiagonal(zf, dvf, evf)
947+
B = Bidiagonal(dv, ev, :L)
948+
@test Tridiagonal{Int}(B) === Tridiagonal(B) === Tridiagonal(ev, dv, z)
949+
@test Tridiagonal{Float64}(B) === Tridiagonal(evf, dvf, zf)
950+
end
951+
936952
end # module TestBidiagonal

test/testhelpers/FillArrays.jl

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ Base.size(F::Fill) = F.size
1111

1212
Base.copy(F::Fill) = F
1313

14+
Base.AbstractArray{T,N}(F::Fill{<:Any,N}) where {T,N} = Fill(T(F.value), F.size)
15+
1416
@inline getindex_value(F::Fill) = F.value
1517

1618
@inline function Base.getindex(F::Fill{<:Any,N}, i::Vararg{Int,N}) where {N}
@@ -29,6 +31,8 @@ end
2931
F
3032
end
3133

34+
Base.zero(F::Fill) = Fill(zero(F.value), size(F))
35+
3236
Base.show(io::IO, F::Fill) = print(io, "Fill($(F.value), $(F.size))")
3337
Base.show(io::IO, ::MIME"text/plain", F::Fill) = show(io, F)
3438

0 commit comments

Comments
 (0)