
Commit 622ba21

add forward, reverse-mode AD operators, enable tests for reverse-mode (#185)
- adds a dynamic variable `emmy.calculus.derivative/*mode*` that allows the user to switch between forward and reverse mode automatic differentiation
- adds a new `emmy.calculus.derivative/gradient` that acts like `emmy.tape/gradient` but is capable of taking multiple variables
- adds new operators `emmy.calculus.derivative/{D-forward, D-reverse}` and operator-returning-functions `emmy.calculus.derivative/{partial-forward, partial-reverse}` that allow the user to explicitly invoke forward-mode or reverse-mode automatic differentiation. `D` and `partial` still default to forward-mode
- modifies `emmy.tape/gradient` to correctly error when passed invalid selectors, just like `emmy.dual/derivative`.
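As a rough usage sketch, assuming a build of Emmy that includes this commit is on the classpath, the new operators and the dynamic variable might be exercised as follows. The namespaces and operator names come from the commit message and diff (`g/partial-derivative` is assumed to live in `emmy.generic`); the functions `cube` and `f2` are hypothetical examples and not part of the commit.

```clojure
(require '[emmy.calculus.derivative :as cd]
         '[emmy.dual :as dual]
         '[emmy.generic :as g])

;; Hypothetical example functions, not part of the commit.
(defn cube [x] (g/* x x x))
(defn f2 [x y] (g/* x x y))

;; Explicitly choose a mode for the full derivative.
(def forward-result ((cd/D-forward cube) 3))
(def reverse-result ((cd/D-reverse cube) 3))

;; Explicitly choose a mode for the partial derivative in slot 0.
(def forward-partial (((cd/partial-forward 0) f2) 2 3))
(def reverse-partial (((cd/partial-reverse 0) f2) 2 3))

;; `*mode*` is read by `g/partial-derivative` when it is called, so
;; rebinding it affects direct calls to that generic function.
(def dynamic-result
  (binding [cd/*mode* dual/REVERSE-MODE]
    ((g/partial-derivative cube []) 3)))
```

Because `D` and `partial` are defined as `D-forward` and `partial-forward`, which bind `*mode*` themselves, rebinding `*mode*` around a call to `D` would not be expected to change which mode it uses.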
1 parent d4c4c07 commit 622ba21

4 files changed, +192 -54 lines changed

CHANGELOG.md

+17
@@ -2,6 +2,23 @@

 ## [unreleased]

+- #185:
+
+  - adds a dynamic variable `emmy.calculus.derivative/*mode*` that allows the
+    user to switch between forward and reverse mode automatic differentiation
+
+  - adds a new `emmy.calculus.derivative/gradient` that acts like
+    `emmy.tape/gradient` but is capable of taking multiple variables
+
+  - adds new operators `emmy.calculus.derivative/{D-forward, D-reverse}` and
+    operator-returning-functions `emmy.calculus.derivative/{partial-forward,
+    partial-reverse}` that allow the user to explicitly invoke forward-mode or
+    reverse-mode automatic differentiation. `D` and `partial` still default to
+    forward-mode
+
+  - modifies `emmy.tape/gradient` to correctly error when passed invalid
+    selectors, just like `emmy.dual/derivative`.
+
 - #183:

   - adds `emmy.{autodiff, tape}` to `emmy.sci`'s exported namespace set

src/emmy/calculus/derivative.cljc

+132 -33
@@ -22,9 +22,8 @@

 ;; ## Single and Multivariable Calculus
 ;;
-;; These functions put together the pieces laid out
-;; in [[emmy.dual]] and declare an interface for taking
-;; derivatives.
+;; These functions put together the pieces laid out in [[emmy.dual]] and declare
+;; an interface for taking derivatives.

 ;; The result of applying the derivative `(D f)` of a multivariable function `f`
 ;; to a sequence of `args` is a structure of the same shape as `args` with all
@@ -37,7 +36,8 @@
 ;; To generate the result:
 ;;
 ;; - For a single non-structural argument, return `(d/derivative f)`
-;; - else, bundle up all arguments into a single [[s/Structure]] instance `xs`
+;; - else, bundle up all arguments into a single [[emmy.structure/Structure]]
+;;   instance `xs`
 ;; - Generate `xs'` by replacing each entry in `xs` with `((d/derivative f')
 ;;   entry)`, where `f'` is a function of ONLY that entry that
 ;;   calls `(f (assoc-in xs path entry))`. In other words, replace each entry
@@ -49,7 +49,7 @@
 ;; above.
 ;;
 ;; [[jacobian]] handles this main logic. [[jacobian]] can only take a structural
-;; input. [[euclidean]] and [[multivariate]] below widen handle, respectively,
+;; input. [[euclidean]] and [[multivariate]] below handle, respectively,
 ;; optionally-structural and multivariable arguments.

 (defn- deep-partial
@@ -109,12 +109,12 @@
         (u/illegal (str "Bad selectors " selectors " for structure " input))))))

 (defn- euclidean
-  "Slightly more general version of [[jacobian]] that can handle a single
-  non-structural input; dispatches to either [[jacobian]] or [[derivative]]
-  depending on the input type.
+  "Slightly more general version of [[jacobian]] that can handle a single input;
+  dispatches to either [[jacobian]] or [[derivative]] depending on whether or
+  not the input is structural.

   If you pass non-empty `selectors`, the returned function will throw if it
-  receives a non-structural, non-numerical argument."
+  receives a non-structural, non-scalar argument."
   ([f] (euclidean f []))
   ([f selectors]
    (let [selectors (vec selectors)]
@@ -143,6 +143,28 @@
         (str "Selectors " selectors
              " not allowed for non-structural input " input)))))))

+(defn- multi
+  "Given
+
+  - some higher-order function `op` that transforms a function of a single
+    variable into another function of a single variable
+  - function `f` capable of taking multiple arguments
+
+  returns a new function that acts like `(op f)` but can take multiple
+  arguments.
+
+  When passed multiple arguments, the returned function packages them into a
+  single `[[emmy.structure/up]]` instance. Any [[emmy.matrix/Matrix]] present in
+  the argument list will be converted into a `down` of `up`s (a row of columns)."
+  [op f]
+  (-> (fn
+        ([] 0)
+        ([x] ((op f) x))
+        ([x & more]
+         ((multi op (fn [xs] (apply f xs)))
+          (matrix/seq-> (cons x more)))))
+      (f/with-arity (f/arity f) {:from ::multi})))
+
 (defn- multivariate
   "Slightly wider version of [[euclidean]]. Accepts:

@@ -152,24 +174,39 @@

   And returns a new function that computes either the
   full [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
-  or the entry at `selectors`.
+  or the entry at `selectors` using [forward-mode automatic
+  differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation#Forward_accumulation).

   Any multivariable function will have its argument vector coerced into an `up`
-  structure. Any [[matrix/Matrix]] in a multiple-arg function call will be
+  structure. Any [[emmy.matrix/Matrix]] in a multiple-arg function call will be
   converted into a `down` of `up`s (a row of columns).

-  Single-argument functions don't transform their arguments."
+  Arguments to single-variable functions are not transformed."
   ([f] (multivariate f []))
   ([f selectors]
-   (let [d #(euclidean % selectors)
-         df (d f)
-         df* (d (fn [args] (apply f args)))]
-     (-> (fn
-           ([] 0)
-           ([x] (df x))
-           ([x & more]
-            (df* (matrix/seq-> (cons x more)))))
-         (f/with-arity (f/arity f) {:from ::multivariate})))))
+   (let [d #(euclidean % selectors)]
+     (multi d f))))
+
+(defn gradient
+  "Accepts:
+
+  - some function `f` of potentially many arguments
+  - optionally, a sequence of selectors meant to index into the structural
+    argument, or argument vector, of `f`
+
+  And returns a new function that computes either the
+  full [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
+  or the entry at `selectors` using [reverse-mode automatic
+  differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation).
+
+  Any multivariable function will have its argument vector coerced into an `up`
+  structure. Any [[emmy.matrix/Matrix]] in a multiple-arg function call will be
+  converted into a `down` of `up`s (a row of columns).
+
+  Arguments to single-variable functions are not transformed."
+  ([f] (gradient f []))
+  ([f selectors]
+   (multi #(tape/gradient % selectors) f)))

 ;; ## Generic [[g/partial-derivative]] Installation
 ;;
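The argument-packaging pattern that `multi` factors out of `multivariate` and `gradient` can be illustrated without Emmy. The following is a simplified sketch under stated assumptions: a plain vector stands in for the `up` structure produced by `matrix/seq->`, the arity metadata added by `f/with-arity` is omitted, and `multi-sketch` and `show-input` are hypothetical names introduced only for illustration.

```clojure
;; Lifts `op`, which transforms one-argument functions, so that the
;; result of `(multi-sketch op f)` also accepts zero or many arguments.
(defn multi-sketch
  [op f]
  (fn
    ([] 0)
    ([x] ((op f) x))
    ([x & more]
     ;; Package every argument into a single vector, then apply `op` to a
     ;; function of that one vector.
     ((op (fn [xs] (apply f xs)))
      (vec (cons x more))))))

;; Hypothetical stand-in for a derivative operator: it records the single
;; input that the wrapped function received alongside the output.
(defn show-input [f]
  (fn [x] {:input x :output (f x)}))

((multi-sketch show-input inc) 1)
;; => {:input 1, :output 2}

((multi-sketch show-input +) 1 2 3)
;; => {:input [1 2 3], :output 6}
```

A single argument is passed through untouched, while several arguments reach `op` as one packaged value, which is why both the forward-mode and reverse-mode entry points can share the same wrapper.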
@@ -192,24 +229,66 @@
 ;; passed to the structure of functions, instead of separately for every entry
 ;; in the structure.
 ;;
+;; A dynamic variable controls whether or not this process uses forward-mode or
+;; reverse-mode AD.
+;;
 ;; TODO: I think this is going to cause problems for, say, a Structure of
 ;; PowerSeries, where there is actually a cheap `g/partial-derivative`
 ;; implementation for the components. I vote to back out this `::s/structure`
 ;; installation.

+(def ^:dynamic *mode* d/FORWARD-MODE)
+
 (doseq [t [::v/function ::s/structure]]
   (defmethod g/partial-derivative [t v/seqtype] [f selectors]
-    (multivariate f selectors))
+    (if (= *mode* d/FORWARD-MODE)
+      (multivariate f selectors)
+      (gradient f selectors)))

   (defmethod g/partial-derivative [t nil] [f _]
-    (multivariate f [])))
+    (if (= *mode* d/FORWARD-MODE)
+      (multivariate f [])
+      (gradient f []))))

 ;; ## Operators
 ;;
-;; This section exposes various differential operators as [[o/Operator]]
-;; instances.
+;; This section exposes various differential operators
+;; as [[emmy.operator/Operator]] instances.
+
+(def ^{:arglists '([f])}
+  D-forward
+  "Forward-mode derivative operator. Takes some function `f` and returns a
+  function whose value at some point can multiply an increment in the arguments
+  to produce the best linear estimate of the increment in the function value.

-(def D
+  For univariate functions, [[D-forward]] computes a derivative. For vector-valued
+  functions, [[D-forward]] computes
+  the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
+  of `f`."
+  (o/make-operator
+   (fn [x]
+     (binding [*mode* d/FORWARD-MODE]
+       (g/partial-derivative x [])))
+   g/derivative-symbol))
+
+(def ^{:arglists '([f])}
+  D-reverse
+  "Reverse-mode derivative operator. Takes some function `f` and returns a
+  function whose value at some point can multiply an increment in the arguments
+  to produce the best linear estimate of the increment in the function value.
+
+  For univariate functions, [[D-reverse]] computes a derivative. For vector-valued
+  functions, [[D-reverse]] computes
+  the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
+  of `f`."
+  (o/make-operator
+   (fn [x]
+     (binding [*mode* d/REVERSE-MODE]
+       (g/partial-derivative x [])))
+   g/derivative-symbol))
+
+(def ^{:arglists '([f])}
+  D
   "Derivative operator. Takes some function `f` and returns a function whose value
   at some point can multiply an increment in the arguments to produce the best
   linear estimate of the increment in the function value.
@@ -222,8 +301,7 @@
   The related [[emmy.env/Grad]] returns a function that produces a structure of
   the opposite orientation as [[D]]. Both of these functions use forward-mode
   automatic differentiation."
-  (o/make-operator #(g/partial-derivative % [])
-                   g/derivative-symbol))
+  D-forward)

 (defn D-as-matrix [F]
   (fn [s]
@@ -232,13 +310,34 @@
      ((D F) s)
      s)))

-(defn partial
+(defn partial-forward
+  "Returns an operator that, when applied to a function `f`, produces a function
+  that uses forward-mode automatic differentiation to compute the partial
+  derivative of `f` at the (zero-based) slot index provided via `selectors`."
+  [& selectors]
+  (o/make-operator
+   (fn [x]
+     (binding [*mode* d/FORWARD-MODE]
+       (g/partial-derivative x selectors)))
+   `(~'partial ~@selectors)))
+
+(defn partial-reverse
   "Returns an operator that, when applied to a function `f`, produces a function
-  that computes the partial derivative of `f` at the (zero-based) slot index
-  provided via `selectors`."
+  that uses reverse-mode automatic differentiation to compute the partial
+  derivative of `f` at the (zero-based) slot index provided via `selectors`."
   [& selectors]
-  (o/make-operator #(g/partial-derivative % selectors)
-                   `(~'partial ~@selectors)))
+  (o/make-operator
+   (fn [x]
+     (binding [*mode* d/REVERSE-MODE]
+       (g/partial-derivative x selectors)))
+   `(~'partial ~@selectors)))
+
+(def ^{:arglists '([& selectors])}
+  partial
+  "Returns an operator that, when applied to a function `f`, produces a function
+  that uses forward-mode automatic differentiation to compute the partial
+  derivative of `f` at the (zero-based) slot index provided via `selectors`."
+  partial-forward)

 ;; ## Derivative Utilities
 ;;
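The interaction between the dynamic `*mode*` variable and the explicit operators can be modeled in plain Clojure. This is only a sketch: keyword modes stand in for `d/FORWARD-MODE` and `d/REVERSE-MODE`, and `partial-derivative`, `d-forward`, `d-reverse` and `d` below are hypothetical stand-ins that tag their result with the mode instead of differentiating anything.

```clojure
;; Keyword modes stand in for `d/FORWARD-MODE` and `d/REVERSE-MODE`.
(def ^:dynamic *mode* :forward)

(defn partial-derivative
  "Hypothetical stand-in for `g/partial-derivative`: picks an
  implementation based on the value of `*mode*` at the moment it is called."
  [f]
  (let [mode *mode*]
    (fn [x] [mode (f x)])))

(defn d-forward [f]
  (binding [*mode* :forward]
    (partial-derivative f)))

(defn d-reverse [f]
  (binding [*mode* :reverse]
    (partial-derivative f)))

;; Mirrors `D` being defined as `D-forward`.
(def d d-forward)

((d-reverse inc) 1)
;; => [:reverse 2]

(binding [*mode* :reverse] ((d inc) 1))
;; => [:forward 2]

(binding [*mode* :reverse] ((partial-derivative inc) 1))
;; => [:reverse 2]
```

In this model the mode is fixed when the operator is applied to the function, so an operator that binds the mode itself ignores any outer rebinding, while a direct call to the dispatching function respects it.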

src/emmy/tape.cljc

+12 -8
@@ -512,16 +512,20 @@
           (u/illegal
            (str "Selectors " selectors
                 " not allowed for non-structural input " x)))
-
         (let [tag (d/fresh-tag)
-              inputs (if (empty? selectors)
-                       (tapify x tag)
-                       (update-in x selectors tapify tag))
-              output (d/with-active-tag tag f [inputs])
+              input (if-let [piece (get-in x selectors)]
+                      (if (empty? selectors)
+                        (tapify piece tag)
+                        (assoc-in x selectors (tapify piece tag)))
+                      ;; The call to `get-in` will return nil if the
+                      ;; `selectors` don't index correctly into the supplied
+                      ;; `input`, triggering this exception.
+                      (u/illegal
+                       (str "Bad selectors " selectors " for structure " x)))
+              output (d/with-active-tag tag f [input])
               completed (d/extract-tangent output tag d/REVERSE-MODE)]
-          (if (empty? selectors)
-            (interpret inputs completed tag)
-            (interpret (get-in inputs selectors) completed tag)))))))
+          (-> (get-in input selectors)
+              (interpret completed tag)))))))

 (defmethod g/zero-like [::tape] [_] 0)
 (defmethod g/one-like [::tape] [_] 1)
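The selector check added to `emmy.tape/gradient` follows a look-up-then-replace shape that can be sketched with ordinary nested vectors. The sketch below makes several assumptions: plain vectors stand in for Emmy structures, an arbitrary function `f` stands in for `tapify`, `ex-info` stands in for `u/illegal`, and the name `transform-at` is hypothetical.

```clojure
(defn transform-at
  "Applies `f` to the entry of `x` found at `selectors`, throwing if the
  path does not resolve to an entry."
  [x selectors f]
  (if-let [piece (get-in x selectors)]
    (if (empty? selectors)
      ;; With an empty path `get-in` returns `x` itself. `assoc-in` with an
      ;; empty path would not replace `x`, so this case is handled separately.
      (f piece)
      (assoc-in x selectors (f piece)))
    ;; `get-in` returned nil: the path does not index into `x`.
    (throw (ex-info (str "Bad selectors " selectors " for structure " x)
                    {:selectors selectors}))))

(transform-at [1 [2 3]] [1 0] #(* 10 %))
;; => [1 [20 3]]

(transform-at 5 [] #(* 10 %))
;; => 50
```

One consequence of using `if-let` in this shape is that an entry whose value is `nil` or `false` is treated the same way as a path that does not exist.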

test/emmy/calculus/derivative_test.cljc

+31 -13
@@ -446,13 +446,27 @@
   (testing "space"
     (let [g (af/literal-function 'g [0 0] 0)
           h (af/literal-function 'h [0 0] 0)]
-      (is (= '(+ (((partial 0) g) x y) (((partial 0) h) x y))
-             (simplify (((partial 0) (+ g h)) 'x 'y))))
-      (is (= '(+ (* (((partial 0) g) x y) (h x y)) (* (((partial 0) h) x y) (g x y)))
-             (simplify (((partial 0) (* g h)) 'x 'y))))
-      (is (= '(+ (* (((partial 0) g) x y) (h x y) (expt (g x y) (+ (h x y) -1)))
-                 (* (((partial 0) h) x y) (log (g x y)) (expt (g x y) (h x y))))
-             (simplify (((partial 0) (g/expt g h)) 'x 'y))))))
+      (is (zero?
+           (simplify
+            (g/- (g/+ (((partial 0) g) 'x 'y)
+                      (((partial 0) h) 'x 'y))
+                 (((partial 0) (+ g h)) 'x 'y)))))
+      (is (zero?
+           (simplify
+            (g/-
+             (g/+ (g/* (((partial 0) g) 'x 'y) (h 'x 'y))
+                  (g/* (((partial 0) h) 'x 'y) (g 'x 'y)))
+             (((partial 0) (* g h)) 'x 'y)))))
+      (is (zero?
+           (simplify
+            (g/-
+             (g/+ (g/* (((partial 0) g) 'x 'y)
+                       (h 'x 'y)
+                       (g/expt (g 'x 'y) (+ (h 'x 'y) -1)))
+                  (g/* (((partial 0) h) 'x 'y)
+                       (g/log (g 'x 'y))
+                       (g/expt (g 'x 'y) (h 'x 'y))))
+             (((partial 0) (g/expt g h)) 'x 'y)))))))

   (testing "operators"
     (is (= '(down 1 1 1 1 1 1 1 1 1 1)
@@ -485,9 +499,12 @@
           f3 (fn [x y] (* (tan x) (log y)))
           f4 (fn [x y] (* (tan x) (sin y)))
           f5 (fn [x y] (/ (tan x) (sin y)))]
-      (is (= '(down (* (log y) (cos x))
-                    (/ (sin x) y))
-             (simplify ((D f2) 'x 'y))))
+      (is (= '(down 0 0)
+             (simplify
+              (g/- (s/down
+                    (g/* (g/log 'y) (g/cos 'x))
+                    (g// (g/sin 'x) 'y))
+                   ((D f2) 'x 'y)))))
       (is (= '(down (/ (log y) (expt (cos x) 2))
                     (/ (tan x) y))
              (simplify ((D f3) 'x 'y))))
@@ -1616,9 +1633,6 @@
         "symbolic-taylor-series keeps the arguments symbolic, even when they
         are numbers."))))

-;; TODO enable when we add our gradient impl in the next PR.
-
-#_
 (deftest mixed-mode-tests
   (testing "multiple input, vector output"
     (let [f (fn [a b c d e f]
@@ -1686,3 +1700,7 @@

 (deftest forward-mode-tests
   (all-tests d/D d/partial))
+
+(deftest reverse-mode-tests
+  (all-tests d/D-reverse
+             d/partial-reverse))
