[GJEP] DynamicChoiceMap and DisjointUnionChoiceMap for dynamic index addresses #473
After reflection, my discussion in
One may also be wondering: what about addresses like the following:
Well here, we must preserve the As a callee
Migrated to internal Notion.
The design of this extension is inspired by recent struggles in attempting to write rejuvenation kernels for resample-move SMC:
Typically, users will use the vector combinators to gain access to generative models which produce posterior targets in a sequence. Here’s an example proposal which a user might attempt to use when targeting a model like this:
Here, the “chain” address is the call site of the vector combinator in a larger model. The user is attempting to get the current sequence step `s` and use it as an address in their proposal. GenJAX currently disallows this in the `Builtin` modeling language (the language used to construct the above program).

To fully support dynamic addresses (e.g. JAX integer values) as addresses within choice maps generated by the `Builtin` language (and others), I propose the following:

Roughly, this plan should support a very expressive form of dynamic address structure in the builtin language, at the runtime cost of duplicate data and runtime `cond` calls from masking.

Dynamic choice map
Dynamic choice map likely has the following layout:
Creation of a dynamic choice map will occur during the GFI methods invoked by the `Builtin` language, but only optionally, i.e. when a user actually uses dynamic addresses. The `List` representations are trace time only: they're a convenient representation which goes away at runtime.

The real magic occurs in the choice map interfaces. Let's consider `get_subtree` and `has_subtree`.
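Before turning to those interfaces, a toy picture of the trace-time `List` layout may help. This is a minimal sketch under assumptions: the field names `indices` and `submaps`, the `add` helper, the `build` function, and the use of a single JAX array as a "submap" are all hypothetical simplifications, not GenJAX definitions. It shows that a traced integer can be recorded as data in a Python list while tracing, and that the list itself is unrolled away, leaving only arrays in the compiled program.

```python
from dataclasses import dataclass, field

import jax
import jax.numpy as jnp


@dataclass
class DynamicChoiceMap:
    # Hypothetical toy layout: two parallel trace-time Python lists.
    # indices[i] is a (possibly traced) integer address; submaps[i] holds the
    # choices stored under it (here just one JAX array, for brevity).
    indices: list = field(default_factory=list)
    submaps: list = field(default_factory=list)

    def add(self, idx, submap):
        # Hypothetical helper. Appending is ordinary Python executed during
        # tracing; the list object itself never reaches the compiled program.
        self.indices.append(jnp.asarray(idx, dtype=jnp.int32))
        self.submaps.append(jnp.asarray(submap))
        return self


@jax.jit
def build(step):
    # `step` is traced: its value is unknown while tracing, but it can still
    # be recorded as data next to the submap it addresses.
    chm = DynamicChoiceMap()
    chm.add(step, jnp.array([1.0, 2.0]))
    chm.add(step + 1, jnp.array([3.0, 4.0]))
    # Stack the list entries so the jitted function returns plain arrays.
    return jnp.stack(chm.indices), jnp.stack(chm.submaps)


idxs, vals = build(5)
print(idxs)  # [5 6]
print(vals.shape)  # (2, 2)
```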
get_subtree

`ChoiceMap.get_subtree(addr)` exposes a way to access submaps. When a user attempts to index into the `DynamicChoiceMap` with a dynamic index, we can't know statically which submap they are referring to. However, we can create `genjax.Mask` instances with dynamic (runtime) flags over each submap.

Indexing into the `DynamicChoiceMap` thus creates a `DisjointUnionChoiceMap` with masked leaves: a new type of choice map which also exposes custom choice map interfaces. `DisjointUnionChoiceMap.get_subtree` attempts to get the subtree of every choice map in the list (and does a mixture of trace-time and runtime math to expose the right choices).
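The indexing story can be sketched with toy classes. This is a minimal sketch, assuming leaves are scalar values rather than nested submaps; `ToyMask` is a hypothetical stand-in (not `genjax.Mask`), and the `get_value` method and `lookup` function are hypothetical names used only here. Indexing with a traced `idx` wraps every submap in a mask whose flag is `stored_index == idx`, and resolution then materializes all candidates and selects the flagged one at runtime, which is where the duplicate-data cost comes from.

```python
from dataclasses import dataclass, field

import jax
import jax.numpy as jnp


@dataclass
class ToyMask:
    # Toy stand-in for a masked value (not genjax.Mask): `flag` is a
    # runtime boolean saying whether `value` is valid.
    flag: jax.Array
    value: jax.Array


@dataclass
class DisjointUnionChoiceMap:
    # Toy version: at most one mask is expected to be "on" at runtime.
    masks: list

    def get_value(self):
        # Hypothetical resolution step. Every candidate value is materialized,
        # then the one whose flag is set is selected (first match wins).
        flags = jnp.stack([m.flag for m in self.masks])
        values = jnp.stack([m.value for m in self.masks])
        found = jnp.any(flags)
        chosen = values[jnp.argmax(flags.astype(jnp.int32))]
        return found, chosen


@dataclass
class DynamicChoiceMap:
    indices: list = field(default_factory=list)
    submaps: list = field(default_factory=list)

    def get_subtree(self, idx):
        # `idx` may be traced, so no entry can be picked at trace time.
        # Instead, wrap every submap in a mask whose flag is computed at runtime.
        return DisjointUnionChoiceMap(
            [ToyMask(i == idx, sub) for i, sub in zip(self.indices, self.submaps)]
        )


@jax.jit
def lookup(idx):
    chm = DynamicChoiceMap()
    for addr, v in [(0, 10.0), (1, 20.0), (2, 30.0)]:
        chm.indices.append(jnp.int32(addr))
        chm.submaps.append(jnp.float32(v))
    return chm.get_subtree(idx).get_value()


found, value = lookup(1)
print(bool(found), float(value))  # True 20.0
```

When no stored index matches, `found` is `False` and the selected value is meaningless; a caller would have to respect the flag, just as with a real mask.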
has_subtree

The story for `has_subtree` is similar.
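A correspondingly small sketch of `has_subtree`, again under assumptions: `has_subtree` here is a hypothetical free function over a list of stored indices, and `score_if_present` is a hypothetical caller. The point it illustrates is that the answer is a runtime boolean, so downstream code cannot use Python `if` on it under `jit` and must branch with something like `jax.lax.cond`, which is the source of the runtime `cond` calls mentioned earlier.

```python
import jax
import jax.numpy as jnp


def has_subtree(indices, idx):
    # Hypothetical toy version: the map "has" idx iff some stored index
    # equals it. Under jit the result is a traced boolean, not a Python bool.
    return jnp.any(jnp.stack(indices) == idx)


@jax.jit
def score_if_present(idx):
    indices = [jnp.int32(0), jnp.int32(3), jnp.int32(7)]
    present = has_subtree(indices, idx)
    # A Python `if present:` would raise under jit, so branch with lax.cond.
    return jax.lax.cond(
        present,
        lambda x: x,                   # taken when the address exists
        lambda x: jnp.zeros_like(x),   # taken when it does not
        jnp.float32(1.0),
    )


print(score_if_present(3))  # 1.0
print(score_if_present(4))  # 0.0
```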