[POPL] PJAX, compile to PJAX, and a great key massacre. #1567


Draft · wants to merge 43 commits into base: main

Conversation


@femtomc femtomc commented Mar 15, 2025

What this PR changes & adds:

  • Adds a sample_p primitive which denotes "invoke a primitive sampler", and which is configurable (it can be used to register new JAX samplers, samplers which only work with specific backends, etc.).
  • This raises the abstraction level of the library -- instead of being about pure JAX samplers (which take and evolve a PRNGKey), it's about Jaxpr + sample_p (PJAX), a level above "pure JAX samplers".
  • To adapt Gen to this viewpoint, we remove key: PRNGKey from all generative function code in the library. This is a breaking change to the signatures of the GFI.
  • When generative functions need to sample, they use sample_p to denote "a probabilistic sampling operation is occurring".
  • This primitive is given semantics by a particular backend: JAX, for instance, needs a key: PRNGKey to actually implement a sampler, so for JAX, the sample_p primitive registers a jax_impl which accepts a key explicitly.
  • This primitive is given batch semantics that automatically account for vectorization.
  • Using PJAX, we can completely remove the requirement that a user ever interact with or create their own PRNGKeys. All code which previously manipulated keys can have that key manipulation expunged, while retaining the correctness and functionality of the library.
  • Of course, when we do this we lose reproducibility -- the ability for a user to repeat an execution with a fixed PRNGKey. We restore reproducibility by exposing a program transformation, genjax.seed, that gives the sample_p primitives in a Jaxpr + sample_p program a fixed order of seeds. This transformation is defined across scan_p, cond_p, etc., so it works with generic JAX code containing switches and scans, and across code generated by the combinators.
  • All the generative function interfaces and implementations are changed to cohere with this change.
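The counter-vs-seed behavior described in the last two bullets can be sketched in pure JAX. Everything below is an illustrative toy, not the genjax implementation; only the name genjax.seed comes from this PR, and the seed function here merely mimics its contract (take a program with hidden randomness, return a function that accepts an explicit PRNGKey):

```python
import jax.random as jrand

# Toy sketch (NOT the genjax implementation): samplers draw keys from an
# ambient source. By default the source is a global non-JAX counter; a
# genjax.seed-style transform swaps in deterministic splits of a user key.

_state = {"mode": "counter", "count": 0, "key": None}

def _next_key():
    if _state["mode"] == "counter":
        _state["count"] += 1                # static counter, fresh key per call
        return jrand.key(_state["count"])
    _state["key"], sub = jrand.split(_state["key"])
    return sub

def normal(mu, sigma):
    # "numpy style" sampler: no key in the signature
    return mu + sigma * jrand.normal(_next_key())

def seed(fn):
    # Returns a function taking an explicit PRNGKey, restoring reproducibility.
    def seeded(key, *args):
        _state["mode"], _state["key"] = "seeded", key
        try:
            return fn(*args)
        finally:
            _state["mode"] = "counter"
    return seeded

def model():
    return normal(0.0, 1.0) + normal(1.0, 0.5)

x = model()                     # works with no key in sight (counter mode)
a = seed(model)(jrand.key(0))   # explicit key: reproducible
b = seed(model)(jrand.key(0))
assert bool(a == b)
```

The try/finally is just hygiene for the toy: it restores the "numpy style" counter mode even if the seeded program raises.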

Sample usage of sample_p via sample_binder, a utility for defining new sampler primitives:

import genjax
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
from genjax import sample_binder
from jax import make_jaxpr
from jax import jit
import jax.numpy as jnp
import jax.random as jrand

beta = sample_binder(
    lambda key, conc0, conc1: tfd.Beta(conc0, conc1).sample(seed=key) # JAX impl
)
print(make_jaxpr(beta)(2.0, 2.0))
jit(beta)(2.0, 2.0)

Users can define new "primitive samplers" with sample_binder by providing a JAX implementation of the sampler. Calling the resulting sampler leaves a primitive in the Jaxpr.

The pjax transformation eliminates these primitives in favor of their pure JAX implementations, threading a PRNGKey through.
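For intuition, here is a hand-written sketch of what such an elimination might produce for the beta example above. This is an assumed illustration, not the actual generated code; beta_jax_impl stands in for the registered TFP sampler (using jax.random.beta so the sketch is self-contained):

```python
import jax
import jax.random as jrand

# Hand-written sketch (assumed) of a pjax-style lowering: each sample_p site
# becomes its jax_impl applied to a fresh split of the threaded key.

def beta_jax_impl(key, conc0, conc1):
    # stand-in for the TFP Beta sampler registered via sample_binder
    return jrand.beta(key, conc0, conc1)

def lowered(key, conc0, conc1):
    key, sub = jrand.split(key)   # one split per sample_p occurrence
    return beta_jax_impl(sub, conc0, conc1)

# After lowering, the function is ordinary pure JAX and can be jitted.
x = jax.jit(lowered)(jrand.key(0), 2.0, 2.0)
assert 0.0 < float(x) < 1.0  # Beta draws lie in (0, 1)
```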

Notebook here.

@femtomc femtomc changed the title Implementation of interpreter from PJAX to pure JAX. [Cold Harbor, 2 / N] Implementation of interpreter from PJAX to pure JAX. Mar 15, 2025

femtomc commented Mar 15, 2025

This is not ready -- but I'm asking for reviews so people can look at what's going on:

import genjax
from genjax import pjax
import jax.numpy as jnp
import jax.random as jrand
from genjax import gen
from jax import make_jaxpr

@gen
def model():
    x = genjax.normal(0.0, 1.0) @ "x"
    return x

make_jaxpr(model.simulate)(())
{ lambda ; . let
    a:f32[] = sample[
      abs_eval=<function initial_style_bind.<locals>.bind.<locals>.wrapped.<locals>._abs_eval at 0x11b707920>
      impl=<function initial_style_bind.<locals>.bind.<locals>.wrapped.<locals>._impl at 0x11b7077e0>
      in_tree=PyTreeDef((*, *))
      jax_impl=<function tfp_distribution.<locals>.sampler.<locals>._sampler at 0x11b706980>
      num_consts=0
      out_tree=<function transformation_with_aux2.<locals>.<lambda> at 0x11b706fc0>
      raise_exception=<function sample_binder.<locals>.sampler.<locals>.raise_exception at 0x11b706e80>
    ] 0.0 1.0
    b:f32[] = div a 1.0
    c:f32[] = div 0.0 1.0
    d:f32[] = sub b c
    e:f32[] = square d
    f:f32[] = mul -0.5 e
    g:f32[] = log 1.0
    h:f32[] = add 0.9189385175704956 g
    i:f32[] = sub f h
  in (a, 0.0, 1.0, a, i) }

Changing the signatures of the GFI methods, we arrive at something mini-like: the interface model.simulate produces code in terms of JAX + sample_p.

To implement model.simulate in terms of pure JAX, we use a transformation called pjax:

tr = pjax(model.simulate, jrand.key(1))(())
StaticTrace(
  gen_fn=StaticGenerativeFunction(
    source=Closure(dyn_args=(), fn=<function model at 0x11c7ec900>),
  ),
  args=(),
  retval=<jax.Array(-0.24392003, dtype=float32)>,
  subtraces={
    'x': DistributionTrace(
      gen_fn=genjax.normal(),
      args=(0.0, 1.0),
      value=<jax.Array(-0.24392003, dtype=float32)>,
      score=<jax.Array(-0.948687, dtype=float32)>,
    ),
  },
)


femtomc commented Mar 16, 2025

Added batch semantics for the sample_p primitive:

make_jaxpr(vmap(genjax.normal.sample, in_axes=(0, None)))(jnp.ones(5), 1.0)
{ lambda ; a:f32[5] b:f32[]. let
    c:f32[5] = sample[
      abstract=<function initial_style_bind.<locals>.bind.<locals>.wrapped.<locals>.abstract at 0x117b21e40>
      impl=<function initial_style_bind.<locals>.bind.<locals>.wrapped.<locals>.impl at 0x117b20680>
      in_tree=PyTreeDef((*, *))
      jax_impl=<function tfp_distribution.<locals>.sampler.<locals>._sampler at 0x117b222a0>
      num_consts=0
      out_tree=<function transformation_with_aux2.<locals>.<lambda> at 0x117b21ee0>
      raise_exception=<function sample_binder.<locals>.sampler.<locals>.raise_exception at 0x117b20180>
    ] a b
  in (c,) }
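In pure JAX terms, the batching rule behaves like ordinary vmap over a sampler. A sketch with a hand-written normal sampler (illustrative only; in PJAX the key is not visible in the jaxpr, so here we batch explicit keys instead):

```python
import jax
import jax.numpy as jnp
import jax.random as jrand

# Batch-semantics sketch in pure JAX: vmap a sampler over a batch of means
# (axis 0) while broadcasting the scale (None), analogous to the vmapped
# sample_p jaxpr above.

def normal_sample(key, mu, sigma):
    return mu + sigma * jrand.normal(key)

keys = jrand.split(jrand.key(0), 5)   # one key per batch element
mus = jnp.ones(5)
batched = jax.vmap(normal_sample, in_axes=(0, 0, None))
out = batched(keys, mus, 1.0)
assert out.shape == (5,)
```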


femtomc commented Mar 16, 2025

Adding batch semantics to sample_p means that the Vmap combinator works immediately:

(screenshots showing the Vmap combinator working)


femtomc commented Mar 16, 2025

Working on scan_p and cond_p now (for the Scan and Switch combinators).


femtomc commented Mar 16, 2025

This is turning into a big PR … but only because it cleans out several layers of cruft from GenJAX, and completely removes (!) the concern of “managing keys” — unless desired!

By allowing everything to compile down to PJAX (Jaxpr + sample_p):

  • All interfaces and their implementations are simplified (no more key splitting)
  • All existing interpreters don’t need to worry about keys (this includes VI, incremental, etc)
  • The keys can be completely hidden from the user — allowing “numpy style” interactive randomness — which drastically simplifies the presentation of the system (we don’t even have to talk about keys… if we don’t want to — one can recover the old behavior by using a new transformation (genjax.seed))

@MathieuHuot @sritchie this PR is immensely breaking, but I think it is a drastic simplification — and I feel confident about arguing for it.

I think a good indication is that, despite touching the entire library, this change has managed to remove about 400 LoC (of ridiculous and redundant key splitting, setting up vmaps over keys, yada yada)

@MathieuHuot
Collaborator

Hey @femtomc, thanks for all the great work!
Some questions:

  • how hard is it to add a pretty printer for the PJAX code that it would look more like x : f[5] = sample(normal, a b)?
  • is the high-level picture as follows: compile GF code to PJAX, then apply a key transform + simulate transform to sample, which compiles to Jaxpr, and then we're good to go?
  • It's not immediately clear to me just yet whether the key handling will be correct, because it's now hidden from the user and there could be blobs of JAX code all over the place. For instance, suppose I have a Python for loop, and within it I call a jitted model.simulate, but I don't jit the for loop. In that case, won't we run the exact same simulation a bunch of times and mess up probabilities? Internally, the computation is deterministic starting at a fixed key of your choice, but one JAX computation doesn't know about another one unless you jit the overall code to make the two blobs of JAX code into one.

@femtomc femtomc requested a review from MathieuHuot March 16, 2025 15:11
@femtomc femtomc changed the title [Cold Harbor, 2 / N] Implementation of interpreter from PJAX to pure JAX. [Cold Harbor, 2 / N] Compiler to PJAX, and the great key massacre. Mar 16, 2025

femtomc commented Mar 16, 2025

@MathieuHuot

how hard is it to add a pretty printer for the PJAX code that it would look more like x : f[5] = sample(normal, a b)?

Not sure, I haven't overloaded Jaxpr printing before. Seems like we could do it with an interpreter, but we'd need to handle e.g. indentation, etc. -- like pretty-printing any AST. I'd feel confident saying: we can do it, but it will add another ~100 LoC to this PR as a new interpreter (and maybe we want to wait and do it with Matt H.)

is the high level picture as follows: compile GF code to PJAX, then apply a key transform + simulate transform to sample, which compiles to JAXPR, and then we're good to do?

The high level is that this PR moves all of GenJAX to be implemented in PJAX. Any time a sampler is required in the implementation of anything, we use sample_p. You have roughly the correct idea about the transforms -- all interfaces emit code as Jaxpr + sample_p, and a key transformation can then be used to transform Jaxpr + sample_p -> Jaxpr. Using sample_p by itself doesn't require any use of PRNGKey; it's just a new primitive.

Instead of requiring that a user always seed a key if they want to execute with JAX, we also define the default evaluation Jaxpr interpreter for sample_p to seed a key from a global non-JAX int counter (https://github.com/ChiSym/genjax/blob/de102e7283b8115de1e7220db2793906067fb141/src/genjax/_src/core/compiler/pjax.py#L75-L77). This is static (not dynamic for JAX), and simply counts the number of unique calls to sample_p, and uses that to generate an "infinite" stream of keys.

When you re-run e.g. cells in Jupyter, you get a stream of new randomness into your sampler.

Edit: @MathieuHuot I was wrong about JIT -- if you JIT, the keys get baked in. So what users will want to do if they're JIT'ing is jit(seed(their_fn))(key, *args), which returns a function that allows new randomness (using the key transform explicitly). The "non-JIT" behavior is numpy-like -- you can generate randomness on the fly (no baking in).
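The baking-in pitfall is easy to reproduce with a toy counter-seeded sampler (a stand-in I made up to mirror the description above, not genjax itself):

```python
import jax
import jax.random as jrand

# The counter is a Python int, so jit traces the function once and bakes the
# derived key into the cached computation -- every call returns the same draw.
_count = [0]

def draw():
    _count[0] += 1
    return jrand.normal(jrand.key(_count[0]))

jitted = jax.jit(draw)
a, b = jitted(), jitted()   # second call hits the cache: same baked-in key
assert bool(a == b)

# The fix described above: take the key explicitly (the seed(...) pattern),
# then jit -- the key is now a runtime argument, not a traced constant.
def draw_seeded(key):
    return jrand.normal(key)

jitted2 = jax.jit(draw_seeded)
c, d = jitted2(jrand.key(1)), jitted2(jrand.key(2))
assert bool(c != d)
```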

If a user wants to, they can seed their computation with a fixed key. This recovers reproducibility for all GenJAX code.

It's not immediately clear to me just yet whether the key handling will be correct because it's now hidden from the user and there could be blobs of JAX code all over the place. For instance, if I have a Python for loop, and within in I call a jitted model.simulate, but I don't jit the for loop. In that case, won't we run the exact same simulation a bunch of times and mess up probabilities? Internally, the computation is deterministic starting at a fixed key of your choice, but one JAX computation doesn't know about another one unless you jit the overall code to make the 2 blobs of JAX code into one.

No, we won't run with the same key twice -- because every unique occurrence of sample_p will get a globally unique key in the default JAX eval interpreter (as I mention above).


femtomc commented Mar 16, 2025

I've added a new summary of the changes to the top post:

What this PR changes & adds:

  • Adds a common sample_p primitive which denotes "any sampler", and which is configurable (it can be used to register new JAX samplers, samplers which only work with specific backends, etc.).
  • This raises the abstraction level of the library -- instead of being about pure JAX samplers (which take and evolve a PRNGKey), it's about Jaxpr + sample_p (PJAX), a level above "pure JAX samplers".
  • To adapt Gen to this viewpoint, we remove key: PRNGKey from all generative function code in the library. This is a breaking change to the signatures of the GFI.
  • When generative functions need to sample, they use sample_p to denote "a probabilistic sampling operation is occurring".
  • This primitive is given semantics by a particular backend: JAX, for instance, needs a key: PRNGKey, so for JAX, the sample_p primitive registers a jax_impl which accepts a key explicitly.
  • This primitive is given batch semantics that automatically account for vectorization.
  • Using PJAX, we can completely remove the requirement that a user ever interact with or create their own PRNGKeys. All code which previously manipulated keys can have that key manipulation expunged, while retaining the correctness and functionality of the library.
  • Of course, when we do this we lose reproducibility -- the ability for a user to repeat an execution with a fixed PRNGKey. We restore reproducibility by exposing a program transformation, genjax.seed, that gives the sample_p primitives in a Jaxpr + sample_p program a fixed order of seeds. This transformation is defined across scan_p, cond_p, etc., so it works with generic JAX code containing switches and scans, and across code generated by the combinators.
  • All the generative function interfaces and implementations are changed to cohere with this change.


femtomc commented Mar 16, 2025

I still need to update the documentation and notebooks to account for the changes, so this PR is not ready yet.


femtomc commented Mar 16, 2025

Things are slightly more subtle with “hiding the keys everywhere” than I thought:

  • JIT is problematic, because if we use a library global counter to seed keys, JIT will bake in the counter values as static — and cache the result — implying that JITing a PJAX sampler will return the same value each time.
  • of course, seeding the keys from outside fixes this, so for users we should raise a sharp error -- the right place to put it is when JAX attempts to lower the primitive to MLIR. The error: "hey, JAX is attempting to lower some of your code, and it will bake the key in -- please seed your code". Seeding eliminates the primitive, so once you seed at any level which contains sample_p, you're good to go
  • The problem is that JIT is not the only place where JAX lowers the primitive to MLIR -- executing any control-flow / second-order primitive involves lowering its Jaxpr to MLIR, which means the above error gets triggered not just by JIT, but also by any combinator

This is not disastrous — for users, the fix is the same: seed your code — we provide a transformation for it, and that will enable you to generate a sampler which accepts PRNGKeys and doesn’t bake anything in.

However, it’s a little unsatisfying in terms of the model where users don’t need to worry about keys very much.

The general, most often used, pattern will be “write some complex logic, don’t worry about keys” — done, okay — now “seed” your program, which eliminates all sample_p — now you get back a function which accepts a key, which you can JIT, vmap over keys, yada yada.
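That end-to-end pattern, sketched with a stand-in for a seeded program (fn below plays the role of seed(model.simulate); illustrative only):

```python
import jax
import jax.random as jrand

# "Write logic, then seed": once the program takes an explicit PRNGKey, it
# composes with jit and vmap over keys in the ordinary JAX way.

def fn(key):                       # stand-in for a seeded GenJAX program
    return jrand.normal(key)

keys = jrand.split(jrand.key(0), 8)
samples = jax.jit(jax.vmap(fn))(keys)   # JIT + vmap over a batch of keys
assert samples.shape == (8,)
```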

it’s still a nice modular separation of concerns, but it’s less powerful than I originally intended — unless I figure something else out

@femtomc femtomc marked this pull request as draft March 16, 2025 20:45
@femtomc femtomc changed the title [Cold Harbor, 2 / N] Compiler to PJAX, and the great key massacre. [POPL] Compiler to PJAX, and the great key massacre. Mar 16, 2025
@femtomc femtomc changed the title [POPL] Compiler to PJAX, and the great key massacre. [POPL] PJAX, compile to PJAX, and a great key massacre. Mar 16, 2025
Base automatically changed from mrb/cold_harbor to main March 16, 2025 21:35
femtomc and others added 16 commits March 16, 2025 19:04
Our (myself, Mathieu -- at least, and maybe more) claim is that the
system should be about creating and editing traces, with a coherent
batch semantics, and concise idioms for writing SMC programs using
`edit`. This harkens back to several of Colin's intuitions about
"promoting array semantics" into the generative world.

What better place to start than with the README example -- what did I
do?
* I added `importance_k` and `edit_k` ("derived") interfaces to
`GenerativeFunction` and `Trace` -- these methods just invoke the
underlying `importance` and `edit` methods using batch semantics. The
`_k` suffix means _either_ (for `importance_k`: give me an `n : Nat`
for the batch dimension) _or_ (for `edit_k`: apply an `EditRequest`
vectorized over a batch of traces).
* I added a `resample_k` method to `Trace` -- `resample_k` takes a batch
trace & SMC weights and returns a batch trace, as well as an estimate of
the log marginal likelihood.

This enables the concise posterior program in the new README
("HMC-within-SIR").

View the README on this branch for a walkthrough.
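The batch-trace idiom described above can be sketched in pure JAX with a trivial model; importance_k and resample_k here are hypothetical stand-ins mirroring the names in the commit message, not the actual implementations:

```python
import jax
import jax.numpy as jnp
import jax.random as jrand
from jax.scipy.special import logsumexp

# Hypothetical sketch of the batch-trace idiom: draw a batch of n particles
# with importance weights, then resample proportionally to the weights and
# estimate the log marginal likelihood.

def importance_k(key, n):
    keys = jrand.split(key, n)
    xs = jax.vmap(jrand.normal)(keys)     # batch of n "traces" from the prior
    log_ws = -0.5 * (xs - 1.0) ** 2       # unnormalized importance log-weights
    return xs, log_ws

def resample_k(key, xs, log_ws):
    n = xs.shape[0]
    idx = jrand.categorical(key, log_ws, shape=(n,))  # multinomial resampling
    log_ml = logsumexp(log_ws) - jnp.log(n)           # log marginal estimate
    return xs[idx], log_ml

k1, k2 = jrand.split(jrand.key(0))
xs, log_ws = importance_k(k1, 100)
resampled, log_ml = resample_k(k2, xs, log_ws)
assert resampled.shape == (100,)
```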