Skip to content

Commit 15e55ee

Browse files
committed
move branches into InplaceableThunk
1 parent 7b50309 commit 15e55ee

File tree

1 file changed

+36
-23
lines changed

1 file changed

+36
-23
lines changed

src/rulesets/LinearAlgebra/norm.jl

+36-23
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,36 @@ end
2020
function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
2121
y = LinearAlgebra.norm(x, p)
2222
function norm_pullback_p(Δy)
23-
∂x = if isempty(x) || p == 0
24-
InplaceableThunk(
25-
@thunk(zero.(x) .* (zero(y) * zero(real(Δy)))),
26-
identity,
27-
)
23+
∂x = InplaceableThunk(
24+
# out-of-place versions
25+
if isempty(x) || p == 0
26+
@thunk(zero.(x) .* (zero(y) * zero(real(Δy))))
2827
elseif p == 2
29-
InplaceableThunk(
30-
@thunk(_norm2_back(x, y, Δy)),
31-
dx -> _norm2_back!(dx, x, y, Δy),
32-
)
28+
@thunk(_norm2_back(x, y, Δy))
3329
elseif p == 1
34-
InplaceableThunk(
35-
@thunk(_norm1_back(x, y, Δy)),
36-
dx -> _norm1_back!(dx, x, y, Δy),
37-
)
30+
@thunk(_norm1_back(x, y, Δy))
3831
elseif p == Inf
39-
_normInf_back(x, y, Δy)
32+
@thunk(_normInf_back(x, y, Δy))
4033
elseif p == -Inf
41-
_normInf_back(x, y, Δy)
34+
@thunk(_normInf_back(x, y, Δy))
4235
else
43-
_normp_back_x(x, p, y, Δy)
36+
@thunk(_normp_back_x(x, p, y, Δy))
37+
end,
38+
# in-place versions
39+
if isempty(x) || p == 0
40+
identity
41+
elseif p == 2
42+
dx -> _norm2_back!(dx, x, y, Δy)
43+
elseif p == 1
44+
dx -> _norm1_back!(dx, x, y, Δy)
45+
elseif p == Inf
46+
dx -> dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved
47+
elseif p == -Inf
48+
dx -> dx .+= _normInf_back(x, y, Δy)
49+
else
50+
dx -> dx .+= _normp_back_x(x, p, y, Δy)
4451
end
52+
)
4553
∂p = @thunk _normp_back_p(x, p, y, Δy)
4654
return (NO_FIELDS, ∂x, ∂p)
4755
end
@@ -51,14 +59,19 @@ end
5159
function rrule(::typeof(norm), x::AbstractArray{<:Number})
5260
y = LinearAlgebra.norm(x)
5361
function norm_pullback_2(Δy)
54-
∂x = if isempty(x)
55-
zero.(x) .* (zero(y) * zero(real(Δy)))
56-
else
57-
InplaceableThunk(
58-
@thunk(_norm2_back(x, y, Δy)),
59-
dx -> _norm2_back!(dx, x, y, Δy),
62+
∂x = InplaceableThunk(
63+
if isempty(x)
64+
@thunk(zero.(x) .* (zero(y) * zero(real(Δy))))
65+
else
66+
@thunk(_norm2_back(x, y, Δy))
67+
end
68+
,
69+
if isempty(x)
70+
identity
71+
else
72+
dx -> _norm2_back!(dx, x, y, Δy)
73+
end
6074
)
61-
end
6275
return (NO_FIELDS, ∂x)
6376
end
6477
norm_pullback_2(::Zero) = (NO_FIELDS, Zero())

0 commit comments

Comments
 (0)