Skip to content

Commit 96b10a1

Browse files
Add support for sample_shape to primitive distributions (#1576)
This PR passes any `sample_shape` argument provided to a tfp distribution through to the `sample` method, instead of passing it to the constructor. @femtomc one issue that I hit here was that to use this, I have to wrap the argument in `Const`, because it seems like our code tries to trace non-jit-compiled fns, or something like that. I'll add the error I saw as a reply. Co-authored-by: Mathieu Huot <[email protected]>
1 parent 511ff50 commit 96b10a1

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

src/genjax/_src/core/pytree.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ class Const(Generic[R], Pytree):
250250
251251
252252
def f(c):
253-
if c.const == 5:
253+
if c.unwrap() == 5:
254254
return 10.0
255255
else:
256256
return 5.0

src/genjax/_src/generative_functions/distributions/tensorflow_probability/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import jax.numpy as jnp
1717
from tensorflow_probability.substrates import jax as tfp
1818

19+
from genjax._src.core.pytree import Const
1920
from genjax._src.core.typing import Array, Callable
2021
from genjax._src.generative_functions.distributions.distribution import (
2122
ExactDensity,
@@ -49,11 +50,15 @@ def tfp_distribution(
4950
"""
5051

5152
def sampler(key, *args, **kwargs):
53+
sample_shape = kwargs.pop("sample_shape", ())
5254
d = dist(*args, **kwargs)
53-
return d.sample(seed=key)
55+
return d.sample(seed=key, sample_shape=Const.unwrap(sample_shape))
5456

5557
def logpdf(v, *args, **kwargs):
58+
# Remove unused kwarg to match sampler function behavior
59+
kwargs.pop("sample_shape", ())
5660
d = dist(*args, **kwargs)
61+
5762
return d.log_prob(v)
5863

5964
return exact_density(sampler, logpdf, name or dist.__name__)

tests/generative_functions/test_static_gen_fn.py

+8
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ def annotated_function(x: float, y: float) -> float:
8484

8585

8686
class TestMisc:
87+
def test_static_sample_shape(self):
88+
@genjax.gen
89+
def f():
90+
return genjax.normal(0.0, 1.0, sample_shape=genjax.Const((2, 2))) @ "normal"
91+
92+
tr = f.simulate(jax.random.key(0), ())
93+
assert tr.get_retval().shape == (2, 2)
94+
8795
def test_switch_chm_and_static(self):
8896
@genjax.gen
8997
def model():

0 commit comments

Comments
 (0)