Skip to content

Commit ef4b893

Browse files
committed
Add checkify option to console.
1 parent 6906d91 commit ef4b893

File tree

5 files changed

+72
-18
lines changed

5 files changed

+72
-18
lines changed

.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
max-line-length = 88
33
max-complexity = 15
44
select = C,E,F,W,B,B950
5-
extend-ignore = E203,E501,W503,W601,E302,W605,B950,B023,F811,C901,E731
5+
extend-ignore = E203,E501,W503,W601,E302,W605,B950,B023,F811,C901,E731,E722
66
exclude =
77
.git,
88
__init__.py,

src/genjax/_src/console.py

+37-12
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,40 @@ def console(
9696
enforce_checkify=False,
9797
**pretty_kwargs,
9898
):
99-
traceback_kwargs = {
100-
"word_wrap": True,
101-
"show_locals": False,
102-
"max_frames": 30,
103-
"suppress": [jax, plum],
104-
**pretty_kwargs,
105-
}
106-
return GenJAXConsole(
107-
Console(soft_wrap=True),
108-
traceback_kwargs,
109-
enforce_checkify,
110-
)
99+
try:
100+
# Try to ignore these packages in pretty printing.
101+
import asyncio
102+
103+
import ipykernel
104+
import tornado
105+
import traitlets
106+
107+
traceback_kwargs = {
108+
"word_wrap": True,
109+
"show_locals": False,
110+
"max_frames": 30,
111+
"suppress": [
112+
jax,
113+
plum,
114+
asyncio,
115+
tornado,
116+
traitlets,
117+
ipykernel,
118+
],
119+
**pretty_kwargs,
120+
}
121+
except:
122+
traceback_kwargs = {
123+
"word_wrap": True,
124+
"show_locals": False,
125+
"max_frames": 30,
126+
"suppress": [jax, plum],
127+
**pretty_kwargs,
128+
}
129+
130+
finally:
131+
return GenJAXConsole(
132+
Console(soft_wrap=True),
133+
traceback_kwargs,
134+
enforce_checkify,
135+
)

src/genjax/_src/core/datatypes/generative.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,10 @@ def unmask(self):
813813
# contexts.
814814
def _check():
815815
check_flag = jnp.all(self.mask)
816-
checkify.check(check_flag, "Mask is False, the masked value is invalid.\n")
816+
checkify.check(
817+
check_flag,
818+
"Attempted to unmask when the mask flag is False: the masked value is invalid.\n",
819+
)
817820

818821
optional_check(_check)
819822
return self.value

src/genjax/_src/generative_functions/combinators/vector/unfold_combinator.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,12 @@ def _inner(carry, xs):
205205

206206
def _optional_out_of_bounds_check(self, count: IntArray):
207207
def _check():
208-
check_flag = jnp.less(count + 1, self.max_length)
208+
check_flag = jnp.less_equal(count + 1, self.max_length)
209209
checkify.check(
210210
check_flag,
211-
f"\nUnfoldCombinator received a length argument ({count}) longer than specified max length ({self.max_length})",
211+
"UnfoldCombinator received an index argument (idx = {count}) with idx + 1 > max length ({max_length})",
212+
count=jnp.array(count, copy=False),
213+
max_length=jnp.array(self.max_length),
212214
)
213215

214216
optional_check(_check)

src/genjax/_src/inference/translator.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,17 @@
1818
import jax
1919
import jax.numpy as jnp
2020
import jax.tree_util as jtu
21+
from jax.experimental.checkify import check
2122

23+
from genjax._src.checkify import optional_check
2224
from genjax._src.core.datatypes.generative import Choice
2325
from genjax._src.core.datatypes.generative import ChoiceMap
2426
from genjax._src.core.datatypes.generative import GenerativeFunction
2527
from genjax._src.core.datatypes.generative import Trace
2628
from genjax._src.core.pytree.pytree import Pytree
2729
from genjax._src.core.pytree.utilities import tree_grad_split
2830
from genjax._src.core.pytree.utilities import tree_zipper
31+
from genjax._src.core.typing import Bool
2932
from genjax._src.core.typing import Callable
3033
from genjax._src.core.typing import FloatArray
3134
from genjax._src.core.typing import PRNGKey
@@ -92,6 +95,7 @@ def safe_slogdet(v):
9295
class ExtendingTraceTranslator(TraceTranslator):
9396
choice_map_forward: Callable # part of bijection
9497
choice_map_inverse: Callable # part of bijection
98+
check_bijection: Bool
9599
p_argdiffs: Tuple
96100
q_forward: GenerativeFunction
97101
q_forward_args: Tuple
@@ -103,7 +107,7 @@ def flatten(self):
103107
self.q_forward,
104108
self.q_forward_args,
105109
self.new_observations,
106-
), (self.choice_map_forward, self.choice_map_inverse)
110+
), (self.choice_map_forward, self.choice_map_inverse, self.check_bijection)
107111

108112
@classmethod
109113
def new(
@@ -114,18 +118,21 @@ def new(
114118
new_obs: Choice,
115119
choice_map_forward: Callable,
116120
choice_map_inverse: Callable,
121+
check_bijection: Bool,
117122
):
118123
return ExtendingTraceTranslator(
119124
choice_map_forward,
120125
choice_map_inverse,
126+
check_bijection,
121127
p_argdiffs,
122128
q_forward,
123129
q_forward_args,
124130
new_obs,
125131
)
126132

127133
def value_and_jacobian_correction(self, forward, trace):
128-
grad_tree, no_grad_tree = tree_grad_split(trace.get_choices())
134+
trace_choices = trace.get_choices()
135+
grad_tree, no_grad_tree = tree_grad_split(trace_choices)
129136

130137
def _inner(differentiable):
131138
choices = tree_zipper(differentiable, no_grad_tree)
@@ -134,6 +141,21 @@ def _inner(differentiable):
134141

135142
inner_jacfwd = jax.jacfwd(_inner, has_aux=True)
136143
J, transformed = inner_jacfwd(grad_tree)
144+
if self.check_bijection:
145+
146+
def optional_check_bijection_is_bijection():
147+
backwards = self.choice_map_inverse(transformed)
148+
flattened = jtu.tree_leaves(
149+
jtu.tree_map(
150+
lambda v1, v2: jnp.all(v1 == v2),
151+
trace_choices,
152+
backwards,
153+
)
154+
)
155+
check_flag = jnp.all(jnp.array(flattened))
156+
check(check_flag, "Bijection check failed")
157+
158+
optional_check(optional_check_bijection_is_bijection)
137159
J = stack_differentiable(J)
138160
(_, J_log_abs_det) = safe_slogdet(J)
139161
return transformed, J_log_abs_det
@@ -166,6 +188,7 @@ def extending_trace_translator(
166188
new_obs: ChoiceMap,
167189
choice_map_forward: Callable,
168190
choice_map_backward: Callable,
191+
check_bijection=False,
169192
):
170193
return ExtendingTraceTranslator.new(
171194
p_argdiffs,
@@ -174,6 +197,7 @@ def extending_trace_translator(
174197
new_obs,
175198
choice_map_forward,
176199
choice_map_backward,
200+
check_bijection,
177201
)
178202

179203

0 commit comments

Comments
 (0)