28
28
from genjax ._src .core .datatypes .generative import GenerativeFunction
29
29
from genjax ._src .core .datatypes .masks import mask
30
30
from genjax ._src .core .datatypes .tracetypes import TraceType
31
+ from genjax ._src .core .datatypes .trie import TrieChoiceMap
31
32
from genjax ._src .core .interpreters .staging import concrete_cond
32
33
from genjax ._src .core .typing import Any
33
34
from genjax ._src .core .typing import FloatArray
35
+ from genjax ._src .core .typing import Int
34
36
from genjax ._src .core .typing import IntArray
35
37
from genjax ._src .core .typing import PRNGKey
36
38
from genjax ._src .core .typing import Tuple
@@ -194,6 +196,14 @@ def _padder(self, v, key_len):
194
196
np .pad (v , pad_axes ) if isinstance (v , np .ndarray ) else jnp .pad (v , pad_axes )
195
197
)
196
198
199
+ def _static_check_trie_index_compatible (
200
+ self , chm : TrieChoiceMap , broadcast_dim_length : Int
201
+ ):
202
+ for (k , _ ) in chm .get_subtrees_shallow ():
203
+ assert isinstance (k , int )
204
+ # TODO: pull outside loop, just check the last address.
205
+ assert k < broadcast_dim_length
206
+
197
207
def _importance_vcm (self , key , chm , args ):
198
208
def _importance (key , chm , args ):
199
209
return self .kernel .importance (key , chm , args )
@@ -218,6 +228,16 @@ def _inner(key, index, chm, args):
218
228
map_tr = VectorTrace (self , indices , tr , args , retval , scores )
219
229
return key , (w , map_tr )
220
230
231
+ # Implements a conversion from `TrieChoiceMap`.
232
+ def _importance_tchm (self , key , chm , args ):
233
+ broadcast_dim_length = self ._static_broadcast_dim_length (args )
234
+ self ._static_check_trie_index_compatible (chm , broadcast_dim_length )
235
+
236
+ # Okay, so the TrieChoiceMap has an address hierarchy which is compatible with the index structure of the MapCombinator choices.
237
+ # Let's coerce TrieChoiceMap into VectorChoiceMap and then just call `_importance_vcm`.
238
+ vector_chm = self ._coerce_to_vector_chm (chm )
239
+ return self ._importance_vcm (key , vector_chm , args )
240
+
221
241
def _importance_empty (self , key , _ , args ):
222
242
key , map_tr = self .simulate (key , args )
223
243
w = 0.0
@@ -227,7 +247,7 @@ def _importance_empty(self, key, _, args):
227
247
def importance (
228
248
self ,
229
249
key : PRNGKey ,
230
- chm : Union [EmptyChoiceMap , VectorChoiceMap ],
250
+ chm : Union [EmptyChoiceMap , TrieChoiceMap , VectorChoiceMap ],
231
251
args : Tuple ,
232
252
** _ ,
233
253
) -> Tuple [PRNGKey , Tuple [FloatArray , VectorTrace ]]:
@@ -237,6 +257,8 @@ def importance(
237
257
# Note: these branches are resolved at tracing time.
238
258
if isinstance (chm , VectorChoiceMap ):
239
259
return self ._importance_vcm (key , chm , args )
260
+ elif isinstance (chm , TrieChoiceMap ):
261
+ return self ._importance_tchm (key , chm , args )
240
262
else :
241
263
assert isinstance (chm , EmptyChoiceMap )
242
264
return self ._importance_empty (key , chm , args )
0 commit comments