From fcbb02fbcc933c022afc04ab645555a102ffcf47 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 27 Jun 2022 09:19:20 +0000 Subject: [PATCH 1/2] filter per worker --- torch_geometric/loader/base.py | 2 +- torch_geometric/loader/hgt_loader.py | 25 ++++++++++++++-- .../loader/link_neighbor_loader.py | 29 +++++++++++++++---- torch_geometric/loader/neighbor_loader.py | 28 ++++++++++++++---- 4 files changed, 70 insertions(+), 14 deletions(-) diff --git a/torch_geometric/loader/base.py b/torch_geometric/loader/base.py index 57270c8112fb..eaf9ecd99438 100644 --- a/torch_geometric/loader/base.py +++ b/torch_geometric/loader/base.py @@ -3,7 +3,7 @@ from torch.utils.data.dataloader import _BaseDataLoaderIter -class DataLoaderIterator(object): +class DataLoaderIterator: r"""A data loader iterator extended by a simple post transformation function :meth:`transform_fn`. While the iterator may request items from different sub-processes, :meth:`transform_fn` will always be executed in diff --git a/torch_geometric/loader/hgt_loader.py b/torch_geometric/loader/hgt_loader.py index a892a62eff93..bfc927aeaa6d 100644 --- a/torch_geometric/loader/hgt_loader.py +++ b/torch_geometric/loader/hgt_loader.py @@ -77,6 +77,14 @@ class HGTLoader(torch.utils.data.DataLoader): transform (Callable, optional): A function/transform that takes in an a sampled mini-batch and returns a transformed version. (default: :obj:`None`) + filter_per_worker (bool, optional): If set to :obj:`True`, will filter + the returning data in each worker's subprocess rather than in the + main process. + Setting this to :obj:`True` is generally not recommended: + (1) it may result in too many open file handles, + (2) it may slown down data loading, + (3) it requires operating on CPU tensors. + (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. @@ -87,6 +95,7 @@ def __init__( num_samples: Union[List[int], Dict[NodeType, List[int]]], input_nodes: Union[NodeType, Tuple[NodeType, Optional[Tensor]]], transform: Callable = None, + filter_per_worker: bool = False, **kwargs, ): if 'collate_fn' in kwargs: @@ -112,6 +121,7 @@ def __init__( self.input_nodes = input_nodes self.num_hops = max([len(v) for v in num_samples.values()]) self.transform = transform + self.filter_per_worker = filter_per_worker self.sample_fn = torch.ops.torch_sparse.hgt_sample # Convert the graph data into a suitable format for sampling. @@ -134,8 +144,7 @@ def sample(self, indices: List[int]) -> HeteroData: ) return node_dict, row_dict, col_dict, edge_dict, len(indices) - def transform_fn(self, out: Any) -> HeteroData: - # NOTE This function will always be executed on the main thread! + def filter_fn(self, out: Any) -> HeteroData: node_dict, row_dict, col_dict, edge_dict, batch_size = out data = filter_hetero_data(self.data, node_dict, row_dict, col_dict, @@ -144,8 +153,18 @@ def transform_fn(self, out: Any) -> HeteroData: return data if self.transform is None else self.transform(data) + def collate_fn(self, indices: List[int]) -> Any: + out = self.sample(indices) + if self.filter_per_worker: + # We execute `filter_fn` in the worker process. + out = self.filter_fn(out) + return out + def _get_iterator(self) -> Iterator: - return DataLoaderIterator(super()._get_iterator(), self.transform_fn) + if self.filter_per_worker: + return super()._get_iterator() + # We execute `filter_fn` in the main process. + return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: return f'{self.__class__.__name__}()' diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index da9b35fd330a..49dd4589c023 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -223,6 +223,14 @@ class LinkNeighborLoader(torch.utils.data.DataLoader): :obj:`edge_index` is sorted by column. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: :obj:`False`) + filter_per_worker (bool, optional): If set to :obj:`True`, will filter + the returning data in each worker's subprocess rather than in the + main process. + Setting this to :obj:`True` is generally not recommended: + (1) it may result in too many open file handles, + (2) it may slown down data loading, + (3) it requires operating on CPU tensors. + (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. @@ -238,6 +246,7 @@ def __init__( neg_sampling_ratio: float = 0.0, transform: Callable = None, is_sorted: bool = False, + filter_per_worker: bool = False, neighbor_sampler: Optional[LinkNeighborSampler] = None, **kwargs, ): @@ -255,9 +264,10 @@ def __init__( self.edge_label = edge_label self.replace = replace self.directed = directed + self.neg_sampling_ratio = neg_sampling_ratio self.transform = transform self.neighbor_sampler = neighbor_sampler - self.neg_sampling_ratio = neg_sampling_ratio + self.neighbor_sampler = neighbor_sampler edge_type, edge_label_index = get_edge_label_index( data, edge_label_index) @@ -275,10 +285,9 @@ def __init__( ) super().__init__(Dataset(edge_label_index, edge_label), - collate_fn=self.neighbor_sampler, **kwargs) + collate_fn=self.collate_fn, **kwargs) - def transform_fn(self, out: Any) -> Union[Data, HeteroData]: - # NOTE This function will always be executed on the main thread! + def filter_fn(self, out: Any) -> Union[Data, HeteroData]: if isinstance(self.data, Data): node, row, col, edge, edge_label_index, edge_label = out data = filter_data(self.data, node, row, col, edge, @@ -300,8 +309,18 @@ def transform_fn(self, out: Any) -> Union[Data, HeteroData]: return data if self.transform is None else self.transform(data) + def collate_fn(self, index: Union[List[int], Tensor]) -> Any: + out = self.neighbor_sampler(index) + if self.filter_per_worker: + # We execute `filter_fn` in the worker process. + out = self.filter_fn(out) + return out + def _get_iterator(self) -> Iterator: - return DataLoaderIterator(super()._get_iterator(), self.transform_fn) + if self.filter_per_worker: + return super()._get_iterator() + # We execute `filter_fn` in the main process. + return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: return f'{self.__class__.__name__}()' diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index 5493b31e02f5..9f0fe4193871 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -370,6 +370,14 @@ class NeighborLoader(torch.utils.data.DataLoader): :obj:`edge_index` is sorted by column. This avoids internal re-sorting of the data and can improve runtime and memory efficiency. (default: :obj:`False`) + filter_per_worker (bool, optional): If set to :obj:`True`, will filter + the returning data in each worker's subprocess rather than in the + main process. + Setting this to :obj:`True` is generally not recommended: + (1) it may result in too many open file handles, + (2) it may slown down data loading, + (3) it requires operating on CPU tensors. + (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. @@ -384,6 +392,7 @@ def __init__( time_attr: Optional[str] = None, transform: Callable = None, is_sorted: bool = False, + filter_per_worker: bool = False, neighbor_sampler: Optional[NeighborSampler] = None, **kwargs, ): @@ -401,6 +410,7 @@ def __init__( self.replace = replace self.directed = directed self.transform = transform + self.filter_per_worker = filter_per_worker self.neighbor_sampler = neighbor_sampler node_type, input_nodes = get_input_nodes(data, input_nodes) @@ -417,11 +427,9 @@ def __init__( share_memory=kwargs.get('num_workers', 0) > 0, ) - super().__init__(input_nodes, collate_fn=self.neighbor_sampler, - **kwargs) + super().__init__(input_nodes, collate_fn=self.collate_fn, **kwargs) - def transform_fn(self, out: Any) -> Union[Data, HeteroData]: - # NOTE This function will always be executed on the main thread! + def filter_fn(self, out: Any) -> Union[Data, HeteroData]: if isinstance(self.data, Data): node, row, col, edge, batch_size = out data = filter_data(self.data, node, row, col, edge, @@ -445,8 +453,18 @@ def transform_fn(self, out: Any) -> Union[Data, HeteroData]: return data if self.transform is None else self.transform(data) + def collate_fn(self, index: Union[List[int], Tensor]) -> Any: + out = self.neighbor_sampler(index) + if self.filter_per_worker: + # We execute `filter_fn` in the worker process. + out = self.filter_fn(out) + return out + def _get_iterator(self) -> Iterator: - return DataLoaderIterator(super()._get_iterator(), self.transform_fn) + if self.filter_per_worker: + return super()._get_iterator() + # We execute `filter_fn` in the main process. + return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: return f'{self.__class__.__name__}()' From 128bba7d9525a056be942696a3780f8f65cf0140 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 27 Jun 2022 09:23:47 +0000 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 1 + torch_geometric/loader/hgt_loader.py | 2 +- torch_geometric/loader/link_neighbor_loader.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 082c62d3e17a..977f759a5bf9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873)) - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815)) - Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857)) - Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847)) diff --git a/torch_geometric/loader/hgt_loader.py b/torch_geometric/loader/hgt_loader.py index bfc927aeaa6d..00e9c2d3eb53 100644 --- a/torch_geometric/loader/hgt_loader.py +++ b/torch_geometric/loader/hgt_loader.py @@ -130,7 +130,7 @@ def __init__( self.colptr_dict, self.row_dict, self.perm_dict = to_hetero_csc( data, device='cpu', share_memory=kwargs.get('num_workers', 0) > 0) - super().__init__(input_nodes[1].tolist(), collate_fn=self.sample, + super().__init__(input_nodes[1].tolist(), collate_fn=self.collate_fn, **kwargs) def sample(self, indices: List[int]) -> HeteroData: diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index 49dd4589c023..d4da7db645c6 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -266,7 +266,7 @@ def __init__( self.directed = directed self.neg_sampling_ratio = neg_sampling_ratio self.transform = transform - self.neighbor_sampler = neighbor_sampler + self.filter_per_worker = filter_per_worker self.neighbor_sampler = neighbor_sampler edge_type, edge_label_index = get_edge_label_index(