[GJEP] DynamicChoiceMap and DisjointUnionChoiceMap for dynamic index addresses #473

Closed
femtomc opened this issue Aug 9, 2023 · 3 comments

@femtomc
Collaborator

femtomc commented Aug 9, 2023

The design of this extension is inspired by recent struggles in attempting to write rejuvenation kernels for resample-move SMC:

I’m realizing I’m unsure how to make an MH drift proposal in GenJAX.

Typically, users will use the vector combinators to gain access to generative models which produce posterior targets in a sequence. Here’s an example proposal which a user might attempt to use when targeting a model like this:

@genjax.gen
def step_pose_drift_proposal(tr, position_noise, heading_noise):
    s = tr.args[1] - 1

    old_heading = tr[('chain', 'heading')][s]
    genjax.tfp_normal(old_heading, heading_noise) @ ('chain', s, 'heading')

    return tr

Here, the “chain” address is the call site of the vector combinator in a larger model.

Here, the user is attempting to get the current sequence step s, and use it as an address in their proposal. GenJAX currently disallows this in the Builtin modeling language (the language used to construct the above program).

To fully support dynamic addresses (e.g. JAX integer values) as addresses within choice maps generated by the Builtin language (and others), I propose the following:

  • when users use static integer addresses inside the builtin language, business as normal — they can be used as keys, no changes.
  • when users use dynamic integer addresses inside the builtin language, a “dynamic choice map” is created. The dynamic choice map is a core data type, so it can be used in the implementation of the builtin language interfaces.
  • In the builtin language, dynamic integer addresses can scope heterogeneous address hierarchies. When heterogeneous hierarchies are introduced, the existing dynamic choice map is merged and filled in with choices to form another type of union choice map (separate from Switch). Heterogeneity is represented with masked values.
  • Duplicate dynamic index addresses are disallowed by Gen’s semantics, but unlike static addresses we can’t enforce this statically, so runtime error handling is done with checkify.

Roughly, this plan should support a very expressive form of dynamic address structure in the builtin language — at the runtime cost of duplicate data and runtime cond calls from masking.
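As a rough illustration of the checkify point, here is a minimal sketch of a runtime uniqueness check on dynamic indices. The function name `check_unique_indices` and the pairwise-comparison strategy are assumptions made for this example, not the GenJAX implementation; only `checkify.check` and `checkify.checkify` are existing JAX calls.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def check_unique_indices(indices):
    # `indices`: 1-D integer array of dynamic index addresses (hypothetical input).
    n = indices.shape[0]
    # Pairwise equality matrix; with no duplicates only the n diagonal entries are True.
    equal_pairs = indices[:, None] == indices[None, :]
    checkify.check(jnp.sum(equal_pairs) == n, "duplicate dynamic index address")
    return indices


# checkify turns the check into a functional error value that survives jit.
checked = jax.jit(checkify.checkify(check_unique_indices))

err_ok, _ = checked(jnp.array([0, 1, 2]))
err_dup, _ = checked(jnp.array([0, 1, 1]))
```

Calling `err_dup.throw()` would raise the error on the host, while `err_ok.throw()` does nothing.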

Dynamic choice map

Dynamic choice map likely has the following layout:

@dataclass
class DynamicChoiceMap(ChoiceMap):
    dynamic_indices: List[IntArray]
    submaps: List[ChoiceMap]

Creation of a dynamic choice map will occur during the GFI methods invoked by the Builtin language, but only when a user uses dynamic addresses.

The List representations are trace time only: they’re a convenient representation which goes away at runtime.
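One way to see why the lists cost nothing at runtime is JAX's pytree flattening, which separates array leaves from container structure. The sketch below uses plain Python lists and dicts as hypothetical stand-ins for the `dynamic_indices` and `submaps` fields; it is not the actual DynamicChoiceMap.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins: dicts play the role of submaps here.
dynamic_indices = [jnp.array(1), jnp.array(2)]
submaps = [{"z": jnp.zeros(3)}, {"z": jnp.ones((2, 2))}]

# Flattening separates runtime data (array leaves) from static structure (treedef).
leaves, treedef = jax.tree_util.tree_flatten((dynamic_indices, submaps))
num_leaves = len(leaves)

# The list/dict structure is rebuilt from the static treedef, not stored in the arrays.
rebuilt_indices, rebuilt_submaps = jax.tree_util.tree_unflatten(treedef, leaves)
```

Only the four arrays are runtime values; the nesting of lists and dicts lives in `treedef`, which is fixed at trace time.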

The real magic occurs in the choice map interfaces. Let’s consider get_subtree and has_subtree:

get_subtree

ChoiceMap.get_subtree(addr) exposes a way to access submaps. When a user attempts to access into the DynamicChoiceMap using an index, we can't know which submap they are referring to statically. However, we can create genjax.Mask instances with dynamic (runtime) flags over each submap.

Indexing into the DynamicChoiceMap thus creates a DisjointUnionChoiceMap with masked leaves -- a new type of choice map which also exposes custom choice map interfaces. DisjointUnionChoiceMap.get_subtree attempts to get the subtree of all choice maps in the list (and does a mixture of trace time / runtime math to expose the right choices).
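A minimal sketch of this indexing step, assuming a hypothetical `MaskSketch` class in place of genjax.Mask and plain dicts in place of submaps. The helper name `index_dynamic` is also an assumption for illustration.

```python
from dataclasses import dataclass
from typing import Any

import jax.numpy as jnp


@dataclass
class MaskSketch:
    flag: Any   # boolean array; a tracer (runtime value) when run under jit
    value: Any  # the submap this flag guards


def index_dynamic(dynamic_indices, submaps, idx):
    # Every submap is kept; the flag records whether its index matches the query.
    return [
        MaskSketch(flag=(idx == i), value=sub)
        for i, sub in zip(dynamic_indices, submaps)
    ]


masked = index_dynamic(
    [jnp.array(1), jnp.array(2)],
    [{"z": jnp.zeros(3)}, {"z": jnp.ones((2, 2))}],
    jnp.array(2),
)
flags = [bool(m.flag) for m in masked]
```

The list of masked submaps is what a DisjointUnionChoiceMap would wrap: all candidates are present, and only the flags say which one is live.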

has_subtree

The story for has_subtree is similar.
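The issue does not spell this out, so the following is only a guess at what "similar" could mean: the answer becomes a runtime boolean, computed by comparing the queried index against each stored dynamic index. The helper name `has_dynamic_index` is hypothetical.

```python
import jax
import jax.numpy as jnp


def has_dynamic_index(dynamic_indices, idx):
    # One comparison per stored index; the result is a runtime boolean scalar.
    flags = jnp.stack([idx == i for i in dynamic_indices])
    return jnp.any(flags)


stored = [jnp.array(1), jnp.array(2)]
query = jax.jit(lambda idx: has_dynamic_index(stored, idx))

found = query(jnp.array(2))
missing = query(jnp.array(5))
```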

@femtomc femtomc self-assigned this Aug 9, 2023
@femtomc femtomc changed the title [GJEP] [GJEP] RaggedChoiceMap and DisjointUnionChoiceMap for dynamic index addresses Aug 9, 2023
@femtomc femtomc changed the title [GJEP] RaggedChoiceMap and DisjointUnionChoiceMap for dynamic index addresses [GJEP] DynamicChoiceMap and DisjointUnionChoiceMap for dynamic index addresses Aug 10, 2023
@femtomc
Collaborator Author

femtomc commented Aug 10, 2023

After reflection, my discussion in get_subtree is not complete:

get_subtree

Indexing into the DynamicChoiceMap thus creates a DisjointUnionChoiceMap with masked leaves -- a new type of choice map which also exposes custom choice map interfaces. DisjointUnionChoiceMap.get_subtree attempts to get the subtree of all choice maps in the list (and does a mixture of trace time / runtime math to expose the right choices).

I don't think masking alone is sufficient to solve the problem. What I mean will become clear if we look at an example implementation of DisjointUnionChoiceMap.get_subtree. Below, let's assume that the DisjointUnionChoiceMap instance in our thought experiment is valid and was generated from DynamicChoiceMap, meaning that it has masked leaves but only one of the masks is actually True.

Let's consider an initial implementation of DisjointUnionChoiceMap.get_subtree:

def get_subtree(self, addr):
    new_subtrees = list(
        filter(
            lambda v: not isinstance(v, EmptyChoiceMap),
            map(lambda v: v.get_subtree(addr), self.subtrees),
        )
    )
    if len(new_subtrees) == 0:
        return EmptyChoiceMap()
    elif len(new_subtrees) == 1:
        return new_subtrees[0]
    else:
        if all(map(lambda v: isinstance(v, AddressLeaf), new_subtrees)):
            # Materialize as a list: a bare map iterator would be exhausted by the assert below.
            leaf_values = list(map(lambda v: v.get_leaf_value(), new_subtrees))
            assert all(map(lambda v: isinstance(v, Mask), leaf_values))
            masks = jnp.array(
                list(map(lambda v: v.mask, leaf_values)),
            )

            def _check():
                check_flag = jnp.any(masks)
                return checkify.check(
                    check_flag,
                    "(DisjointUnionChoiceMap.get_subtree): masked leaf values have no valid data.",
                )

            global_options.optional_check(_check)

            tag = jnp.argwhere(masks, size=1)
            return tagged_union(tag, new_subtrees)
        else:
            return DisjointUnionChoiceMap(new_subtrees)

Here, the type of DisjointUnionChoiceMap is propagated down as you index into the possible set of subtrees which it shadows. The fundamental reason is that you can't always statically resolve which subtree you're in. Sometimes you can (!) e.g. if the subtrees don't share addresses -- but in the worst case, you might have two subtrees which share addresses down to their leaves.

This fact has two implications:

  • The generic versions of the generative function interface which each implementor (e.g. language, combinator) provides must be able to handle this propagation.
  • The propagation stops when you hit AddressLeaf types (e.g. ValueChoiceMap).

What happens when you traverse all the way down to leaves, and you haven't been able to statically resolve which subtree you are in?

Well, because of the heterogeneity -- it's possible that the leaf values are different types. For now, I'm only considering cases where the values are JAX array-like (and not pieces of structured data e.g. other Pytrees -- which adds complexity, and we can care about that later).

We're assuming that our instance is a valid choice map -- so, if we hit leaves and we have more than 1 subtree, we have to have masked leaves -- and, of those masked leaves, only 1 can actually be True. By using this fact, we can identify now at the leaves which leaf is actually active.

But lo! We still can't collapse it -- because which leaf is actually active is runtime determined, and each leaf may contain an array of different shape from the other leaf! So we have to return a new type of "mask" -- a TaggedUnion, which pulls the leaf values into a single List, and returns a single leaf -- but contains a tag metadata which indicates which leaf value is the actual one.
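The shape problem can be reproduced directly in JAX. In this small sketch the two leaves are made-up arrays; `jax.lax.cond` requires both branches to return the same shape and dtype, so it refuses to select between them.

```python
import jax
import jax.numpy as jnp

leaf_a = jnp.zeros(3)
leaf_b = jnp.ones((2, 2))

try:
    # Both branches are traced, and their output types must agree.
    jax.lax.cond(jnp.array(True), lambda: leaf_a, lambda: leaf_b)
    collapsed = True
except TypeError:
    # JAX rejects the mismatched branch outputs.
    collapsed = False
```

This is why the leaves have to travel together, with a tag, until some computation maps them to a common shape.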

TaggedUnion is a lot like a Mask, but it exposes a special method, TaggedUnion.switch, which implements switch functionality. If a user wishes to handle a TaggedUnion, they must provide callables which handle each branch of the TaggedUnion, without knowing which they are actually on. The callables must be defined so that they return the same dtype/shape, but this constraint already holds for most (if not all) lower level computations performed during inference metadata computation by Gen.
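A minimal sketch of what such a switch method could look like, built on `jax.lax.switch`. The class `TaggedUnionSketch` and its fields are assumptions for illustration, not the GenJAX type.

```python
from dataclasses import dataclass
from typing import Any, Callable, List

import jax
import jax.numpy as jnp


@dataclass
class TaggedUnionSketch:
    tag: Any           # scalar integer array, known only at runtime
    values: List[Any]  # candidate leaf values, possibly of different shapes

    def switch(self, *branches: Callable):
        # One callable per candidate value; all must return the same shape/dtype.
        closures = [lambda v=v, f=f: f(v) for v, f in zip(self.values, branches)]
        return jax.lax.switch(self.tag, closures)


candidates = [jnp.zeros(3), jnp.ones((2, 2))]

# Eager use: the tag selects the second value, and both branches reduce to a scalar.
second_total = TaggedUnionSketch(jnp.array(1), candidates).switch(jnp.sum, jnp.sum)

# Under jit the tag is a tracer, so the branch is chosen at runtime.
first_total = jax.jit(
    lambda tag: TaggedUnionSketch(tag, candidates).switch(jnp.sum, jnp.sum)
)(jnp.array(0))
```

The values have different shapes, but because each branch returns a float scalar the switch is well typed.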

Below is an example:
[image: example not preserved]

The example should also make clear what I mean by "heterogeneity" in the above discussion: with DynamicChoiceMap, it's certainly possible to have addresses like the following:

("x", jnp.array(1), "z")
("x", jnp.array(2), "z")

whose leaf choice maps are like the ValueChoiceMap instances in my example.

When you index into the index address scope -- you'd get a DisjointUnionChoiceMap over the subtrees below the dynamic indices. Further indexing in (by e.g. get_subtree("z")) would result in a ValueChoiceMap whose leaf value is a TaggedUnion.

@femtomc
Collaborator Author

femtomc commented Aug 10, 2023

One may also be wondering: what about addresses like the following:

("x", jnp.array(1), "z", "y")
("x", jnp.array(2), "z")

Well here, we must preserve the DisjointUnionChoiceMap type -- despite the fact that we've hit a leaf type on one of the branches. Just like masking, the trick is that higher-level generative functions push the handling of this type down into their callees, ultimately -- all the way down to the "leaf" callees (like Distribution).

As a callee, Distribution needs to handle DisjointUnionChoiceMap because it represents type uncertainty. When I index into "z", give a Distribution the result of that indexing, and ask it to invoke a GFI method, it needs to be able to handle a DisjointUnionChoiceMap which is saying "I may be a ValueChoiceMap, but I also might not apply to you" (e.g. I'm a HierarchicalChoiceMap with an address "y").

@sritchie
Contributor

Migrated to internal Notion.
