
Releases: genjax-dev/genjax

v0.10.3

26 Mar 19:07
96b10a1

What's Changed

  • Kill the manual CHANGELOG (we already have an automated one). by @femtomc in #1564
  • [Cold Harbor, 1 / N] Organize the lowest levels of the GenJAX compiler. by @femtomc in #1566
  • New logo, docs assets clean up. by @femtomc in #1568
  • Update README.md to taste. by @femtomc in #1570
  • Bump webfactory/ssh-agent from 0.9.0 to 0.9.1 by @dependabot in #1575
  • Add support for sample_shape to primitive distributions by @sritchie in #1576
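In the usual convention (for example, TensorFlow Probability's), sample_shape asks a distribution for a batch of independent draws whose leading dimensions are sample_shape. The pure-Python sketch below illustrates only that convention. sample_normal and shape_of are hypothetical helpers, not GenJAX functions, and the sketch assumes GenJAX's primitives follow the same convention.

```python
import random
from typing import Any, Optional, Tuple

def sample_normal(mean: float, stddev: float,
                  sample_shape: Tuple[int, ...] = (),
                  rng: Optional[random.Random] = None) -> Any:
    """Hypothetical helper: draw i.i.d. normal samples shaped like sample_shape.

    With sample_shape=() a single float is returned; each entry of sample_shape
    adds one leading dimension of independent draws (nested lists here).
    """
    rng = rng if rng is not None else random.Random()
    if not sample_shape:
        return rng.gauss(mean, stddev)
    head, *rest = sample_shape
    return [sample_normal(mean, stddev, tuple(rest), rng) for _ in range(head)]

def shape_of(x: Any) -> Tuple[int, ...]:
    """Recover the nested-list shape (assumes non-empty, rectangular lists)."""
    if isinstance(x, list):
        return (len(x),) + shape_of(x[0])
    return ()

# Six independent draws arranged as a 3 x 2 batch.
draws = sample_normal(0.0, 1.0, sample_shape=(3, 2), rng=random.Random(0))
```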

Full Changelog: v0.10.2...v0.10.3

v0.10.2

14 Mar 19:55
af24cd7

What's Changed

  • Show mapping tutorial in docs by @fzaiser in #1557
  • Do a little cleanup around InitialStylePrimitive. by @femtomc in #1556
  • Update usage of JAX partial_eval / wrap_init to accept debug info by @femtomc in #1563

Full Changelog: v0.10.1...v0.10.2

v0.10.1

11 Mar 15:28
0565036

What's Changed

New Contributors

Full Changelog: v0.10.0...v0.10.1

v0.10.0

27 Feb 22:09
24c097b

What's Changed

Breaking Changes

  • Remove MaskSel, Flag => bool in a number of spots (GEN-979) by @sritchie in #1529
  • Expunge the concept of Constraint; just rely on choice map everywhere. by @femtomc in #1532
  • Clean up the usage of the trace primitive and remove global "handler stack". by @femtomc in #1531

Fixes

Cookbook

  • Adding IndexRequest for Vmap and cookbook entry on IndexRequest by @MathieuHuot in #1518
  • Add mathjax support for $, $$ delimiters (GEN-975) by @sritchie in #1525
  • fix math rendering in notebooks (GEN-987) by @sritchie in #1535

Misc

Full Changelog: v0.9.3...v0.10.0

v0.9.3: bugfix release

14 Feb 19:16
93c9b9d

What's Changed

Full Changelog: v0.9.2...v0.9.3

v0.9.2: simpler trace printing

12 Feb 00:39
a898281

What's Changed

Full Changelog: v0.9.1...v0.9.2

v0.9.1, trace simplification proceeds

11 Feb 04:07
a272c43

What's Changed

Full Changelog: v0.9.0...v0.9.1

v0.9.0: Trace tidying, we're on pypi!

10 Feb 20:49
3905831

This is the first GenJAX release published to PyPI proper.

PyPI page: https://pypi.org/project/genjax/
Docs live here: https://chisym.github.io/genjax/

What's Changed

New Stuff

  • Adding some TFP distributions (initially Gamma), a test, nuking Bates as it doesn't work by @MathieuHuot in #1493
  • Add switch to ChoiceMapBuilder (GEN-897) by @sritchie in #1469
  • Add all, none, leaf and S[()] == S.none (GEN-945) by @sritchie in #1492
  • Implement project for VmapCombinator by @sritchie in #1505

Changes / Fixes

Cookbook

Misc

Full Changelog: v0.8.1...v0.8.2

v0.8.1

18 Dec 21:30
552dcef

What's Changed

New Contributors

Full Changelog: v0.8.0...v0.8.1

v0.8.0: simpler ChoiceMaps, new combinators

03 Dec 21:40
d39957c

GenJAX 0.8.0 is here! We have a few minor breaking changes, a nice set of new features and a bunch of ChoiceMap improvements for your pleasure. The sections below list out each PR that was merged since the previous release, along with some exposition about the new features where appropriate.

To install this new version, run

pip install keyring keyrings.google-artifactregistry-auth
pip install "genjax==0.8.0" --extra-index-url https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple/

or bump the version to "0.8.0" wherever you've pinned your current version.

Thanks to @georgematheos, @esli999, @MathieuHuot, @femtomc, @limarta and @ahiser1117 for code, testing, bug reports and repros.

Breaking Changes

  • dimap's post-xform should accept the original args (GEN-781) by @sritchie in #1429
  • Present transformed args to post_fn for DimapCombinator by @sritchie in #1449

dimap's post argument and contramap's argument now take three arguments: the original (pre-transformed) input args, the transformed input args, and the return value. Previously they took only the original args and the return value.
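To make the new three-argument convention concrete, here is a pure-Python analogue of a dimap-style wrapper. It is a hypothetical sketch of the calling convention only, not GenJAX's DimapCombinator: `pre` maps the caller's args to the args the wrapped function sees, and `post` receives the original args, the transformed args, and the return value, in that order.

```python
from typing import Any, Callable, Tuple

def dimap(fn: Callable[..., Any],
          pre: Callable[..., Tuple[Any, ...]],
          post: Callable[[Tuple[Any, ...], Tuple[Any, ...], Any], Any]) -> Callable[..., Any]:
    """Hypothetical pure-Python analogue of a dimap wrapper (not GenJAX's API)."""
    def wrapped(*args: Any) -> Any:
        xformed_args = pre(*args)       # transform the inputs
        retval = fn(*xformed_args)      # run the inner function on the transformed inputs
        # post sees the original args, the transformed args, and the return value
        return post(args, xformed_args, retval)
    return wrapped

# Double the inputs, add them, and report both argument tuples alongside the result.
scaled_sum = dimap(
    lambda a, b: a + b,
    pre=lambda a, b: (2 * a, 2 * b),
    post=lambda args, xargs, ret: (args, xargs, ret),
)
```

Passing both argument tuples lets a post-transformation depend on the inputs as the caller wrote them as well as on what the inner function actually received.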

Fixes / new stuff

This fix allows you to use pytree instances as the input to a vmapped model.

This allows you to use Const.unwrap(x) on a Const or non-Const object; Const.unwrap acts as identity when supplied with a non-Const.
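The identity-on-non-Const behavior can be sketched with a small stand-in class. This Const is a hypothetical toy written for illustration, not GenJAX's own class; it only mirrors the unwrap contract described above.

```python
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class Const:
    """Toy stand-in for a wrapper that marks a value as static (hypothetical)."""
    val: Any

    @staticmethod
    def unwrap(x: Any) -> Any:
        # Unwrap a Const; pass any other value through unchanged (identity).
        return x.val if isinstance(x, Const) else x
```

With this contract, code that may receive either a wrapped or a plain value can call unwrap unconditionally instead of branching on the type.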

It's now possible to use the @gen decorator on the method of a class. See below for an example:

In [5]: import jax
   ...: import genjax
   ...: from genjax._src.core.pytree import Pytree
   ...: from genjax._src.core.typing import Array, ArrayLike
   ...:
   ...:
   ...: @genjax.gen
   ...: def outside(x):
   ...:     return genjax.normal(x, 1.0) @ "x"
   ...:
   ...: @Pytree.dataclass
   ...: class Model(Pytree):
   ...:     foo: ArrayLike
   ...:     bar: ArrayLike
   ...:
   ...:     @genjax.gen
   ...:     def run(self, x):
   ...:         y = genjax.normal(self.foo, self.bar) @ "y"
   ...:         z = genjax.normal(x, 1.0) @ "z"
   ...:         return y + z
   ...:
   ...: key = jax.random.PRNGKey(0)
   ...: m = Model(foo=4.0, bar=6.0)
   ...: m.run.simulate(key, (1.0,)).get_choices()
Out[5]:
Static({
  'y': Choice(v=<jax.Array(-4.405305, dtype=float32)>),
  'z': Choice(v=<jax.Array(1.3905154, dtype=float32)>),
})
  • masked_iterate, masked_iterate_final (GEN-327, GEN-859) by @sritchie in #1450

We have two new combinators! masked_iterate and masked_iterate_final are similar to their non-masked variants, except that instead of providing an initial value alone, you provide a pair: the initial value and a boolean array. The nth iteration is scored if the nth boolean entry is True and is not scored if it is False.

For example:

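import jax
import jax.numpy as jnp
import genjax

masks = jnp.array([True, True])

@genjax.gen
def step(x):
    # Two masked normal draws at "rats"; each is active only where its mask entry is True
    _ = (
        genjax.normal.mask().vmap(in_axes=(0, None, None))(masks, x, 1.0)
        @ "rats"
    )
    return x

# Create some initial traces: 10 iterations, only the first 5 of which are scored
key = jax.random.key(0)
mask_steps = jnp.arange(10) < 5
model = step.masked_iterate_final()
init_particle = model.simulate(key, (0.0, mask_steps))
init_particle.get_choices()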
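The scoring rule can be sketched without JAX. The loop below is a hypothetical pure-Python model of masked iteration, not GenJAX's implementation. It assumes that every step still runs and that a False mask entry only removes that step's log-density from the total score. toy_masked_iterate_final and step are illustrative names.

```python
import math
from typing import Callable, List, Tuple

def toy_masked_iterate_final(step: Callable[[float], Tuple[float, float]],
                             init: float,
                             masks: List[bool]) -> Tuple[float, float]:
    """Hypothetical sketch: run `step` once per mask entry, scoring only unmasked steps.

    `step` maps a state to (next_state, log_density_of_that_step).
    Returns (final_state, total_score).
    """
    state, score = init, 0.0
    for active in masks:
        state, logp = step(state)
        if active:
            score += logp  # masked-out steps contribute nothing to the score
    return state, score

def step(x: float) -> Tuple[float, float]:
    # Deterministic toy step: increment the state and report a fixed log-density.
    return x + 1.0, math.log(0.5)

# Mirrors `jnp.arange(10) < 5`: ten steps, only the first five scored.
final, score = toy_masked_iterate_final(step, 0.0, [i < 5 for i in range(10)])
```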
  • (StaticGenerativeFunction) Prevents allocation and accumulation of score when it's available in callee subtraces. by @femtomc in #1418
  • Allow PythonicPytree indexing using JAX arrays by @georgematheos in #1440

ChoiceMap

Breaking change: anywhere you've used an ellipsis (...) to query a ChoiceMap, replace it with :.

ChoiceMap instances now automatically simplify! There is no more "filtered choicemap", as all filters are pushed down to the leaves:

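import jax.numpy as jnp
from genjax import ChoiceMapBuilder as C

l_side = C["x"].set({"y": 1.0, "z": 2.0}).mask(jnp.array(True))
r_side = C["x", "b"].set(1.0)

# Merge the two choicemaps; the mask on l_side is pushed down to its leaves
l_side | r_side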
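One way to picture "filters are pushed down to the leaves" is with nested dicts. The sketch below is a hypothetical toy, not GenJAX's ChoiceMap. A statically known flag simplifies the mask away entirely (True keeps the map, False empties it). An unresolved flag wraps each leaf rather than the whole map. The merge keeps the left value on a leaf conflict, which is an assumption of this sketch; the example above has no conflicting leaves.

```python
from dataclasses import dataclass
from typing import Any, Dict

@dataclass(frozen=True)
class Masked:
    """A leaf guarded by a flag that cannot be resolved yet (hypothetical)."""
    value: Any
    flag: Any

def mask(chm: Dict[str, Any], flag: Any) -> Dict[str, Any]:
    """Push a mask down to the leaves instead of wrapping the whole map."""
    if flag is True:
        return chm  # statically true: the mask disappears
    if flag is False:
        return {}   # statically false: nothing survives
    return {k: mask(v, flag) if isinstance(v, dict) else Masked(v, flag)
            for k, v in chm.items()}

def or_(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
    """Merge two nested maps; on a leaf conflict this sketch keeps the left value."""
    out = dict(right)
    for k, v in left.items():
        if k in out and isinstance(v, dict) and isinstance(out[k], dict):
            out[k] = or_(v, out[k])
        else:
            out[k] = v
    return out

l_side = mask({"x": {"y": 1.0, "z": 2.0}}, True)
r_side = {"x": {"b": 1.0}}
merged = or_(l_side, r_side)  # one flat tree: no filter node left in the result
```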

We've also gotten rid of any returned Indexed choicemap layer. JAX's arrays already keep track of their indexing; there's no need to introduce another Indexed layer on our side.

import jax
import jax.numpy as jnp
import genjax

@genjax.gen
def kernel(mean):
    # Three random choices: y depends on x, while z is independent of both
    x = genjax.normal(mean, 1.0) @ "x"
    y = genjax.normal(x, 2.0) @ "y"  # Conditioned on x
    z = genjax.normal(0.0, 3.0) @ "z"
    return x + y + z

# Vectorize the kernel over a batch of 10 means
vmapped_model = kernel.vmap()
key = jax.random.key(0)
chm = vmapped_model.simulate(key, (jnp.arange(10.0),)).get_choices()
chm

You can still query this choicemap as before, using : to mark the index:

chm[:, "x"]

But you can also omit the : and get the value via chm["x"]. Any non-string entries you place inside the square brackets are pushed down to the array-shaped leaves of the choicemap.
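The "push non-string keys down to the leaves" rule can be mimicked in plain Python. VectorChoiceMap below is a hypothetical toy in which each address holds a list standing in for an array leaf. It splits a query into string addresses and index entries, looks up the leaf, and then applies the indices to it. That is why chm[:, "x"] and chm["x"] agree.

```python
from typing import Any, Dict, List

class VectorChoiceMap:
    """Hypothetical toy: each address holds a list standing in for an array leaf."""

    def __init__(self, leaves: Dict[str, List[float]]):
        self.leaves = leaves

    def __getitem__(self, key: Any) -> Any:
        keys = key if isinstance(key, tuple) else (key,)
        addrs = [k for k in keys if isinstance(k, str)]
        indices = [k for k in keys if not isinstance(k, str)]
        # This sketch supports a single string address per query.
        value = self.leaves[addrs[0]]
        # Non-string entries (ints, slices) are pushed down and applied to the leaf.
        for i in indices:
            value = value[i]
        return value

chm = VectorChoiceMap({"x": [0.5, 1.5, 2.5], "y": [3.0, 4.0, 5.0]})
```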

  • Update choicemap cookbook (GEN-568) by @MathieuHuot in #1381
  • Remove ellipsis support from queries, add SwitchChm (GEN-698, GEN-696, GEN-542) by @sritchie in #1421
  • Enforce "Mask has scalar flag", allow vectorized Mask (GEN-768, GEN-765, GEN-769) by @sritchie in #1423
  • remove Filtered, fix MaskTrace vmappability (GEN-662, GEN-569, GEN-775, GEN-767) by @sritchie in #1422
  • Remove Xor, chip away at Or (GEN-774) by @sritchie in #1436
  • Remove Indexed layers with None address by @sritchie in #1437

Edit

  • (1 / N) Remove ChoiceMapEditRequest, replace with StaticRequest by @femtomc in #1404
  • (2 / N) Do a bit of re-org to support deriving the implementations of edit for primitive EditRequest types by @femtomc in #1407
  • (3 / N) Add an edit request which can be used to forcibly coerce Diff values according to user input. by @femtomc in #1406
  • (5 / N) Add HMC as an EditRequest. by @femtomc in #1408
  • (6 / N) Add Rejuvenate -- a proposal-driven regeneration EditRequest by @femtomc in #1411
  • (7 / N) Add IndexRequest edit to ScanCombinator by @femtomc in #1405

Cookbook

Misc

  • Support JAX>=0.4.34 by removing recast_to_float0 and using our own implementation. by @femtomc in #1410
  • Adding a test for mask scan combinator by @MathieuHuot in #1443

Full Changelog: v0.7.0...v0.8.0