Skip to content

Commit cb97e8e

Browse files
authored
[Cold Harbor, 1 / N] Organize the lowest levels of the GenJAX compiler. (#1566)
**No new functionality** Performing a little re-organization in `_src.core` to make the structure of the GenJAX stack significantly more clear. A new lower level module `genjax._src.core.compiler` -- contains infrastructure for JAX primitives, interpreters, staging, etc. Interpreters share some core functionality (`Environment`) in `genjax._src.core.compiler.interpreters.environment` and `genjax._src.core.compiler.initial_style_primitive`.
1 parent 9e2a23b commit cb97e8e

38 files changed

+399
-326
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
data
2+
.hypothesis
23
*.pyc
34
*.so
45
*.egg-info

docs/cookbook/active/generative_function_interface.ipynb

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@
3030
"\n",
3131
"from genjax import ChoiceMapBuilder as C\n",
3232
"from genjax import (\n",
33+
" Diff,\n",
34+
" NoChange,\n",
35+
" UnknownChange,\n",
3336
" bernoulli,\n",
3437
" beta,\n",
3538
" gen,\n",
3639
" pretty,\n",
3740
")\n",
3841
"from genjax._src.generative_functions.static import MissingAddress\n",
39-
"from genjax.incremental import Diff, NoChange, UnknownChange\n",
4042
"\n",
4143
"key = jax.random.key(0)\n",
4244
"pretty()\n",

docs/cookbook/inactive/inference/mcmc.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
"\n",
3636
"from genjax import ChoiceMapBuilder as C\n",
3737
"from genjax import gen, normal, pretty\n",
38-
"from genjax._src.core.interpreters.incremental import Diff\n",
38+
"from genjax._src.core.compiler.interpreters.incremental import Diff\n",
3939
"\n",
4040
"key = jax.random.key(0)\n",
4141
"pretty()"

src/genjax/_src/adev/core.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
from jax.interpreters import ad as jax_autodiff
2525
from jax.interpreters import batching
2626

27-
from genjax._src.core.interpreters.forward import (
28-
Environment,
27+
from genjax._src.core.compiler.initial_style_primitive import (
2928
InitialStylePrimitive,
3029
initial_style_bind,
3130
)
32-
from genjax._src.core.interpreters.staging import stage
31+
from genjax._src.core.compiler.interpreters.environment import Environment
32+
from genjax._src.core.compiler.staging import stage
3333
from genjax._src.core.pytree import Pytree
3434
from genjax._src.core.typing import (
3535
Annotated,
@@ -248,7 +248,7 @@ def flat_unzip(duals: list[Any]):
248248
return list(primals), list(tangents)
249249

250250
@staticmethod
251-
def _eval_jaxpr_adev_jvp(
251+
def eval_jaxpr_adev(
252252
key: PRNGKey,
253253
jaxpr: Jaxpr,
254254
consts: list[ArrayLike],
@@ -404,7 +404,7 @@ def _inner(key, dual_tree: DualTree):
404404
closed_jaxpr, (_, _, out_tree) = stage(f)(*primals)
405405
jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
406406
dual_leaves = Dual.tree_leaves(Dual.tree_pure(dual_tree))
407-
out_duals = ADInterpreter._eval_jaxpr_adev_jvp(
407+
out_duals = ADInterpreter.eval_jaxpr_adev(
408408
key,
409409
jaxpr,
410410
consts,

src/genjax/core/interpreters.py renamed to src/genjax/_src/core/compiler/__init__.py

-17
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,3 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
from genjax._src.core.interpreters.forward import StatefulHandler, forward
16-
from genjax._src.core.interpreters.incremental import incremental
17-
from genjax._src.core.interpreters.staging import (
18-
get_shaped_aval,
19-
stage,
20-
to_shape_fn,
21-
)
22-
23-
__all__ = [
24-
"StatefulHandler",
25-
"forward",
26-
"get_shaped_aval",
27-
"incremental",
28-
"stage",
29-
"to_shape_fn",
30-
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2024 MIT Probabilistic Computing Project
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import itertools as it
16+
17+
import jax.core as jc
18+
from jax import tree_util
19+
from jax import util as jax_util
20+
from jax.extend.core import Primitive
21+
from jax.interpreters import mlir
22+
from jax.interpreters import partial_eval as pe
23+
24+
from genjax._src.core.compiler.staging import stage
25+
26+
#########################
27+
# Custom JAX primitives #
28+
#########################
29+
30+
31+
class InitialStylePrimitive(Primitive):
32+
"""Contains default implementations of transformations."""
33+
34+
def __init__(self, name):
35+
super(InitialStylePrimitive, self).__init__(name)
36+
self.multiple_results = True
37+
38+
def _abstract(*flat_avals, **params):
39+
abs_eval = params["abs_eval"]
40+
return abs_eval(*flat_avals, **params)
41+
42+
self.def_abstract_eval(_abstract)
43+
44+
def fun_impl(*args, **params):
45+
impl = params["impl"]
46+
return impl(*args, **params)
47+
48+
self.def_impl(fun_impl)
49+
50+
def _mlir(ctx: mlir.LoweringRuleContext, *mlir_args, **params):
51+
lowering = mlir.lower_fun(self.impl, multiple_results=True)
52+
return lowering(ctx, *mlir_args, **params)
53+
54+
mlir.register_lowering(self, _mlir)
55+
56+
57+
def initial_style_bind(prim, **params):
58+
"""Binds a primitive to a function call."""
59+
60+
def bind(f):
61+
"""Wraps a function to be bound to a primitive, keeping track of Pytree
62+
information."""
63+
64+
def wrapped(*args, **kwargs):
65+
"""Runs a function and binds it to a call primitive."""
66+
jaxpr, (flat_args, in_tree, out_tree) = stage(f)(*args, **kwargs)
67+
debug_info = jaxpr.jaxpr.debug_info
68+
69+
def _impl(*args, **params):
70+
consts, args = jax_util.split_list(args, [params["num_consts"]])
71+
return jc.eval_jaxpr(jaxpr.jaxpr, consts, *args)
72+
73+
def _abs_eval(*flat_avals, **params):
74+
return pe.abstract_eval_fun(
75+
_impl,
76+
*flat_avals,
77+
debug_info=debug_info,
78+
**params,
79+
)
80+
81+
outs = prim.bind(
82+
*it.chain(jaxpr.literals, flat_args),
83+
abs_eval=params.get("abs_eval", _abs_eval),
84+
impl=_impl,
85+
in_tree=in_tree,
86+
out_tree=out_tree,
87+
num_consts=len(jaxpr.literals),
88+
**params,
89+
)
90+
return tree_util.tree_unflatten(out_tree(), outs)
91+
92+
return wrapped
93+
94+
return bind
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2024 The MIT Probabilistic Computing Project
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import jax.core as jc
15+
from jax.extend.core import Literal, Var
16+
17+
from genjax._src.core.pytree import Pytree
18+
from genjax._src.core.typing import Any
19+
20+
VarOrLiteral = Var | Literal
21+
22+
23+
@Pytree.dataclass
24+
class Environment(Pytree):
25+
"""Keeps track of variables and their values during interpretation."""
26+
27+
env: dict[int, Any] = Pytree.field(default_factory=dict)
28+
29+
def read(self, var: VarOrLiteral) -> Any:
30+
"""
31+
Read a value from a variable in the environment.
32+
"""
33+
v = self.get(var)
34+
if v is None:
35+
assert isinstance(var, Var)
36+
raise ValueError(
37+
f"Unbound variable in interpreter environment at count {var.count}:\nEnvironment keys (count): {list(self.env.keys())}"
38+
)
39+
return v
40+
41+
def get(self, var: VarOrLiteral) -> Any:
42+
if isinstance(var, Literal):
43+
return var.val
44+
else:
45+
return self.env.get(var.count)
46+
47+
def write(self, var: VarOrLiteral, cell: Any) -> Any:
48+
"""
49+
Write a value to a variable in the environment.
50+
"""
51+
if isinstance(var, Literal):
52+
return cell
53+
cur_cell = self.get(var)
54+
if isinstance(var, jc.DropVar):
55+
return cur_cell
56+
self.env[var.count] = cell
57+
return self.env[var.count]
58+
59+
def __getitem__(self, var: VarOrLiteral) -> Any:
60+
return self.read(var)
61+
62+
def __setitem__(self, key, val):
63+
raise ValueError(
64+
"Environments do not support __setitem__. Please use the "
65+
"`write` method instead."
66+
)
67+
68+
def __contains__(self, var: VarOrLiteral):
69+
"""
70+
Check if a variable is in the environment.
71+
"""
72+
if isinstance(var, Literal):
73+
return True
74+
return var.count in self.env
75+
76+
def copy(self):
77+
"""
78+
`Environment.copy` is sometimes used to create a new environment with the same variables and values as the original, especially in CPS interpreters (where a continuation closes over the application of an interpreter to a `Jaxpr`).
79+
"""
80+
keys = list(self.env.keys())
81+
return Environment({k: self.env[k] for k in keys})

src/genjax/_src/core/interpreters/incremental.py renamed to src/genjax/_src/core/compiler/interpreters/incremental.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@
3232
from jax import util as jax_util
3333
from jax.extend.core import Jaxpr, Primitive
3434

35-
from genjax._src.core.interpreters.forward import Environment, StatefulHandler
36-
from genjax._src.core.interpreters.staging import stage
35+
from genjax._src.core.compiler.interpreters.environment import Environment
36+
from genjax._src.core.compiler.interpreters.stateful import StatefulHandler
37+
from genjax._src.core.compiler.staging import stage
3738
from genjax._src.core.pytree import Pytree
3839
from genjax._src.core.typing import (
3940
Any,
@@ -312,9 +313,9 @@ class IncrementalInterpreter(Pytree):
312313
default_factory=dict
313314
)
314315

315-
def _eval_jaxpr_forward(
316+
def eval_jaxpr_incremental(
316317
self,
317-
_stateful_handler,
318+
stateful_handler,
318319
jaxpr: Jaxpr,
319320
consts: list[Any],
320321
primals: list[Any],
@@ -334,8 +335,8 @@ def _eval_jaxpr_forward(
334335
]
335336
subfuns, params = _eqn.primitive.get_bind_params(_eqn.params)
336337
args = subfuns + induals
337-
if _stateful_handler and _stateful_handler.handles(_eqn.primitive):
338-
outduals = _stateful_handler.dispatch(_eqn.primitive, *args, **params)
338+
if stateful_handler and stateful_handler.handles(_eqn.primitive):
339+
outduals = stateful_handler.dispatch(_eqn.primitive, *args, **params)
339340
else:
340341
outduals = default_propagation_rule(_eqn.primitive, *args, **params)
341342
if not _eqn.primitive.multiple_results:
@@ -353,7 +354,7 @@ def _inner(*args):
353354
tangents, is_leaf=lambda v: isinstance(v, ChangeTangent)
354355
)
355356
jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
356-
flat_out = self._eval_jaxpr_forward(
357+
flat_out = self.eval_jaxpr_incremental(
357358
_stateful_handler,
358359
jaxpr,
359360
consts,

0 commit comments

Comments
 (0)