Skip to content

Commit 407d332

Browse files
simeonschauboxinabox
andauthoredJun 8, 2021
fix lu deprecation warnings on Julia nightly (#434)
Ref JuliaLang/julia#40623 Co-authored-by: Lyndon White <[email protected]>
1 parent a1809f4 commit 407d332

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed
 

‎Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.8.3"
3+
version = "0.8.4"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

‎src/rulesets/LinearAlgebra/factorization.jl

+7-4
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!
1515
# for derivations for wide and tall matrices, see
1616
# https://sethaxen.com/blog/2021/02/differentiating-the-lu-decomposition/
1717

18+
const LU_RowMaximum = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum : Val{true}
19+
const LU_NoPivot = VERSION >= v"1.7.0-DEV.1188" ? NoPivot : Val{false}
20+
1821
function frule(
19-
(_, ΔA), ::typeof(lu!), A::StridedMatrix, pivot::Union{Val{false},Val{true}}; kwargs...
22+
(_, ΔA), ::typeof(lu!), A::StridedMatrix, pivot::Union{LU_RowMaximum,LU_NoPivot}; kwargs...
2023
)
2124
F = lu!(A, pivot; kwargs...)
22-
∂factors = pivot === Val(true) ? ΔA[F.p, :] : ΔA
25+
∂factors = pivot isa LU_RowMaximum ? ΔA[F.p, :] : ΔA
2326
m, n = size(∂factors)
2427
q = min(m, n)
2528
if m == n # square A
@@ -72,7 +75,7 @@ function frule(
7275
end
7376

7477
function rrule(
75-
::typeof(lu), A::StridedMatrix, pivot::Union{Val{false},Val{true}}; kwargs...
78+
::typeof(lu), A::StridedMatrix, pivot::Union{LU_RowMaximum,LU_NoPivot}; kwargs...
7679
)
7780
F = lu(A, pivot; kwargs...)
7881
function lu_pullback(ΔF::Tangent)
@@ -124,7 +127,7 @@ function rrule(
124127
ldiv!(L1', ∂A1)
125128
rdiv!(∂A, U')
126129
end
127-
if pivot === Val(true)
130+
if pivot isa LU_RowMaximum
128131
∂A = ∂A[invperm(F.p), :]
129132
end
130133
return NoTangent(), ∂A, NoTangent()

‎test/rulesets/LinearAlgebra/factorization.jl

+12-9
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@ function FiniteDifferences.to_vec(x::Val)
2020
return Bool[], Val_from_vec
2121
end
2222

23+
const LU_ROW_MAXIMUM = VERSION >= v"1.7.0-DEV.1188" ? RowMaximum() : Val(true)
24+
const LU_NO_PIVOT = VERSION >= v"1.7.0-DEV.1188" ? NoPivot() : Val(false)
25+
2326
@testset "Factorizations" begin
2427
@testset "lu decomposition" begin
2528
n = 10
2629
@testset "lu! frule" begin
2730
@testset "lu!(A::Matrix{$T}, $pivot) for size(A)=($m, $n)" for
2831
T in (Float64, ComplexF64),
29-
pivot in (Val(true), Val(false)),
32+
pivot in (LU_ROW_MAXIMUM, LU_NO_PIVOT),
3033
m in (7, 10, 13)
3134

3235
test_frule(lu!, randn(T, m, n), pivot NoTangent())
@@ -35,26 +38,26 @@ end
3538
Asingular = zeros(n, n)
3639
ΔAsingular = rand_tangent(Asingular)
3740
@test_throws SingularException frule(
38-
(ZeroTangent(), copy(ΔAsingular)), lu!, copy(Asingular), Val(true)
41+
(ZeroTangent(), copy(ΔAsingular)), lu!, copy(Asingular), LU_ROW_MAXIMUM
3942
)
40-
frule((ZeroTangent(), ΔAsingular), lu!, Asingular, Val(true); check=false)
43+
frule((ZeroTangent(), ΔAsingular), lu!, Asingular, LU_ROW_MAXIMUM; check=false)
4144
@test true # above line would have errored if this was not working right
4245
end
4346
end
4447
@testset "lu rrule" begin
4548
@testset "lu(A::Matrix{$T}, $pivot) for size(A)=($m, $n)" for
4649
T in (Float64, ComplexF64),
47-
pivot in (Val(true), Val(false)),
50+
pivot in (LU_ROW_MAXIMUM, LU_NO_PIVOT),
4851
m in (7, 10, 13)
4952

5053
test_rrule(lu, randn(T, m, n), pivot NoTangent())
5154
end
5255
@testset "check=false passed to primal function" begin
5356
Asingular = zeros(n, n)
54-
F = lu(Asingular, Val(true); check=false)
57+
F = lu(Asingular, LU_ROW_MAXIMUM; check=false)
5558
ΔF = Tangent{typeof(F)}(; U=rand_tangent(F.U), L=rand_tangent(F.L))
56-
@test_throws SingularException rrule(lu, Asingular, Val(true))
57-
_, back = rrule(lu, Asingular, Val(true); check=false)
59+
@test_throws SingularException rrule(lu, Asingular, LU_ROW_MAXIMUM)
60+
_, back = rrule(lu, Asingular, LU_ROW_MAXIMUM; check=false)
5861
back(ΔF)
5962
@test true # above line would have errored if this was not working right
6063
end
@@ -72,8 +75,8 @@ end
7275
end
7376
@testset "matrix inverse using LU" begin
7477
@testset "inv!(lu(::LU{$T,<:StridedMatrix}))" for T in (Float64,ComplexF64)
75-
test_frule(LinearAlgebra.inv!, lu(randn(T, n, n), Val(true)))
76-
test_rrule(inv, lu(randn(T, n, n), Val(true)))
78+
test_frule(LinearAlgebra.inv!, lu(randn(T, n, n), LU_ROW_MAXIMUM))
79+
test_rrule(inv, lu(randn(T, n, n), LU_ROW_MAXIMUM))
7780
end
7881
end
7982
end

2 commit comments

Comments
 (2)

simeonschaub commented on Jun 8, 2021

@simeonschaub
MemberAuthor

JuliaRegistrator commented on Jun 8, 2021

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/38425

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.4 -m "<description of version>" 407d332e115c6002d1ccb0c1f4615a8223e9ec48
git push origin v0.8.4
Please sign in to comment.