[POPL] PJAX, compile to PJAX, and a great key massacre. #1567
for more information, see https://pre-commit.ci
This is not ready -- but I'm asking for reviews so people look at what's going on:

    import genjax
    from genjax import pjax
    import jax.numpy as jnp
    import jax.random as jrand
    from genjax import gen
    from jax import make_jaxpr

    @gen
    def model():
        x = genjax.normal(0.0, 1.0) @ "x"
        return x

    make_jaxpr(model.simulate)(())
Changing the signatures of the GFI methods -- we arrive at something to implement:

    tr = pjax(model.simulate, jrand.key(1))(())
Added a batch semantics for the make_jaxpr(vmap(genjax.normal.sample, in_axes=(0, None)))(jnp.ones(5), 1.0)
Working on |
This is turning into a big PR … but only because it cleans out several layers of cruft from GenJAX, and completely removes (!) the concern of “managing keys” — unless desired! By allowing everything to compile down to PJAX (Jaxpr + sample_p):
@MathieuHuot @sritchie this PR is immensely breaking, but I think it is a drastic simplification, and I feel confident about arguing for it. I think a good indication is that, despite touching the entire library, this change has managed to remove about 400 LoC (of ridiculous and redundant splitting of keys, setting up vmaps over keys, yada yada).
Hey @femtomc, thanks for all the great work!
Not sure, I haven't overloaded Jaxpr printing before. Seems like we could do it with an interpreter, but we'd need to handle e.g. indentation, etc -- like pprinting any AST. I'd feel confident saying: we can do it, but it will add another ~100 LoC to this PR as a new interpreter (and maybe we want to wait to do it with Matt H.)
The high level is that this PR moves all of GenJAX to be implemented in PJAX. Any time a sampler is required in the implementation of anything, we use

Instead of requiring that a user always seed a key if they want to execute with JAX, we also define the default evaluation Jaxpr interpreter for

When you re-run e.g. cells in Jupyter, you get a stream of new randomness into your sampler.

Edit: @MathieuHuot I was wrong about JIT -- if you JIT, the keys get baked in. So what users will want to do if they're JIT'ing is

If a user wants to, they can seed their computation with a fixed key. This recovers reproducibility for all GenJAX code.
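To make the "keys get baked in" point concrete, here is a minimal sketch in plain JAX. It does not use GenJAX or `sample_p`; `fresh_key`, `sampler_implicit`, and `sampler_explicit` are hypothetical names standing in for "fresh randomness by default" versus "key passed explicitly".

```python
import jax
import jax.random as jrand

# Hypothetical stand-in for "a stream of new randomness": every call
# hands out a new key as a Python side effect.
_counter = [0]

def fresh_key():
    _counter[0] += 1
    return jrand.key(_counter[0])

def sampler_implicit():
    # The key is obtained by a Python side effect, not passed as an argument.
    return jrand.normal(fresh_key())

def sampler_explicit(key):
    # The key is an ordinary input, so JIT treats it as data.
    return jrand.normal(key)

# Eager execution: each call sees a new key, so the draws differ.
e1, e2 = sampler_implicit(), sampler_implicit()

# Under JIT, fresh_key() runs once at trace time; the key is baked into
# the compiled program and every later call returns the same draw.
jitted_implicit = jax.jit(sampler_implicit)
a, b = jitted_implicit(), jitted_implicit()

# With an explicit key argument nothing is baked in.
jitted_explicit = jax.jit(sampler_explicit)
c, d = jitted_explicit(jrand.key(1)), jitted_explicit(jrand.key(2))
```

The second case is the one that motivates seeding before JIT: once the key is an argument, JIT, `vmap` over keys, and reproducibility all behave the usual JAX way.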
No, we won't run with the same key twice -- because every unique occurrence of |
I've added a new summary of the changes to the top post: What this PR changes & adds:
I still need to update the documentation and notebooks to account for the changes, so this PR is not ready yet.
Things are slightly more subtle with “hiding the keys everywhere” than I thought:
This is not disastrous. For users, the fix is the same: seed your code. We provide a transformation for it, and that will enable you to generate a sampler which accepts PRNGKeys and doesn’t bake anything in. However, it’s a little unsatisfying in terms of the model where users don’t need to worry about keys very much. The general, most often used, pattern will be: write some complex logic without worrying about keys; then “seed” your program, which eliminates all sample_p; now you get back a function which accepts a key, which you can JIT, vmap over keys, yada yada. It’s still a nice modular separation of concerns, but it’s less powerful than I originally intended, unless I figure something else out.
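A toy sketch of the "write logic without keys, then seed" pattern, in plain JAX. This is not the GenJAX implementation and not its API: `sample`, `toy_seed`, and `normal_impl` are hypothetical names, and the key threading is done with a Python-level stack rather than a Jaxpr interpreter. It only illustrates the shape of the idea: sample sites carry no key, and the seeding wrapper returns a function that takes a key and splits it in a fixed order.

```python
import jax
import jax.random as jrand

# Hypothetical stand-ins, NOT the GenJAX API.
_key_stack = []

def sample(jax_impl, *args):
    """A keyless sample site: draws its key from the enclosing seeded scope."""
    if not _key_stack:
        raise RuntimeError("unseeded sample site; wrap the program with toy_seed")
    key, sub = jrand.split(_key_stack[-1][0])
    _key_stack[-1][0] = key          # thread the remaining key forward
    return jax_impl(sub, *args)      # the sampler itself takes a key explicitly

def toy_seed(f):
    """Turn a keyless program into a function whose first argument is a key."""
    def seeded(key, *args):
        _key_stack.append([key])
        try:
            return f(*args)
        finally:
            _key_stack.pop()
    return seeded

def normal_impl(key, mu, sigma):
    return mu + sigma * jrand.normal(key)

def program(mu):
    # No keys anywhere in the "model" logic.
    x = sample(normal_impl, mu, 1.0)
    y = sample(normal_impl, x, 0.1)
    return x, y

seeded = jax.jit(toy_seed(program))
x1, y1 = seeded(jrand.key(0), 0.0)
x2, y2 = seeded(jrand.key(0), 0.0)   # same key, same draws
x3, y3 = seeded(jrand.key(1), 0.0)   # different key, different draws

# Because the key is now an argument, vmap over keys works as usual.
keys = jrand.split(jrand.key(2), 5)
xs, ys = jax.vmap(seeded, in_axes=(0, None))(keys, 0.0)
```

The design point this is meant to show is the separation of concerns: `program` never mentions keys, and everything key-related lives in the wrapper applied at the end.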
Our (myself, Mathieu -- at least, and maybe more) claim is that the system should be about creating and editing traces, with a coherent batch semantics, and concise idioms for writing SMC programs using `edit`. This harkens back to several of Colin's intuitions about "promoting array semantics" into the generative world.

What better place to start than with the README example -- what did I do?

* I added `importance_k` and `edit_k` ("derived") interfaces to `GenerativeFunction` and `Trace` -- these methods just invoke the underlying `importance` and `edit` methods using batch semantics. The `_k` suffix means _either_ (for `importance_k`: give me an `n : Nat` for batch dimension) _or_ (for `edit_k`: apply an `EditRequest` vectorized over a batch of traces).
* I added a `resample_k` method to `Trace` -- `resample_k` takes a batch trace & SMC weights and returns a batch trace, as well as an estimate of the log marginal likelihood.

This enables the concise posterior program in the new README ("HMC-within-SIR"). View the README on this branch for a walkthrough.
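For readers unfamiliar with the resampling step, here is a minimal plain-JAX sketch of what an operation like `resample_k` computes. It is an assumption-laden illustration, not the method added in this PR: the function name `resample`, the explicit `key` argument, and the use of a dict of arrays in place of a batched `Trace` are all hypothetical. It uses standard multinomial resampling, with the log marginal likelihood estimated as the log of the mean weight.

```python
import jax
import jax.numpy as jnp
import jax.random as jrand
from jax.scipy.special import logsumexp

def resample(key, particles, log_weights):
    """Multinomial resampling over the leading (batch) axis.

    particles:   pytree whose leaves all have leading dimension n
                 (a hypothetical stand-in for a batched trace)
    log_weights: array of shape (n,), unnormalized log importance weights
    """
    n = log_weights.shape[0]
    # log of the average weight: the usual importance-sampling estimate
    # of the log marginal likelihood.
    log_ml = logsumexp(log_weights) - jnp.log(n)
    # Draw n ancestor indices with probability proportional to the weights.
    idx = jrand.categorical(key, log_weights, shape=(n,))
    resampled = jax.tree_util.tree_map(lambda leaf: leaf[idx], particles)
    return resampled, log_ml

# Degenerate example: essentially all weight sits on the first particle.
particles = {"x": jnp.array([1.0, 2.0, 3.0, 4.0])}
log_weights = jnp.array([0.0, -1e9, -1e9, -1e9])
resampled, log_ml = resample(jrand.key(0), particles, log_weights)
```

In this degenerate case every resampled particle is a copy of the first one, and the estimate is the log of one quarter, since one unit of weight is averaged over four particles.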
What this PR changes & adds:
* `sample_p` primitive which denotes "invoke a primitive sampler", and which is configurable (can be used to register new JAX samplers, new samplers which only work with specific backends, etc).
* `Jaxpr + sample_p` (PJAX), a level above "pure JAX samplers".
* `key: PRNGKey` from all generative function code in the library. This is a breaking change to the signatures of the GFI.
* `sample_p` to denote "a probabilistic sampling operation is occurring".
* `key: PRNGKey` to actually implement a sampler, so for JAX, the `sample_p` primitive registers a `jax_impl` which accepts a key explicitly.
* `PRNGKey`. We restore reproducibility by exposing a program transformation `genjax.seed` that allows the `sample_p` primitives in a `Jaxpr + sample_p` program to have a fixed order of seeds. This transformation is defined across `scan_p`, `cond_p`, etc -- so it works with generic JAX code with switches and scans. It works across code generated by the combinators.

Sample usage of `sample_p` via `sample_binder`, a utility for defining new sampler primitives:

Users can define new "primitive samplers" using `sample_binder` -- by providing a JAX implementation of the sampler. However, using the sampler results in a primitive that occurs in the `Jaxpr`.

The `pjax` transformation eliminates these primitives in terms of their pure JAX implementation, and winds a PRNG key through.

Notebook here.