Skip to content

Commit acb68ca

Browse files
committed
use singleton types
1 parent 3ebb14c commit acb68ca

14 files changed

+83
-100
lines changed

NEWS.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@ Standard library changes
103103
* The shape of an `UpperHessenberg` matrix is preserved under certain arithmetic operations, e.g. when multiplying or dividing by an `UpperTriangular` matrix. ([#40039])
104104
* `cis(A)` now supports matrix arguments ([#40194]).
105105
* `dot` now supports `UniformScaling` with `AbstractMatrix` ([#40250]).
106-
* `qr[!]` and `lu[!]` now support `Symbol` values as their optional `pivot` argument:
107-
defaults are `qr(A, :none)` (vs. `qr(A, :colnorm)` for pivoting) and `lu(A, :rowmax)`
108-
(vs. `lu(A, :none)` without pivoting); the former `Val{true/false}`-based calls are deprecated. ([#40623])
106+
* `qr[!]` and `lu[!]` now support `PivotingStrategy` values as their optional `pivot` argument:
107+
defaults are `qr(A, NoPivot())` (vs. `qr(A, ColNorm())` for pivoting) and `lu(A, RowMax())`
108+
(vs. `lu(A, NoPivot())` without pivoting); the former `Val{true/false}`-based calls are deprecated. ([#40623])
109109

110110
#### Markdown
111111

stdlib/LinearAlgebra/src/LinearAlgebra.jl

+7
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,22 @@ export
3535
BunchKaufman,
3636
Cholesky,
3737
CholeskyPivoted,
38+
ColNorm,
3839
Eigen,
3940
GeneralizedEigen,
4041
GeneralizedSVD,
4142
GeneralizedSchur,
4243
Hessenberg,
4344
LU,
4445
LDLt,
46+
NoPivot,
4547
QR,
4648
QRPivoted,
4749
LQ,
4850
Schur,
4951
SVD,
5052
Hermitian,
53+
RowMax,
5154
Symmetric,
5255
LowerTriangular,
5356
UpperTriangular,
@@ -164,6 +167,10 @@ abstract type Algorithm end
164167
struct DivideAndConquer <: Algorithm end
165168
struct QRIteration <: Algorithm end
166169

170+
abstract type PivotingStrategy end
171+
struct NoPivot <: PivotingStrategy end
172+
struct RowMax <: PivotingStrategy end
173+
struct ColNorm <: PivotingStrategy end
167174

168175
# Check that stride of matrix/vector is 1
169176
# Writing like this to avoid splatting penalty when called with multiple arguments,

stdlib/LinearAlgebra/src/dense.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1371,7 +1371,7 @@ function factorize(A::StridedMatrix{T}) where T
13711371
end
13721372
return lu(A)
13731373
end
1374-
qr(A, :colnorm)
1374+
qr(A, ColNorm())
13751375
end
13761376
factorize(A::Adjoint) = adjoint(factorize(parent(A)))
13771377
factorize(A::Transpose) = transpose(factorize(parent(A)))

stdlib/LinearAlgebra/src/factorization.jl

+3-8
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,9 @@ size(F::Adjoint{<:Any,<:Factorization}) = reverse(size(parent(F)))
1616
size(F::Transpose{<:Any,<:Factorization}) = reverse(size(parent(F)))
1717

1818
checkpositivedefinite(info) = info == 0 || throw(PosDefException(info))
19-
function checknonsingular(info, pivoted = :rowmax)
20-
if info != 0
21-
pivoted === :rowmax && throw(SingularException(info))
22-
pivoted === :none && throw(ZeroPivotException(info))
23-
else
24-
return nothing
25-
end
26-
end
19+
checknonsingular(info, ::RowMax) = info == 0 || throw(SingularException(info))
20+
checknonsingular(info, ::NoPivot) = info == 0 || throw(ZeroPivotException(info))
21+
checknonsingular(info) = checknonsingular(info, RowMax())
2722

2823
"""
2924
issuccess(F::Factorization)

stdlib/LinearAlgebra/src/generic.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,7 @@ function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
11411141
end
11421142
return lu(A) \ B
11431143
end
1144-
return qr(A, :colnorm) \ B
1144+
return qr(A, ColNorm()) \ B
11451145
end
11461146

11471147
(\)(a::AbstractVector, b::AbstractArray) = pinv(a) * b

stdlib/LinearAlgebra/src/lu.jl

+30-33
Original file line numberDiff line numberDiff line change
@@ -76,28 +76,26 @@ adjoint(F::LU) = Adjoint(F)
7676
transpose(F::LU) = Transpose(F)
7777

7878
# StridedMatrix
79-
function lu!(A::StridedMatrix{T}, pivot::Symbol = :rowmax; check::Bool = true) where {T<:BlasFloat}
80-
if pivot === :none
81-
return generic_lufact!(A, pivot; check = check)
82-
elseif pivot === :rowmax
83-
lpt = LAPACK.getrf!(A)
84-
check && checknonsingular(lpt[3])
85-
return LU{T,typeof(A)}(lpt[1], lpt[2], lpt[3])
86-
else
87-
throw(ArgumentError("only `:rowmax` and `:none` are supported as `pivot` argument but you supplied `$pivot`"))
88-
end
79+
lu!(A::StridedMatrix{<:BlasFloat}; check::Bool = true) = lu!(A, RowMax(); check=check)
80+
function lu!(A::StridedMatrix{T}, ::RowMax; check::Bool = true) where {T<:BlasFloat}
81+
lpt = LAPACK.getrf!(A)
82+
check && checknonsingular(lpt[3])
83+
return LU{T,typeof(A)}(lpt[1], lpt[2], lpt[3])
8984
end
90-
function lu!(A::HermOrSym, pivot::Symbol = :rowmax; check::Bool = true)
85+
function lu!(A::StridedMatrix{<:BlasFloat}, pivot::NoPivot; check::Bool = true)
86+
return generic_lufact!(A, pivot; check = check)
87+
end
88+
function lu!(A::HermOrSym, pivot::PivotingStrategy = RowMax(); check::Bool = true)
9189
copytri!(A.data, A.uplo, isa(A, Hermitian))
9290
lu!(A.data, pivot; check = check)
9391
end
9492
# for backward compatibility
9593
# TODO: remove towards Julia v2
96-
@deprecate lu!(A::Union{StridedMatrix,HermOrSym,Tridiagonal}, ::Val{true}; check::Bool = true) lu!(A, :rowmax; check=check)
97-
@deprecate lu!(A::Union{StridedMatrix,HermOrSym,Tridiagonal}, ::Val{false}; check::Bool = true) lu!(A, :none; check=check)
94+
@deprecate lu!(A::Union{StridedMatrix,HermOrSym,Tridiagonal}, ::Val{true}; check::Bool = true) lu!(A, RowMax(); check=check)
95+
@deprecate lu!(A::Union{StridedMatrix,HermOrSym,Tridiagonal}, ::Val{false}; check::Bool = true) lu!(A, NoPivot(); check=check)
9896

9997
"""
100-
lu!(A, pivot = :rowmax; check = true) -> LU
98+
lu!(A, pivot = RowMax(); check = true) -> LU
10199
102100
`lu!` is the same as [`lu`](@ref), but saves space by overwriting the
103101
input `A`, instead of creating a copy. An [`InexactError`](@ref)
@@ -133,26 +131,26 @@ Stacktrace:
133131
[...]
134132
```
135133
"""
136-
lu!(A::StridedMatrix, pivot::Symbol = :rowmax; check::Bool = true) =
134+
lu!(A::StridedMatrix, pivot::PivotingStrategy = RowMax(); check::Bool = true) =
137135
generic_lufact!(A, pivot; check = check)
138-
function generic_lufact!(A::StridedMatrix{T}, pivot = :rowmax; check::Bool = true) where T
139-
# Check arguments
140-
if pivot !== :rowmax && pivot !== :none
141-
throw(ArgumentError("only `rowmax` and `none` are supported as `pivot` argument but you supplied `$pivot`"))
142-
end
143-
136+
function generic_lufact!(A::StridedMatrix{T}, pivot::PivotingStrategy = RowMax();
137+
check::Bool = true) where {T}
144138
# Extract values
145139
m, n = size(A)
146140
minmn = min(m,n)
147141

142+
if pivot !== RowMax() && pivot !== NoPivot()
143+
throw(ArgumentError("only `RowMax()` and `NoPivot()` are supported as `pivot` argument but you supplied `$pivot`"))
144+
end
145+
148146
# Initialize variables
149147
info = 0
150148
ipiv = Vector{BlasInt}(undef, minmn)
151149
@inbounds begin
152150
for k = 1:minmn
153151
# find index max
154152
kp = k
155-
if pivot === :rowmax && k < m
153+
if pivot === RowMax() && k < m
156154
amax = abs(A[k, k])
157155
for i = k+1:m
158156
absi = abs(A[i,k])
@@ -213,7 +211,7 @@ end
213211

214212
# for all other types we must promote to a type which is stable under division
215213
"""
216-
lu(A, pivot = :rowmax; check = true) -> F::LU
214+
lu(A, pivot = RowMax(); check = true) -> F::LU
217215
218216
Compute the LU factorization of `A`.
219217
@@ -224,7 +222,7 @@ validity (via [`issuccess`](@ref)) lies with the user.
224222
In most cases, if `A` is a subtype `S` of `AbstractMatrix{T}` with an element
225223
type `T` supporting `+`, `-`, `*` and `/`, the return type is `LU{T,S{T}}`. If
226224
pivoting is chosen (default) the element type should also support [`abs`](@ref) and
227-
[`<`](@ref). Pivoting can be turned off by passing `pivot = :none`.
225+
[`<`](@ref). Pivoting can be turned off by passing `pivot = NoPivot()`.
228226
229227
The individual components of the factorization `F` can be accessed via [`getproperty`](@ref):
230228
@@ -280,13 +278,13 @@ julia> l == F.L && u == F.U && p == F.p
280278
true
281279
```
282280
"""
283-
function lu(A::AbstractMatrix{T}, pivot::Symbol = :rowmax; check::Bool = true) where {T}
281+
function lu(A::AbstractMatrix{T}, pivot::PivotingStrategy = RowMax(); check::Bool = true) where {T}
284282
S = lutype(T)
285283
lu!(copy_oftype(A, S), pivot; check = check)
286284
end
287285
# TODO: remove for Julia v2.0
288-
@deprecate lu(A::AbstractMatrix, ::Val{true}; check::Bool = true) lu(A, :rowmax; check=check)
289-
@deprecate lu(A::AbstractMatrix, ::Val{false}; check::Bool = true) lu(A, :none; check=check)
286+
@deprecate lu(A::AbstractMatrix, ::Val{true}; check::Bool = true) lu(A, RowMax(); check=check)
287+
@deprecate lu(A::AbstractMatrix, ::Val{false}; check::Bool = true) lu(A, NoPivot(); check=check)
290288

291289

292290
lu(S::LU) = S
@@ -497,13 +495,12 @@ inv(A::LU{<:BlasFloat,<:StridedMatrix}) = inv!(copy(A))
497495
# Tridiagonal
498496

499497
# See dgttrf.f
500-
function lu!(A::Tridiagonal{T,V}, pivot::Symbol = :rowmax; check::Bool = true) where {T,V}
498+
function lu!(A::Tridiagonal{T,V}, pivot::PivotingStrategy = RowMax(); check::Bool = true) where {T,V}
501499
# Extract values
502500
n = size(A, 1)
503501

504-
# Check arguments
505-
if pivot !== :rowmax && pivot !== :none
506-
throw(ArgumentError("only `:row` and `:none` are supported as `pivot` argument but you supplied `$pivot`"))
502+
if pivot !== RowMax() && pivot !== NoPivot()
503+
throw(ArgumentError("only `RowMax()` and `NoPivot()` are supported as `pivot` argument but you supplied `$pivot`"))
507504
end
508505

509506
# Initialize variables
@@ -523,7 +520,7 @@ function lu!(A::Tridiagonal{T,V}, pivot::Symbol = :rowmax; check::Bool = true) w
523520
end
524521
for i = 1:n-2
525522
# pivot or not?
526-
if pivot === :none || abs(d[i]) >= abs(dl[i])
523+
if pivot === NoPivot() || abs(d[i]) >= abs(dl[i])
527524
# No interchange
528525
if d[i] != 0
529526
fact = dl[i]/d[i]
@@ -546,7 +543,7 @@ function lu!(A::Tridiagonal{T,V}, pivot::Symbol = :rowmax; check::Bool = true) w
546543
end
547544
if n > 1
548545
i = n-1
549-
if pivot === :none || abs(d[i]) >= abs(dl[i])
546+
if pivot === NoPivot() || abs(d[i]) >= abs(dl[i])
550547
if d[i] != 0
551548
fact = dl[i]/d[i]
552549
dl[i] = fact

stdlib/LinearAlgebra/src/qr.jl

+15-27
Original file line numberDiff line numberDiff line change
@@ -246,20 +246,14 @@ function qrfactPivotedUnblocked!(A::AbstractMatrix)
246246
end
247247

248248
# LAPACK version
249-
Base.@aggressive_constprop function qr!(A::StridedMatrix{<:BlasFloat}, pivot::Symbol = :none; blocksize=36)
250-
if pivot === :none
251-
return QRCompactWY(LAPACK.geqrt!(A, min(min(size(A)...), blocksize))...)
252-
elseif pivot === :colnorm
253-
return QRPivoted(LAPACK.geqp3!(A)...)
254-
else
255-
throw(ArgumentError("only `:colnorm` and `:none` are supported as `pivot` argument but you supplied `$pivot`"))
256-
end
257-
end
249+
qr!(A::StridedMatrix{<:BlasFloat}, ::NoPivot; blocksize=36) =
250+
QRCompactWY(LAPACK.geqrt!(A, min(min(size(A)...), blocksize))...)
251+
qr!(A::StridedMatrix{<:BlasFloat}, ::ColNorm) = QRPivoted(LAPACK.geqp3!(A)...)
258252

259253
# Generic fallbacks
260254

261255
"""
262-
qr!(A, pivot = :none; blocksize)
256+
qr!(A, pivot = NoPivot(); blocksize)
263257
264258
`qr!` is the same as [`qr`](@ref) when `A` is a subtype of [`StridedMatrix`](@ref),
265259
but saves space by overwriting the input `A`, instead of creating a copy.
@@ -298,23 +292,17 @@ Stacktrace:
298292
[...]
299293
```
300294
"""
301-
Base.@aggressive_constprop function qr!(A::AbstractMatrix, pivot::Symbol = :none)
302-
if pivot === :none
303-
return qrfactUnblocked!(A)
304-
elseif pivot === :colnorm
305-
return qrfactPivotedUnblocked!(A)
306-
else
307-
throw(ArgumentError("only `:colnorm` and `:none` are supported as `pivot` argument but you supplied `$pivot`"))
308-
end
309-
end
295+
qr!(A::AbstractMatrix, ::NoPivot) = qrfactUnblocked!(A)
296+
qr!(A::AbstractMatrix, ::ColNorm) = qrfactPivotedUnblocked!(A)
297+
qr!(A::AbstractMatrix) = qr!(A, NoPivot())
310298
# TODO: Remove in Julia v2.0
311-
@deprecate qr!(A::AbstractMatrix, ::Val{true}) qr!(A, :colnorm)
312-
@deprecate qr!(A::AbstractMatrix, ::Val{false}) qr!(A, :none)
299+
@deprecate qr!(A::AbstractMatrix, ::Val{true}) qr!(A, ColNorm())
300+
@deprecate qr!(A::AbstractMatrix, ::Val{false}) qr!(A, NoPivot())
313301

314302
_qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))
315303

316304
"""
317-
qr(A, pivot = :none; blocksize) -> F
305+
qr(A, pivot = NoPivot(); blocksize) -> F
318306
319307
Compute the QR factorization of the matrix `A`: an orthogonal (or unitary if `A` is
320308
complex-valued) matrix `Q`, and an upper triangular matrix `R` such that
@@ -325,7 +313,7 @@ A = Q R
325313
326314
The returned object `F` stores the factorization in a packed format:
327315
328-
- if `pivot == :colnorm` then `F` is a [`QRPivoted`](@ref) object,
316+
- if `pivot == ColNorm()` then `F` is a [`QRPivoted`](@ref) object,
329317
330318
- otherwise if the element type of `A` is a BLAS type ([`Float32`](@ref), [`Float64`](@ref),
331319
`ComplexF32` or `ComplexF64`), then `F` is a [`QRCompactWY`](@ref) object,
@@ -355,7 +343,7 @@ and `F.Q*A` are supported. A `Q` matrix can be converted into a regular matrix w
355343
orthogonal matrix.
356344
357345
The block size for QR decomposition can be specified by keyword argument
358-
`blocksize :: Integer` when `pivot == :none` and `A isa StridedMatrix{<:BlasFloat}`.
346+
`blocksize :: Integer` when `pivot == NoPivot()` and `A isa StridedMatrix{<:BlasFloat}`.
359347
It is ignored when `blocksize > minimum(size(A))`. See [`QRCompactWY`](@ref).
360348
361349
!!! compat "Julia 1.4"
@@ -391,15 +379,15 @@ true
391379
elementary reflectors, so that the `Q` and `R` matrices can be stored
392380
compactly rather as two separate dense matrices.
393381
"""
394-
Base.@aggressive_constprop function qr(A::AbstractMatrix{T}, arg...; kwargs...) where T
382+
function qr(A::AbstractMatrix{T}, arg...; kwargs...) where T
395383
require_one_based_indexing(A)
396384
AA = similar(A, _qreltype(T), size(A))
397385
copyto!(AA, A)
398386
return qr!(AA, arg...; kwargs...)
399387
end
400388
# TODO: remove in Julia v2.0
401-
@deprecate qr(A::AbstractMatrix, ::Val{false}; kwargs...) qr(A, :none; kwargs...)
402-
@deprecate qr(A::AbstractMatrix, ::Val{true}; kwargs...) qr(A, :colnorm; kwargs...)
389+
@deprecate qr(A::AbstractMatrix, ::Val{false}; kwargs...) qr(A, NoPivot(); kwargs...)
390+
@deprecate qr(A::AbstractMatrix, ::Val{true}; kwargs...) qr(A, ColNorm(); kwargs...)
403391

404392
qr(x::Number) = qr(fill(x,1,1))
405393
function qr(v::AbstractVector)

stdlib/LinearAlgebra/test/diagonal.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ end
565565
D = Diagonal(randn(5))
566566
Q = qr(randn(5, 5)).Q
567567
@test D * Q' == Array(D) * Q'
568-
Q = qr(randn(5, 5), :colnorm).Q
568+
Q = qr(randn(5, 5), ColNorm()).Q
569569
@test_throws ArgumentError lmul!(Q, D)
570570
end
571571

stdlib/LinearAlgebra/test/generic.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,13 @@ LinearAlgebra.Transpose(a::ModInt{n}) where {n} = transpose(a)
382382
A = [ModInt{2}(1) ModInt{2}(0); ModInt{2}(1) ModInt{2}(1)]
383383
b = [ModInt{2}(1), ModInt{2}(0)]
384384

385-
@test A*(lu(A, :none)\b) == b
385+
@test A*(lu(A, NoPivot())\b) == b
386386

387387
# Needed for pivoting:
388388
Base.abs(a::ModInt{n}) where {n} = a
389389
Base.:<(a::ModInt{n}, b::ModInt{n}) where {n} = a.k < b.k
390390

391-
@test A*(lu(A, :rowmax)\b) == b
391+
@test A*(lu(A, RowMax())\b) == b
392392
end
393393

394394
@testset "Issue 18742" begin

stdlib/LinearAlgebra/test/lq.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ rectangularQ(Q::LinearAlgebra.LQPackedQ) = convert(Array, Q)
4040
lqa = lq(a)
4141
x = lqa\b
4242
l,q = lqa.L, lqa.Q
43-
qra = qr(a, :colnorm)
43+
qra = qr(a, ColNorm())
4444
@testset "Basic ops" begin
4545
@test size(lqa,1) == size(a,1)
4646
@test size(lqa,3) == 1

stdlib/LinearAlgebra/test/lu.jl

+8-8
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ dimg = randn(n)/2
8585
end
8686
κd = cond(Array(d),1)
8787
@testset "Tridiagonal LU" begin
88-
lud = @inferred lu(d)
88+
lud = @inferred lu(d)
8989
@test LinearAlgebra.issuccess(lud)
9090
@test @inferred(lu(lud)) == lud
9191
@test_throws ErrorException lud.Z
@@ -229,12 +229,12 @@ end
229229
@test_throws SingularException lu!(copy(A); check = true)
230230
@test !issuccess(lu(A; check = false))
231231
@test !issuccess(lu!(copy(A); check = false))
232-
@test_throws ZeroPivotException lu(A, :none)
233-
@test_throws ZeroPivotException lu!(copy(A), :none)
234-
@test_throws ZeroPivotException lu(A, :none; check = true)
235-
@test_throws ZeroPivotException lu!(copy(A), :none; check = true)
236-
@test !issuccess(lu(A, :none; check = false))
237-
@test !issuccess(lu!(copy(A), :none; check = false))
232+
@test_throws ZeroPivotException lu(A, NoPivot())
233+
@test_throws ZeroPivotException lu!(copy(A), NoPivot())
234+
@test_throws ZeroPivotException lu(A, NoPivot(); check = true)
235+
@test_throws ZeroPivotException lu!(copy(A), NoPivot(); check = true)
236+
@test !issuccess(lu(A, NoPivot(); check = false))
237+
@test !issuccess(lu!(copy(A), NoPivot(); check = false))
238238
F = lu(A; check = false)
239239
@test sprint((io, x) -> show(io, "text/plain", x), F) ==
240240
"Failed factorization of type $(typeof(F))"
@@ -320,7 +320,7 @@ include("trickyarithmetic.jl")
320320
@testset "lu with type whose sum is another type" begin
321321
A = TrickyArithmetic.A[1 2; 3 4]
322322
ElT = TrickyArithmetic.D{TrickyArithmetic.C,TrickyArithmetic.C}
323-
B = lu(A, :none)
323+
B = lu(A, NoPivot())
324324
@test B isa LinearAlgebra.LU{ElT,Matrix{ElT}}
325325
end
326326

0 commit comments

Comments
 (0)