Skip to content

Commit 0317cbc

Browse files
authored
Merge branch 'master' into gs_respect_is_sorted
2 parents d51ee88 + 31866b5 commit 0317cbc

23 files changed

+196
-144
lines changed

.github/workflows/testing.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
- name: Install internal dependencies
3838
run: |
3939
pip install torch-scatter -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
40-
pip install torch-sparse -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
40+
pip install torch-sparse==0.6.13 -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
4141
pip install torch-cluster -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
4242
pip install torch-spline-conv -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
4343

CHANGELOG.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.0.5] - 2022-MM-DD
77
### Added
8+
- Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926))
9+
- Added notebook tutorial for `torch_geometric.nn.aggr` package to documentation ([#4927](https://github.com/pyg-team/pytorch_geometric/pull/4927))
810
- Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
911
- Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
1012
- Added `LinkNeighborLoader` support to `LightningDataModule` ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
@@ -26,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2628
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
2729
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
2830
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
29-
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872))
31+
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935))
3032
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
3133
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
3234
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
@@ -48,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4850
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
4951
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
5052
### Changed
53+
- `len(batch)` will now return the number of graphs inside the batch, not the number of attributes ([#4931](https://github.com/pyg-team/pytorch_geometric/pull/4931))
54+
- Fixed `data.subgraph` generation for 0-dim tensors ([#4932](https://github.com/pyg-team/pytorch_geometric/pull/4932))
5155
- Removed unnecssary inclusion of self-loops when sampling negative edges ([#4880](https://github.com/pyg-team/pytorch_geometric/pull/4880))
5256
- Fixed `InMemoryDataset` inferring wrong `len` for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
5357
- Fixed `Batch.separate` when using it for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))

docker/Dockerfile

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
FROM ubuntu:18.04
22

3+
# metainformation
4+
LABEL org.opencontainers.image.version = "2.0.4"
5+
LABEL org.opencontainers.image.authors = "Matthias Fey"
6+
LABEL org.opencontainers.image.source = "https://github.com/pyg-team/pytorch_geometric"
7+
LABEL org.opencontainers.image.licenses = "MIT"
8+
LABEL org.opencontainers.image.base.name="docker.io/library/ubuntu:18.04"
9+
310
RUN apt-get update && apt-get install -y apt-transport-https ca-certificates && \
411
rm -rf /var/lib/apt/lists/*
512

docs/source/notes/colabs.rst

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ We have prepared a list of colab notebooks that practically introduces you to th
99
4. `Scaling Graph Neural Networks <https://colab.research.google.com/drive/1XAjcjRHrSR_ypCk_feIWFbcBKyT4Lirs?usp=sharing>`__
1010
5. `Point Cloud Classification with Graph Neural Networks <https://colab.research.google.com/drive/1D45E5bUK3gQ40YpZo65ozs7hg5l-eo_U?usp=sharing>`__
1111
6. `Explaining GNN Model Predictions using Captum <https://colab.research.google.com/drive/1fLJbFPz0yMCQg81DdCP5I8jXw9LoggKO?usp=sharing>`__
12+
7. `Customizing Aggregations within Message Passing <https://colab.research.google.com/drive/1KKw-VUDQuHhMo7sCd7ZaRROza3leBjRR?usp=sharing>`__
1213

1314
**Stanford CS224W Graph ML Tutorials:**
1415

test/data/test_batch.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def test_batch():
5252
assert str(batch) == ('DataBatch(x=[3], edge_index=[2, 4], y=[1], '
5353
'x_sp=[3, 1, nnz=3], adj=[3, 3, nnz=4], s=[1], '
5454
'array=[1], num_nodes=3, batch=[3], ptr=[2])')
55-
assert batch.num_graphs == 1
56-
assert len(batch) == 10
55+
assert batch.num_graphs == len(batch) == 1
5756
assert batch.x.tolist() == [1, 2, 3]
5857
assert batch.y.tolist() == [1]
5958
assert batch.x_sp.to_dense().view(-1).tolist() == batch.x.tolist()
@@ -72,8 +71,7 @@ def test_batch():
7271
'x_sp=[9, 1, nnz=9], adj=[9, 9, nnz=12], s=[3], '
7372
's_batch=[3], s_ptr=[4], array=[3], num_nodes=9, '
7473
'batch=[9], ptr=[4])')
75-
assert batch.num_graphs == 3
76-
assert len(batch) == 12
74+
assert batch.num_graphs == len(batch) == 3
7775
assert batch.x.tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4]
7876
assert batch.y.tolist() == [1, 2, 3]
7977
assert batch.x_sp.to_dense().view(-1).tolist() == batch.x.tolist()
@@ -174,7 +172,7 @@ def __cat_dim__(self, key, value, *args, **kwargs):
174172

175173
assert str(batch) == ('MyDataBatch(x=[5], y=[2], foo=[2, 4], batch=[5], '
176174
'ptr=[3])')
177-
assert len(batch) == 5
175+
assert batch.num_graphs == len(batch) == 2
178176
assert batch.x.tolist() == [1, 2, 3, 1, 2]
179177
assert batch.foo.size() == (2, 4)
180178
assert batch.foo[0].tolist() == foo1.tolist()
@@ -208,7 +206,7 @@ def test_pickling():
208206
assert batch.num_nodes == 20
209207

210208
assert batch.__class__.__name__ == 'DataBatch'
211-
assert len(batch) == 3
209+
assert batch.num_graphs == len(batch) == 4
212210

213211
os.remove(path)
214212

@@ -230,8 +228,7 @@ def test_recursive_batch():
230228

231229
batch = Batch.from_data_list([data1, data2])
232230

233-
assert len(batch) == 5
234-
assert batch.num_graphs == 2
231+
assert batch.num_graphs == len(batch) == 2
235232
assert batch.num_nodes == 90
236233

237234
assert torch.allclose(batch.x['1'],
@@ -267,7 +264,7 @@ def test_batching_of_batches():
267264
batch = Batch.from_data_list([data, data])
268265

269266
batch = Batch.from_data_list([batch, batch])
270-
assert len(batch) == 2
267+
assert batch.num_graphs == len(batch) == 2
271268
assert batch.x[0:2].tolist() == data.x.tolist()
272269
assert batch.x[2:4].tolist() == data.x.tolist()
273270
assert batch.x[4:6].tolist() == data.x.tolist()
@@ -296,8 +293,7 @@ def test_hetero_batch():
296293

297294
batch = Batch.from_data_list([data1, data2])
298295

299-
assert len(batch) == 5
300-
assert batch.num_graphs == 2
296+
assert batch.num_graphs == len(batch) == 2
301297
assert batch.num_nodes == 450
302298

303299
assert torch.allclose(batch['p'].x[:100], data1['p'].x)

test/datasets/test_enzymes.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,24 @@ def test_enzymes(get_dataset):
2222
assert len(dataset[mask]) == 100
2323

2424
loader = DataLoader(dataset, batch_size=len(dataset))
25-
for data in loader:
26-
assert data.num_graphs == 600
25+
for batch in loader:
26+
assert batch.num_graphs == len(batch) == 600
2727

28-
avg_num_nodes = data.num_nodes / data.num_graphs
28+
avg_num_nodes = batch.num_nodes / batch.num_graphs
2929
assert pytest.approx(avg_num_nodes, abs=1e-2) == 32.63
3030

31-
avg_num_edges = data.num_edges / (2 * data.num_graphs)
31+
avg_num_edges = batch.num_edges / (2 * batch.num_graphs)
3232
assert pytest.approx(avg_num_edges, abs=1e-2) == 62.14
3333

34-
assert len(data) == 5
35-
assert list(data.x.size()) == [data.num_nodes, 3]
36-
assert list(data.y.size()) == [data.num_graphs]
37-
assert data.y.max() + 1 == 6
38-
assert list(data.batch.size()) == [data.num_nodes]
39-
assert data.ptr.numel() == data.num_graphs + 1
34+
assert list(batch.x.size()) == [batch.num_nodes, 3]
35+
assert list(batch.y.size()) == [batch.num_graphs]
36+
assert batch.y.max() + 1 == 6
37+
assert list(batch.batch.size()) == [batch.num_nodes]
38+
assert batch.ptr.numel() == batch.num_graphs + 1
4039

41-
assert data.has_isolated_nodes()
42-
assert not data.has_self_loops()
43-
assert data.is_undirected()
40+
assert batch.has_isolated_nodes()
41+
assert not batch.has_self_loops()
42+
assert batch.is_undirected()
4443

4544
loader = DataListLoader(dataset, batch_size=len(dataset))
4645
for data_list in loader:
@@ -49,7 +48,6 @@ def test_enzymes(get_dataset):
4948
dataset.transform = ToDense(num_nodes=126)
5049
loader = DenseDataLoader(dataset, batch_size=len(dataset))
5150
for data in loader:
52-
assert len(data) == 4
5351
assert list(data.x.size()) == [600, 126, 3]
5452
assert list(data.adj.size()) == [600, 126, 126]
5553
assert list(data.mask.size()) == [600, 126]

test/datasets/test_planetoid.py

+18-19
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,24 @@ def test_citeseer(get_dataset):
88
assert len(dataset) == 1
99
assert dataset.__repr__() == 'CiteSeer()'
1010

11-
for data in loader:
12-
assert data.num_graphs == 1
13-
assert data.num_nodes == 3327
14-
assert data.num_edges / 2 == 4552
15-
16-
assert len(data) == 8
17-
assert list(data.x.size()) == [data.num_nodes, 3703]
18-
assert list(data.y.size()) == [data.num_nodes]
19-
assert data.y.max() + 1 == 6
20-
assert data.train_mask.sum() == 6 * 20
21-
assert data.val_mask.sum() == 500
22-
assert data.test_mask.sum() == 1000
23-
assert (data.train_mask & data.val_mask & data.test_mask).sum() == 0
24-
assert list(data.batch.size()) == [data.num_nodes]
25-
assert data.ptr.tolist() == [0, data.num_nodes]
26-
27-
assert data.has_isolated_nodes()
28-
assert not data.has_self_loops()
29-
assert data.is_undirected()
11+
for batch in loader:
12+
assert batch.num_graphs == len(batch) == 1
13+
assert batch.num_nodes == 3327
14+
assert batch.num_edges / 2 == 4552
15+
16+
assert list(batch.x.size()) == [batch.num_nodes, 3703]
17+
assert list(batch.y.size()) == [batch.num_nodes]
18+
assert batch.y.max() + 1 == 6
19+
assert batch.train_mask.sum() == 6 * 20
20+
assert batch.val_mask.sum() == 500
21+
assert batch.test_mask.sum() == 1000
22+
assert (batch.train_mask & batch.val_mask & batch.test_mask).sum() == 0
23+
assert list(batch.batch.size()) == [batch.num_nodes]
24+
assert batch.ptr.tolist() == [0, batch.num_nodes]
25+
26+
assert batch.has_isolated_nodes()
27+
assert not batch.has_self_loops()
28+
assert batch.is_undirected()
3029

3130

3231
def test_citeseer_with_full_split(get_dataset):

test/loader/test_dataloader.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_dataloader(num_workers):
3939
assert len(loader) == 2
4040

4141
for batch in loader:
42-
assert len(batch) == 8
42+
assert batch.num_graphs == len(batch) == 2
4343
assert batch.batch.tolist() == [0, 0, 0, 1, 1, 1]
4444
assert batch.ptr.tolist() == [0, 3, 6]
4545
assert batch.x.tolist() == [[1], [1], [1], [1], [1], [1]]
@@ -58,7 +58,7 @@ def test_dataloader(num_workers):
5858
assert len(loader) == 2
5959

6060
for batch in loader:
61-
assert len(batch) == 10
61+
assert batch.num_graphs == len(batch) == 2
6262
assert batch.edge_index_batch.tolist() == [0, 0, 0, 0, 1, 1, 1, 1]
6363

6464

@@ -72,10 +72,10 @@ def test_multiprocessing():
7272
queue.put(batch)
7373

7474
batch = queue.get()
75-
assert len(batch) == 3
75+
assert batch.num_graphs == len(batch) == 2
7676

7777
batch = queue.get()
78-
assert len(batch) == 3
78+
assert batch.num_graphs == len(batch) == 2
7979

8080

8181
def test_pin_memory():
@@ -104,7 +104,7 @@ def test_heterogeneous_dataloader(num_workers):
104104
assert len(loader) == 2
105105

106106
for batch in loader:
107-
assert len(batch) == 5
107+
assert batch.num_graphs == len(batch) == 2
108108
assert batch.num_nodes == 600
109109

110110
for store in batch.stores:

test/loader/test_shadow.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_shadow_k_hop_sampler():
2020
assert len(loader) == 1
2121

2222
batch1 = next(iter(loader))
23-
assert len(batch1) == 7
23+
assert batch1.num_graphs == len(batch1) == 2
2424

2525
assert batch1.batch.tolist() == [0, 0, 0, 0, 1, 1, 1]
2626
assert batch1.ptr.tolist() == [0, 4, 7]
@@ -42,7 +42,7 @@ def test_shadow_k_hop_sampler():
4242
assert len(loader) == 1
4343

4444
batch2 = next(iter(loader))
45-
assert len(batch2) == 6
45+
assert batch2.num_graphs == len(batch2) == 2
4646

4747
assert batch1.batch.tolist() == batch2.batch.tolist()
4848
assert batch1.ptr.tolist() == batch2.ptr.tolist()

test/nn/aggr/test_scaler.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,16 @@ def test_degree_scaler_aggregation():
1010
ptr = torch.tensor([0, 2, 5, 6])
1111
deg = torch.tensor([0, 3, 0, 1, 1, 0])
1212

13-
aggrs = ['mean', 'sum', 'max']
14-
scalers = [
13+
aggr = ['mean', 'sum', 'max']
14+
scaler = [
1515
'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear'
1616
]
17-
aggr = DegreeScalerAggregation(aggrs, scalers, deg)
17+
aggr = DegreeScalerAggregation(aggr, scaler, deg)
1818
assert str(aggr) == 'DegreeScalerAggregation()'
1919

2020
out = aggr(x, index)
2121
assert out.size() == (3, 240)
22+
assert torch.allclose(torch.jit.script(aggr)(x, index), out)
2223

2324
with pytest.raises(NotImplementedError):
2425
aggr(x, ptr=ptr)

test/nn/test_resolver.py

+19
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch_geometric.nn.resolver import (
66
activation_resolver,
77
aggregation_resolver,
8+
normalization_resolver,
89
)
910

1011

@@ -34,3 +35,21 @@ def test_aggregation_resolver(aggr_tuple):
3435
aggr_module, aggr_repr = aggr_tuple
3536
assert isinstance(aggregation_resolver(aggr_module()), aggr_module)
3637
assert isinstance(aggregation_resolver(aggr_repr), aggr_module)
38+
39+
40+
@pytest.mark.parametrize('norm_tuple', [
41+
(torch_geometric.nn.norm.BatchNorm, 'batch_norm', (16, )),
42+
(torch_geometric.nn.norm.InstanceNorm, 'instance_norm', (16, )),
43+
(torch_geometric.nn.norm.LayerNorm, 'layer_norm', (16, )),
44+
(torch_geometric.nn.norm.GraphNorm, 'graph_norm', (16, )),
45+
(torch_geometric.nn.norm.GraphSizeNorm, 'graphsize_norm', ()),
46+
(torch_geometric.nn.norm.PairNorm, 'pair_norm', ()),
47+
(torch_geometric.nn.norm.MessageNorm, 'message_norm', ()),
48+
(torch_geometric.nn.norm.DiffGroupNorm, 'diffgroup_norm', (16, 4)),
49+
])
50+
def test_normalization_resolver(norm_tuple):
51+
norm_module, norm_repr, norm_args = norm_tuple
52+
assert isinstance(normalization_resolver(norm_module(*norm_args)),
53+
norm_module)
54+
assert isinstance(normalization_resolver(norm_repr, *norm_args),
55+
norm_module)

test/transforms/test_rooted_subgraph.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_rooted_subgraph_minibatch():
7373
loader = DataLoader([data, data], batch_size=2)
7474
batch = next(iter(loader))
7575
batch = batch.map_data()
76-
assert len(batch) == 6
76+
assert batch.num_graphs == len(batch) == 2
7777

7878
assert batch.x.size() == (14, 8)
7979
assert batch.edge_index.size() == (2, 16)

test/transforms/test_to_superpixels.py

+16-16
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,13 @@ def test_to_superpixels():
5757
assert y == 7
5858

5959
loader = DataLoader(dataset, batch_size=2, shuffle=False)
60-
for data, y in loader:
61-
assert len(data) == 4
62-
assert data.pos.dim() == 2 and data.pos.size(1) == 2
63-
assert data.x.dim() == 2 and data.x.size(1) == 1
64-
assert data.batch.dim() == 1
65-
assert data.ptr.dim() == 1
66-
assert data.pos.size(0) == data.x.size(0) == data.batch.size(0)
60+
for batch, y in loader:
61+
assert batch.num_graphs == len(batch) == 2
62+
assert batch.pos.dim() == 2 and batch.pos.size(1) == 2
63+
assert batch.x.dim() == 2 and batch.x.size(1) == 1
64+
assert batch.batch.dim() == 1
65+
assert batch.ptr.dim() == 1
66+
assert batch.pos.size(0) == batch.x.size(0) == batch.batch.size(0)
6767
assert y.tolist() == [7, 2]
6868
break
6969

@@ -81,15 +81,15 @@ def test_to_superpixels():
8181
assert y == 7
8282

8383
loader = DataLoader(dataset, batch_size=2, shuffle=False)
84-
for data, y in loader:
85-
assert len(data) == 6
86-
assert data.pos.dim() == 2 and data.pos.size(1) == 2
87-
assert data.x.dim() == 2 and data.x.size(1) == 1
88-
assert data.batch.dim() == 1
89-
assert data.ptr.dim() == 1
90-
assert data.pos.size(0) == data.x.size(0) == data.batch.size(0)
91-
assert data.seg.size() == (2, 28, 28)
92-
assert data.img.size() == (2, 1, 28, 28)
84+
for batch, y in loader:
85+
assert batch.num_graphs == len(batch) == 2
86+
assert batch.pos.dim() == 2 and batch.pos.size(1) == 2
87+
assert batch.x.dim() == 2 and batch.x.size(1) == 1
88+
assert batch.batch.dim() == 1
89+
assert batch.ptr.dim() == 1
90+
assert batch.pos.size(0) == batch.x.size(0) == batch.batch.size(0)
91+
assert batch.seg.size() == (2, 28, 28)
92+
assert batch.img.size() == (2, 1, 28, 28)
9393
assert y.tolist() == [7, 2]
9494
break
9595

torch_geometric/data/batch.py

+3
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ def num_graphs(self) -> int:
180180
else:
181181
raise ValueError("Can not infer the number of graphs")
182182

183+
def __len__(self) -> int:
184+
return self.num_graphs
185+
183186
def __reduce__(self):
184187
state = self.__dict__.copy()
185188
return DynamicInheritanceGetter(), self.__class__.__bases__, state

0 commit comments

Comments
 (0)