-
Notifications
You must be signed in to change notification settings - Fork 4
Add deterministic (& incremental aware) caching to the static language. #1357
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@@ -777,3 +777,37 @@ def higher_higher_model(): | |||
== genjax.normal.assess(choice.get_submap("y1"), (0.0, 1.0))[0] | |||
+ genjax.normal.assess(choice.get_submap("y2"), (0.0, 1.0))[0] | |||
) | |||
|
|||
|
|||
class TestCaching: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@georgematheos See these tests for usage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@femtomc this looks great!
Is there a way to add to this test something that checks if f
is run multiple times? Would it work to have f
increment a global counter variable, and check that this isn't incremented after update
calls that shouldn't cause re-computation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately JAX may run f
abstractly multiple times during tracing -- that doesn't say much about the generated code.
For this one, we might just have to inspect the Jaxprs...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just checked the Jaxprs, it works
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's try and see to get it working in your codebase
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1357 +/- ##
==========================================
- Coverage 87.94% 87.52% -0.42%
==========================================
Files 53 53
Lines 3856 3992 +136
==========================================
+ Hits 3391 3494 +103
- Misses 465 498 +33 ☔ View full report in Codecov by Sentry. |
) | ||
|
||
|
||
def cache( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bound it on TypeVar("T", bound=Callable)
, then it'll pin the full type sig
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I do this, I get an error -- that Callable
is a variable and it's not allowed in type expressions.
addresses = self.cached_addresses.get_visited() | ||
return ChoiceMap.from_mapping(zip(addresses, self.cached_values)) | ||
|
||
def get_cached_state(self, addr) -> Any: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type this addr, please
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and docs for these new public methods (use Cursor to generate with an example!)
@georgematheos could you give some more context by providing the original example from ChiSight that motivated this (or a simplified version)? My understanding from one of our discussions was that you wanted to visualise intermediate results within inference for debugging. If that's the main purpose, we want to tune for that purpose, or at least think about this use case some more. |
@MathieuHuot / @sritchie Let's not merge this in now -- I need to continue working on it to develop an example from George's Notion: https://www.notion.so/chi-mit/Incremental-computation-example-for-McCoy-10701b4b8cde803f9919d308cb0e331c?pvs=4 |
…ctions with custom implementations.
…x into mrb/deterministic_caching
…ctions with custom implementations.
Caching a little message I talked over with Mathieu: # In the static language...
@genjax.gen
def model(x):
v = submodel.vmap(in_axes=(0, ))(x) @ "v"
x = cache("x", jnp.sum)(v) # this is new: incremental aware caching
return some_other_model(x) @ "y"
tr # model trace in hand
# if you change "y", but don't change "v"
# jnp.sum @ "x" doesn't get recomputed.
#
# If we change "v", but don't change the argument x
# then v gets Diff(new_v, UnknownChange)
# the cache primitive sees this, and recomputes jnp.sum @ "x"
#
# Let's say we change "v" using IndexEditRequest(idx = 3, ...)
# NOTE: this is a TODO in the library (we have to implement IndexEditRequest for vmap).
# Now, the value for v -- what is it tagged with?
#
# We should have: v gets Diff(new_v, IndexChange(idx, old_v_at_idx))
# And then what does cache do?
#
# That's where we might want to allow users to specify rules.
@genjax.gen
def model(x):
v = submodel.vmap(in_axes=(0, ))(x) @ "v"
v = some_other_fn(v)
x = cache("x", jnp.sum, custom_rule= lambda v: ...)(v) # this is new: incremental aware caching
return some_other_model(x) @ "y" |
…x into mrb/deterministic_caching
Adds the ability to store deterministic data in
StaticTrace
via a new caching primitive.The caching primitive is aware of incremental computation, and will only recompute the cached function in
edit
(andupdate
) if changes to the arguments are indicated byUnknownChange
.When holding a
StaticTrace
, accessing the cache can be performed viaStaticTrace.get_cache()
(it's stored as a choice map, for now).TODOs:
StaticGenerativeFunction
at the toplevel (so e.g. you get to hold aStaticTrace
, and not some other trace (e.g. calling aStaticGenerativeFunction
with caching inside of aVmapCombinator
-- gives you aVmapTrace
, and there's not yet aTrace
interface which allows you to fetch the cache inside of theVmapTrace
).