Add LinkNeighborLoader to Pytorch Lightning datamodule #4868

Merged · 20 commits · Jun 29, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `LinkNeighborLoader` support to the lightning datamodule ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))
- Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
25 changes: 24 additions & 1 deletion test/data/test_lightning_datamodule.py
@@ -4,7 +4,11 @@
import torch
import torch.nn.functional as F

from torch_geometric.data import LightningDataset, LightningNodeData
from torch_geometric.data import (
    LightningDataset,
    LightningLinkData,
    LightningNodeData,
)
from torch_geometric.nn import global_mean_pool
from torch_geometric.testing import onlyFullTest, withCUDA, withPackage

@@ -264,3 +268,22 @@ def test_lightning_hetero_node_data(get_dataset):
    offset += 5 * devices * math.ceil(400 / (devices * 32))  # `train`
    offset += 5 * devices * math.ceil(400 / (devices * 32))  # `val`
    assert torch.all(new_x > (old_x + offset - 4))  # Ensure shared data.


@withCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
def test_lightning_hetero_link_data(get_dataset):
    # TODO: Add more datasets.
    dataset = get_dataset(name='DBLP')
    data = dataset[0]
    datamodule = LightningLinkData(data, loader='link_neighbor',
                                   num_neighbors=[5], batch_size=32,
                                   num_workers=3)
    input_edges = (('author', 'dummy', 'paper'),
                   data['author', 'paper']['edge_index'])
    loader = datamodule.dataloader(input_edges=input_edges, input_labels=None,
                                   shuffle=True)
    batch = next(iter(loader))
    assert (batch['author', 'dummy',
                  'paper']['edge_label_index'].shape[1] == 32)
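
Editorial note: the test above exercises the new module through its internal `dataloader()` helper. Below is a hedged sketch of the more typical route via the Lightning hooks; it reuses `data` (the DBLP `HeteroData` object) from the test, and the all-positive `edge_label` tensor is invented purely for illustration.

import torch

from torch_geometric.data import LightningLinkData

# Reuse `data` from the test above; build dummy all-positive edge labels.
edge_index = data['author', 'paper']['edge_index']
edge_label = torch.ones(edge_index.size(1))

datamodule = LightningLinkData(
    data,
    input_train_edges=(('author', 'dummy', 'paper'), edge_index),
    input_train_edge_label=edge_label,
    loader='link_neighbor',
    num_neighbors=[5],
    batch_size=32,
    num_workers=3,
)
loader = datamodule.train_dataloader()  # wraps LinkNeighborLoader internally
batch = next(iter(loader))
# Each mini-batch carries 32 seed edges plus their sampled 1-hop neighborhoods.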
7 changes: 6 additions & 1 deletion torch_geometric/data/__init__.py
@@ -4,7 +4,11 @@
from .batch import Batch
from .dataset import Dataset
from .in_memory_dataset import InMemoryDataset
from .lightning_datamodule import LightningDataset, LightningNodeData
from .lightning_datamodule import (
    LightningDataset,
    LightningLinkData,
    LightningNodeData,
)
from .makedirs import makedirs
from .download import download_url
from .extract import extract_tar, extract_zip, extract_bz2, extract_gz
@@ -18,6 +22,7 @@
'InMemoryDataset',
'LightningDataset',
'LightningNodeData',
'LightningLinkData',
'makedirs',
'download_url',
'extract_tar',
169 changes: 165 additions & 4 deletions torch_geometric/data/lightning_datamodule.py
@@ -4,13 +4,14 @@
import torch

from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.loader.dataloader import DataLoader
from torch_geometric.loader.neighbor_loader import (
    NeighborLoader,
    NeighborSampler,
    get_input_nodes,
)
from torch_geometric.typing import InputNodes
from torch_geometric.typing import InputEdges, InputNodes

try:
    from pytorch_lightning import LightningDataModule as PLLightningDataModule
@@ -245,9 +246,8 @@ def __init__(

if input_val_nodes is None:
input_val_nodes = infer_input_nodes(data, split='val')

if input_val_nodes is None:
input_val_nodes = infer_input_nodes(data, split='valid')
if input_val_nodes is None:
input_val_nodes = infer_input_nodes(data, split='valid')

if input_test_nodes is None:
input_test_nodes = infer_input_nodes(data, split='test')
@@ -352,6 +352,167 @@ def __repr__(self) -> str:
return f'{self.__class__.__name__}({kwargs})'


# TODO: Unify implementation with LightningNodeData via a common base class.
class LightningLinkData(LightningDataModule):
    r"""Converts a :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData` object into a
    :class:`pytorch_lightning.LightningDataModule` variant, which can be
    automatically used as a :obj:`datamodule` for multi-GPU link-level
    training (such as for link prediction) via `PyTorch Lightning
    <https://www.pytorchlightning.ai>`_. :class:`LightningLinkData` will
    take care of providing mini-batches via
    :class:`~torch_geometric.loader.LinkNeighborLoader`.

    .. note::

        Currently only the
        :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and
        :class:`pytorch_lightning.strategies.DDPSpawnStrategy` training
        strategies of `PyTorch Lightning
        <https://pytorch-lightning.readthedocs.io/en/latest/guides/
        speed.html>`__ are supported in order to correctly share data across
        all devices/processes:

        .. code-block::

            import pytorch_lightning as pl
            trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                                 devices=4)
            trainer.fit(model, datamodule)

    Args:
        data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` graph object.
        input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The training edges. (default: :obj:`None`)
        input_train_edge_label (Tensor, optional): The labels of the training
            edges. (default: :obj:`None`)
        input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The validation edges. (default: :obj:`None`)
        input_val_edge_label (Tensor, optional): The labels of the validation
            edges. (default: :obj:`None`)
        input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):
            The test edges. (default: :obj:`None`)
        input_test_edge_label (Tensor, optional): The labels of the test
            edges. (default: :obj:`None`)
        loader (str): The scalability technique to use (:obj:`"full"`,
            :obj:`"link_neighbor"`). (default: :obj:`"link_neighbor"`)
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        num_workers (int, optional): How many subprocesses to use for data
            loading. :obj:`0` means that the data will be loaded in the main
            process. (default: :obj:`0`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.loader.LinkNeighborLoader`.
    """
    def __init__(
        self,
        data: Union[Data, HeteroData],
        input_train_edges: InputEdges = None,
        input_train_edge_label: torch.Tensor = None,
        input_val_edges: InputEdges = None,
        input_val_edge_label: torch.Tensor = None,
        input_test_edges: InputEdges = None,
        input_test_edge_label: torch.Tensor = None,
        loader: str = "link_neighbor",
        batch_size: int = 1,
        num_workers: int = 0,
        **kwargs,
    ):

        assert loader in ['full', 'link_neighbor']
        # TODO: Handle or document behavior where none of train, val, test
        # edges are specified.

        if loader == 'full' and batch_size != 1:
            warnings.warn(f"Re-setting 'batch_size' to 1 in "
                          f"'{self.__class__.__name__}' for loader='full' "
                          f"(got '{batch_size}')")
            batch_size = 1

        if loader == 'full' and num_workers != 0:
            warnings.warn(f"Re-setting 'num_workers' to 0 in "
                          f"'{self.__class__.__name__}' for loader='full' "
                          f"(got '{num_workers}')")
            num_workers = 0

        super().__init__(
            has_val=input_val_edges is not None,
            has_test=input_test_edges is not None,
            batch_size=batch_size,
            num_workers=num_workers,
            **kwargs,
        )

        if loader == 'full':
            if kwargs.get('pin_memory', False):
                warnings.warn(f"Re-setting 'pin_memory' to 'False' in "
                              f"'{self.__class__.__name__}' for loader='full' "
                              f"(got 'True')")
            self.kwargs['pin_memory'] = False

        self.data = data
        self.loader = loader

        self.input_train_edges = input_train_edges
        self.input_train_edge_label = input_train_edge_label
        self.input_val_edges = input_val_edges
        self.input_val_edge_label = input_val_edge_label
        self.input_test_edges = input_test_edges
        self.input_test_edge_label = input_test_edge_label

    def prepare_data(self):
        """"""
        if self.loader == 'full':
            try:
                num_devices = self.trainer.num_devices
            except AttributeError:
                # PyTorch Lightning < 1.6 backward compatibility:
                num_devices = self.trainer.num_processes
                num_devices = max(num_devices, self.trainer.num_gpus)

            if num_devices > 1:
                raise ValueError(
                    f"'{self.__class__.__name__}' with loader='full' requires "
                    f"training on a single device")
        super().prepare_data()

    def dataloader(self, input_edges: InputEdges, input_labels: torch.Tensor,
                   shuffle: bool) -> DataLoader:
        if self.loader == 'full':
            warnings.filterwarnings('ignore', '.*does not have many workers.*')
            warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')
            return torch.utils.data.DataLoader([self.data], shuffle=False,
                                               collate_fn=lambda xs: xs[0],
                                               **self.kwargs)

        if self.loader == 'link_neighbor':
            return LinkNeighborLoader(data=self.data,
                                      edge_label_index=input_edges,
                                      edge_label=input_labels, shuffle=shuffle,
                                      **self.kwargs)

        raise NotImplementedError

    def train_dataloader(self) -> DataLoader:
        """"""
        return self.dataloader(self.input_train_edges,
                               self.input_train_edge_label, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """"""
        return self.dataloader(self.input_val_edges, self.input_val_edge_label,
                               shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """"""
        return self.dataloader(self.input_test_edges,
                               self.input_test_edge_label, shuffle=False)

    def __repr__(self) -> str:
        kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs)
        return f'{self.__class__.__name__}({kwargs})'


###############################################################################
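
Editorial note: an end-to-end sketch of how the new `LightningLinkData` module slots into a training run, not part of the diff. `LinkPredModel` and the pre-computed edge splits (`train_edge_label_index`, `train_edge_label`, `val_edge_label_index`, `val_edge_label`) are assumed to exist; only the datamodule wiring follows the API added above.

import pytorch_lightning as pl

from torch_geometric.data import LightningLinkData

datamodule = LightningLinkData(
    data,                                      # a Data or HeteroData object
    input_train_edges=train_edge_label_index,  # assumed pre-computed split
    input_train_edge_label=train_edge_label,
    input_val_edges=val_edge_label_index,
    input_val_edge_label=val_edge_label,
    loader='link_neighbor',
    num_neighbors=[10, 10],  # forwarded to LinkNeighborLoader via **kwargs
    batch_size=256,
    num_workers=4,
)

model = LinkPredModel()  # a user-defined pl.LightningModule (assumed)
trainer = pl.Trainer(strategy='ddp_spawn', accelerator='gpu', devices=4)
trainer.fit(model, datamodule)

With `loader='full'` the module instead returns a single full-batch `DataLoader` and, as the warnings in `__init__` indicate, resets `batch_size`, `num_workers`, and `pin_memory` accordingly.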


5 changes: 2 additions & 3 deletions torch_geometric/loader/link_neighbor_loader.py
@@ -108,7 +108,7 @@ class LinkNeighborLoader(torch.utils.data.DataLoader):
.. code-block:: python

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
from torch_geometric.loader import LinkNeighborLoader

data = Planetoid(path, name='Cora')[0]

@@ -276,7 +276,7 @@ def __init__(

def filter_fn(self, out: Any) -> Union[Data, HeteroData]:
if isinstance(self.data, Data):
node, row, col, edge, edge_label_index, edge_label = out
(node, row, col, edge, edge_label_index, edge_label) = out
data = filter_data(self.data, node, row, col, edge,
self.neighbor_sampler.perm)
data.edge_label_index = edge_label_index
@@ -355,7 +355,6 @@ def get_edge_label_index(

edge_type, edge_label_index = edge_label_index
edge_type = data._to_canonical(*edge_type)
assert edge_type in data.edge_types

if edge_label_index is None:
return edge_type, data[edge_type].edge_index
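
Editorial note: the assertion removed above previously rejected label edge types that are not present in the graph; with it gone, a custom relation name appears to be accepted for the supervision edges (as the 'dummy' relation in the new test suggests). Below is a hedged, self-contained sketch of the `(edge_type, edge_label_index)` input form on a toy heterogeneous graph; the graph and sizes are made up for illustration.

import torch

from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader

data = HeteroData()
data['author'].x = torch.randn(8, 16)
data['paper'].x = torch.randn(12, 16)
data['author', 'writes', 'paper'].edge_index = torch.stack([
    torch.randint(0, 8, (40, )),   # source (author) indices
    torch.randint(0, 12, (40, )),  # destination (paper) indices
])

loader = LinkNeighborLoader(
    data,
    num_neighbors=[5],
    # An (edge_type, tensor) tuple; a `None` tensor falls back to all edges
    # of that type, mirroring `get_edge_label_index` above.
    edge_label_index=(('author', 'writes', 'paper'), None),
    batch_size=16,
    shuffle=True,
)
batch = next(iter(loader))
print(batch['author', 'writes', 'paper'].edge_label_index.shape)  # [2, 16]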