Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NeighborLoader benchmark suite #4815

Merged
merged 7 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
Expand Down
93 changes: 93 additions & 0 deletions benchmark/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import argparse
import os.path as osp
from timeit import default_timer

import tqdm
from ogb.nodeproppred import PygNodePropPredDataset

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader


def _time_loader(loader, runs: int):
    """Iterate over ``loader`` ``runs`` times and time each full pass.

    Args:
        loader: Iterable of mini-batches; only iterated, batches are unused.
        runs: Number of complete passes to time.

    Returns:
        A ``(runtimes, num_iterations)`` tuple. ``runtimes`` holds one
        wall-clock duration in seconds (rounded to 3 decimals) per pass.
        ``num_iterations`` is the number of batches seen, accumulated over
        *all* passes.
    """
    runtimes = []
    num_iterations = 0
    for _ in range(runs):
        start = default_timer()
        for _batch in tqdm.tqdm(loader):
            num_iterations += 1
        stop = default_timer()
        runtimes.append(round(stop - start, 3))
    return runtimes, num_iterations


def _report(batch_size, runtimes, num_iterations) -> None:
    """Print the timing summary line for one loader configuration."""
    # Guard the division so that `--runs 0` reports 0.0 instead of crashing.
    if runtimes:
        average_time = round(sum(runtimes) / len(runtimes), 3)
    else:
        average_time = 0.0
    print(f'batch size={batch_size}, iterations={num_iterations}, '
          f'runtimes={runtimes}, average runtime={average_time}')


def run(args: argparse.Namespace) -> None:
    """Benchmark ``NeighborLoader`` sampling for every requested dataset.

    For each name in ``args.datasets`` the dataset is loaded below
    ``args.root``, its first graph is moved to ``args.device`` and two
    sweeps are timed:

    * training sampling: every fan-out in the dataset's neighbor-size list
      combined with every entry of ``args.batch_sizes`` (shuffled);
    * evaluation sampling: ``num_neighbors=[-1]`` combined with every entry
      of ``args.eval_batch_sizes`` (not shuffled).

    Each loader is iterated ``args.runs`` times and the per-pass runtimes are
    printed.

    Args:
        args: Parsed command-line namespace providing ``datasets``, ``root``,
            ``device``, ``batch_sizes``, ``eval_batch_sizes``,
            ``homo_neighbor_sizes``, ``hetero_neighbor_sizes``,
            ``num_workers`` and ``runs``.
    """
    for dataset_name in args.datasets:
        print(f"Dataset: {dataset_name}")
        root = osp.join(args.root, dataset_name)

        if dataset_name == 'mag':
            # 'mag' is loaded through OGB_MAG and seeded from 'paper' nodes,
            # so the seed nodes are given as (node_type, index) pairs.
            transform = T.ToUndirected(merge=True)
            dataset = OGB_MAG(root=root, transform=transform)
            train_idx = ('paper', dataset[0]['paper'].train_mask)
            eval_idx = ('paper', None)
            neighbor_sizes = args.hetero_neighbor_sizes
        else:
            dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root)
            split_idx = dataset.get_idx_split()
            train_idx = split_idx['train']
            eval_idx = None
            neighbor_sizes = args.homo_neighbor_sizes

        data = dataset[0].to(args.device)

        for num_neighbors in neighbor_sizes:
            print(f'Training sampling with {num_neighbors} neighbors')
            for batch_size in args.batch_sizes:
                train_loader = NeighborLoader(
                    data,
                    num_neighbors=num_neighbors,
                    input_nodes=train_idx,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=args.num_workers,
                )
                runtimes, num_iterations = _time_loader(
                    train_loader, args.runs)
                _report(batch_size, runtimes, num_iterations)

        print('Evaluation sampling with all neighbors')
        for batch_size in args.eval_batch_sizes:
            subgraph_loader = NeighborLoader(
                data,
                num_neighbors=[-1],
                input_nodes=eval_idx,
                batch_size=batch_size,
                shuffle=False,
                num_workers=args.num_workers,
            )
            runtimes, num_iterations = _time_loader(
                subgraph_loader, args.runs)
            _report(batch_size, runtimes, num_iterations)


def _int_list(text: str) -> list:
    """Parse a comma-separated string such as ``"10,5"`` into ``[10, 5]``.

    Used as an argparse ``type``; a ``ValueError`` from ``int`` is turned
    into a regular usage error by argparse.
    """
    return [int(part) for part in text.split(',')]


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the benchmark.

    Every numeric option carries an explicit ``type`` so that values given
    on the command line have the same Python types as the defaults. List
    options take one or more space-separated values. Each neighbor-size
    value is one comma-separated fan-out, e.g.
    ``--homo-neighbor_sizes 10,5 15,10,5``.
    """
    parser = argparse.ArgumentParser('NeighborLoader Sampling Benchmarking')

    add = parser.add_argument
    add('--device', default='cpu')
    add('--datasets', nargs="+", default=['arxiv', 'products', 'mag'])
    add('--root', default='../../data')
    add('--batch-sizes', nargs='+', type=int,
        default=[8192, 4096, 2048, 1024, 512])
    add('--eval-batch-sizes', nargs='+', type=int,
        default=[16384, 8192, 4096, 2048, 1024, 512])
    # The original mixed dash/underscore spelling is kept for backward
    # compatibility; the all-dash spelling is accepted as an alias.
    add('--homo-neighbor_sizes', '--homo-neighbor-sizes',
        dest='homo_neighbor_sizes', nargs='+', type=_int_list,
        default=[[10, 5], [15, 10, 5], [20, 15, 10]])
    add('--hetero-neighbor_sizes', '--hetero-neighbor-sizes',
        dest='hetero_neighbor_sizes', nargs='+', type=_int_list,
        default=[[5], [10], [10, 5]])
    add('--num-workers', type=int, default=0)
    add('--runs', type=int, default=3)
    return parser


if __name__ == '__main__':
    run(build_parser().parse_args())