
Commit 9bb19b6

Add batch and ptr vectors for a list of tensors and nested dicts (#4837)

Commit message:

* Added support for `_ptr` and `_batch` for nested structures, matching the implementation of `_collate`. Added a test in `test_batch.py`.
* Changed a comment; applied pre-commit fixes.
* Added support for batch vectors for recursive data structures.
* Updated `CHANGELOG.md` (pull request link still missing at that point).
* Extended `InMemoryDataset` to infer `len()` correctly when using lists of tensors. Modified `separate` so that `InMemoryDataset` can deal with lists of tensors correctly. Added a test for it.
* Added a test to cover the special `elif` case in `separate`.
* Applied suggestions from code review: fixed the `elif` bug, changed the `_batch` name, and cleaned up (by rusty1s).
* Merged the functions `_batch` and `_ptr` into `_batch_and_ptr`.
* Updated `CHANGELOG.md`.
* Fixed a type hint.
* Removed an old/wrong code comment in `collate.py`.

Co-authored-by: janmeissnerRWTH <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>

1 parent: 3d6eb74

File tree: 6 files changed (+147, -11 lines)

- CHANGELOG.md
- test/data/test_batch.py
- test/data/test_dataset.py
- torch_geometric/data/collate.py
- torch_geometric/data/in_memory_dataset.py
- torch_geometric/data/separate.py

CHANGELOG.md (+3)

```diff
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
 - Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
 - Added `LinkNeighborLoader` support to `LightningDataModule` ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
 - Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))
@@ -47,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- Fixed `InMemoryDataset` inferring wrong `len` for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
+- Fixed `Batch.separate` when using it for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
 - Correct docstring for SAGEConv ([#4852](https://github.com/pyg-team/pytorch_geometric/pull/4852))
 - Fixed a bug in `TUDataset` where `pre_filter` was not applied whenever `pre_transform` was present
 - Renamed `RandomTranslate` to `RandomJitter` - the usage of `RandomTranslate` is now deprecated ([#4828](https://github.com/pyg-team/pytorch_geometric/pull/4828))
```
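For context on the `follow_batch` entry above, here is a minimal usage sketch of the new behaviour. The attribute names `x`, `xs`, and `a` simply mirror the test added below; the asserted values follow from the per-element sizes.

```python
# Minimal usage sketch of `follow_batch` on list- and dict-valued attributes
# (attribute names mirror `test_nested_follow_batch` below).
import torch
from torch_geometric.data import Batch, Data

d1 = Data(x=torch.rand(3, 5), xs=[torch.rand(4, 3)], a={'aa': torch.rand(11, 3)})
d2 = Data(x=torch.rand(2, 5), xs=[torch.rand(5, 3)], a={'aa': torch.rand(2, 3)})

batch = Batch.from_data_list([d1, d2], follow_batch=['xs', 'a'])

# Batch vectors are now created per nested element:
assert batch.xs_batch[0].tolist() == [0] * 4 + [1] * 5
assert batch.a_batch['aa'].tolist() == [0] * 11 + [1] * 2
```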

test/data/test_batch.py (+39)

```diff
@@ -387,3 +387,42 @@ def test_batch_with_empty_list():
     assert batch.nontensor == [[], []]
     assert batch[0].nontensor == []
     assert batch[1].nontensor == []
+
+
+def test_nested_follow_batch():
+    def tr(n, m):
+        return torch.rand((n, m))
+
+    d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)], a={"aa": tr(11, 3)},
+              x=tr(10, 5))
+    d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)], a={"aa": tr(2, 3)},
+              x=tr(11, 5))
+    d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)], a={"aa": tr(4, 3)},
+              x=tr(9, 5))
+    d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)], a={"aa": tr(8, 3)},
+              x=tr(8, 5))
+
+    # Dataset
+    data_list = [d1, d2, d3, d4]
+
+    batch = Batch.from_data_list(data_list, follow_batch=['xs', 'a'])
+
+    # assert shapes
+    assert batch.xs[0].shape == (19, 3)
+    assert batch.xs[1].shape == (56, 4)
+    assert batch.xs[2].shape == (7, 2)
+    assert batch.a['aa'].shape == (25, 3)
+
+    assert len(batch.xs_batch) == 3
+    assert len(batch.a_batch) == 1
+
+    # assert _batch
+    assert batch.xs_batch[0].tolist() == \
+        [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3]
+    assert batch.xs_batch[1].tolist() == \
+        [0] * 11 + [1] * 14 + [2] * 15 + [3] * 16
+    assert batch.xs_batch[2].tolist() == \
+        [0] * 1 + [1] * 3 + [2] * 2 + [3] * 1
+
+    assert batch.a_batch['aa'].tolist() == \
+        [0] * 11 + [1] * 2 + [2] * 4 + [3] * 8
```
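The test asserts the `_batch` vectors only; per the `collate.py` change further below, a matching `_ptr` vector (cumulative element sizes with a leading zero) is stored alongside each of them. The following asserts are an illustrative sketch of the values one would expect for the data above, derived from the element sizes; they are not part of the committed test.

```python
# Expected `_ptr` companions for `test_nested_follow_batch`, derived from the
# element sizes above (illustrative only, not part of the committed test).
assert batch.xs_ptr[0].tolist() == [0, 4, 9, 15, 19]
assert batch.xs_ptr[1].tolist() == [0, 11, 25, 40, 56]
assert batch.xs_ptr[2].tolist() == [0, 1, 4, 6, 7]
assert batch.a_ptr['aa'].tolist() == [0, 11, 13, 17, 25]
```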

test/data/test_dataset.py (+55)

```diff
@@ -1,4 +1,5 @@
 import torch
+from torch_sparse import SparseTensor

 from torch_geometric.data import Data, Dataset, HeteroData, InMemoryDataset

@@ -158,3 +159,57 @@ def _process(self):
     ds = DS3()
     assert not ds.enter_download
     assert not ds.enter_process
+
+
+class MyTestDataset2(InMemoryDataset):
+    def __init__(self, data_list):
+        super().__init__('/tmp/MyTestDataset2')
+        self.data, self.slices = self.collate(data_list)
+
+
+def test_lists_of_tensors_in_memory_dataset():
+    def tr(n, m):
+        return torch.rand((n, m))
+
+    d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)])
+    d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)])
+    d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)])
+    d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)])
+
+    data_list = [d1, d2, d3, d4]
+
+    dataset = MyTestDataset2(data_list)
+    assert len(dataset) == 4
+    assert dataset[0].xs[1].shape == (11, 4)
+    assert dataset[0].xs[2].shape == (1, 2)
+    assert dataset[1].xs[0].shape == (5, 3)
+    assert dataset[2].xs[1].shape == (15, 4)
+    assert dataset[3].xs[1].shape == (16, 4)
+
+
+class MyTestDataset3(InMemoryDataset):
+    def __init__(self, data_list):
+        super().__init__('/tmp/MyTestDataset3')
+        self.data, self.slices = self.collate(data_list)
+
+
+def test_lists_of_SparseTensors():
+    e1 = torch.tensor([[4, 1, 3, 2, 2, 3], [1, 3, 2, 3, 3, 2]])
+    e2 = torch.tensor([[0, 1, 4, 7, 2, 9], [7, 2, 2, 1, 4, 7]])
+    e3 = torch.tensor([[3, 5, 1, 2, 3, 3], [5, 0, 2, 1, 3, 7]])
+    e4 = torch.tensor([[0, 1, 9, 2, 0, 3], [1, 1, 2, 1, 3, 2]])
+    adj1 = SparseTensor.from_edge_index(e1, sparse_sizes=(11, 11))
+    adj2 = SparseTensor.from_edge_index(e2, sparse_sizes=(22, 22))
+    adj3 = SparseTensor.from_edge_index(e3, sparse_sizes=(12, 12))
+    adj4 = SparseTensor.from_edge_index(e4, sparse_sizes=(15, 15))
+
+    d1 = Data(adj_test=[adj1, adj2])
+    d2 = Data(adj_test=[adj3, adj4])
+
+    data_list = [d1, d2]
+    dataset = MyTestDataset3(data_list)
+    assert len(dataset) == 2
+    assert dataset[0].adj_test[0].sparse_sizes() == (11, 11)
+    assert dataset[0].adj_test[1].sparse_sizes() == (22, 22)
+    assert dataset[1].adj_test[0].sparse_sizes() == (12, 12)
+    assert dataset[1].adj_test[1].sparse_sizes() == (15, 15)
```

torch_geometric/data/collate.py (+36, -5)

```diff
@@ -96,12 +96,10 @@ def collate(
                 inc_dict[attr] = incs

             # Add an additional batch vector for the given attribute:
-            if (attr in follow_batch and isinstance(slices, Tensor)
-                    and slices.dim() == 1):
-                repeats = slices[1:] - slices[:-1]
-                batch = repeat_interleave(repeats.tolist(), device=device)
+            if attr in follow_batch:
+                batch, ptr = _batch_and_ptr(slices, device)
                 out_store[f'{attr}_batch'] = batch
-                out_store[f'{attr}_ptr'] = cumsum(repeats.to(device))
+                out_store[f'{attr}_ptr'] = ptr

         # In case the storage holds node, we add a top-level batch vector it:
         if (add_batch and isinstance(stores[0], NodeStorage)
@@ -199,6 +197,39 @@ def _collate(
     return values, slices, None


+def _batch_and_ptr(
+    slices: Any,
+    device: Optional[torch.device] = None,
+) -> Tuple[Any, Any]:
+    if (isinstance(slices, Tensor) and slices.dim() == 1):
+        # Default case, turn slices tensor into batch.
+        repeats = slices[1:] - slices[:-1]
+        batch = repeat_interleave(repeats.tolist(), device=device)
+        ptr = cumsum(repeats.to(device))
+        return batch, ptr
+
+    elif isinstance(slices, Mapping):
+        # Recursively batch elements of dictionaries.
+        batch, ptr = {}, {}
+        for k, v in slices.items():
+            batch[k], ptr[k] = _batch_and_ptr(v, device)
+        return batch, ptr
+
+    elif (isinstance(slices, Sequence) and not isinstance(slices, str)
+          and isinstance(slices[0], Tensor)):
+        # Recursively batch elements of lists.
+        batch, ptr = [], []
+        for s in slices:
+            sub_batch, sub_ptr = _batch_and_ptr(s, device)
+            batch.append(sub_batch)
+            ptr.append(sub_ptr)
+        return batch, ptr
+
+    else:
+        # Failure of batching, usually due to slices.dim() != 1
+        return None, None
+
+
 ###############################################################################


```
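To make the recursion concrete, here is a small sketch of what `_batch_and_ptr` returns for the three input shapes it handles. The helper is private to `torch_geometric.data.collate`, so importing and calling it directly is done purely for illustration, and the slice values below are hypothetical.

```python
# Illustrative calls of the private helper `_batch_and_ptr` with hypothetical
# slice values; not a public API.
import torch

from torch_geometric.data.collate import _batch_and_ptr

device = torch.device('cpu')

# 1) Plain 1-D slices tensor (two elements of sizes 4 and 5):
batch, ptr = _batch_and_ptr(torch.tensor([0, 4, 9]), device)
assert batch.tolist() == [0] * 4 + [1] * 5
assert ptr.tolist() == [0, 4, 9]

# 2) Dictionary of slices -> dictionary of batch/ptr vectors:
batch, ptr = _batch_and_ptr({'aa': torch.tensor([0, 2, 5])}, device)
assert batch['aa'].tolist() == [0, 0, 1, 1, 1]
assert ptr['aa'].tolist() == [0, 2, 5]

# 3) List of slices -> list of batch/ptr vectors:
batch, ptr = _batch_and_ptr([torch.tensor([0, 1, 3]),
                             torch.tensor([0, 2, 2])], device)
assert batch[0].tolist() == [0, 1, 1]
assert batch[1].tolist() == [0, 0]
```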

torch_geometric/data/in_memory_dataset.py (+9, -6)

```diff
@@ -1,5 +1,5 @@
 import copy
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

 import torch
@@ -130,10 +130,13 @@ def copy(self, idx: Optional[IndexType] = None) -> 'InMemoryDataset':
         return dataset


-def nested_iter(mapping: Mapping) -> Iterable:
-    for key, value in mapping.items():
-        if isinstance(value, Mapping):
+def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
+    if isinstance(node, Mapping):
+        for key, value in node.items():
             for inner_key, inner_value in nested_iter(value):
                 yield inner_key, inner_value
-        else:
-            yield key, value
+    elif isinstance(node, Sequence):
+        for i, inner_value in enumerate(node):
+            yield i, inner_value
+    else:
+        yield None, node
```
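A sketch of why this fixes `len()` for list-valued attributes: the slice dictionary then stores one slice tensor per list element, and `nested_iter` now descends into that list so a leaf slice tensor can be reached. The slice layout below is hypothetical, and the `len(leaf) - 1` computation mirrors how `InMemoryDataset.len()` reads the number of examples off a leaf.

```python
# Hypothetical slice layout for a 4-graph dataset whose `xs` attribute is a
# list of three tensors; `nested_iter` now yields the leaf slice tensors.
import torch

from torch_geometric.data.in_memory_dataset import nested_iter

slices = {'xs': [torch.tensor([0, 4, 9, 15, 19]),
                 torch.tensor([0, 11, 25, 40, 56]),
                 torch.tensor([0, 1, 4, 6, 7])]}

for key, leaf in nested_iter(slices):
    assert len(leaf) - 1 == 4  # number of examples, as `len()` infers it
```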

torch_geometric/data/separate.py (+5)

```diff
@@ -90,6 +90,11 @@ def _separate(
           and not isinstance(value[0], str) and len(value[0]) > 0
           and isinstance(value[0][0], (Tensor, SparseTensor))):
         # Recursively separate elements of lists of lists.
+        return [elem[idx] for elem in value]
+
+    elif (isinstance(value, Sequence) and not isinstance(value, str)
+          and isinstance(value[0], (Tensor, SparseTensor))):
+        # Recursively separate elements of lists of Tensors/SparseTensors.
         return [
             _separate(key, elem, idx, slices[i],
                       incs[i] if decrement else None, batch, store, decrement)
```
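The new branch is what allows indexing an example back out of a collated object when an attribute is a plain list of tensors. A minimal round-trip sketch, with the attribute name `xs` chosen only to mirror the tests above:

```python
# Minimal round-trip sketch for the new list-of-Tensors branch in `_separate`;
# the attribute name `xs` is illustrative.
import torch
from torch_geometric.data import Batch, Data

d1 = Data(x=torch.rand(4, 7), xs=[torch.rand(4, 3), torch.rand(1, 2)])
d2 = Data(x=torch.rand(5, 7), xs=[torch.rand(5, 3), torch.rand(3, 2)])

batch = Batch.from_data_list([d1, d2])
out = batch[1]  # separation runs through `_separate`
assert out.xs[0].shape == (5, 3)
assert out.xs[1].shape == (3, 2)
```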
