|
19 | 19 | [emmy.numsymb :as sym]
|
20 | 20 | [emmy.polynomial]
|
21 | 21 | [emmy.structure :as s]
|
| 22 | + [emmy.tape :as tape] |
22 | 23 | [emmy.util :as u]
|
23 | 24 | [emmy.value :as v])
|
24 | 25 | #?(:clj
|
|
250 | 251 | (->Function
|
251 | 252 | fexp (f/arity f) (domain-types f) (range-type f))))
|
252 | 253 |
|
253 |
| -(defn- forward-mode-fold [f primal-s tag] |
| 254 | +(defn- forward-mode-fold |
| 255 | + "Takes |
| 256 | +
|
| 257 | + - a literal function `f` |
| 258 | + - a structure `primal-s` of the primal components of the args to `f` (with |
| 259 | + respect to `tag`) |
| 260 | + - the `tag` of the innermost active derivative call |
| 261 | +
|
| 262 | + And returns a folding function (designed for use |
| 263 | + with [[emmy.structure/fold-chain]]) that |
| 264 | +
|
| 265 | + generates a new [[emmy.differential/Dual]] by applying the chain rule and |
| 266 | + summing the partial derivatives for each perturbed argument in the input |
| 267 | + structure." |
| 268 | + [f primal-s tag] |
254 | 269 | (fn
|
255 | 270 | ([] 0)
|
256 | 271 | ([tangent] (d/bundle-element (apply f primal-s) tangent tag))
|
|
262 | 277 | (g/+ tangent (g/* (literal-apply partial primal-s)
|
263 | 278 | dx))))))))
|
264 | 279 |
|
| 280 | +(defn- reverse-mode-fold |
| 281 | + "Takes |
| 282 | +
|
| 283 | + - a literal function `f` |
| 284 | + - a structure `primal-s` of the primal components of the args to `f` (with |
| 285 | + respect to `tag`) |
| 286 | + - the `tag` of the innermost active derivative call |
| 287 | +
|
| 288 | + And returns a folding function (designed for use |
| 289 | + with [[emmy.structure/fold-chain]]) that assembles all partial derivatives of |
| 290 | + `f` into a new [[emmy.tape/TapeCell]]." |
| 291 | + [f primal-s tag] |
| 292 | + (fn |
| 293 | + ([] []) |
| 294 | + ([partials] |
| 295 | + (tape/make tag (apply f primal-s) partials)) |
| 296 | + ([partials [entry path _]] |
| 297 | + (if (and (tape/tape? entry) (= tag (tape/tape-tag entry))) |
| 298 | + (let [partial (literal-partial f path)] |
| 299 | + (conj partials [entry (literal-apply partial primal-s)])) |
| 300 | + partials)))) |
| 301 | + |
265 | 302 | (defn- literal-derivative
|
266 |
| - "Takes a literal function `f` and a sequence of arguments `xs`, and generates |
267 |
| - an expanded `((D f) xs)` by applying the chain rule and summing the partial |
268 |
| - derivatives for each [[emmy.differential/Dual]] argument in the input |
269 |
| - structure." |
| 303 | + "Takes |
| 304 | +
|
| 305 | + - a literal function `f` |
| 306 | + - a structure `s` of arguments |
| 307 | + - the `tag` of the innermost active derivative call |
| 308 | + - an instance of a perturbation `dx` associated with `tag` |
| 309 | +
|
| 310 | + and generates the proper return value for `((D f) xs)`. |
| 311 | +
|
| 312 | + In forward-mode AD this is a new [[emmy.differential/Dual]] generated by |
| 313 | + applying the chain rule and summing the partial derivatives for each perturbed |
| 314 | + argument in the input structure. |
| 315 | +
|
| 316 | + In reverse-mode, this is a new [[emmy.tape/TapeCell]] containing a sequence of |
| 317 | + pairs of each input paired with the partial derivative of `f` with respect to |
| 318 | + that input." |
270 | 319 | [f s tag dx]
|
271 |
| - (let [fold-fn (cond (d/dual? dx) forward-mode-fold |
272 |
| - :else (u/illegal "No tape or differential inputs.")) |
273 |
| - primal-s (s/mapr (fn [x] (d/primal x tag)) s)] |
| 320 | + (let [fold-fn (cond (tape/tape? dx) reverse-mode-fold |
| 321 | + (d/dual? dx) forward-mode-fold |
| 322 | + :else (u/illegal "No tape or differential inputs.")) |
| 323 | + primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)] |
274 | 324 | (s/fold-chain (fold-fn f primal-s tag) s)))
|
275 | 325 |
|
276 | 326 | (defn- check-argument-type
|
|
305 | 355 | (if-let [[tag dx] (s/fold-chain
|
306 | 356 | (fn
|
307 | 357 | ([] [])
|
308 |
| - ([acc] (apply d/tag+perturbation acc)) |
| 358 | + ([acc] (apply tape/tag+perturbation acc)) |
309 | 359 | ([acc [d]] (conj acc d)))
|
310 | 360 | s)]
|
311 | 361 | (literal-derivative f s tag dx)
|
|
0 commit comments