@@ -20,13 +20,16 @@ function FiniteDifferences.to_vec(x::Val)
20
20
return Bool[], Val_from_vec
21
21
end
22
22
23
+ const LU_ROW_MAXIMUM = VERSION >= v " 1.7.0-DEV.1188" ? RowMaximum () : Val (true )
24
+ const LU_NO_PIVOT = VERSION >= v " 1.7.0-DEV.1188" ? NoPivot () : Val (false )
25
+
23
26
@testset " Factorizations" begin
24
27
@testset " lu decomposition" begin
25
28
n = 10
26
29
@testset " lu! frule" begin
27
30
@testset " lu!(A::Matrix{$T }, $pivot ) for size(A)=($m , $n )" for
28
31
T in (Float64, ComplexF64),
29
- pivot in (Val ( true ), Val ( false ) ),
32
+ pivot in (LU_ROW_MAXIMUM, LU_NO_PIVOT ),
30
33
m in (7 , 10 , 13 )
31
34
32
35
test_frule (lu!, randn (T, m, n), pivot ⊢ NoTangent ())
35
38
Asingular = zeros (n, n)
36
39
ΔAsingular = rand_tangent (Asingular)
37
40
@test_throws SingularException frule (
38
- (ZeroTangent (), copy (ΔAsingular)), lu!, copy (Asingular), Val ( true )
41
+ (ZeroTangent (), copy (ΔAsingular)), lu!, copy (Asingular), LU_ROW_MAXIMUM
39
42
)
40
- frule ((ZeroTangent (), ΔAsingular), lu!, Asingular, Val ( true ) ; check= false )
43
+ frule ((ZeroTangent (), ΔAsingular), lu!, Asingular, LU_ROW_MAXIMUM ; check= false )
41
44
@test true # above line would have errored if this was not working right
42
45
end
43
46
end
44
47
@testset " lu rrule" begin
45
48
@testset " lu(A::Matrix{$T }, $pivot ) for size(A)=($m , $n )" for
46
49
T in (Float64, ComplexF64),
47
- pivot in (Val ( true ), Val ( false ) ),
50
+ pivot in (LU_ROW_MAXIMUM, LU_NO_PIVOT ),
48
51
m in (7 , 10 , 13 )
49
52
50
53
test_rrule (lu, randn (T, m, n), pivot ⊢ NoTangent ())
51
54
end
52
55
@testset " check=false passed to primal function" begin
53
56
Asingular = zeros (n, n)
54
- F = lu (Asingular, Val ( true ) ; check= false )
57
+ F = lu (Asingular, LU_ROW_MAXIMUM ; check= false )
55
58
ΔF = Tangent {typeof(F)} (; U= rand_tangent (F. U), L= rand_tangent (F. L))
56
- @test_throws SingularException rrule (lu, Asingular, Val ( true ) )
57
- _, back = rrule (lu, Asingular, Val ( true ) ; check= false )
59
+ @test_throws SingularException rrule (lu, Asingular, LU_ROW_MAXIMUM )
60
+ _, back = rrule (lu, Asingular, LU_ROW_MAXIMUM ; check= false )
58
61
back (ΔF)
59
62
@test true # above line would have errored if this was not working right
60
63
end
72
75
end
73
76
@testset " matrix inverse using LU" begin
74
77
@testset " inv!(lu(::LU{$T ,<:StridedMatrix}))" for T in (Float64,ComplexF64)
75
- test_frule (LinearAlgebra. inv!, lu (randn (T, n, n), Val ( true ) ))
76
- test_rrule (inv, lu (randn (T, n, n), Val ( true ) ))
78
+ test_frule (LinearAlgebra. inv!, lu (randn (T, n, n), LU_ROW_MAXIMUM ))
79
+ test_rrule (inv, lu (randn (T, n, n), LU_ROW_MAXIMUM ))
77
80
end
78
81
end
79
82
end
0 commit comments