Skip to content

Commit 481112f

Browse files
authored
Merge pull request #425 from probcomp/mrb/407
Fix Map.update for #407.
2 parents 605fff9 + 157d48b commit 481112f

File tree

3 files changed

+28
-15
lines changed

3 files changed

+28
-15
lines changed

Diff for: src/genjax/_src/generative_functions/combinators/vector/map_combinator.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
from genjax._src.core.datatypes.generative import NoneSelection
3030
from genjax._src.core.datatypes.generative import Selection
3131
from genjax._src.core.datatypes.generative import Trace
32-
from genjax._src.core.datatypes.masks import mask
3332
from genjax._src.core.datatypes.tracetypes import TraceType
3433
from genjax._src.core.datatypes.trie import Trie
3534
from genjax._src.core.interpreters.staging import concrete_cond
35+
from genjax._src.core.transforms.incremental import tree_diff_primal
3636
from genjax._src.core.typing import Any
3737
from genjax._src.core.typing import FloatArray
3838
from genjax._src.core.typing import IntArray
@@ -334,30 +334,22 @@ def update(
334334
chm: VectorChoiceMap,
335335
argdiffs: Tuple,
336336
):
337-
def _update(key, prev, chm, argdiffs):
338-
key, (retdiff, w, tr, d) = self.kernel.update(key, prev, chm, argdiffs)
339-
return key, (retdiff, w, tr, d)
340-
341-
def _inner(key, index, prev, chm, argdiffs):
342-
check = index == chm.get_index()
343-
masked = mask(check, chm.inner)
344-
return _update(key, prev, masked, argdiffs)
345-
346-
args = jtu.tree_leaves(argdiffs)
337+
args = tree_diff_primal(argdiffs)
347338
self._static_check_broadcastable(args)
348339
broadcast_dim_length = self._static_broadcast_dim_length(args)
349-
indices = np.array([i for i in range(0, broadcast_dim_length)])
350340
prev_inaxes_tree = jtu.tree_map(
351341
lambda v: None if v.shape == () else 0, prev.inner
352342
)
353343
key, sub_keys = slash(key, broadcast_dim_length)
344+
354345
_, (retdiff, w, tr, discard) = jax.vmap(
355-
_inner, in_axes=(0, 0, prev_inaxes_tree, 0, self.in_axes)
356-
)(sub_keys, indices, prev.inner, chm, argdiffs)
346+
self.kernel.update, in_axes=(0, prev_inaxes_tree, 0, self.in_axes)
347+
)(sub_keys, prev.inner, chm.inner, argdiffs)
357348
w = jnp.sum(w)
358349
retval = tr.get_retval()
359350
scores = tr.get_score()
360351
map_tr = MapTrace(self, tr, args, retval, jnp.sum(scores))
352+
discard = VectorChoiceMap(discard)
361353
return key, (retdiff, w, map_tr, discard)
362354

363355
# The choice map passed in here is empty, but perhaps

Diff for: src/genjax/_src/inference/mcmc/metropolis_hastings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def apply(self, key: PRNGKey, trace: Trace, proposal_args: Tuple):
4343
fwd_weight = proposal_tr.get_score()
4444
diffs = Diff.no_change(model_args)
4545
key, (_, weight, new, discard) = model.update(
46-
key, trace, proposal_tr.get_choices(), diffs
46+
key, trace, proposal_tr.strip(), diffs
4747
)
4848
proposal_args_bwd = (new, *proposal_args)
4949
key, (bwd_weight, _) = self.proposal.importance(key, discard, proposal_args_bwd)

Diff for: tests/inference/mcmc/test_metropolis_hastings.py

+21
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import jax
17+
import jax.numpy as jnp
1718

1819
import genjax
1920
from genjax.inference.mcmc import MetropolisHastings
@@ -46,3 +47,23 @@ def proposal(nowAt, d):
4647
assert tr.get_score() != new.get_score()
4748
else:
4849
assert tr.get_score() == new.get_score()
50+
51+
def test_map_combinator(self):
52+
@genjax.gen
53+
def model():
54+
loc = genjax.normal(0., 1.) @ 'loc'
55+
xs = genjax.Map(genjax.normal, in_axes=(None, 0))(loc, jnp.arange(10)) @ 'xs'
56+
return xs
57+
58+
59+
@genjax.gen
60+
def proposal(choices):
61+
loc = choices['loc']
62+
xs = genjax.Map(genjax.normal, in_axes=(None, 0))(loc, jnp.arange(10)) @ 'xs'
63+
return xs
64+
65+
66+
key = jax.random.PRNGKey(314159)
67+
key, trace = genjax.simulate(model)(key, ())
68+
genjax.inference.mcmc.mh(proposal).apply(key, trace, ())
69+
assert True

0 commit comments

Comments
 (0)