[BUG] `genjax.mh` and `genjax.Map` don't play well together #407
@midfield I think the correct example is this:

```python
import jax
import jax.numpy as jnp
import genjax

key = jax.random.PRNGKey(0)  # PRNG key used by simulate and mh

@genjax.gen
def model():
    loc = genjax.normal(0., 1.) @ 'loc'
    xs = genjax.Map(genjax.normal, in_axes=(None, 0))(loc, jnp.arange(10)) @ 'xs'
    return xs

_, trace = genjax.simulate(model)(key, ())

@genjax.gen
def proposal(choices):
    loc = choices['loc']
    xs = genjax.Map(genjax.normal, in_axes=(None, 0))(loc, jnp.arange(10)) @ 'xs'
    return xs

genjax.mh(proposal).apply(key, trace, ())
```

(Note change to …) This throws for me with an error from one of the axis size dimension functions.
Yes, that's right; sorry for the typo. I've edited the original report. But this is still an error, no?
It is! That's a type issue: a function expects a … I'll investigate and fix.
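For readers unfamiliar with the `in_axes=(None, 0)` pattern, here is a minimal pure-JAX sketch of what the `model` above is meant to compute, assuming `genjax.Map(genjax.normal, in_axes=(None, 0))` vectorizes like `jax.vmap`: `loc` is shared across all draws, and the second argument (the scale) is mapped over axis 0. The helper names `normal_sample`, `model_sketch`, and `vmap_error` are hypothetical and not part of GenJAX. The sketch also shows two ways `jax.vmap` raises a `ValueError` while computing axis sizes (a mapped argument with rank 0, and mapped arguments of different lengths); the thread does not establish that either is the cause of this bug.

```python
import jax
import jax.numpy as jnp

def normal_sample(key, loc, scale):
    # Hypothetical helper: draw one sample from Normal(loc, scale).
    return loc + scale * jax.random.normal(key)

def model_sketch(key):
    # Hypothetical pure-JAX analogue of `model`, assuming Map behaves like vmap.
    key_loc, key_xs = jax.random.split(key)
    loc = normal_sample(key_loc, 0.0, 1.0)
    scales = jnp.arange(10)  # mirrors jnp.arange(10) in the issue (scales 0..9)
    keys = jax.random.split(key_xs, scales.shape[0])
    # keys and scales are mapped along axis 0; loc is broadcast (axis None).
    xs = jax.vmap(normal_sample, in_axes=(0, None, 0))(keys, loc, scales)
    return loc, xs

def vmap_error(fn, *args):
    # Return the ValueError message jax.vmap raises for these args, or None.
    try:
        jax.vmap(fn)(*args)
    except ValueError as err:
        return str(err)
    return None

loc, xs = model_sketch(jax.random.PRNGKey(0))
# A rank-0 argument cannot be mapped along axis 0.
print(vmap_error(lambda v: v + 1.0, jnp.float32(1.0)))
# Mapped arguments must agree on the size of the mapped axis.
print(vmap_error(lambda a, b: a + b, jnp.ones(3), jnp.ones(4)))
```

Note that the first element of `jnp.arange(10)` is 0, so the first `xs` entry is drawn with scale 0 and equals `loc` exactly.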
Describe the bug
See code below
To Reproduce
Expected behavior
It should run without raising an error.
Execution environment (please complete the following information):