Skip to content

Commit 8559e47

Browse files
dkarraschandreasnoack
authored andcommitted
Replace Val-types by singleton types in lu and qr (JuliaLang#40623)
Co-authored-by: Andreas Noack <[email protected]>
1 parent 94c3efa commit 8559e47

14 files changed

+95
-67
lines changed

NEWS.md

+4
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ Standard library changes
115115
* The shape of an `UpperHessenberg` matrix is preserved under certain arithmetic operations, e.g. when multiplying or dividing by an `UpperTriangular` matrix. ([#40039])
116116
* `cis(A)` now supports matrix arguments ([#40194]).
117117
* `dot` now supports `UniformScaling` with `AbstractMatrix` ([#40250]).
118+
* `qr[!]` and `lu[!]` now support `LinearAlgebra.PivotingStrategy` (singleton type) values
119+
as their optional `pivot` argument: defaults are `qr(A, NoPivot())` (vs.
120+
`qr(A, ColumnNorm())` for pivoting) and `lu(A, RowMaximum())` (vs. `lu(A, NoPivot())`
121+
without pivoting); the former `Val{true/false}`-based calls are deprecated. ([#40623])
118122
* `det(M::AbstractMatrix{BigInt})` now calls `det_bareiss(M)`, which uses the [Bareiss](https://en.wikipedia.org/wiki/Bareiss_algorithm) algorithm to calculate precise values.([#40868]).
119123

120124
#### Markdown

stdlib/LinearAlgebra/src/LinearAlgebra.jl

+7
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,22 @@ export
3535
BunchKaufman,
3636
Cholesky,
3737
CholeskyPivoted,
38+
ColumnNorm,
3839
Eigen,
3940
GeneralizedEigen,
4041
GeneralizedSVD,
4142
GeneralizedSchur,
4243
Hessenberg,
4344
LU,
4445
LDLt,
46+
NoPivot,
4547
QR,
4648
QRPivoted,
4749
LQ,
4850
Schur,
4951
SVD,
5052
Hermitian,
53+
RowMaximum,
5154
Symmetric,
5255
LowerTriangular,
5356
UpperTriangular,
@@ -164,6 +167,10 @@ abstract type Algorithm end
164167
struct DivideAndConquer <: Algorithm end
165168
struct QRIteration <: Algorithm end
166169

170+
abstract type PivotingStrategy end
171+
struct NoPivot <: PivotingStrategy end
172+
struct RowMaximum <: PivotingStrategy end
173+
struct ColumnNorm <: PivotingStrategy end
167174

168175
# Check that stride of matrix/vector is 1
169176
# Writing like this to avoid splatting penalty when called with multiple arguments,

stdlib/LinearAlgebra/src/dense.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1371,7 +1371,7 @@ function factorize(A::StridedMatrix{T}) where T
13711371
end
13721372
return lu(A)
13731373
end
1374-
qr(A, Val(true))
1374+
qr(A, ColumnNorm())
13751375
end
13761376
factorize(A::Adjoint) = adjoint(factorize(parent(A)))
13771377
factorize(A::Transpose) = transpose(factorize(parent(A)))

stdlib/LinearAlgebra/src/factorization.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ size(F::Adjoint{<:Any,<:Factorization}) = reverse(size(parent(F)))
1616
size(F::Transpose{<:Any,<:Factorization}) = reverse(size(parent(F)))
1717

1818
checkpositivedefinite(info) = info == 0 || throw(PosDefException(info))
19-
checknonsingular(info, pivoted::Val{true}) = info == 0 || throw(SingularException(info))
20-
checknonsingular(info, pivoted::Val{false}) = info == 0 || throw(ZeroPivotException(info))
21-
checknonsingular(info) = checknonsingular(info, Val{true}())
19+
checknonsingular(info, ::RowMaximum) = info == 0 || throw(SingularException(info))
20+
checknonsingular(info, ::NoPivot) = info == 0 || throw(ZeroPivotException(info))
21+
checknonsingular(info) = checknonsingular(info, RowMaximum())
2222

2323
"""
2424
issuccess(F::Factorization)

stdlib/LinearAlgebra/src/generic.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,7 @@ function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
11411141
end
11421142
return lu(A) \ B
11431143
end
1144-
return qr(A,Val(true)) \ B
1144+
return qr(A, ColumnNorm()) \ B
11451145
end
11461146

11471147
(\)(a::AbstractVector, b::AbstractArray) = pinv(a) * b

stdlib/LinearAlgebra/src/lu.jl

+32-20
Original file line numberDiff line numberDiff line change
@@ -76,22 +76,26 @@ adjoint(F::LU) = Adjoint(F)
7676
transpose(F::LU) = Transpose(F)
7777

7878
# StridedMatrix
79-
function lu!(A::StridedMatrix{T}, pivot::Union{Val{false}, Val{true}} = Val(true);
80-
check::Bool = true) where T<:BlasFloat
81-
if pivot === Val(false)
82-
return generic_lufact!(A, pivot; check = check)
83-
end
79+
lu!(A::StridedMatrix{<:BlasFloat}; check::Bool = true) = lu!(A, RowMaximum(); check=check)
80+
function lu!(A::StridedMatrix{T}, ::RowMaximum; check::Bool = true) where {T<:BlasFloat}
8481
lpt = LAPACK.getrf!(A)
8582
check && checknonsingular(lpt[3])
8683
return LU{T,typeof(A)}(lpt[1], lpt[2], lpt[3])
8784
end
88-
function lu!(A::HermOrSym, pivot::Union{Val{false}, Val{true}} = Val(true); check::Bool = true)
85+
function lu!(A::StridedMatrix{<:BlasFloat}, pivot::NoPivot; check::Bool = true)
86+
return generic_lufact!(A, pivot; check = check)
87+
end
88+
function lu!(A::HermOrSym, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true)
8989
copytri!(A.data, A.uplo, isa(A, Hermitian))
9090
lu!(A.data, pivot; check = check)
9191
end
92+
# for backward compatibility
93+
# TODO: remove towards Julia v2
94+
@deprecate lu!(A::Union{StridedMatrix,HermOrSym,Tridiagonal}, ::Val{true}; check::Bool = true) lu!(A, RowMaximum(); check=check)
95+
@deprecate lu!(A::Union{StridedMatrix,HermOrSym,Tridiagonal}, ::Val{false}; check::Bool = true) lu!(A, NoPivot(); check=check)
9296

9397
"""
94-
lu!(A, pivot=Val(true); check = true) -> LU
98+
lu!(A, pivot = RowMaximum(); check = true) -> LU
9599
96100
`lu!` is the same as [`lu`](@ref), but saves space by overwriting the
97101
input `A`, instead of creating a copy. An [`InexactError`](@ref)
@@ -127,19 +131,22 @@ Stacktrace:
127131
[...]
128132
```
129133
"""
130-
lu!(A::StridedMatrix, pivot::Union{Val{false}, Val{true}} = Val(true); check::Bool = true) =
134+
lu!(A::StridedMatrix, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) =
131135
generic_lufact!(A, pivot; check = check)
132-
function generic_lufact!(A::StridedMatrix{T}, ::Val{Pivot} = Val(true);
133-
check::Bool = true) where {T,Pivot}
136+
function generic_lufact!(A::StridedMatrix{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum();
137+
check::Bool = true) where {T}
138+
# Extract values
134139
m, n = size(A)
135140
minmn = min(m,n)
141+
142+
# Initialize variables
136143
info = 0
137144
ipiv = Vector{BlasInt}(undef, minmn)
138145
@inbounds begin
139146
for k = 1:minmn
140147
# find index max
141148
kp = k
142-
if Pivot && k < m
149+
if pivot === RowMaximum() && k < m
143150
amax = abs(A[k, k])
144151
for i = k+1:m
145152
absi = abs(A[i,k])
@@ -175,7 +182,7 @@ function generic_lufact!(A::StridedMatrix{T}, ::Val{Pivot} = Val(true);
175182
end
176183
end
177184
end
178-
check && checknonsingular(info, Val{Pivot}())
185+
check && checknonsingular(info, pivot)
179186
return LU{T,typeof(A)}(A, ipiv, convert(BlasInt, info))
180187
end
181188

@@ -200,7 +207,7 @@ end
200207

201208
# for all other types we must promote to a type which is stable under division
202209
"""
203-
lu(A, pivot=Val(true); check = true) -> F::LU
210+
lu(A, pivot = RowMaximum(); check = true) -> F::LU
204211
205212
Compute the LU factorization of `A`.
206213
@@ -211,7 +218,7 @@ validity (via [`issuccess`](@ref)) lies with the user.
211218
In most cases, if `A` is a subtype `S` of `AbstractMatrix{T}` with an element
212219
type `T` supporting `+`, `-`, `*` and `/`, the return type is `LU{T,S{T}}`. If
213220
pivoting is chosen (default) the element type should also support [`abs`](@ref) and
214-
[`<`](@ref).
221+
[`<`](@ref). Pivoting can be turned off by passing `pivot = NoPivot()`.
215222
216223
The individual components of the factorization `F` can be accessed via [`getproperty`](@ref):
217224
@@ -267,11 +274,14 @@ julia> l == F.L && u == F.U && p == F.p
267274
true
268275
```
269276
"""
270-
function lu(A::AbstractMatrix{T}, pivot::Union{Val{false}, Val{true}}=Val(true);
271-
check::Bool = true) where T
277+
function lu(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T}
272278
S = lutype(T)
273279
lu!(copy_oftype(A, S), pivot; check = check)
274280
end
281+
# TODO: remove for Julia v2.0
282+
@deprecate lu(A::AbstractMatrix, ::Val{true}; check::Bool = true) lu(A, RowMaximum(); check=check)
283+
@deprecate lu(A::AbstractMatrix, ::Val{false}; check::Bool = true) lu(A, NoPivot(); check=check)
284+
275285

276286
lu(S::LU) = S
277287
function lu(x::Number; check::Bool=true)
@@ -481,9 +491,11 @@ inv(A::LU{<:BlasFloat,<:StridedMatrix}) = inv!(copy(A))
481491
# Tridiagonal
482492

483493
# See dgttrf.f
484-
function lu!(A::Tridiagonal{T,V}, pivot::Union{Val{false}, Val{true}} = Val(true);
485-
check::Bool = true) where {T,V}
494+
function lu!(A::Tridiagonal{T,V}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T,V}
495+
# Extract values
486496
n = size(A, 1)
497+
498+
# Initialize variables
487499
info = 0
488500
ipiv = Vector{BlasInt}(undef, n)
489501
dl = A.dl
@@ -500,7 +512,7 @@ function lu!(A::Tridiagonal{T,V}, pivot::Union{Val{false}, Val{true}} = Val(true
500512
end
501513
for i = 1:n-2
502514
# pivot or not?
503-
if pivot === Val(false) || abs(d[i]) >= abs(dl[i])
515+
if pivot === NoPivot() || abs(d[i]) >= abs(dl[i])
504516
# No interchange
505517
if d[i] != 0
506518
fact = dl[i]/d[i]
@@ -523,7 +535,7 @@ function lu!(A::Tridiagonal{T,V}, pivot::Union{Val{false}, Val{true}} = Val(true
523535
end
524536
if n > 1
525537
i = n-1
526-
if pivot === Val(false) || abs(d[i]) >= abs(dl[i])
538+
if pivot === NoPivot() || abs(d[i]) >= abs(dl[i])
527539
if d[i] != 0
528540
fact = dl[i]/d[i]
529541
dl[i] = fact

stdlib/LinearAlgebra/src/qr.jl

+18-11
Original file line numberDiff line numberDiff line change
@@ -246,17 +246,17 @@ function qrfactPivotedUnblocked!(A::AbstractMatrix)
246246
end
247247

248248
# LAPACK version
249-
qr!(A::StridedMatrix{<:BlasFloat}, ::Val{false} = Val(false); blocksize=36) =
249+
qr!(A::StridedMatrix{<:BlasFloat}, ::NoPivot; blocksize=36) =
250250
QRCompactWY(LAPACK.geqrt!(A, min(min(size(A)...), blocksize))...)
251-
qr!(A::StridedMatrix{<:BlasFloat}, ::Val{true}) = QRPivoted(LAPACK.geqp3!(A)...)
251+
qr!(A::StridedMatrix{<:BlasFloat}, ::ColumnNorm) = QRPivoted(LAPACK.geqp3!(A)...)
252252

253253
# Generic fallbacks
254254

255255
"""
256-
qr!(A, pivot=Val(false); blocksize)
256+
qr!(A, pivot = NoPivot(); blocksize)
257257
258-
`qr!` is the same as [`qr`](@ref) when `A` is a subtype of
259-
[`StridedMatrix`](@ref), but saves space by overwriting the input `A`, instead of creating a copy.
258+
`qr!` is the same as [`qr`](@ref) when `A` is a subtype of [`StridedMatrix`](@ref),
259+
but saves space by overwriting the input `A`, instead of creating a copy.
260260
An [`InexactError`](@ref) exception is thrown if the factorization produces a number not
261261
representable by the element type of `A`, e.g. for integer types.
262262
@@ -292,14 +292,17 @@ Stacktrace:
292292
[...]
293293
```
294294
"""
295-
qr!(A::AbstractMatrix, ::Val{false}) = qrfactUnblocked!(A)
296-
qr!(A::AbstractMatrix, ::Val{true}) = qrfactPivotedUnblocked!(A)
297-
qr!(A::AbstractMatrix) = qr!(A, Val(false))
295+
qr!(A::AbstractMatrix, ::NoPivot) = qrfactUnblocked!(A)
296+
qr!(A::AbstractMatrix, ::ColumnNorm) = qrfactPivotedUnblocked!(A)
297+
qr!(A::AbstractMatrix) = qr!(A, NoPivot())
298+
# TODO: Remove in Julia v2.0
299+
@deprecate qr!(A::AbstractMatrix, ::Val{true}) qr!(A, ColumnNorm())
300+
@deprecate qr!(A::AbstractMatrix, ::Val{false}) qr!(A, NoPivot())
298301

299302
_qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))
300303

301304
"""
302-
qr(A, pivot=Val(false); blocksize) -> F
305+
qr(A, pivot = NoPivot(); blocksize) -> F
303306
304307
Compute the QR factorization of the matrix `A`: an orthogonal (or unitary if `A` is
305308
complex-valued) matrix `Q`, and an upper triangular matrix `R` such that
@@ -310,7 +313,7 @@ A = Q R
310313
311314
The returned object `F` stores the factorization in a packed format:
312315
313-
- if `pivot == Val(true)` then `F` is a [`QRPivoted`](@ref) object,
316+
- if `pivot == ColumnNorm()` then `F` is a [`QRPivoted`](@ref) object,
314317
315318
- otherwise if the element type of `A` is a BLAS type ([`Float32`](@ref), [`Float64`](@ref),
316319
`ComplexF32` or `ComplexF64`), then `F` is a [`QRCompactWY`](@ref) object,
@@ -340,7 +343,7 @@ and `F.Q*A` are supported. A `Q` matrix can be converted into a regular matrix w
340343
orthogonal matrix.
341344
342345
The block size for QR decomposition can be specified by keyword argument
343-
`blocksize :: Integer` when `pivot == Val(false)` and `A isa StridedMatrix{<:BlasFloat}`.
346+
`blocksize :: Integer` when `pivot == NoPivot()` and `A isa StridedMatrix{<:BlasFloat}`.
344347
It is ignored when `blocksize > minimum(size(A))`. See [`QRCompactWY`](@ref).
345348
346349
!!! compat "Julia 1.4"
@@ -382,6 +385,10 @@ function qr(A::AbstractMatrix{T}, arg...; kwargs...) where T
382385
copyto!(AA, A)
383386
return qr!(AA, arg...; kwargs...)
384387
end
388+
# TODO: remove in Julia v2.0
389+
@deprecate qr(A::AbstractMatrix, ::Val{false}; kwargs...) qr(A, NoPivot(); kwargs...)
390+
@deprecate qr(A::AbstractMatrix, ::Val{true}; kwargs...) qr(A, ColumnNorm(); kwargs...)
391+
385392
qr(x::Number) = qr(fill(x,1,1))
386393
function qr(v::AbstractVector)
387394
require_one_based_indexing(v)

stdlib/LinearAlgebra/test/diagonal.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ end
565565
D = Diagonal(randn(5))
566566
Q = qr(randn(5, 5)).Q
567567
@test D * Q' == Array(D) * Q'
568-
Q = qr(randn(5, 5), Val(true)).Q
568+
Q = qr(randn(5, 5), ColumnNorm()).Q
569569
@test_throws ArgumentError lmul!(Q, D)
570570
end
571571

stdlib/LinearAlgebra/test/generic.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -387,13 +387,13 @@ LinearAlgebra.Transpose(a::ModInt{n}) where {n} = transpose(a)
387387
A = [ModInt{2}(1) ModInt{2}(0); ModInt{2}(1) ModInt{2}(1)]
388388
b = [ModInt{2}(1), ModInt{2}(0)]
389389

390-
@test A*(lu(A, Val(false))\b) == b
390+
@test A*(lu(A, NoPivot())\b) == b
391391

392392
# Needed for pivoting:
393393
Base.abs(a::ModInt{n}) where {n} = a
394394
Base.:<(a::ModInt{n}, b::ModInt{n}) where {n} = a.k < b.k
395395

396-
@test A*(lu(A, Val(true))\b) == b
396+
@test A*(lu(A, RowMaximum())\b) == b
397397
end
398398

399399
@testset "Issue 18742" begin

stdlib/LinearAlgebra/test/lq.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ rectangularQ(Q::LinearAlgebra.LQPackedQ) = convert(Array, Q)
4040
lqa = lq(a)
4141
x = lqa\b
4242
l,q = lqa.L, lqa.Q
43-
qra = qr(a, Val(true))
43+
qra = qr(a, ColumnNorm())
4444
@testset "Basic ops" begin
4545
@test size(lqa,1) == size(a,1)
4646
@test size(lqa,3) == 1

stdlib/LinearAlgebra/test/lu.jl

+12-12
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ dimg = randn(n)/2
6161
lua = factorize(a)
6262
@test_throws ErrorException lua.Z
6363
l,u,p = lua.L, lua.U, lua.p
64-
ll,ul,pl = lu(a)
64+
ll,ul,pl = @inferred lu(a)
6565
@test ll * ul a[pl,:]
6666
@test l*u a[p,:]
6767
@test (l*u)[invperm(p),:] a
@@ -85,9 +85,9 @@ dimg = randn(n)/2
8585
end
8686
κd = cond(Array(d),1)
8787
@testset "Tridiagonal LU" begin
88-
lud = lu(d)
88+
lud = @inferred lu(d)
8989
@test LinearAlgebra.issuccess(lud)
90-
@test lu(lud) == lud
90+
@test @inferred(lu(lud)) == lud
9191
@test_throws ErrorException lud.Z
9292
@test lud.L*lud.U lud.P*Array(d)
9393
@test lud.L*lud.U Array(d)[lud.p,:]
@@ -199,14 +199,14 @@ dimg = randn(n)/2
199199
@test lua.L*lua.U lua.P*a[:,1:n1]
200200
end
201201
@testset "Fat LU" begin
202-
lua = lu(a[1:n1,:])
202+
lua = @inferred lu(a[1:n1,:])
203203
@test lua.L*lua.U lua.P*a[1:n1,:]
204204
end
205205
end
206206

207207
@testset "LU of Symmetric/Hermitian" begin
208208
for HS in (Hermitian(a'a), Symmetric(a'a))
209-
luhs = lu(HS)
209+
luhs = @inferred lu(HS)
210210
@test luhs.L*luhs.U luhs.P*Matrix(HS)
211211
end
212212
end
@@ -229,12 +229,12 @@ end
229229
@test_throws SingularException lu!(copy(A); check = true)
230230
@test !issuccess(lu(A; check = false))
231231
@test !issuccess(lu!(copy(A); check = false))
232-
@test_throws ZeroPivotException lu(A, Val(false))
233-
@test_throws ZeroPivotException lu!(copy(A), Val(false))
234-
@test_throws ZeroPivotException lu(A, Val(false); check = true)
235-
@test_throws ZeroPivotException lu!(copy(A), Val(false); check = true)
236-
@test !issuccess(lu(A, Val(false); check = false))
237-
@test !issuccess(lu!(copy(A), Val(false); check = false))
232+
@test_throws ZeroPivotException lu(A, NoPivot())
233+
@test_throws ZeroPivotException lu!(copy(A), NoPivot())
234+
@test_throws ZeroPivotException lu(A, NoPivot(); check = true)
235+
@test_throws ZeroPivotException lu!(copy(A), NoPivot(); check = true)
236+
@test !issuccess(lu(A, NoPivot(); check = false))
237+
@test !issuccess(lu!(copy(A), NoPivot(); check = false))
238238
F = lu(A; check = false)
239239
@test sprint((io, x) -> show(io, "text/plain", x), F) ==
240240
"Failed factorization of type $(typeof(F))"
@@ -320,7 +320,7 @@ include("trickyarithmetic.jl")
320320
@testset "lu with type whose sum is another type" begin
321321
A = TrickyArithmetic.A[1 2; 3 4]
322322
ElT = TrickyArithmetic.D{TrickyArithmetic.C,TrickyArithmetic.C}
323-
B = lu(A, Val(false))
323+
B = lu(A, NoPivot())
324324
@test B isa LinearAlgebra.LU{ElT,Matrix{ElT}}
325325
end
326326

0 commit comments

Comments
 (0)