|
| 1 | +# Copyright 2024 MIT Probabilistic Computing Project |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import itertools as it |
| 16 | + |
| 17 | +import jax.core as jc |
| 18 | +from jax import tree_util |
| 19 | +from jax import util as jax_util |
| 20 | +from jax.extend.core import Primitive |
| 21 | +from jax.interpreters import mlir |
| 22 | +from jax.interpreters import partial_eval as pe |
| 23 | + |
| 24 | +from genjax._src.core.compiler.staging import stage |
| 25 | + |
| 26 | +######################### |
| 27 | +# Custom JAX primitives # |
| 28 | +######################### |
| 29 | + |
| 30 | + |
| 31 | +class InitialStylePrimitive(Primitive): |
| 32 | + """Contains default implementations of transformations.""" |
| 33 | + |
| 34 | + def __init__(self, name): |
| 35 | + super(InitialStylePrimitive, self).__init__(name) |
| 36 | + self.multiple_results = True |
| 37 | + |
| 38 | + def _abstract(*flat_avals, **params): |
| 39 | + abs_eval = params["abs_eval"] |
| 40 | + return abs_eval(*flat_avals, **params) |
| 41 | + |
| 42 | + self.def_abstract_eval(_abstract) |
| 43 | + |
| 44 | + def fun_impl(*args, **params): |
| 45 | + impl = params["impl"] |
| 46 | + return impl(*args, **params) |
| 47 | + |
| 48 | + self.def_impl(fun_impl) |
| 49 | + |
| 50 | + def _mlir(ctx: mlir.LoweringRuleContext, *mlir_args, **params): |
| 51 | + lowering = mlir.lower_fun(self.impl, multiple_results=True) |
| 52 | + return lowering(ctx, *mlir_args, **params) |
| 53 | + |
| 54 | + mlir.register_lowering(self, _mlir) |
| 55 | + |
| 56 | + |
| 57 | +def initial_style_bind(prim, **params): |
| 58 | + """Binds a primitive to a function call.""" |
| 59 | + |
| 60 | + def bind(f): |
| 61 | + """Wraps a function to be bound to a primitive, keeping track of Pytree |
| 62 | + information.""" |
| 63 | + |
| 64 | + def wrapped(*args, **kwargs): |
| 65 | + """Runs a function and binds it to a call primitive.""" |
| 66 | + jaxpr, (flat_args, in_tree, out_tree) = stage(f)(*args, **kwargs) |
| 67 | + debug_info = jaxpr.jaxpr.debug_info |
| 68 | + |
| 69 | + def _impl(*args, **params): |
| 70 | + consts, args = jax_util.split_list(args, [params["num_consts"]]) |
| 71 | + return jc.eval_jaxpr(jaxpr.jaxpr, consts, *args) |
| 72 | + |
| 73 | + def _abs_eval(*flat_avals, **params): |
| 74 | + return pe.abstract_eval_fun( |
| 75 | + _impl, |
| 76 | + *flat_avals, |
| 77 | + debug_info=debug_info, |
| 78 | + **params, |
| 79 | + ) |
| 80 | + |
| 81 | + outs = prim.bind( |
| 82 | + *it.chain(jaxpr.literals, flat_args), |
| 83 | + abs_eval=params.get("abs_eval", _abs_eval), |
| 84 | + impl=_impl, |
| 85 | + in_tree=in_tree, |
| 86 | + out_tree=out_tree, |
| 87 | + num_consts=len(jaxpr.literals), |
| 88 | + **params, |
| 89 | + ) |
| 90 | + return tree_util.tree_unflatten(out_tree(), outs) |
| 91 | + |
| 92 | + return wrapped |
| 93 | + |
| 94 | + return bind |
0 commit comments