Skip to content

Commit 03b3b94

Browse files
committed
[benchmark] Add NeighborLoader bench
1 parent c957f7f commit 03b3b94

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed
+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from ogb.nodeproppred import PygNodePropPredDataset
2+
import torch_geometric.transforms as T
3+
from torch_geometric.datasets import OGB_MAG
4+
import os.path as osp
5+
import argparse
6+
from timeit import default_timer
7+
from torch_geometric.loader import NeighborLoader
8+
9+
10+
def run(args: argparse.Namespace) -> None:
    """Benchmark ``NeighborLoader`` sampling on every requested dataset.

    For each dataset name in ``args.datasets`` the dataset is loaded, moved to
    ``args.device`` and sampled in two phases:

    * "Train sampling": one loader per combination of neighbor sizes and
      ``args.batch_sizes``, shuffled, seeded with the training nodes.
    * "Evaluation sampling": one loader per ``args.eval_batch_sizes`` with
      ``num_neighbors=[-1]`` and no shuffling.

    Every loader is iterated to exhaustion ``args.runs`` times and the
    wall-clock time of each pass is printed.

    Args:
        args: Parsed command-line options. Attributes read here: ``datasets``,
            ``root``, ``device``, ``batch_sizes``, ``eval_batch_sizes``,
            ``homo_neighbor_sizes``, ``hetero_neighbor_sizes``,
            ``num_workers`` and ``runs``.
    """
    print("BENCHMARK STARTS")
    for dataset_name in args.datasets:
        print("Dataset: ", dataset_name)

        # Data lives under <script dir>/<args.root>/<suffix>, where <suffix>
        # is whatever follows the first "-" in the dataset name
        # (e.g. "arxiv" for "ogbn-arxiv").
        root = osp.join(osp.dirname(osp.realpath(__file__)), args.root,
                        dataset_name.partition("-")[2])

        if dataset_name == 'ogbn-mag':
            # 'ogbn-mag' is loaded through OGB_MAG and sampled with the
            # hetero neighbor sizes; seeds are given as (node_type, index).
            transform = T.ToUndirected(merge=True)
            dataset = OGB_MAG(root=root, transform=transform)
            train_idx = ('paper', dataset[0]['paper'].train_mask)
            # NOTE(review): a ``None`` index presumably selects every 'paper'
            # node -- confirm against the NeighborLoader documentation.
            eval_idx = ('paper', None)
            neighbor_sizes = args.hetero_neighbor_sizes
        else:
            dataset = PygNodePropPredDataset(dataset_name, root)
            split_idx = dataset.get_idx_split()
            train_idx = split_idx['train']
            eval_idx = None
            neighbor_sizes = args.homo_neighbor_sizes

        data = dataset[0].to(args.device)

        print('Train sampling')
        for sizes in neighbor_sizes:
            print(f'Sizes={sizes}')
            for batch_size in args.batch_sizes:
                train_loader = NeighborLoader(
                    data,
                    num_neighbors=sizes,
                    input_nodes=train_idx,
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=args.num_workers,
                )
                _report_sampling(train_loader, batch_size, args.runs)

        print('Evaluation sampling')
        for batch_size in args.eval_batch_sizes:
            subgraph_loader = NeighborLoader(
                data,
                num_neighbors=[-1],
                input_nodes=eval_idx,
                batch_size=batch_size,
                shuffle=False,
                num_workers=args.num_workers,
            )
            _report_sampling(subgraph_loader, batch_size, args.runs)


def _time_loader(loader, runs):
    """Exhaust *loader* ``runs`` times and time each full pass.

    Returns:
        A ``(num_batches, times)`` tuple. ``num_batches`` is the number of
        batches seen summed over all passes (not per pass); ``times`` holds
        the duration of each pass in seconds, rounded to 3 decimals.
    """
    num_batches = 0
    times = []
    for _ in range(runs):
        start = default_timer()
        for _batch in loader:
            num_batches += 1
        times.append(round(default_timer() - start, 3))
    return num_batches, times


def _report_sampling(loader, batch_size, runs):
    """Time *loader* with ``_time_loader`` and print one summary line."""
    num_batches, times = _time_loader(loader, runs)
    # With zero runs there is nothing to average; report 0.0 rather than
    # raising ZeroDivisionError.
    average_time = round(sum(times) / len(times), 3) if times else 0.0
    print(f'Batch size={batch_size} iterations={num_batches} '
          f'times={times} average_time={average_time}')
76+
77+
78+
def parse_int_list(text):
    """Parse a comma-separated string of integers such as ``"15,10,5"``.

    Surrounding whitespace and an optional pair of square brackets are
    ignored, so ``"[15, 10, 5]"`` is accepted as well.

    Raises:
        ValueError: If any element is not an integer. argparse turns this
            into a regular usage error when the function is used as ``type``.
    """
    return [int(item) for item in text.strip().strip('[]').split(',')]


def parse_args(argv=None):
    """Build the benchmark's command-line parser and parse *argv*.

    Args:
        argv: Argument strings to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    argparser = argparse.ArgumentParser('NeighborLoader Sampling Benchmarking')

    argparser.add_argument('--device', default='cpu', type=str)
    argparser.add_argument('--datasets', nargs="+",
                           default=['ogbn-arxiv', 'ogbn-products', 'ogbn-mag'],
                           type=str)
    argparser.add_argument('--root', default='../../data', type=str)
    # The list-valued options need nargs='+': with a bare ``type=int`` a value
    # given on the command line parses to a single int, which the benchmark
    # loops cannot iterate over.
    argparser.add_argument('--batch-sizes', nargs='+',
                           default=[8192, 4096, 2048, 1024, 512], type=int)
    # Each value is one comma-separated list of per-hop sizes, e.g.
    # ``--homo-neighbor-sizes 10,5 15,10,5``. The original mixed
    # dash/underscore spelling is kept as the first option string so existing
    # invocations and the attribute name stay the same.
    argparser.add_argument('--homo-neighbor_sizes', '--homo-neighbor-sizes',
                           dest='homo_neighbor_sizes', nargs='+',
                           default=[[10, 5], [15, 10, 5], [20, 15, 10]],
                           type=parse_int_list, metavar='N,N,...')
    argparser.add_argument('--hetero-neighbor_sizes',
                           '--hetero-neighbor-sizes',
                           dest='hetero_neighbor_sizes', nargs='+',
                           default=[[5], [10], [10, 5]],
                           type=parse_int_list, metavar='N,N,...')
    argparser.add_argument('--eval-batch-sizes', nargs='+',
                           default=[16384, 8192, 4096, 2048, 1024, 512],
                           type=int)
    argparser.add_argument('--num-workers', default=0, type=int)
    argparser.add_argument('--runs', default=3, type=int)

    return argparser.parse_args(argv)


if __name__ == '__main__':
    run(parse_args())

0 commit comments

Comments
 (0)