Skip to content

Commit bd47efd

Browse files
authored
Support nested forward/reverse mode (#156)
1 parent d354463 commit bd47efd

File tree

9 files changed

+371
-448
lines changed

9 files changed

+371
-448
lines changed

CHANGELOG.md

+7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
## [unreleased]
44

5+
- #156:
6+
7+
- Makes forward- and reverse-mode automatic differentiation compatible with
8+
each other, allowing for proper mixed-mode AD
9+
10+
- Adds support for derivatives of literal functions in reverse-mode
11+
512
- #165:
613

714
- Fixes Alexey's Amazing Bug for our tape implementation

src/emmy/abstract/function.cljc

+59-9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
[emmy.numsymb :as sym]
2020
[emmy.polynomial]
2121
[emmy.structure :as s]
22+
[emmy.tape :as tape]
2223
[emmy.util :as u]
2324
[emmy.value :as v])
2425
#?(:clj
@@ -250,7 +251,21 @@
250251
(->Function
251252
fexp (f/arity f) (domain-types f) (range-type f))))
252253

253-
(defn- forward-mode-fold [f primal-s tag]
254+
(defn- forward-mode-fold
255+
"Takes
256+
257+
- a literal function `f`
258+
- a structure `primal-s` of the primal components of the args to `f` (with
259+
respect to `tag`)
260+
- the `tag` of the innermost active derivative call
261+
262+
And returns a folding function (designed for use
263+
with [[emmy.structure/fold-chain]]) that
264+
265+
generates a new [[emmy.differential/Dual]] by applying the chain rule and
266+
summing the partial derivatives for each perturbed argument in the input
267+
structure."
268+
[f primal-s tag]
254269
(fn
255270
([] 0)
256271
([tangent] (d/bundle-element (apply f primal-s) tangent tag))
@@ -262,15 +277,50 @@
262277
(g/+ tangent (g/* (literal-apply partial primal-s)
263278
dx))))))))
264279

280+
(defn- reverse-mode-fold
281+
"Takes
282+
283+
- a literal function `f`
284+
- a structure `primal-s` of the primal components of the args to `f` (with
285+
respect to `tag`)
286+
- the `tag` of the innermost active derivative call
287+
288+
And returns a folding function (designed for use
289+
with [[emmy.structure/fold-chain]]) that assembles all partial derivatives of
290+
`f` into a new [[emmy.tape/TapeCell]]."
291+
[f primal-s tag]
292+
(fn
293+
([] [])
294+
([partials]
295+
(tape/make tag (apply f primal-s) partials))
296+
([partials [entry path _]]
297+
(if (and (tape/tape? entry) (= tag (tape/tape-tag entry)))
298+
(let [partial (literal-partial f path)]
299+
(conj partials [entry (literal-apply partial primal-s)]))
300+
partials))))
301+
265302
(defn- literal-derivative
266-
"Takes a literal function `f` and a sequence of arguments `xs`, and generates
267-
an expanded `((D f) xs)` by applying the chain rule and summing the partial
268-
derivatives for each [[emmy.differential/Dual]] argument in the input
269-
structure."
303+
"Takes
304+
305+
- a literal function `f`
306+
- a structure `s` of arguments
307+
- the `tag` of the innermost active derivative call
308+
- an instance of a perturbation `dx` associated with `tag`
309+
310+
and generates the proper return value for `((D f) xs)`.
311+
312+
In forward-mode AD this is a new [[emmy.differential/Dual]] generated by
313+
applying the chain rule and summing the partial derivatives for each perturbed
314+
argument in the input structure.
315+
316+
In reverse-mode, this is a new [[emmy.tape/TapeCell]] containing a sequence of
317+
pairs of each input paired with the partial derivative of `f` with respect to
318+
that input."
270319
[f s tag dx]
271-
(let [fold-fn (cond (d/dual? dx) forward-mode-fold
272-
:else (u/illegal "No tape or differential inputs."))
273-
primal-s (s/mapr (fn [x] (d/primal x tag)) s)]
320+
(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
321+
(d/dual? dx) forward-mode-fold
322+
:else (u/illegal "No tape or differential inputs."))
323+
primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)]
274324
(s/fold-chain (fold-fn f primal-s tag) s)))
275325

276326
(defn- check-argument-type
@@ -305,7 +355,7 @@
305355
(if-let [[tag dx] (s/fold-chain
306356
(fn
307357
([] [])
308-
([acc] (apply d/tag+perturbation acc))
358+
([acc] (apply tape/tag+perturbation acc))
309359
([acc [d]] (conj acc d)))
310360
s)]
311361
(literal-derivative f s tag dx)

src/emmy/calculus/derivative.cljc

+19-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
[emmy.operator :as o]
1616
[emmy.series :as series]
1717
[emmy.structure :as s]
18+
[emmy.tape :as tape]
1819
[emmy.util :as u]
1920
[emmy.value :as v])
2021
#?(:clj
@@ -531,12 +532,23 @@
531532
(letfn [(process-term [term]
532533
(g/simplify
533534
(s/mapr (fn rec [x]
534-
(if (d/dual? x)
535-
(d/bundle-element
536-
(rec (d/primal x))
537-
(rec (d/tangent x))
538-
(d/tag x))
539-
(-> (g/simplify x)
540-
(x/substitute replace-m))))
535+
(cond (d/dual? x)
536+
(d/bundle-element
537+
(rec (d/primal x))
538+
(rec (d/tangent x))
539+
(d/tag x))
540+
541+
(tape/tape? x)
542+
(tape/->TapeCell
543+
(tape/tape-tag x)
544+
(tape/tape-id x)
545+
(rec (tape/tape-primal x))
546+
(mapv (fn [[node partial]]
547+
[(rec node)
548+
(rec partial)])
549+
(tape/tape-partials x)))
550+
551+
:else (-> (g/simplify x)
552+
(x/substitute replace-m))))
541553
term)))]
542554
(series/fmap process-term series)))))

0 commit comments

Comments
 (0)