18
18
import jax
19
19
import jax .numpy as jnp
20
20
import jax .tree_util as jtu
21
+ from jax .experimental .checkify import check
21
22
23
+ from genjax ._src .checkify import optional_check
22
24
from genjax ._src .core .datatypes .generative import Choice
23
25
from genjax ._src .core .datatypes .generative import ChoiceMap
24
26
from genjax ._src .core .datatypes .generative import GenerativeFunction
25
27
from genjax ._src .core .datatypes .generative import Trace
26
28
from genjax ._src .core .pytree .pytree import Pytree
27
29
from genjax ._src .core .pytree .utilities import tree_grad_split
28
30
from genjax ._src .core .pytree .utilities import tree_zipper
31
+ from genjax ._src .core .typing import Bool
29
32
from genjax ._src .core .typing import Callable
30
33
from genjax ._src .core .typing import FloatArray
31
34
from genjax ._src .core .typing import PRNGKey
@@ -92,6 +95,7 @@ def safe_slogdet(v):
92
95
class ExtendingTraceTranslator (TraceTranslator ):
93
96
choice_map_forward : Callable # part of bijection
94
97
choice_map_inverse : Callable # part of bijection
98
+ check_bijection : Bool
95
99
p_argdiffs : Tuple
96
100
q_forward : GenerativeFunction
97
101
q_forward_args : Tuple
@@ -103,7 +107,7 @@ def flatten(self):
103
107
self .q_forward ,
104
108
self .q_forward_args ,
105
109
self .new_observations ,
106
- ), (self .choice_map_forward , self .choice_map_inverse )
110
+ ), (self .choice_map_forward , self .choice_map_inverse , self . check_bijection )
107
111
108
112
@classmethod
109
113
def new (
@@ -114,18 +118,21 @@ def new(
114
118
new_obs : Choice ,
115
119
choice_map_forward : Callable ,
116
120
choice_map_inverse : Callable ,
121
+ check_bijection : Bool ,
117
122
):
118
123
return ExtendingTraceTranslator (
119
124
choice_map_forward ,
120
125
choice_map_inverse ,
126
+ check_bijection ,
121
127
p_argdiffs ,
122
128
q_forward ,
123
129
q_forward_args ,
124
130
new_obs ,
125
131
)
126
132
127
133
def value_and_jacobian_correction (self , forward , trace ):
128
- grad_tree , no_grad_tree = tree_grad_split (trace .get_choices ())
134
+ trace_choices = trace .get_choices ()
135
+ grad_tree , no_grad_tree = tree_grad_split (trace_choices )
129
136
130
137
def _inner (differentiable ):
131
138
choices = tree_zipper (differentiable , no_grad_tree )
@@ -134,6 +141,21 @@ def _inner(differentiable):
134
141
135
142
inner_jacfwd = jax .jacfwd (_inner , has_aux = True )
136
143
J , transformed = inner_jacfwd (grad_tree )
144
+ if self .check_bijection :
145
+
146
+ def optional_check_bijection_is_bijection ():
147
+ backwards = self .choice_map_inverse (transformed )
148
+ flattened = jtu .tree_leaves (
149
+ jtu .tree_map (
150
+ lambda v1 , v2 : jnp .all (v1 == v2 ),
151
+ trace_choices ,
152
+ backwards ,
153
+ )
154
+ )
155
+ check_flag = jnp .all (jnp .array (flattened ))
156
+ check (check_flag , "Bijection check failed" )
157
+
158
+ optional_check (optional_check_bijection_is_bijection )
137
159
J = stack_differentiable (J )
138
160
(_ , J_log_abs_det ) = safe_slogdet (J )
139
161
return transformed , J_log_abs_det
@@ -166,6 +188,7 @@ def extending_trace_translator(
166
188
new_obs : ChoiceMap ,
167
189
choice_map_forward : Callable ,
168
190
choice_map_backward : Callable ,
191
+ check_bijection = False ,
169
192
):
170
193
return ExtendingTraceTranslator .new (
171
194
p_argdiffs ,
@@ -174,6 +197,7 @@ def extending_trace_translator(
174
197
new_obs ,
175
198
choice_map_forward ,
176
199
choice_map_backward ,
200
+ check_bijection ,
177
201
)
178
202
179
203
0 commit comments