[GJEP] Dynamic addresses #701
Comments
As mentioned in the section on [...] Consider the following model:

    @gen(Static)
    def model(t):
        x = normal(0.0, 1.0) @ ("x", t)
        _ = normal(0.0, 1.0) @ ("x", t + 1)

Now, this is a valid model -- the addresses will always be different (so it doesn't violate the no-duplicate-address restriction). However, when we consider `model.update(key, prev, new_constraints, (Diff(t, NoChange), ))` [...] The way that update is implemented for the [...] Now we have another implementation choice to make: we can insert a [...] We've encountered a "switch combinator"-like scenario inside our implementation of [...] Now, of course, just like [...]
The larger problem is that the number of calls is quadratic in the number of possibly conflated dynamic addresses. So e.g. if you have 4 of these addresses in your model code, each time you update -- for each of those addresses, you're going to have to run an update for all of them.
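To make the quadratic cost concrete, here is a minimal sketch in plain JAX. It is not GenJAX code: `update_all`, the integer stand-ins for addresses, and the zero placeholder for "no match" are all assumptions made for illustration. Each call site compares its new address against every old subtrace it might be conflated with, so 4 sites produce 16 checks.

```python
import jax.numpy as jnp


def update_all(new_addrs, old_addrs, old_values):
    """Hypothetical helper: every call site tests every old subtrace."""
    n = old_addrs.shape[0]
    checks = 0  # Python-level counter, incremented once per comparison
    out = []
    for i in range(n):          # each call site in the model body
        value = jnp.zeros(())   # placeholder used when nothing matches
        for j in range(n):      # each old subtrace that might live there
            checks += 1
            value = jnp.where(new_addrs[i] == old_addrs[j], old_values[j], value)
        out.append(value)
    return jnp.stack(out), checks


old_addrs = jnp.array([0, 1, 2, 3])
old_values = jnp.array([10.0, 11.0, 12.0, 13.0])
new_addrs = old_addrs + 1  # same model, argument shifted by one

values, checks = update_all(new_addrs, old_addrs, old_values)
print(checks, values)
```

In traced code each of those checks would become a select or a branch in the compiled program, which is where the runtime penalty comes from.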
For the above issue, one possible solution I came up with: take advantage of the fact that the traversal order of addresses in GenJAX is deterministic. The pathological case in that solution is if a change to an argument changes the address support -- but I believe we can handle that with a simple cond.
Here's the interesting case to consider in the above solution. Again, consider the following model:

    @gen(Static)
    def model(t):
        x = normal(0.0, 1.0) @ ("x", t)
        _ = normal(1.0, 1.0) @ ("x", t + 1)

If we generate a trace with [...] `tr.update(key, empty_choice(), (Diff(1, UnknownChange), ))` [...] Now, at the first address, we have [...] So it does not suffice to just grab the subtrace at the index in the list of addresses by traversal order -- because changes to the arguments can affect what subtrace we want to grab (e.g. because we've shifted the address, we need to grab the second callee, not the first -- but making that decision might be dynamic!).
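The shift can be worked through numerically. The sketch below is plain JAX, not GenJAX: `match_subtraces` and `OFFSETS` are hypothetical names, addresses are reduced to their integer component, and the old trace is assumed to have been generated with `t = 0` (the new argument `1` comes from the update call above). For each new address it reports which old callee, by traversal index, holds the matching subtrace, and whether any match exists.

```python
import jax
import jax.numpy as jnp

# The two call sites use the addresses ("x", t) and ("x", t + 1).
OFFSETS = jnp.array([0, 1])


@jax.jit
def match_subtraces(t_old, t_new):
    old_addrs = t_old + OFFSETS
    new_addrs = t_new + OFFSETS
    # same[i, j] is True when new call site i lands on old call site j's address.
    same = new_addrs[:, None] == old_addrs[None, :]
    found = same.any(axis=1)          # does any old subtrace live at this address?
    index = jnp.argmax(same, axis=1)  # which old callee (only meaningful if found)
    return index, found


index, found = match_subtraces(0, 1)
print(index, found)
```

Under these assumptions the first call site must read the second callee's subtrace (index 1), and the second call site has no old subtrace at all. Both facts are only known at runtime when `t` is traced.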
Based on my reasoning, here are the degrees of freedom we have available to us for this feature:
In general, I think the trade-offs are acceptable. Model changes via dynamic changes to arguments which affect dynamic addresses are the "most exotic" form of dynamism in GenJAX -- and it seems fair to me that the implementation is the least performant around this feature.
Closing this as "won't address". After internal discussion, I believe that the path forward to support dynamic addresses is to expose them via specific inference functionality (like trace translators), and not to try to push them into the core
This is a reference issue which ties together several other design issues related to "dynamic addresses" in GenJAX, including #556, #473, #239 -- even formalization issues like #235 (which interact with this feature, by virtue of the fact that this feature requires robust masking, described in a moment) and #212.
Quick note:
Dynamic addresses
Fundamentally, this issue is about the following type of generative function:
Seemingly innocuous -- except it's not: it uses `t` - which can be a dynamic, JAX traced, integer array - as part of the address.
On first design, addresses of this form were explicitly not supported in GenJAX's `Static` language -- the interfaces were implemented using `Trie`-like types, whose keys had to be hashable. In first design, address components were constrained to be statically hashable things - including strings, and Python literal ints.
However, @derifatives quickly noted that expressing SMC update steps over sequences of target distributions expressed via `genjax.UnfoldCombinator` models is inconvenient with `Static` language programs whose addresses cannot be dynamic. Why? Well, you'd really like to write a proposal (like the program above):
And then, inside SMC -- write a tight `jax.lax.scan` -- which provides the current SMC target index `t` to the proposal, to produce a choice map which targets that address.
If the choice maps from the `Static` language are disallowed from containing dynamic integer addresses, you cannot construct such a `jax.lax.scan`-based SMC program - because `jax.lax.scan` always provides dynamic values to the body function (meaning, you can't use `t` in the address, because it's dynamic -- and it will always be dynamic -- inside of `jax.lax.scan`).
There are several "shallow" ways to fix this problem. One way would be to force users who use SMC to provide a proposal which doesn't actively include the current SMC target step `t` in their address space -- and then massage their resulting choice maps to work correctly at the current time step. By hoisting `t` out of the proposal entirely, we don't have to worry about dynamic addresses. On the other hand, you often do want to know what `t` is so e.g. you can write proposals that look at the current observations. More pressingly, removing `t` from user control restricts the user from writing more expressive proposals. As part of Gen's philosophy, we want to err on the side of supporting expressivity as much as possible (and, within reason, for our implementation's limitations).
Dynamic addresses are infectious
Dynamism in JAX is a tricky property. Where dynamism occurs, code cannot make a decision at trace time, and instead must be conservative. Dynamism immediately implies that we have to push a branching decision into runtime somewhere. We have control over the somewhere, but we cannot eliminate the branching.
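A small plain-JAX sketch of both points, with no GenJAX involved: `step` and `saw_tracer` are names made up for this illustration. Inside `jax.lax.scan` the index `t` arrives as a tracer rather than a concrete Python integer, so a decision that depends on it cannot be made at trace time and has to be expressed as a runtime branch, here with `jax.lax.cond`.

```python
import jax
import jax.numpy as jnp

saw_tracer = []  # filled in while scan traces the body


def step(carry, t):
    # t is abstract here: scan always hands the body a traced value.
    saw_tracer.append(isinstance(t, jax.core.Tracer))
    # A Python `if t % 2 == 0:` would fail on a tracer, so the branch
    # is pushed into runtime instead.
    carry = jax.lax.cond(t % 2 == 0, lambda c: c + 1.0, lambda c: c - 1.0, carry)
    return carry, t


final, ts = jax.lax.scan(step, jnp.zeros(()), jnp.arange(5))
print(all(saw_tracer), final, ts)
```

Both branches of the `cond` are traced and compiled. Only the choice between them is deferred to runtime, which is the sense in which the branching can be moved but not eliminated.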
As mentioned in other issues, the `ChoiceMap` interfaces like `get_submap` and `has_submap` must contend with dynamism in certain settings. One example is in `genjax.SwitchCombinator` - where "what branch generative function was called" is a dynamic quantity. To handle dynamism in this context, GenJAX has a type called `Mask` (#235, https://probcomp.github.io/genjax/genjax/library/core/functional_types.html?h=mask#option-types-the-masking-system): `Mask` wraps a dynamic piece of data with a dynamic `Bool` tag -- it's just like a functional `Maybe` or `Option` type.
If one invokes `SwitchChoiceMap.get_submap` -- and the address is not valid for the branch which was actually taken, we need some way to denote that behavior. We always return some value - but because we can't know what branch was actually active until runtime, we need to tag that value with a piece of information to say whether or not it is valid or invalid at runtime. We do so using `Mask` - which always wraps a value, but the "validity" of the value depends on the tag.
But surprise surprise! Dynamism is part of dynamic addresses -- if one constructs a choice map which supports dynamic addresses, and a user asks for `DynamicChoiceMap.get_submap` with a runtime integer array, whether or not that integer array is actually in the choice map is dynamic. Thus, dynamic addresses will likely involve, at the very least, some form of masking.
Let's explore what else is required below.
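The masking idea can be sketched in a few lines of plain JAX. This is a toy stand-in, not GenJAX's actual `Mask` or `DynamicChoiceMap`: the names `ToyMask`, `toy_get_submap` and `unmask_or` are hypothetical, and the "choice map" is just a pair of arrays holding integer addresses and values. The lookup always returns some value, and the flag says at runtime whether that value is valid.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyMask(NamedTuple):
    """Option-like wrapper: a value plus a runtime validity flag."""
    flag: jax.Array
    value: jax.Array


def toy_get_submap(addrs, values, query):
    hits = addrs == query
    # Always return *some* value; the flag records whether it is meaningful.
    return ToyMask(flag=hits.any(), value=values[jnp.argmax(hits)])


def unmask_or(mask, default):
    # Consumers must branch (or select) on the flag before trusting the value.
    return jnp.where(mask.flag, mask.value, default)


addrs = jnp.array([3, 4, 5])
values = jnp.array([0.3, 0.4, 0.5])

lookup = jax.jit(toy_get_submap)
present = lookup(addrs, values, 4)  # address is in the map
absent = lookup(addrs, values, 9)   # address is not in the map
print(present, absent, unmask_or(absent, -1.0))
```

Because a `NamedTuple` is a pytree, the masked result passes through `jax.jit` unchanged, and the query can be a traced integer.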
Dynamic addresses and `update`
Furthermore, we can quickly consider an extreme manifestation of the "infectiousness" of dynamism -- let's say you write:
and generate a trace from this generative function.
If you attempt to update this trace in JAX traced code, and you change `t` -- what happens?
Well, Gen allows you to change the measure, including the addresses, by changing the arguments. But now, the address which the `update` interface sees might imply that a subtrace (for the callee, here a `Distribution` generative function `tfp_normal`) either exists or does not exist.
That means that the `update` interface must handle dynamism by introducing branching -- because whether or not a subtrace lives at that address is now a runtime decision.
I cover an example like this in more detail in my replies to this issue thread (below).
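A rough sketch of that runtime decision for a single call site, in plain JAX rather than GenJAX: `update_site` is a hypothetical helper, addresses are reduced to integers, and the weight and discard that a real `update` would also compute are left out. If the new address matches the old one, the old choice is kept, otherwise a fresh value is sampled.

```python
import jax
import jax.numpy as jnp


@jax.jit
def update_site(key, new_addr, old_addr, old_value):
    old_value = jnp.asarray(old_value, dtype=jnp.float32)
    # Whether a subtrace lives at the new address is only known at runtime.
    exists = new_addr == old_addr
    value = jax.lax.cond(
        exists,
        lambda k: old_value,                                # reuse the old choice
        lambda k: jax.random.normal(k, dtype=jnp.float32),  # sample a new choice
        key,
    )
    return value, exists


key = jax.random.PRNGKey(0)
kept, kept_flag = update_site(key, 3, 3, 0.25)      # address unchanged
fresh, fresh_flag = update_site(key, 4, 3, 0.25)    # address shifted
print(kept, kept_flag, fresh, fresh_flag)
```

Both branches are compiled into the program, so the cost of the branch is paid even when the argument does not actually change.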
Solutions: new dynamic choice map type and masking
To take a step towards the solution space, I believe that a design solution will consist of a combination of features:
- `StaticLanguage` will involve heavy use of masking.
- `Pytree` -- to separate dynamic and static data.
- `Builtin` proposers for `Vector` combinators #212, and [Defactor] Choice map conversion and generic fallbacks #556. The new type of dynamic choice map, in a very concrete sense, is the most general type of choice map. It can support hierarchical addresses of any structure. By virtue of this fact, it's worth supporting code paths for all existing languages to accept dynamic choice maps. This provides a common conversion target for generative function languages which utilize specialized choice maps (e.g. interoperability between the choice map types of different languages, which is required by Gen's semantics -- which don't specialize on the data representation of a choice map, but do require interoperability of the address space of a choice map, no matter the representation). Without further experience, my prediction is that the cost of conversion will be a tracing-time cost -- and utilizing dynamic choice maps will incur runtime penalties due to branching. But fundamentally, strong code paths for this functionality support "slow, but stable and compositional" interface implementations.
The criteria for completeness involve the following testing axes: