Add deterministic (& incremental aware) caching to the static language. #1357

Open
wants to merge 25 commits into
base: main

@femtomc (Collaborator) commented Sep 21, 2024

Adds the ability to store deterministic data in StaticTrace via a new caching primitive.

The caching primitive is aware of incremental computation: in edit (and update), it recomputes the cached function only if the arguments are tagged with UnknownChange.
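A minimal sketch of the recompute-only-on-change idea described above, in plain Python. The classes here are toy stand-ins that borrow the names `Diff` and `UnknownChange` from this thread (`NoChange` is assumed as the complementary tag); `cached_call` is a hypothetical helper, not the primitive this PR adds.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


class NoChange:
    """Toy tag: the argument is known to be unchanged."""


class UnknownChange:
    """Toy tag: the argument may have changed in an unknown way."""


@dataclass(frozen=True)
class Diff:
    """Toy stand-in: a value paired with a change tag."""
    value: Any
    tag: type


def cached_call(
    fn: Callable[[Any], Any], cached_value: Any, arg: Diff
) -> Tuple[Any, bool]:
    """Return (value, recomputed). Hypothetical helper for illustration only."""
    if arg.tag is UnknownChange:
        # The argument may have changed, so the cached value cannot be trusted.
        return fn(arg.value), True
    # Otherwise reuse the stored value without calling fn.
    return cached_value, False


first = sum([1, 2, 3])  # value stored when the trace was first built
kept, recomputed_kept = cached_call(sum, first, Diff([1, 2, 3], NoChange))
fresh, recomputed_fresh = cached_call(sum, first, Diff([1, 2, 4], UnknownChange))
```

Here `kept` reuses the stored value, while `fresh` is recomputed because its argument carries the `UnknownChange` tag.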

When holding a StaticTrace, you can access the cache via StaticTrace.get_cache() (it's stored as a choice map, for now).

TODOs:

  • Caching is not yet compositional: there is no compositional trace interface for it, so you can only use it when a StaticGenerativeFunction is at the top level and you hold a StaticTrace. For example, calling a StaticGenerativeFunction that uses caching inside a VmapCombinator gives you a VmapTrace, and there is not yet a Trace interface that lets you fetch the cache from inside the VmapTrace.

@@ -777,3 +777,37 @@ def higher_higher_model():
== genjax.normal.assess(choice.get_submap("y1"), (0.0, 1.0))[0]
+ genjax.normal.assess(choice.get_submap("y2"), (0.0, 1.0))[0]
)


class TestCaching:
Collaborator Author

@georgematheos See these tests for usage.

Collaborator

@femtomc this looks great!

Is there a way to add something to this test that checks whether f is run multiple times? Would it work to have f increment a global counter variable, and check that the counter isn't incremented after update calls that shouldn't cause recomputation?

Collaborator Author

Unfortunately, JAX may run f abstractly multiple times during tracing, so a counter wouldn't say much about the generated code.

For this one, we might just have to inspect the Jaxprs...

Collaborator Author

Just checked the Jaxprs; it works.

Collaborator Author

Let's try to get it working in your codebase.

codecov bot commented Sep 21, 2024

Codecov Report

Attention: Patch coverage is 76.76056% with 33 lines in your changes missing coverage. Please review.

Project coverage is 87.52%. Comparing base (799c83c) to head (76c5315).

Files with missing lines | Patch % | Lines
src/genjax/_src/generative_functions/static.py | 75.38% | 32 Missing ⚠️
...genjax/_src/core/generative/generative_function.py | 87.50% | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1357      +/-   ##
==========================================
- Coverage   87.94%   87.52%   -0.42%     
==========================================
  Files          53       53              
  Lines        3856     3992     +136     
==========================================
+ Hits         3391     3494     +103     
- Misses        465      498      +33     

@sritchie added this to the 0.7.0 milestone Sep 23, 2024
)


def cache(
Contributor

Bound it with TypeVar("T", bound=Callable); then it'll pin the full type signature.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I do this, I get an error: Callable is a variable and it's not allowed in type expressions.

addresses = self.cached_addresses.get_visited()
return ChoiceMap.from_mapping(zip(addresses, self.cached_values))

def get_cached_state(self, addr) -> Any:
Contributor

type this addr, please

Contributor

and docs for these new public methods (use Cursor to generate with an example!)

@MathieuHuot (Collaborator) commented Sep 25, 2024

@georgematheos could you give some more context by providing the original example from ChiSight that motivated this (or a simplified version)?

My understanding from one of our discussions was that you wanted to visualise intermediate results within inference for debugging. If that's the main goal, we should tune the design for it, or at least think about this use case some more.

@femtomc (Collaborator, Author) commented Sep 25, 2024

@MathieuHuot / @sritchie Let's not merge this in now -- I need to continue working on it to develop an example from George's Notion: https://www.notion.so/chi-mit/Incremental-computation-example-for-McCoy-10701b4b8cde803f9919d308cb0e331c?pvs=4

@femtomc (Collaborator, Author) commented Sep 26, 2024

Caching a little message I talked over with Mathieu:

# In the static language...
@genjax.gen
def model(x):
    v = submodel.vmap(in_axes=(0, ))(x) @ "v"
    x = cache("x", jnp.sum)(v) # this is new: incremental aware caching
    return some_other_model(x) @ "y"

# Suppose we hold a trace `tr` of the model.
# If you change "y" but don't change "v",
# jnp.sum @ "x" doesn't get recomputed.
#
# If we change "v" but don't change the argument x,
# then v gets Diff(new_v, UnknownChange);
# the cache primitive sees this and recomputes jnp.sum @ "x".
#
# Let's say we change "v" using IndexEditRequest(idx=3, ...).
# NOTE: this is a TODO in the library (we have to implement IndexEditRequest for vmap).
# Now, what is the value for v tagged with?
#
# We should have: v gets Diff(new_v, IndexChange(idx, old_v_at_idx)).
# And then what does cache do?
#
# That's where we might want to allow users to specify rules.
@genjax.gen
def model(x):
    v = submodel.vmap(in_axes=(0, ))(x) @ "v"
    v = some_other_fn(v)
    x = cache("x", jnp.sum, custom_rule=lambda v: ...)(v)  # proposed: user-specified rule for handling the change tag
    return some_other_model(x) @ "y"
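As one illustration of what such a user-specified rule could do for a sum, here is a plain-Python sketch. `IndexChange` below is a toy dataclass borrowing the name from the message above, and `incremental_sum` is a hypothetical rule; neither is an existing library API, and the signature a real `custom_rule` would take is still an open question in this thread.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class IndexChange:
    """Toy change tag: only position idx changed; old_value is what was there before."""
    idx: int
    old_value: float


def incremental_sum(old_sum: float, new_v: Sequence[float], change: IndexChange) -> float:
    """Hypothetical rule: patch the cached sum instead of summing everything again."""
    return old_sum - change.old_value + new_v[change.idx]


old_v = [1.0, 2.0, 3.0, 4.0]
old_sum = sum(old_v)  # the value that would sit in the cache at "x"

new_v = [1.0, 2.0, 3.0, 10.0]  # only index 3 changed
new_sum = incremental_sum(old_sum, new_v, IndexChange(idx=3, old_value=old_v[3]))
```

The patched result matches a full recomputation over `new_v`, but touches one element instead of all of them.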
