Skip to content

Commit c169d1a

Browse files
committed
Local.
1 parent ae6dda7 commit c169d1a

File tree

4 files changed

+44
-5
lines changed

4 files changed

+44
-5
lines changed

.gitattributes

+1
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ site/* linguist-documentation
55
*.html linguist-vendored
66
*.js linguist-vendored
77
*.ipynb linguist-vendored
8+
*.css linguist-vendored

mkdocs.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ repo_url: https://github.com/probcomp/genjax
1818
copyright: Copyright © 2023 MIT Probabilistic Computing Project
1919

2020
nav:
21-
- Home: index.md
21+
- Home:
22+
- index.md
23+
- Gen's concepts: genjax/concepts/generative_functions.md
24+
- Diff against Gen.jl: genjax/diff_jl.md
2225
- Inference notebooks: genjax/notebooks.md
23-
- Concepts:
24-
- Generative functions: genjax/concepts/generative_functions.md
25-
- Diff with Gen.jl: genjax/concepts/diff_jl.md
2626
- Language apéritifs: genjax/language_aperitifs.md
2727
- Public API reference:
2828
- Core: genjax/library/core.md

src/genjax/_src/core/datatypes/trie.py

+16
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,22 @@ def get_subtrees_shallow(self):
289289
return self.trie.get_subtrees_shallow()
290290

291291

292+
###################
293+
# TrieConvertable #
294+
###################
295+
296+
# A mixin: denotes that a choice map can be converted to a TrieChoiceMap
297+
298+
299+
@dataclass
300+
class TrieConvertable:
301+
def convert(self) -> TrieChoiceMap:
302+
new = TrieChoiceMap.new()
303+
for (k, v) in self.get_submaps_shallow():
304+
pass
305+
return new
306+
307+
292308
##############
293309
# Shorthands #
294310
##############

src/genjax/_src/generative_functions/combinators/vector/map_combinator.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828
from genjax._src.core.datatypes.generative import GenerativeFunction
2929
from genjax._src.core.datatypes.masks import mask
3030
from genjax._src.core.datatypes.tracetypes import TraceType
31+
from genjax._src.core.datatypes.trie import TrieChoiceMap
3132
from genjax._src.core.interpreters.staging import concrete_cond
3233
from genjax._src.core.typing import Any
3334
from genjax._src.core.typing import FloatArray
35+
from genjax._src.core.typing import Int
3436
from genjax._src.core.typing import IntArray
3537
from genjax._src.core.typing import PRNGKey
3638
from genjax._src.core.typing import Tuple
@@ -194,6 +196,14 @@ def _padder(self, v, key_len):
194196
np.pad(v, pad_axes) if isinstance(v, np.ndarray) else jnp.pad(v, pad_axes)
195197
)
196198

199+
def _static_check_trie_index_compatible(
200+
self, chm: TrieChoiceMap, broadcast_dim_length: Int
201+
):
202+
for (k, _) in chm.get_subtrees_shallow():
203+
assert isinstance(k, int)
204+
# TODO: pull outside loop, just check the last address.
205+
assert k < broadcast_dim_length
206+
197207
def _importance_vcm(self, key, chm, args):
198208
def _importance(key, chm, args):
199209
return self.kernel.importance(key, chm, args)
@@ -218,6 +228,16 @@ def _inner(key, index, chm, args):
218228
map_tr = VectorTrace(self, indices, tr, args, retval, scores)
219229
return key, (w, map_tr)
220230

231+
# Implements a conversion from `TrieChoiceMap`.
232+
def _importance_tchm(self, key, chm, args):
233+
broadcast_dim_length = self._static_broadcast_dim_length(args)
234+
self._static_check_trie_index_compatible(chm, broadcast_dim_length)
235+
236+
# Okay, so the TrieChoiceMap has an address hierarchy which is compatible with the index structure of the MapCombinator choices.
237+
# Let's coerce TrieChoiceMap into VectorChoiceMap and then just call `_importance_vcm`.
238+
vector_chm = self._coerce_to_vector_chm(chm)
239+
return self._importance_vcm(key, vector_chm, args)
240+
221241
def _importance_empty(self, key, _, args):
222242
key, map_tr = self.simulate(key, args)
223243
w = 0.0
@@ -227,7 +247,7 @@ def _importance_empty(self, key, _, args):
227247
def importance(
228248
self,
229249
key: PRNGKey,
230-
chm: Union[EmptyChoiceMap, VectorChoiceMap],
250+
chm: Union[EmptyChoiceMap, TrieChoiceMap, VectorChoiceMap],
231251
args: Tuple,
232252
**_,
233253
) -> Tuple[PRNGKey, Tuple[FloatArray, VectorTrace]]:
@@ -237,6 +257,8 @@ def importance(
237257
# Note: these branches are resolved at tracing time.
238258
if isinstance(chm, VectorChoiceMap):
239259
return self._importance_vcm(key, chm, args)
260+
elif isinstance(chm, TrieChoiceMap):
261+
return self._importance_tchm(key, chm, args)
240262
else:
241263
assert isinstance(chm, EmptyChoiceMap)
242264
return self._importance_empty(key, chm, args)

0 commit comments

Comments
 (0)