Skip to content

Commit 7761358

Browse files
authored
Fix labels, pyright in CI (#1592)
1 parent 9be1350 commit 7761358

File tree

6 files changed

+12
-13
lines changed

6 files changed

+12
-13
lines changed

.github/workflows/labeler.yml

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ on:
88
jobs:
99
labeler:
1010
runs-on: ubuntu-22.04
11+
permissions:
12+
contents: read
13+
issues: write
1114
steps:
1215
- name: Check out the repository
1316
uses: actions/checkout@v4

.github/workflows/ruff.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@ jobs:
1818

1919
- uses: chartboost/ruff-action@v1
2020
with:
21-
version: 0.9.9
21+
version: 0.11.2
2222
args: check --output-format github

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,6 @@ repos:
4343
- id: vulture
4444

4545
- repo: https://github.com/RobertCraigie/pyright-python
46-
rev: v1.1.398
46+
rev: v1.1.399
4747
hooks:
4848
- id: pyright

src/genjax/_src/adev/primitives.py

+3
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
"""Defines ADEV primitives."""
1516

17+
# pyright: reportPrivateImportUsage=false
18+
1619
import jax
1720
import jax._src.core
1821
import jax._src.dtypes as jax_dtypes

src/genjax/_src/core/compiler/staging.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,8 @@ def stage(f):
287287
"""Returns a function that stages a function to a ClosedJaxpr."""
288288

289289
def wrapped(*args, **kwargs):
290-
debug_info = api_util.debug_info("Tracing to Jaxpr", f, args, kwargs)
291-
fun = lu.wrap_init(f, params=kwargs, debug_info=debug_info)
290+
debug_info = api_util.debug_info("Tracing to Jaxpr", f, args, kwargs) # pyright: ignore[reportAttributeAccessIssue]
291+
fun = lu.wrap_init(f, params=kwargs, debug_info=debug_info) # pyright: ignore[reportCallIssue]
292292
flat_args, in_tree = jtu.tree_flatten(args)
293293
flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
294294
flat_avals = safe_map(get_shaped_aval, flat_args)

src/genjax/_src/core/generative/choice_map.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -1493,8 +1493,7 @@ def build(chm: ChoiceMap, addr: DynamicAddressComponent) -> ChoiceMap:
14931493
return Indexed(chm, addr)
14941494

14951495
def filter(self, selection: Selection | Flag) -> ChoiceMap:
1496-
addr = _full_slice if self.addr is None else self.addr
1497-
return self.c.filter(selection).extend(addr)
1496+
return self.c.filter(selection).extend(self.addr)
14981497

14991498
def get_value(self) -> Any:
15001499
return None
@@ -1510,13 +1509,7 @@ def get_inner_map(self, addr: AddressComponent) -> ChoiceMap:
15101509
"Only scalar dynamic addresses are supported by get_submap."
15111510
)
15121511

1513-
if self.addr is None:
1514-
# None means that this instance was created with `:`, so no masking is required and we assume that the user will provide an in-bounds `int | ScalarInt`` address. If they don't they will run up against JAX's clamping behavior.
1515-
return jtu.tree_map(
1516-
lambda v: v[addr], self.c, is_leaf=lambda x: isinstance(x, Mask)
1517-
)
1518-
1519-
elif isinstance(self.addr, Array) and self.addr.shape:
1512+
if isinstance(self.addr, Array) and self.addr.shape:
15201513
# We can't allow slices, as self.addr might look like, e.g. `[2,5,6]`, and we don't have any way to combine this "sparse array selector" with an incoming slice.
15211514
assert not isinstance(addr, slice), (
15221515
f"Slices are not allowed against array-shaped dynamic addresses. Tried to apply {addr} to {self.addr}."

0 commit comments

Comments
 (0)