diff --git a/CHANGELOG.md b/CHANGELOG.md
index e9424a4624fb..3a06934bed42 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [2.0.5] - 2022-MM-DD
 ### Added
+- Added faster initialization of `NeighborLoader` in case edge indices are already sorted (via `is_sorted=True`) ([#4620](https://github.com/pyg-team/pytorch_geometric/pull/4620))
 - Added `AddPositionalEncoding` transform ([#4521](https://github.com/pyg-team/pytorch_geometric/pull/4521))
 - Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))
 - Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))
diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py
index 550784a57762..cafe28d4c77a 100644
--- a/torch_geometric/data/lightning_datamodule.py
+++ b/torch_geometric/data/lightning_datamodule.py
@@ -274,6 +274,7 @@ def __init__(
                 directed=kwargs.get('directed', True),
                 input_type=get_input_nodes(data, input_train_nodes)[0],
                 time_attr=kwargs.get('time_attr', None),
+                is_sorted=kwargs.get('is_sorted', False),
             )
         self.input_train_nodes = input_train_nodes
         self.input_val_nodes = input_val_nodes
diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py
index 08390100e2f0..e8162ea66fd9 100644
--- a/torch_geometric/loader/link_neighbor_loader.py
+++ b/torch_geometric/loader/link_neighbor_loader.py
@@ -201,9 +201,6 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
             replacement. (default: :obj:`False`)
         directed (bool, optional): If set to :obj:`False`, will include all
             edges between all sampled nodes. (default: :obj:`True`)
-        transform (Callable, optional): A function/transform that takes in
-            a sampled mini-batch and returns a transformed version.
-            (default: :obj:`None`)
         neg_sampling_ratio (float, optional): The ratio of sampled negative
             edges to the number of positive edges.
             If :obj:`edge_label` does not exist, it will be automatically
@@ -219,6 +216,13 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
             :meth:`F.binary_cross_entropy`) and of type
             :obj:`torch.long` for multi-class classification (to facilitate
             the ease-of-use of :meth:`F.cross_entropy`). (default: :obj:`0.0`).
+        transform (Callable, optional): A function/transform that takes in
+            a sampled mini-batch and returns a transformed version.
+            (default: :obj:`None`)
+        is_sorted (bool, optional): If set to :obj:`True`, assumes that
+            :obj:`edge_index` is sorted by column. This avoids internal
+            re-sorting of the data and can improve runtime and memory
+            efficiency. (default: :obj:`False`)
         **kwargs (optional): Additional arguments of
             :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
             :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
@@ -231,9 +235,10 @@ def __init__(
         edge_label: OptTensor = None,
         replace: bool = False,
         directed: bool = True,
+        neg_sampling_ratio: float = 0.0,
         transform: Callable = None,
+        is_sorted: bool = False,
         neighbor_sampler: Optional[LinkNeighborSampler] = None,
-        neg_sampling_ratio: float = 0.0,
         **kwargs,
     ):
         # Remove for PyTorch Lightning:
@@ -259,9 +264,15 @@ def __init__(
 
         if neighbor_sampler is None:
             self.neighbor_sampler = LinkNeighborSampler(
-                data, num_neighbors, replace, directed, edge_type,
+                data,
+                num_neighbors,
+                replace,
+                directed,
+                input_type=edge_type,
+                is_sorted=is_sorted,
+                neg_sampling_ratio=self.neg_sampling_ratio,
                 share_memory=kwargs.get('num_workers', 0) > 0,
-                neg_sampling_ratio=self.neg_sampling_ratio)
+            )
 
         super().__init__(Dataset(edge_label_index, edge_label),
                          collate_fn=self.neighbor_sampler, **kwargs)
diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py
index b7bdf2c505cb..744b6e4d06e5 100644
--- a/torch_geometric/loader/neighbor_loader.py
+++ b/torch_geometric/loader/neighbor_loader.py
@@ -24,8 +24,9 @@ def __init__(
         replace: bool = False,
         directed: bool = True,
         input_type: Optional[Any] = None,
-        share_memory: bool = False,
         time_attr: Optional[str] = None,
+        is_sorted: bool = False,
+        share_memory: bool = False,
     ):
         self.data_cls = data.__class__
         self.num_neighbors = num_neighbors
@@ -41,7 +42,8 @@
                                 f"'{data.__class__.__name__}' object")
 
             # Convert the graph data into a suitable format for sampling.
-            out = to_csc(data, device='cpu', share_memory=share_memory)
+            out = to_csc(data, device='cpu', share_memory=share_memory,
+                         is_sorted=is_sorted)
             self.colptr, self.row, self.perm = out
 
             assert isinstance(num_neighbors, (list, tuple))
@@ -54,7 +56,8 @@
             # Convert the graph data into a suitable format for sampling.
             # NOTE: Since C++ cannot take dictionaries with tuples as key as
             # input, edge type triplets are converted into single strings.
-            out = to_hetero_csc(data, device='cpu', share_memory=share_memory)
+            out = to_hetero_csc(data, device='cpu', share_memory=share_memory,
+                                is_sorted=is_sorted)
             self.colptr_dict, self.row_dict, self.perm_dict = out
 
             self.node_types, self.edge_types = data.metadata()
@@ -245,6 +248,10 @@ class NeighborLoader(torch.utils.data.DataLoader):
         transform (Callable, optional): A function/transform that takes in
             a sampled mini-batch and returns a transformed version.
             (default: :obj:`None`)
+        is_sorted (bool, optional): If set to :obj:`True`, assumes that
+            :obj:`edge_index` is sorted by column. This avoids internal
+            re-sorting of the data and can improve runtime and memory
+            efficiency. (default: :obj:`False`)
         **kwargs (optional): Additional arguments of
             :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
             :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
@@ -258,6 +265,7 @@ def __init__(
         directed: bool = True,
         time_attr: Optional[str] = None,
         transform: Callable = None,
+        is_sorted: bool = False,
         neighbor_sampler: Optional[NeighborSampler] = None,
         **kwargs,
     ):
@@ -281,9 +289,15 @@ def __init__(
 
         if neighbor_sampler is None:
             self.neighbor_sampler = NeighborSampler(
-                data, num_neighbors, replace, directed, node_type,
+                data,
+                num_neighbors,
+                replace,
+                directed,
+                input_type=node_type,
                 time_attr=time_attr,
-                share_memory=kwargs.get('num_workers', 0) > 0)
+                is_sorted=is_sorted,
+                share_memory=kwargs.get('num_workers', 0) > 0,
+            )
 
         super().__init__(input_nodes, collate_fn=self.neighbor_sampler,
                          **kwargs)
diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py
index a7b783727477..e69ae0c0c9d3 100644
--- a/torch_geometric/loader/utils.py
+++ b/torch_geometric/loader/utils.py
@@ -34,6 +34,7 @@ def to_csc(
     data: Union[Data, EdgeStorage],
     device: Optional[torch.device] = None,
     share_memory: bool = False,
+    is_sorted: bool = False,
 ) -> Tuple[Tensor, Tensor, OptTensor]:
     # Convert the graph data into a suitable format for sampling (CSC format).
     # Returns the `colptr` and `row` indices of the graph, as well as an
@@ -47,17 +48,18 @@
 
     elif hasattr(data, 'edge_index'):
         (row, col) = data.edge_index
-        size = data.size()
-        perm = (col * size[0]).add_(row).argsort()
+        if not is_sorted:
+            size = data.size()
+            perm = (col * size[0]).add_(row).argsort()
+            row = row[perm]
         colptr = torch.ops.torch_sparse.ind2ptr(col[perm], size[1])
-        row = row[perm]
     else:
         raise AttributeError("Data object does not contain attributes "
                              "'adj_t' or 'edge_index'")
 
     colptr = colptr.to(device)
     row = row.to(device)
-    perm = perm if perm is not None else perm.to(device)
+    perm = perm.to(device) if perm is not None else None
 
     if not colptr.is_cuda and share_memory:
         colptr.share_memory_()
@@ -72,6 +74,7 @@ def to_hetero_csc(
     data: HeteroData,
     device: Optional[torch.device] = None,
     share_memory: bool = False,
+    is_sorted: bool = False,
 ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:
     # Convert the heterogeneous graph data into a suitable format for sampling
     # (CSC format).
@@ -83,7 +86,7 @@
 
     for store in data.edge_stores:
         key = edge_type_to_str(store._key)
-        out = to_csc(store, device, share_memory)
+        out = to_csc(store, device, share_memory, is_sorted)
         colptr_dict[key], row_dict[key], perm_dict[key] = out
 
     return colptr_dict, row_dict, perm_dict
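A minimal usage sketch of the new `is_sorted` flag, assuming a PyG build that includes this change; the toy graph, feature sizes, and loader parameters below are illustrative only:

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# Toy graph (illustrative): 100 nodes with 16-dim features, 500 random edges.
num_nodes = 100
edge_index = torch.randint(num_nodes, (2, 500))
data = Data(x=torch.randn(num_nodes, 16), edge_index=edge_index)

# Sort edges by column (destination node) once, up front. This is the
# precondition documented for `is_sorted=True`: `to_csc` can then build
# `colptr` directly instead of re-sorting the edges for every loader.
perm = data.edge_index[1].argsort()
data.edge_index = data.edge_index[:, perm]
# Note: any edge-level attributes (e.g. `data.edge_attr`) would need to be
# permuted with the same `perm` to stay aligned with `edge_index`.

loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],  # sample up to 10 neighbors over two hops
    batch_size=32,
    is_sorted=True,          # skip the internal re-sorting in `to_csc`
)

for batch in loader:
    ...  # each `batch` is a sampled subgraph mini-batch
```

The flag mainly pays off when many loaders are built over the same large graph (e.g., separate train/validation/test loaders), since the per-loader argsort in `to_csc` is skipped.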