|
29 | 29 | from genjax._src.core.datatypes.generative import NoneSelection
|
30 | 30 | from genjax._src.core.datatypes.generative import Selection
|
31 | 31 | from genjax._src.core.datatypes.generative import Trace
|
32 |
| -from genjax._src.core.datatypes.masks import mask |
33 | 32 | from genjax._src.core.datatypes.tracetypes import TraceType
|
34 | 33 | from genjax._src.core.datatypes.trie import Trie
|
35 | 34 | from genjax._src.core.interpreters.staging import concrete_cond
|
| 35 | +from genjax._src.core.transforms.incremental import tree_diff_primal |
36 | 36 | from genjax._src.core.typing import Any
|
37 | 37 | from genjax._src.core.typing import FloatArray
|
38 | 38 | from genjax._src.core.typing import IntArray
|
@@ -334,30 +334,22 @@ def update(
|
334 | 334 | chm: VectorChoiceMap,
|
335 | 335 | argdiffs: Tuple,
|
336 | 336 | ):
|
337 |
| - def _update(key, prev, chm, argdiffs): |
338 |
| - key, (retdiff, w, tr, d) = self.kernel.update(key, prev, chm, argdiffs) |
339 |
| - return key, (retdiff, w, tr, d) |
340 |
| - |
341 |
| - def _inner(key, index, prev, chm, argdiffs): |
342 |
| - check = index == chm.get_index() |
343 |
| - masked = mask(check, chm.inner) |
344 |
| - return _update(key, prev, masked, argdiffs) |
345 |
| - |
346 |
| - args = jtu.tree_leaves(argdiffs) |
| 337 | + args = tree_diff_primal(argdiffs) |
347 | 338 | self._static_check_broadcastable(args)
|
348 | 339 | broadcast_dim_length = self._static_broadcast_dim_length(args)
|
349 |
| - indices = np.array([i for i in range(0, broadcast_dim_length)]) |
350 | 340 | prev_inaxes_tree = jtu.tree_map(
|
351 | 341 | lambda v: None if v.shape == () else 0, prev.inner
|
352 | 342 | )
|
353 | 343 | key, sub_keys = slash(key, broadcast_dim_length)
|
| 344 | + |
354 | 345 | _, (retdiff, w, tr, discard) = jax.vmap(
|
355 |
| - _inner, in_axes=(0, 0, prev_inaxes_tree, 0, self.in_axes) |
356 |
| - )(sub_keys, indices, prev.inner, chm, argdiffs) |
| 346 | + self.kernel.update, in_axes=(0, prev_inaxes_tree, 0, self.in_axes) |
| 347 | + )(sub_keys, prev.inner, chm.inner, argdiffs) |
357 | 348 | w = jnp.sum(w)
|
358 | 349 | retval = tr.get_retval()
|
359 | 350 | scores = tr.get_score()
|
360 | 351 | map_tr = MapTrace(self, tr, args, retval, jnp.sum(scores))
|
| 352 | + discard = VectorChoiceMap(discard) |
361 | 353 | return key, (retdiff, w, map_tr, discard)
|
362 | 354 |
|
363 | 355 | # The choice map passed in here is empty, but perhaps
|
|
0 commit comments