Skip to content

Commit 65ab1e0

Browse files
wangbingnan136pre-commit-ci[bot]rusty1s
authored
Add support for predict_dataloader in LightningNodeData (#4884)
* Update lightning_datamodule.py * Update lightning_datamodule.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * update * update Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: rusty1s <[email protected]>
1 parent 9fc80f3 commit 65ab1e0

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
55

66
## [2.0.5] - 2022-MM-DD
77
### Added
8+
- Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))
89
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877))
910
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
1011
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))

torch_geometric/data/lightning_datamodule.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,30 @@ class LightningNodeData(LightningDataModule):
191191
data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or
192192
:class:`~torch_geometric.data.HeteroData` graph object.
193193
input_train_nodes (torch.Tensor or str or (str, torch.Tensor)): The
194-
indices of training nodes. If not given, will try to automatically
195-
infer them from the :obj:`data` object. (default: :obj:`None`)
194+
indices of training nodes.
195+
If not given, will try to automatically infer them from the
196+
:obj:`data` object by searching for :obj:`train_mask`,
197+
:obj:`train_idx`, or :obj:`train_index` attributes.
198+
(default: :obj:`None`)
196199
input_val_nodes (torch.Tensor or str or (str, torch.Tensor)): The
197-
indices of validation nodes. If not given, will try to
198-
automatically infer them from the :obj:`data` object.
200+
indices of validation nodes.
201+
If not given, will try to automatically infer them from the
202+
:obj:`data` object by searching for :obj:`val_mask`,
203+
:obj:`valid_mask`, :obj:`val_idx`, :obj:`valid_idx`,
204+
:obj:`val_index`, or :obj:`valid_index` attributes.
199205
(default: :obj:`None`)
200206
input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The
201-
indices of test nodes. If not given, will try to automatically
202-
infer them from the :obj:`data` object. (default: :obj:`None`)
207+
indices of test nodes.
208+
If not given, will try to automatically infer them from the
209+
:obj:`data` object by searching for :obj:`test_mask`,
210+
:obj:`test_idx`, or :obj:`test_index` attributes.
211+
(default: :obj:`None`)
212+
input_pred_nodes (torch.Tensor or str or (str, torch.Tensor)): The
213+
indices of prediction nodes.
214+
If not given, will try to automatically infer them from the
215+
:obj:`data` object by searching for :obj:`pred_mask`,
216+
:obj:`pred_idx`, or :obj:`pred_index` attributes.
217+
(default: :obj:`None`)
203218
loader (str): The scalability technique to use (:obj:`"full"`,
204219
:obj:`"neighbor"`). (default: :obj:`"neighbor"`)
205220
batch_size (int, optional): How many samples per batch to load.
@@ -216,6 +231,7 @@ def __init__(
216231
input_train_nodes: InputNodes = None,
217232
input_val_nodes: InputNodes = None,
218233
input_test_nodes: InputNodes = None,
234+
input_pred_nodes: InputNodes = None,
219235
loader: str = "neighbor",
220236
batch_size: int = 1,
221237
num_workers: int = 0,
@@ -236,6 +252,9 @@ def __init__(
236252
if input_test_nodes is None:
237253
input_test_nodes = infer_input_nodes(data, split='test')
238254

255+
if input_pred_nodes is None:
256+
input_pred_nodes = infer_input_nodes(data, split='pred')
257+
239258
if loader == 'full' and batch_size != 1:
240259
warnings.warn(f"Re-setting 'batch_size' to 1 in "
241260
f"'{self.__class__.__name__}' for loader='full' "
@@ -279,6 +298,7 @@ def __init__(
279298
self.input_train_nodes = input_train_nodes
280299
self.input_val_nodes = input_val_nodes
281300
self.input_test_nodes = input_test_nodes
301+
self.input_pred_nodes = input_pred_nodes
282302

283303
def prepare_data(self):
284304
""""""
@@ -323,6 +343,10 @@ def test_dataloader(self) -> DataLoader:
323343
""""""
324344
return self.dataloader(self.input_test_nodes, shuffle=False)
325345

346+
def predict_dataloader(self) -> DataLoader:
347+
""""""
348+
return self.dataloader(self.input_pred_nodes, shuffle=False)
349+
326350
def __repr__(self) -> str:
327351
kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs)
328352
return f'{self.__class__.__name__}({kwargs})'

0 commit comments

Comments
 (0)