Skip to content

Commit 2d72b50

Browse files
committed
Less clever sum rrule for the default case
The sum rrule relies on broadcasting to figure out the result shape of the cotangent. However, this has two disadvantages: 1. We need to keep the original array around 2. Broadcasting machinery is complicated and tough on (higher-order) AD This adds a special case for `dims=:`, which simply stores the dimensions of the original array and uses `fill` in the pullback, which has a simple rrule and is thus much easier to AD.
1 parent 7593339 commit 2d72b50

File tree

1 file changed

+33
-1
lines changed

1 file changed

+33
-1
lines changed

src/rulesets/Base/mapreduce.jl

+33-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,35 @@ function frule((_, ẋ), ::typeof(sum), x; dims=:)
66
return sum(x; dims=dims), sum(ẋ; dims=dims)
77
end
88

9-
function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
9+
# Internal helper for filling while maintaining array type.
10+
# TODO: Ideallty we'd only need typeof(x) here, but Base doesn't have the
11+
# interfaces for that.
12+
function _typed_fill(x, ȳ, axes)
13+
fill!(similar(x, typeof(ȳ), axes), ȳ)
14+
end
15+
16+
function rrule(::typeof(_typed_fill), x, ȳ, axes)
17+
function _typed_fill_pullback(Ȳ)
18+
return (NoTangent(), NoTangent(), sum(Ȳ), NoTangent())
19+
end
20+
return _typed_fill(x, ȳ, axes), _typed_fill_pullback
21+
end
22+
23+
function sum_rrule(x, dims::Colon)
24+
y = sum(x; dims=dims)
25+
let xdims=size(x)
26+
function sum_pullback(ȳ)
27+
= InplaceableThunk(
28+
x -> x .+= ȳ,
29+
@thunk(_typed_fill(x, ȳ, xdims...)),
30+
)
31+
return (NoTangent(), x̄)
32+
end
33+
y, sum_pullback
34+
end
35+
end
36+
37+
function sum_rrule(x, dims)
1038
y = sum(x; dims=dims)
1139
function sum_pullback(ȳ)
1240
# broadcasting the two works out the size no-matter `dims`
@@ -19,6 +47,10 @@ function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
1947
return y, sum_pullback
2048
end
2149

50+
function rrule(::typeof(sum), x::AbstractArray{T}; dims=:) where {T<:Number}
51+
return sum_rrule(x, dims)
52+
end
53+
2254
# Can't map over Adjoint/Transpose Vector
2355
function rrule(
2456
config::RuleConfig{>:HasReverseMode},

0 commit comments

Comments
 (0)