Skip to content

Commit 8fb417f

Browse files
authored
emmy.autodiff, perturbed? removal (#182)
1 parent b80a0f4 commit 8fb417f

18 files changed

+559
-572
lines changed

CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22

33
## [unreleased]
44

5+
- #182:
6+
7+
- moves the generic implementations for `TapeCell` and `Dual` to `emmy.autodiff`
8+
9+
- moves `emmy.calculus.derivative` to `emmy.dual/derivative`
10+
11+
- removes `emmy.dual/perturbed?` from `IPerturbed`, as this is no longer used.
12+
513
- #180 renames `emmy.differential` to `emmy.dual`, since the file now contains a
614
proper dual number implementation, not a truncated multivariate power series.
715

src/emmy/abstract/function.cljc

+7-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
(:refer-clojure :exclude [name])
1313
(:require #?(:clj [clojure.pprint :as pprint])
1414
[emmy.abstract.number :as an]
15-
[emmy.dual :as d]
15+
[emmy.autodiff :as ad]
16+
[emmy.dual :as dual]
1617
[emmy.function :as f]
1718
[emmy.generic :as g]
1819
[emmy.matrix :as m]
@@ -268,9 +269,9 @@
268269
[f primal-s tag]
269270
(fn
270271
([] 0)
271-
([tangent] (d/bundle-element (apply f primal-s) tangent tag))
272+
([tangent] (dual/bundle-element (apply f primal-s) tangent tag))
272273
([tangent [x path _]]
273-
(let [dx (d/tangent x tag)]
274+
(let [dx (dual/tangent x tag)]
274275
(if (g/numeric-zero? dx)
275276
tangent
276277
(let [partial (literal-partial f path)]
@@ -318,9 +319,9 @@
318319
that input."
319320
[f s tag dx]
320321
(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
321-
(d/dual? dx) forward-mode-fold
322+
(dual/dual? dx) forward-mode-fold
322323
:else (u/illegal "No tape or differential inputs."))
323-
primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)]
324+
primal-s (s/mapr (fn [x] (ad/primal-of x tag)) s)]
324325
(s/fold-chain (fold-fn f primal-s tag) s)))
325326

326327
(defn- check-argument-type
@@ -355,7 +356,7 @@
355356
(if-let [[tag dx] (s/fold-chain
356357
(fn
357358
([] [])
358-
([acc] (apply tape/tag+perturbation acc))
359+
([acc] (apply ad/tag+perturbation acc))
359360
([acc [d]] (conj acc d)))
360361
s)]
361362
(literal-derivative f s tag dx)

0 commit comments

Comments
 (0)