Skip to content

Commit 902b6a0

Browse files
fix lu deprecation warnings on Julia nightly (#434)
Ref JuliaLang/julia#40623 Co-authored-by: Lyndon White <[email protected]>
1 parent 5d4e021 commit 902b6a0

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.8.3"
3+
version = "0.8.4"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/factorization.jl

+7-4
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!
1515
# for derivations for wide and tall matrices, see
1616
# https://sethaxen.com/blog/2021/02/differentiating-the-lu-decomposition/
1717

18+
const LU_RowMaximum = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum : Val{true}
19+
const LU_NoPivot = VERSION >= v"1.7.0-DEV.1188" ? NoPivot : Val{false}
20+
1821
function frule(
19-
(_, ΔA), ::typeof(lu!), A::StridedMatrix, pivot::Union{Val{false},Val{true}}; kwargs...
22+
(_, ΔA), ::typeof(lu!), A::StridedMatrix, pivot::Union{LU_RowMaximum,LU_NoPivot}; kwargs...
2023
)
2124
F = lu!(A, pivot; kwargs...)
22-
∂factors = pivot === Val(true) ? ΔA[F.p, :] : ΔA
25+
∂factors = pivot isa LU_RowMaximum ? ΔA[F.p, :] : ΔA
2326
m, n = size(∂factors)
2427
q = min(m, n)
2528
if m == n # square A
@@ -72,7 +75,7 @@ function frule(
7275
end
7376

7477
function rrule(
75-
::typeof(lu), A::StridedMatrix, pivot::Union{Val{false},Val{true}}; kwargs...
78+
::typeof(lu), A::StridedMatrix, pivot::Union{LU_RowMaximum,LU_NoPivot}; kwargs...
7679
)
7780
F = lu(A, pivot; kwargs...)
7881
function lu_pullback(ΔF::Tangent)
@@ -124,7 +127,7 @@ function rrule(
124127
ldiv!(L1', ∂A1)
125128
rdiv!(∂A, U')
126129
end
127-
if pivot === Val(true)
130+
if pivot isa LU_RowMaximum
128131
∂A = ∂A[invperm(F.p), :]
129132
end
130133
return NoTangent(), ∂A, NoTangent()

test/rulesets/LinearAlgebra/factorization.jl

+12-9
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@ function FiniteDifferences.to_vec(x::Val)
2020
return Bool[], Val_from_vec
2121
end
2222

23+
const LU_ROW_MAXIMUM = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum() : Val(true)
24+
const LU_NO_PIVOT = VERSION >= v"1.7.0-DEV.1188" ? NoPivot() : Val(false)
25+
2326
@testset "Factorizations" begin
2427
@testset "lu decomposition" begin
2528
n = 10
2629
@testset "lu! frule" begin
2730
@testset "lu!(A::Matrix{$T}, $pivot) for size(A)=($m, $n)" for
2831
T in (Float64, ComplexF64),
29-
pivot in (Val(true), Val(false)),
32+
pivot in (LU_ROW_MAXIMUM, LU_NO_PIVOT),
3033
m in (7, 10, 13)
3134

3235
test_frule(lu!, randn(T, m, n), pivot NoTangent())
@@ -35,26 +38,26 @@ end
3538
Asingular = zeros(n, n)
3639
ΔAsingular = rand_tangent(Asingular)
3740
@test_throws SingularException frule(
38-
(ZeroTangent(), copy(ΔAsingular)), lu!, copy(Asingular), Val(true)
41+
(ZeroTangent(), copy(ΔAsingular)), lu!, copy(Asingular), LU_ROW_MAXIMUM
3942
)
40-
frule((ZeroTangent(), ΔAsingular), lu!, Asingular, Val(true); check=false)
43+
frule((ZeroTangent(), ΔAsingular), lu!, Asingular, LU_ROW_MAXIMUM; check=false)
4144
@test true # above line would have errored if this was not working right
4245
end
4346
end
4447
@testset "lu rrule" begin
4548
@testset "lu(A::Matrix{$T}, $pivot) for size(A)=($m, $n)" for
4649
T in (Float64, ComplexF64),
47-
pivot in (Val(true), Val(false)),
50+
pivot in (LU_ROW_MAXIMUM, LU_NO_PIVOT),
4851
m in (7, 10, 13)
4952

5053
test_rrule(lu, randn(T, m, n), pivot NoTangent())
5154
end
5255
@testset "check=false passed to primal function" begin
5356
Asingular = zeros(n, n)
54-
F = lu(Asingular, Val(true); check=false)
57+
F = lu(Asingular, LU_ROW_MAXIMUM; check=false)
5558
ΔF = Tangent{typeof(F)}(; U=rand_tangent(F.U), L=rand_tangent(F.L))
56-
@test_throws SingularException rrule(lu, Asingular, Val(true))
57-
_, back = rrule(lu, Asingular, Val(true); check=false)
59+
@test_throws SingularException rrule(lu, Asingular, LU_ROW_MAXIMUM)
60+
_, back = rrule(lu, Asingular, LU_ROW_MAXIMUM; check=false)
5861
back(ΔF)
5962
@test true # above line would have errored if this was not working right
6063
end
@@ -72,8 +75,8 @@ end
7275
end
7376
@testset "matrix inverse using LU" begin
7477
@testset "inv!(lu(::LU{$T,<:StridedMatrix}))" for T in (Float64,ComplexF64)
75-
test_frule(LinearAlgebra.inv!, lu(randn(T, n, n), Val(true)))
76-
test_rrule(inv, lu(randn(T, n, n), Val(true)))
78+
test_frule(LinearAlgebra.inv!, lu(randn(T, n, n), LU_ROW_MAXIMUM))
79+
test_rrule(inv, lu(randn(T, n, n), LU_ROW_MAXIMUM))
7780
end
7881
end
7982
end

0 commit comments

Comments
 (0)