Kill Indexed for good, push down into leaves [in progress] (GEN-888) #1448


Open
wants to merge 9 commits into base: main

sritchie
Contributor

No description provided.

@sritchie sritchie changed the title Kill Indexed for good, push down into leaves [in progress] Kill Indexed for good, push down into leaves [in progress] (GEN-631) Nov 29, 2024

linear bot commented Nov 29, 2024

GEN-631 Investigate removing "Indexed" data structure

I suspect that we could remove Indexed, and the idea of interspersing dynamic indices with static addresses, fully from the codebase.

This observation hit as I was implementing slicing over in github. I realized that to do slicing correctly, Indexed could no longer `tree_map` in its get_submap call, but would have to push its accumulated indices down to the leaves to be applied only once we hit a Choice.

Currently, the indices in Indexed are dense index arrays, usually built by a scan or vmap call. But if a user manually calls jax.vmap or jax.lax.scan, they can introduce extra dimensions into these arrays.
This is broken because each Indexed layer currently handles a single entry in the __get__ addresses; a multi-dimensional outer layer would only perform a single lookup and then pass an array-shaped mask down the line. Even if this doesn't break things, it's certainly hard to reason about.
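
To make the extra-dimension problem concrete, here is a minimal sketch in plain JAX. It does not use Indexed or any GenJAX API; `indices_for_row` is a hypothetical stand-in for whatever builds the dense index array inside one scan or vmap layer.

```python
import jax
import jax.numpy as jnp

def indices_for_row(offset):
    # Hypothetical stand-in for the dense index array one layer builds.
    return jnp.arange(3) + offset

# A single layer yields a 1-D index array, one entry per element.
single = indices_for_row(0)

# Wrapping the same code in a user-level jax.vmap adds a leading batch axis,
# so a layer written for 1-D indices now receives a 2-D array.
batched = jax.vmap(indices_for_row)(jnp.array([0, 10]))

single_shape = single.shape    # (3,)
batched_shape = batched.shape  # (2, 3)
```

Any layer that assumes one index array per address entry has to decide what a `(2, 3)` array means, which is the ambiguity described above.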

Can we delete Indexed?

Notice:

  • a ChoiceMap is a pytree with tensors at its leaves
  • all of the tensor dimensions have to match, as validated by jax.vmap
  • if the dimensions DON'T match, the user has to specify what dimensions to vmap over via in_axes
  • none of the indexing that we intercept in get_submap calls to Indexed does anything in the tree. It's all pushed down to the leaves.
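
The first three points can be sketched with a plain nested dict standing in for a ChoiceMap. The addresses `"x"`, `"y"`, `"z"` and the function `total` are made up for illustration; only `jax.vmap` and its `in_axes` argument are real API here.

```python
import jax
import jax.numpy as jnp

# A choice map modeled as a nested dict with tensors at the leaves.
choices = {"x": jnp.zeros((4, 2)), "y": {"z": jnp.ones((4,))}}

def total(cm):
    # Any function of the leaves will do; this one returns a scalar per row.
    return jnp.sum(cm["x"]) + cm["y"]["z"]

# Every leaf has a leading axis of size 4, so jax.vmap maps over it directly.
per_row = jax.vmap(total)(choices)

# Leaves whose mapped axis sizes disagree are rejected by jax.vmap.
mismatched = {"x": jnp.zeros((4, 2)), "y": {"z": jnp.ones((3,))}}
try:
    jax.vmap(total)(mismatched)
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True

# When the batch axis sits in different positions, in_axes names it per leaf.
transposed = {"x": jnp.zeros((2, 4)), "y": {"z": jnp.ones((4,))}}
per_row_t = jax.vmap(total, in_axes=({"x": 1, "y": {"z": 0}},))(transposed)
```

Nothing in this sketch needs an index-carrying wrapper: the tree structure is static and the batch axis lives entirely in the leaves.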

It would be conceptually simpler to think of ChoiceMap as a nested dictionary with static, string-shaped keys and tensors down at the leaves.

This would solve:

  • the confusion about how to access elements - the path exposed by treescope is now the query path down to the leaf
  • slice syntax would now be automatically supported, down at the leaves
  • the "how do I create a choicemap for vmapped models" question becomes simple: just make the tensor you want and assign it with, e.g., C["path", "to", "tensor"].set(tensor)
  • users wouldn't need to learn a new update syntax: they can get a Python dictionary with tensor leaves out and modify it in Python, with no new functions needed
  • the distinction between jax.vmap and .vmap() can now be summarized simply: .vmap() calls create a new dimension but sum their score along that dimension, while jax.vmap results in broadcast, non-summed scores. The same holds for .scan() vs jax.lax.scan.
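
A rough sketch of what the list above could look like with a plain dict in place of ChoiceMap. The helper `index_leaves`, the addresses, and `score_one` are hypothetical; the sketch deliberately avoids the `C[...]` builder and the `.vmap()` combinator, and only imitates the summing behaviour with `jnp.sum`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical choice map: static string keys, tensors at the leaves.
choices = {
    "path": {"to": {"tensor": jnp.arange(6.0)}},
    "other": jnp.arange(12.0).reshape(6, 2),
}

def index_leaves(cm, idx):
    # Push one index or slice down to the leading axis of every leaf.
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], cm)

first_three = index_leaves(choices, slice(0, 3))

# Updating is ordinary Python: copy the dict and assign a new tensor.
updated = dict(choices)
updated["other"] = jnp.zeros((6, 2))

def score_one(x):
    # Log density of a standard normal, standing in for a model's score.
    return norm.logpdf(x, 0.0, 1.0)

xs = jnp.array([0.0, 1.0, 2.0])
# jax.vmap keeps one score per element along the new axis.
broadcast_scores = jax.vmap(score_one)(xs)
# A summing combinator would instead report a single scalar score.
summed_score = jnp.sum(broadcast_scores)
```

The slice never touches the tree structure; it is applied once, at each leaf, which is the "push down into leaves" behaviour this PR is after.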

What do we lose?

  • removing Indexed would lose the ability to specify updates for a sparse subset of indices. We would need some other way to specify a partial update, e.g., a combination of an array and a selector, or a sparse array
    • Indexed currently accomplishes this by introducing a mask call that gets pushed down to the leaves. Mask can remain as an explicit idea without Indexed having to introduce it.
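
One possible shape for the "array plus selector" idea, sketched with plain JAX arrays. The names `old`, `new`, `selected`, `idx` and `vals` are illustrative, and this is not how Mask is implemented in the codebase.

```python
import jax.numpy as jnp

old = jnp.array([1.0, 2.0, 3.0, 4.0])

# Dense form: a full-size array of new values plus a boolean selector.
new = jnp.array([10.0, 20.0, 30.0, 40.0])
selected = jnp.array([False, True, False, True])
merged_dense = jnp.where(selected, new, old)

# Sparse form: only the touched indices and their values.
idx = jnp.array([1, 3])
vals = jnp.array([20.0, 40.0])
merged_sparse = old.at[idx].set(vals)
```

Both forms produce the same leaf; the dense form keeps static shapes under `jax.vmap` and `jax.lax.scan`, while the sparse form is smaller when few indices change.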

Questions:

  • How do we specify partial updates to a Scan?
  • Can we fully remove Or and Xor? I think we can, in favor of implementing the operations between dictionary-shaped choicemaps.
    • For __or__, if two leaves clash, pick the one on the left.
    • For __xor__, if two leaves clash, error. (This is subject to the more complex combine-leaves-with-masked-values logic that we'll have to preserve)
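
The two merge rules can be sketched over nested dicts in pure Python. The functions `merge`, `or_` and `xor_` are hypothetical names, plain floats stand in for tensor leaves, and the combine-leaves-with-masked-values logic mentioned above is not modeled.

```python
def merge(left, right, on_clash):
    """Recursively merge two nested dicts; on_clash resolves leaf collisions."""
    out = dict(left)
    for key, r_val in right.items():
        if key not in out:
            out[key] = r_val
        elif isinstance(out[key], dict) and isinstance(r_val, dict):
            # Both sides have a subtree here, so keep descending.
            out[key] = merge(out[key], r_val, on_clash)
        else:
            # At least one side is a leaf: this is a clash.
            out[key] = on_clash(key, out[key], r_val)
    return out

def or_(left, right):
    # __or__ rule: on a clash, keep the leaf from the left.
    return merge(left, right, lambda key, l_val, r_val: l_val)

def _raise_clash(key, l_val, r_val):
    raise ValueError(f"clash at {key!r}")

def xor_(left, right):
    # __xor__ rule: any clash is an error.
    return merge(left, right, _raise_clash)

a = {"x": 1.0, "y": {"z": 2.0}}
b = {"x": 9.0, "y": {"w": 3.0}}
left_wins = or_(a, b)
```

With this in place, Or and Xor stop being node types in the tree and become ordinary functions from two dictionaries to one.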

codecov bot commented Dec 2, 2024

Codecov Report

Attention: Patch coverage is 57.40741% with 46 lines in your changes missing coverage. Please review.

Project coverage is 87.86%. Comparing base (52c06fc) to head (6ca40b3).

Files with missing lines                                Patch %   Lines
...rc/genjax/_src/core/generative/functional_types.py   43.83%    41 Missing ⚠️
src/genjax/_src/core/generative/choice_map.py           85.29%    5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1448      +/-   ##
==========================================
- Coverage   88.67%   87.86%   -0.81%     
==========================================
  Files          55       55              
  Lines        3990     4064      +74     
==========================================
+ Hits         3538     3571      +33     
- Misses        452      493      +41     

@sritchie sritchie changed the title Kill Indexed for good, push down into leaves [in progress] (GEN-631) Kill Indexed for good, push down into leaves [in progress] (GEN-888) Dec 4, 2024
linear bot commented Dec 4, 2024
