
Commit 0f7e018

Add filter_per_worker flag to data loaders (#4873)
* filter per worker
* changelog
Parent: 927346e

5 files changed (+72, -15 lines)

CHANGELOG.md (+1)

@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
 - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
 - Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857))
 - Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
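
As a quick illustration of the new flag (not part of the commit itself), here is a minimal usage sketch on a hypothetical random graph: only the `filter_per_worker` argument is new, the rest is an ordinary `NeighborLoader` setup.

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# Hypothetical homogeneous graph with random features and connectivity.
data = Data(
    x=torch.randn(1000, 16),
    edge_index=torch.randint(0, 1000, (2, 5000)),
)

# With `filter_per_worker=True`, each worker subprocess builds the final
# mini-batch itself; the default (`False`) defers that step to the main
# process, which the docstrings below recommend for most workloads.
loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],  # two hops, up to 10 neighbors per hop
    batch_size=128,
    num_workers=2,
    filter_per_worker=True,
)

for batch in loader:
    print(batch)  # a `Data` object holding the sampled subgraph
    break
```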

torch_geometric/loader/base.py (+1, -1)

@@ -3,7 +3,7 @@
 from torch.utils.data.dataloader import _BaseDataLoaderIter


-class DataLoaderIterator(object):
+class DataLoaderIterator:
     r"""A data loader iterator extended by a simple post transformation
     function :meth:`transform_fn`. While the iterator may request items from
     different sub-processes, :meth:`transform_fn` will always be executed in

torch_geometric/loader/hgt_loader.py (+23, -4)

@@ -77,6 +77,14 @@ class HGTLoader(torch.utils.data.DataLoader):
         transform (Callable, optional): A function/transform that takes in
             a sampled mini-batch and returns a transformed version.
             (default: :obj:`None`)
+        filter_per_worker (bool, optional): If set to :obj:`True`, will filter
+            the returned data in each worker's subprocess rather than in the
+            main process.
+            Setting this to :obj:`True` is generally not recommended:
+            (1) it may result in too many open file handles,
+            (2) it may slow down data loading,
+            (3) it requires operating on CPU tensors.
+            (default: :obj:`False`)
         **kwargs (optional): Additional arguments of
             :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
             :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
@@ -87,6 +95,7 @@ def __init__(
         num_samples: Union[List[int], Dict[NodeType, List[int]]],
         input_nodes: Union[NodeType, Tuple[NodeType, Optional[Tensor]]],
         transform: Callable = None,
+        filter_per_worker: bool = False,
         **kwargs,
     ):
         if 'collate_fn' in kwargs:
@@ -112,6 +121,7 @@ def __init__(
         self.input_nodes = input_nodes
         self.num_hops = max([len(v) for v in num_samples.values()])
         self.transform = transform
+        self.filter_per_worker = filter_per_worker
         self.sample_fn = torch.ops.torch_sparse.hgt_sample

         # Convert the graph data into a suitable format for sampling.
@@ -120,7 +130,7 @@ def __init__(
         self.colptr_dict, self.row_dict, self.perm_dict = to_hetero_csc(
             data, device='cpu', share_memory=kwargs.get('num_workers', 0) > 0)

-        super().__init__(input_nodes[1].tolist(), collate_fn=self.sample,
+        super().__init__(input_nodes[1].tolist(), collate_fn=self.collate_fn,
                          **kwargs)

     def sample(self, indices: List[int]) -> HeteroData:
@@ -134,8 +144,7 @@ def sample(self, indices: List[int]) -> HeteroData:
         )
         return node_dict, row_dict, col_dict, edge_dict, len(indices)

-    def transform_fn(self, out: Any) -> HeteroData:
-        # NOTE This function will always be executed on the main thread!
+    def filter_fn(self, out: Any) -> HeteroData:
         node_dict, row_dict, col_dict, edge_dict, batch_size = out

         data = filter_hetero_data(self.data, node_dict, row_dict, col_dict,
@@ -144,8 +153,18 @@ def transform_fn(self, out: Any) -> HeteroData:

         return data if self.transform is None else self.transform(data)

+    def collate_fn(self, indices: List[int]) -> Any:
+        out = self.sample(indices)
+        if self.filter_per_worker:
+            # We execute `filter_fn` in the worker process.
+            out = self.filter_fn(out)
+        return out
+
     def _get_iterator(self) -> Iterator:
-        return DataLoaderIterator(super()._get_iterator(), self.transform_fn)
+        if self.filter_per_worker:
+            return super()._get_iterator()
+        # We execute `filter_fn` in the main process.
+        return DataLoaderIterator(super()._get_iterator(), self.filter_fn)

     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
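
All three loaders touched by this commit repeat the same dispatch: `collate_fn` samples and, only when `filter_per_worker=True`, also filters inside the worker; `_get_iterator` wraps the iterator with `DataLoaderIterator` only when filtering still has to run on the main process. The standalone sketch below isolates that pattern; `SketchLoader` and its `sample`/`filter_fn` bodies are illustrative stand-ins, while `DataLoaderIterator` is the real helper from `torch_geometric.loader.base`.

```python
from typing import Any, Iterator, List

import torch

from torch_geometric.loader.base import DataLoaderIterator


class SketchLoader(torch.utils.data.DataLoader):
    # Stand-in for HGTLoader/NeighborLoader/LinkNeighborLoader; only the
    # `filter_per_worker` dispatch mirrors the commit, the bodies of
    # `sample` and `filter_fn` are placeholders.
    def __init__(self, node_ids: List[int], filter_per_worker: bool = False,
                 **kwargs):
        self.filter_per_worker = filter_per_worker
        super().__init__(node_ids, collate_fn=self.collate_fn, **kwargs)

    def sample(self, node_ids: List[int]) -> Any:
        # Cheap-to-transfer sampling output (e.g. node/edge indices).
        return node_ids

    def filter_fn(self, out: Any) -> Any:
        # Expensive step: materialize the actual mini-batch tensors.
        return torch.tensor(out)

    def collate_fn(self, node_ids: List[int]) -> Any:
        out = self.sample(node_ids)
        if self.filter_per_worker:
            # Filtering already happens inside the worker process.
            out = self.filter_fn(out)
        return out

    def _get_iterator(self) -> Iterator:
        if self.filter_per_worker:
            return super()._get_iterator()
        # Otherwise, defer `filter_fn` to the main process.
        return DataLoaderIterator(super()._get_iterator(), self.filter_fn)


if __name__ == '__main__':
    loader = SketchLoader(list(range(100)), filter_per_worker=True,
                          batch_size=10, num_workers=2)
    print(next(iter(loader)))
```

With the default `filter_per_worker=False`, only the lightweight sampler output crosses the process boundary and `filter_fn` runs on the main process; with `True`, the fully materialized batch is built inside each worker, which is the source of the caveats listed in the docstrings.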

torch_geometric/loader/link_neighbor_loader.py (+24, -5)

@@ -223,6 +223,14 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
             :obj:`edge_index` is sorted by column. This avoids internal
             re-sorting of the data and can improve runtime and memory
             efficiency. (default: :obj:`False`)
+        filter_per_worker (bool, optional): If set to :obj:`True`, will filter
+            the returned data in each worker's subprocess rather than in the
+            main process.
+            Setting this to :obj:`True` is generally not recommended:
+            (1) it may result in too many open file handles,
+            (2) it may slow down data loading,
+            (3) it requires operating on CPU tensors.
+            (default: :obj:`False`)
         **kwargs (optional): Additional arguments of
             :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
             :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
@@ -238,6 +246,7 @@ def __init__(
         neg_sampling_ratio: float = 0.0,
         transform: Callable = None,
         is_sorted: bool = False,
+        filter_per_worker: bool = False,
         neighbor_sampler: Optional[LinkNeighborSampler] = None,
         **kwargs,
     ):
@@ -255,9 +264,10 @@ def __init__(
         self.edge_label = edge_label
         self.replace = replace
         self.directed = directed
+        self.neg_sampling_ratio = neg_sampling_ratio
         self.transform = transform
+        self.filter_per_worker = filter_per_worker
         self.neighbor_sampler = neighbor_sampler
-        self.neg_sampling_ratio = neg_sampling_ratio

         edge_type, edge_label_index = get_edge_label_index(
             data, edge_label_index)
@@ -275,10 +285,9 @@ def __init__(
             )

         super().__init__(Dataset(edge_label_index, edge_label),
-                         collate_fn=self.neighbor_sampler, **kwargs)
+                         collate_fn=self.collate_fn, **kwargs)

-    def transform_fn(self, out: Any) -> Union[Data, HeteroData]:
-        # NOTE This function will always be executed on the main thread!
+    def filter_fn(self, out: Any) -> Union[Data, HeteroData]:
         if isinstance(self.data, Data):
             node, row, col, edge, edge_label_index, edge_label = out
             data = filter_data(self.data, node, row, col, edge,
@@ -300,8 +309,18 @@ def transform_fn(self, out: Any) -> Union[Data, HeteroData]:

         return data if self.transform is None else self.transform(data)

+    def collate_fn(self, index: Union[List[int], Tensor]) -> Any:
+        out = self.neighbor_sampler(index)
+        if self.filter_per_worker:
+            # We execute `filter_fn` in the worker process.
+            out = self.filter_fn(out)
+        return out
+
     def _get_iterator(self) -> Iterator:
-        return DataLoaderIterator(super()._get_iterator(), self.transform_fn)
+        if self.filter_per_worker:
+            return super()._get_iterator()
+        # We execute `filter_fn` in the main process.
+        return DataLoaderIterator(super()._get_iterator(), self.filter_fn)

     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
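
For completeness, a hypothetical `LinkNeighborLoader` setup with the new flag; the graph and batch size are made up, and `edge_label_index` is simply set to all existing edges.

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import LinkNeighborLoader

# Hypothetical homogeneous graph.
data = Data(
    x=torch.randn(1000, 16),
    edge_index=torch.randint(0, 1000, (2, 5000)),
)

loader = LinkNeighborLoader(
    data,
    num_neighbors=[10, 10],
    edge_label_index=data.edge_index,  # sample mini-batches around every edge
    batch_size=256,
    num_workers=2,
    filter_per_worker=True,            # filter inside the worker subprocesses
)

batch = next(iter(loader))
print(batch.edge_label_index.shape)  # seed edges, re-indexed to the subgraph
```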

torch_geometric/loader/neighbor_loader.py (+23, -5)

@@ -370,6 +370,14 @@ class NeighborLoader(torch.utils.data.DataLoader):
             :obj:`edge_index` is sorted by column. This avoids internal
             re-sorting of the data and can improve runtime and memory
             efficiency. (default: :obj:`False`)
+        filter_per_worker (bool, optional): If set to :obj:`True`, will filter
+            the returned data in each worker's subprocess rather than in the
+            main process.
+            Setting this to :obj:`True` is generally not recommended:
+            (1) it may result in too many open file handles,
+            (2) it may slow down data loading,
+            (3) it requires operating on CPU tensors.
+            (default: :obj:`False`)
         **kwargs (optional): Additional arguments of
             :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
             :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
@@ -384,6 +392,7 @@ def __init__(
         time_attr: Optional[str] = None,
         transform: Callable = None,
         is_sorted: bool = False,
+        filter_per_worker: bool = False,
         neighbor_sampler: Optional[NeighborSampler] = None,
         **kwargs,
     ):
@@ -401,6 +410,7 @@ def __init__(
         self.replace = replace
         self.directed = directed
         self.transform = transform
+        self.filter_per_worker = filter_per_worker
         self.neighbor_sampler = neighbor_sampler

         node_type, input_nodes = get_input_nodes(data, input_nodes)
@@ -417,11 +427,9 @@ def __init__(
                 share_memory=kwargs.get('num_workers', 0) > 0,
             )

-        super().__init__(input_nodes, collate_fn=self.neighbor_sampler,
-                         **kwargs)
+        super().__init__(input_nodes, collate_fn=self.collate_fn, **kwargs)

-    def transform_fn(self, out: Any) -> Union[Data, HeteroData]:
-        # NOTE This function will always be executed on the main thread!
+    def filter_fn(self, out: Any) -> Union[Data, HeteroData]:
         if isinstance(self.data, Data):
             node, row, col, edge, batch_size = out
             data = filter_data(self.data, node, row, col, edge,
@@ -445,8 +453,18 @@ def transform_fn(self, out: Any) -> Union[Data, HeteroData]:

         return data if self.transform is None else self.transform(data)

+    def collate_fn(self, index: Union[List[int], Tensor]) -> Any:
+        out = self.neighbor_sampler(index)
+        if self.filter_per_worker:
+            # We execute `filter_fn` in the worker process.
+            out = self.filter_fn(out)
+        return out
+
     def _get_iterator(self) -> Iterator:
-        return DataLoaderIterator(super()._get_iterator(), self.transform_fn)
+        if self.filter_per_worker:
+            return super()._get_iterator()
+        # We execute `filter_fn` in the main process.
+        return DataLoaderIterator(super()._get_iterator(), self.filter_fn)

     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
