Skip to content

Commit a9511eb

Browse files
authored
Merge pull request #419 from JuliaDiff/ox/muladd
Ox/muladd
2 parents ef0c440 + 080fe63 commit a9511eb

File tree

3 files changed

+129
-4
lines changed

3 files changed

+129
-4
lines changed

Project.toml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.68"
3+
version = "0.7.69"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -14,8 +14,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1414
[compat]
1515
ChainRulesCore = "0.9.44"
1616
ChainRulesTestUtils = "0.6.8"
17-
Compat = "3"
18-
FiniteDifferences = "0.11, 0.12"
17+
Compat = "3.30"
18+
FiniteDifferences = "0.12.8"
1919
Reexport = "0.2, 1"
2020
Requires = "0.5.2, 1"
2121
julia = "1"

src/rulesets/Base/arraymath.jl

+84
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,90 @@ function rrule(
9595
end
9696

9797

98+
#####
99+
##### `muladd`
100+
#####
101+
102+
function rrule(
103+
::typeof(muladd),
104+
A::AbstractMatrix{<:CommutativeMulNumber},
105+
B::AbstractVecOrMat{<:CommutativeMulNumber},
106+
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}},
107+
)
108+
# The useful case, mul! fused with +
109+
function muladd_pullback_1(Ȳ)
110+
matmul = (
111+
InplaceableThunk(
112+
@thunk(Ȳ * B'),
113+
dA -> mul!(dA, Ȳ, B', true, true)
114+
),
115+
InplaceableThunk(
116+
@thunk(A' * Ȳ),
117+
dB -> mul!(dB, A', Ȳ, true, true)
118+
)
119+
)
120+
addon = if z isa Bool
121+
DoesNotExist()
122+
elseif z isa Number
123+
@thunk(sum(Ȳ))
124+
else
125+
InplaceableThunk(
126+
@thunk(sum!(similar(z, eltype(Ȳ)), Ȳ)),
127+
dz -> sum!(dz, Ȳ; init=false)
128+
)
129+
end
130+
(NO_FIELDS, matmul..., addon)
131+
end
132+
return muladd(A, B, z), muladd_pullback_1
133+
end
134+
135+
function rrule(
136+
::typeof(muladd),
137+
ut::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber},
138+
v::AbstractVector{<:CommutativeMulNumber},
139+
z::CommutativeMulNumber,
140+
)
141+
# This case is dot(u,v)+z, but would also match signature above.
142+
function muladd_pullback_2(dy)
143+
ut_thunk = InplaceableThunk(
144+
@thunk(v' .* dy),
145+
dut -> dut .+= v' .* dy
146+
)
147+
v_thunk = InplaceableThunk(
148+
@thunk(ut' .* dy),
149+
dv -> dv .+= ut' .* dy
150+
)
151+
(NO_FIELDS, ut_thunk, v_thunk, z isa Bool ? DoesNotExist() : dy)
152+
end
153+
return muladd(ut, v, z), muladd_pullback_2
154+
end
155+
156+
function rrule(
157+
::typeof(muladd),
158+
u::AbstractVector{<:CommutativeMulNumber},
159+
vt::LinearAlgebra.AdjOrTransAbsVec{<:CommutativeMulNumber},
160+
z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}},
161+
)
162+
# Outer product, just broadcasting
163+
function muladd_pullback_3(Ȳ)
164+
proj = (
165+
@thunk(vec(sum(Ȳ .* conj.(vt), dims=2))),
166+
@thunk(vec(sum(u .* conj.(Ȳ), dims=1))'),
167+
)
168+
addon = if z isa Bool
169+
DoesNotExist()
170+
elseif z isa Number
171+
@thunk(sum(Ȳ))
172+
else
173+
InplaceableThunk(
174+
@thunk(sum!(similar(z, eltype(Ȳ)), Ȳ)),
175+
dz -> sum!(dz, Ȳ; init=false)
176+
)
177+
end
178+
(NO_FIELDS, proj..., addon)
179+
end
180+
return muladd(u, vt, z), muladd_pullback_3
181+
end
98182

99183
#####
100184
##### `/`

test/rulesets/Base/arraymath.jl

+42-1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,47 @@
6262
end
6363
end
6464

65+
@testset "muladd: $T" for T in (Float64, ComplexF64)
66+
@testset "add $(typeof(z))" for z in [rand(T), rand(T, 3), rand(T, 3, 3), false]
67+
@testset "matrix * matrix" begin
68+
A = rand(T, 3, 3)
69+
B = rand(T, 3, 3)
70+
test_rrule(muladd, A, B, z)
71+
test_rrule(muladd, A', B, z)
72+
test_rrule(muladd, A , B', z)
73+
74+
C = rand(T, 3, 5)
75+
D = rand(T, 5, 3)
76+
test_rrule(muladd, C, D, z)
77+
end
78+
if ndims(z) <= 1
79+
@testset "matrix * vector" begin
80+
A, B = rand(T, 3, 3), rand(T, 3)
81+
test_rrule(muladd, A, B, z)
82+
test_rrule(muladd, A, B rand(T, 3,1), z)
83+
end
84+
@testset "adjoint * matrix" begin
85+
At, B = rand(T, 3)', rand(T, 3, 3)
86+
test_rrule(muladd, At, B, z')
87+
test_rrule(muladd, At rand(T,1,3), B, z')
88+
end
89+
end
90+
if ndims(z) == 0
91+
@testset "adjoint * vector" begin # like dot
92+
A, B = rand(T, 3)', rand(T, 3)
93+
test_rrule(muladd, A, B, z)
94+
test_rrule(muladd, A rand(T,1,3), B, z')
95+
end
96+
end
97+
if ndims(z) == 2 # other dims lead to e.g. muladd(ones(4), ones(1,4), 1)
98+
@testset "vector * adjoint" begin # outer product
99+
A, B = rand(T, 3), rand(T, 3)'
100+
test_rrule(muladd, A, B, z)
101+
test_rrule(muladd, A, B rand(T,1,3), z)
102+
end
103+
end
104+
end
105+
end
65106

66107
@testset "$f" for f in (/, \)
67108
@testset "Matrix" begin
@@ -89,13 +130,13 @@
89130
end
90131
end
91132
end
133+
92134
@testset "/ and \\ Scalar-AbstractArray" begin
93135
A = randn(3, 4, 5)
94136
test_rrule(/, A, 7.2)
95137
test_rrule(\, 7.2, A)
96138
end
97139

98-
99140
@testset "negation" begin
100141
A = randn(4, 4)
101142
= randn(4, 4)

0 commit comments

Comments
 (0)