Commit 4af89d3

Merge branch 'master' into node_layer_norm
2 parents 400f76a + db5e6d9 commit 4af89d3

File tree

6 files changed: +186 -58 lines changed


CHANGELOG.md (+1 -1)

@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908))
 - Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
 - Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
-- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883))
+- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922))
 - Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
 - Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
 - Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))

test/loader/test_neighbor_loader.py (+54)

@@ -4,6 +4,7 @@
 from torch_sparse import SparseTensor
 
 from torch_geometric.data import Data, HeteroData
+from torch_geometric.data.feature_store import TensorAttr
 from torch_geometric.loader import NeighborLoader
 from torch_geometric.nn import GraphConv, to_hetero
 from torch_geometric.testing import withRegisteredOp
@@ -322,6 +323,15 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
                                edge_type=('author', 'to', 'paper'),
                                layout='csc', size=(200, 100))
 
+    # COO (sorted):
+    edge_index = get_edge_index(200, 200, 100)
+    edge_index = edge_index[:, edge_index[1].argsort()]
+    data['author', 'to', 'author'].edge_index = edge_index
+    coo = (edge_index[0], edge_index[1])
+    graph_store.put_edge_index(edge_index=coo,
+                               edge_type=('author', 'to', 'author'),
+                               layout='coo', size=(200, 200), is_sorted=True)
+
     # Construct neighbor loaders:
     loader1 = NeighborLoader(data, batch_size=20,
                              input_nodes=('paper', range(100)),
@@ -350,3 +360,47 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
         'paper', 'to', 'author'].edge_index.size())
     assert (batch1['author', 'to', 'paper'].edge_index.size() == batch1[
         'author', 'to', 'paper'].edge_index.size())
+
+
+@withRegisteredOp('torch_sparse.hetero_temporal_neighbor_sample')
+@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])
+@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])
+def test_temporal_custom_neighbor_loader_on_cora(get_dataset, FeatureStore,
+                                                 GraphStore):
+    # Initialize dataset (once):
+    dataset = get_dataset(name='Cora')
+    data = dataset[0]
+
+    # Initialize feature store, graph store, and reference:
+    feature_store = FeatureStore()
+    graph_store = GraphStore()
+    hetero_data = HeteroData()
+
+    feature_store.put_tensor(data.x, group_name='paper', attr_name='x',
+                             index=None)
+    hetero_data['paper'].x = data.x
+
+    feature_store.put_tensor(torch.arange(data.num_nodes), group_name='paper',
+                             attr_name='time', index=None)
+    hetero_data['paper'].time = torch.arange(data.num_nodes)
+
+    num_nodes = data.x.size(dim=0)
+    graph_store.put_edge_index(edge_index=data.edge_index,
+                               edge_type=('paper', 'to', 'paper'),
+                               layout='coo', size=(num_nodes, num_nodes))
+    hetero_data['paper', 'to', 'paper'].edge_index = data.edge_index
+
+    loader1 = NeighborLoader(hetero_data, num_neighbors=[-1, -1],
+                             input_nodes='paper', time_attr='time',
+                             batch_size=128)
+
+    loader2 = NeighborLoader(
+        (feature_store, graph_store),
+        num_neighbors=[-1, -1],
+        input_nodes=TensorAttr(group_name='paper', attr_name='x'),
+        time_attr='time',
+        batch_size=128,
+    )
+
+    for batch1, batch2 in zip(loader1, loader2):
+        assert torch.equal(batch1['paper'].time, batch2['paper'].time)
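
The sorted-COO case added above relies on the `GraphStore` convention (also noted in the `graph_store.py` TODO below) that `is_sorted=True` means the COO edge index is ordered by destination node, i.e. by its second row. A minimal standalone sketch of that pre-sorting step, using plain `torch` and made-up sizes rather than the test fixtures:

```python
import torch

# A random COO edge index: row 0 holds source nodes, row 1 destination nodes.
edge_index = torch.randint(0, 200, (2, 100))

# Sort the edges by destination (column) so that `is_sorted=True` holds:
edge_index = edge_index[:, edge_index[1].argsort()]

# Destinations are now non-decreasing:
assert bool((edge_index[1][1:] >= edge_index[1][:-1]).all())
```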

torch_geometric/data/data.py (+23 -4)

@@ -842,6 +842,12 @@ def _put_edge_index(self, edge_index: EdgeTensorType,
         attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
         setattr(self, attr_name, attr_val)
 
+        # Set edge attributes:
+        if not hasattr(self, '_edge_attrs'):
+            self._edge_attrs = {}
+
+        self._edge_attrs[edge_attr.layout.value] = edge_attr
+
         # Set size, if possible:
         size = edge_attr.size
         if size is not None:
@@ -866,13 +872,26 @@ def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
     def get_all_edge_attrs(self) -> List[EdgeAttr]:
         r"""Returns `EdgeAttr` objects corresponding to the edge indices stored
         in `Data` and their layouts"""
-        out = []
+        if not hasattr(self, '_edge_attrs'):
+            return []
+        added_attrs = set()
+
+        # Check edges added via _put_edge_index:
+        edge_attrs = self._edge_attrs.values()
+        for attr in edge_attrs:
+            attr.size = (self.num_nodes, self.num_nodes)
+            added_attrs.add(attr.layout)
+
+        # Check edges added through regular interface:
+        # TODO deprecate this and store edge attributes for all edges in
+        # EdgeStorage
         for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
-            if attr_name in self:
-                out.append(
+            if attr_name in self and layout not in added_attrs:
+                edge_attrs.append(
                     EdgeAttr(edge_type=None, layout=layout,
                              size=(self.num_nodes, self.num_nodes)))
-        return out
+
+        return edge_attrs
 
 
 ###############################################################################
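
A hedged usage sketch (not part of the diff) of what the `_edge_attrs` bookkeeping above enables: an edge index registered through the `GraphStore` interface of `Data` is reported back by `get_all_edge_attrs()` together with the layout it was stored in. The call mirrors the `EdgeAttr` fields (`edge_type`, `layout`, `size`) used elsewhere in this commit.

```python
import torch
from torch_geometric.data import Data

data = Data()  # `Data` implements the `GraphStore` interface

row = torch.tensor([0, 1, 2])
col = torch.tensor([1, 2, 0])

# Register a COO edge index via the GraphStore API (a (row, col) tuple):
data.put_edge_index((row, col), edge_type=None, layout='coo', size=(3, 3))

# The layout is now tracked and returned, including edges added via
# `_put_edge_index`:
for attr in data.get_all_edge_attrs():
    print(attr.layout, attr.size)
```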

torch_geometric/data/graph_store.py (+57 -44)

@@ -117,9 +117,12 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:
         Raises:
             KeyError: if the edge index corresponding to attr was not found.
         """
+
         edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
         edge_attr.layout = EdgeLayout(edge_attr.layout)
         # Override is_sorted for CSC and CSR:
+        # TODO treat is_sorted specially in this function, where is_sorted=True
+        # returns an edge index sorted by column.
         edge_attr.is_sorted = edge_attr.is_sorted or (edge_attr.layout in [
             EdgeLayout.CSC, EdgeLayout.CSR
         ])
@@ -131,9 +134,57 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:
 
     # Layout Conversion #######################################################
 
+    def _edge_to_layout(
+        self,
+        attr: EdgeAttr,
+        layout: EdgeLayout,
+    ) -> Tuple[Tensor, Tensor, OptTensor]:
+        from_tuple = self.get_edge_index(attr)
+
+        if layout == EdgeLayout.COO:
+            if attr.layout == EdgeLayout.CSR:
+                col = from_tuple[1]
+                row = torch.ops.torch_sparse.ptr2ind(from_tuple[0],
+                                                     col.numel())
+            else:
+                row = from_tuple[0]
+                col = torch.ops.torch_sparse.ptr2ind(from_tuple[1],
+                                                     row.numel())
+            perm = None
+
+        elif layout == EdgeLayout.CSR:
+            # We convert to CSR by converting to CSC on the transpose
+            if attr.layout == EdgeLayout.COO:
+                adj = edge_tensor_type_to_adj_type(
+                    attr, (from_tuple[1], from_tuple[0]))
+            else:
+                adj = edge_tensor_type_to_adj_type(attr, from_tuple).t()
+
+            # NOTE we set is_sorted=False here as is_sorted refers to
+            # the edge_index being sorted by the destination node
+            # (column), but here we deal with the transpose
+            attr_copy = copy.copy(attr)
+            attr_copy.is_sorted = False
+            attr_copy.size = None if attr.size is None else (attr.size[1],
+                                                             attr.size[0])
+
+            # Actually rowptr, col, perm
+            row, col, perm = to_csc(adj, attr_copy, device='cpu')
+
+        else:
+            adj = edge_tensor_type_to_adj_type(attr, from_tuple)
+
+            # Actually colptr, row, perm
+            col, row, perm = to_csc(adj, attr, device='cpu')
+
+        return row, col, perm
+
     # TODO support `replace` to replace the existing edge index.
-    def _to_layout(self, layout: EdgeLayout,
-                   store: bool = False) -> ConversionOutputType:
+    def _all_edges_to_layout(
+        self,
+        layout: EdgeLayout,
+        store: bool = False,
+    ) -> ConversionOutputType:
         # Obtain all edge attributes, grouped by type:
         edge_attrs = self.get_all_edge_attrs()
         edge_type_to_attrs: Dict[Any, List[EdgeAttr]] = defaultdict(list)
@@ -165,45 +216,7 @@ def _to_layout(self, layout: EdgeLayout,
             else:
                 from_attr = edge_attrs[edge_layouts.index(EdgeLayout.CSR)]
 
-            from_tuple = self.get_edge_index(from_attr)
-
-            # Convert to the new layout:
-            if layout == EdgeLayout.COO:
-                if from_attr.layout == EdgeLayout.CSR:
-                    col = from_tuple[1]
-                    row = torch.ops.torch_sparse.ptr2ind(
-                        from_tuple[0], col.numel())
-                else:
-                    row = from_tuple[0]
-                    col = torch.ops.torch_sparse.ptr2ind(
-                        from_tuple[1], row.numel())
-                perm = None
-
-            elif layout == EdgeLayout.CSR:
-                # We convert to CSR by converting to CSC on the transpose
-                if from_attr.layout == EdgeLayout.COO:
-                    adj = edge_tensor_type_to_adj_type(
-                        from_attr, (from_tuple[1], from_tuple[0]))
-                else:
-                    adj = edge_tensor_type_to_adj_type(
-                        from_attr, from_tuple).t()
-
-                # NOTE we set is_sorted=False here as is_sorted refers to
-                # the edge_index being sorted by the destination node
-                # (column), but here we deal with the transpose
-                from_attr_copy = copy.copy(from_attr)
-                from_attr_copy.is_sorted = False
-                from_attr_copy.size = None if from_attr.size is None else (
-                    from_attr.size[1], from_attr.size[0])
-
-                # Actually rowptr, col, perm
-                row, col, perm = to_csc(adj, from_attr_copy, device='cpu')
-
-            else:
-                adj = edge_tensor_type_to_adj_type(from_attr, from_tuple)
-
-                # Actually colptr, row, perm
-                col, row, perm = to_csc(adj, from_attr, device='cpu')
+            row, col, perm = self._edge_to_layout(from_attr, layout)
 
             row_dict[from_attr.edge_type] = row
             col_dict[from_attr.edge_type] = col
@@ -235,17 +248,17 @@
     def coo(self, store: bool = False) -> ConversionOutputType:
         r"""Converts the edge indices in the graph store to COO format,
         optionally storing the converted edge indices in the graph store."""
-        return self._to_layout(EdgeLayout.COO, store)
+        return self._all_edges_to_layout(EdgeLayout.COO, store)
 
     def csr(self, store: bool = False) -> ConversionOutputType:
         r"""Converts the edge indices in the graph store to CSR format,
         optionally storing the converted edge indices in the graph store."""
-        return self._to_layout(EdgeLayout.CSR, store)
+        return self._all_edges_to_layout(EdgeLayout.CSR, store)
 
     def csc(self, store: bool = False) -> ConversionOutputType:
         r"""Converts the edge indices in the graph store to CSC format,
         optionally storing the converted edge indices in the graph store."""
-        return self._to_layout(EdgeLayout.CSC, store)
+        return self._all_edges_to_layout(EdgeLayout.CSC, store)
 
     # Additional methods ######################################################
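
For intuition about the CSR-to-COO branch factored into `_edge_to_layout` above: `torch.ops.torch_sparse.ptr2ind` expands a compressed row pointer back into one row index per edge. A rough standalone equivalent in plain `torch` (an illustration of the idea, not the op the diff actually calls):

```python
import torch

# CSR inputs: `rowptr` has num_rows + 1 entries; `col` has one entry per edge.
rowptr = torch.tensor([0, 2, 3, 5])   # 3 rows, 5 edges
col = torch.tensor([1, 2, 0, 0, 1])

# Expand the pointer into per-edge row indices (what ptr2ind computes):
counts = rowptr[1:] - rowptr[:-1]     # edges per row: [2, 1, 2]
row = torch.repeat_interleave(torch.arange(rowptr.numel() - 1), counts)

assert row.tolist() == [0, 0, 1, 2, 2]
# (row, col) is now the COO form of the same adjacency structure.
```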

torch_geometric/data/hetero_data.py (+27 -2)

@@ -695,7 +695,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
         out = self._node_store_dict.get(attr.group_name, None)
         if out:
             # Group name exists, handle index or create new attribute name:
-            val = getattr(out, attr.attr_name)
+            val = getattr(out, attr.attr_name, None)
             if val is not None:
                 val[attr.index] = tensor
             else:
@@ -754,6 +754,13 @@ def _put_edge_index(self, edge_index: EdgeTensorType,
         attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
         setattr(self[edge_attr.edge_type], attr_name, attr_val)
 
+        # Set edge attributes:
+        if not hasattr(self[edge_attr.edge_type], '_edge_attrs'):
+            self[edge_attr.edge_type]._edge_attrs = {}
+
+        self[edge_attr.edge_type]._edge_attrs[
+            edge_attr.layout.value] = edge_attr
+
         key = self._to_canonical(edge_attr.edge_type)
         src, _, dst = key
 
@@ -780,12 +787,30 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]:
         r"""Returns a list of `EdgeAttr` objects corresponding to the edge
         indices stored in `HeteroData` and their layouts."""
         out = []
+        added_attrs = set()
+
+        # Check edges added via _put_edge_index:
+        for edge_type, _ in self.edge_items():
+            if not hasattr(self[edge_type], '_edge_attrs'):
+                continue
+            edge_attrs = self[edge_type]._edge_attrs.values()
+            for attr in edge_attrs:
+                attr.size = self[edge_type].size()
+                added_attrs.add((attr.edge_type, attr.layout))
+            out.extend(edge_attrs)
+
+        # Check edges added through regular interface:
+        # TODO deprecate this and store edge attributes for all edges in
+        # EdgeStorage
         for edge_type, edge_store in self.edge_items():
             for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
-                if attr_name in edge_store:
+                # Don't double count:
+                if attr_name in edge_store and ((edge_type, layout)
+                                                not in added_attrs):
                     out.append(
                         EdgeAttr(edge_type=edge_type, layout=layout,
                                  size=self[edge_type].size()))
+
         return out
 
 
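
A hedged sketch (not from the diff) of the behaviour the `getattr(out, attr.attr_name, None)` fix above enables: adding a second attribute name to an existing node group through the `FeatureStore` interface of `HeteroData`, which is exactly what the temporal test does with `x` and `time`.

```python
import torch
from torch_geometric.data import HeteroData

store = HeteroData()  # `HeteroData` implements the `FeatureStore` interface

store.put_tensor(torch.randn(10, 16), group_name='paper', attr_name='x',
                 index=None)

# Previously, a second attribute name on an existing group hit the bare
# `getattr` call and raised; with a default of `None` it is created instead:
store.put_tensor(torch.arange(10), group_name='paper', attr_name='time',
                 index=None)

assert store['paper'].time.numel() == 10
```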

torch_geometric/loader/neighbor_loader.py (+24 -7)

@@ -95,9 +95,21 @@ def __init__(
             # TODO support `collect` on `FeatureStore`
             self.node_time_dict = None
             if time_attr is not None:
-                raise ValueError(
-                    f"'time_attr' attribute not yet supported for "
-                    f"'{data[0].__class__.__name__}' object")
+                # We need to obtain all features with 'attr_name=time_attr'
+                # from the feature store and store them in node_time_dict. To
+                # do so, we make an explicit feature store GET call here with
+                # the relevant 'TensorAttr's
+                time_attrs = [
+                    attr for attr in feature_store.get_all_tensor_attrs()
+                    if attr.attr_name == time_attr
+                ]
+                for attr in time_attrs:
+                    attr.index = None
+                time_tensors = feature_store.multi_get_tensor(time_attrs)
+                self.node_time_dict = {
+                    time_attr.group_name: time_tensor
+                    for time_attr, time_tensor in zip(time_attrs, time_tensors)
+                }
 
             # Obtain all node and edge metadata:
             node_attrs = feature_store.get_all_tensor_attrs()
@@ -475,18 +487,23 @@ def to_index(tensor):
     if isinstance(input_nodes, Tensor):
         return None, to_index(input_nodes)
 
+    # Can't infer number of nodes from a group_name; need an attr_name
     if isinstance(input_nodes, str):
-        num_nodes = feature_store.get_tensor_size(input_nodes)[0]
-        return input_nodes, range(num_nodes)
+        raise NotImplementedError(
+            f"Cannot infer the number of nodes from a single string "
+            f"(got '{input_nodes}'). Please pass a more explicit "
+            f"representation. ")
 
     if isinstance(input_nodes, (list, tuple)):
         assert len(input_nodes) == 2
         assert isinstance(input_nodes[0], str)
 
         node_type, input_nodes = input_nodes
         if input_nodes is None:
-            num_nodes = feature_store.get_tensor_size(input_nodes)[0]
-            return input_nodes[0], range(num_nodes)
+            raise NotImplementedError(
+                f"Cannot infer the number of nodes from a node type alone "
+                f"(got '{input_nodes}'). Please pass a more explicit "
+                f"representation. ")
         return node_type, to_index(input_nodes)
 
     assert isinstance(input_nodes, TensorAttr)
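
Taken together, the two hunks above mean that with a `(FeatureStore, GraphStore)` tuple, `input_nodes` must be explicit enough to enumerate the seed nodes (a bare node-type string or a `(node_type, None)` pair now raises `NotImplementedError`), while `time_attr` is resolved through `get_all_tensor_attrs()` and `multi_get_tensor()`. A hedged usage sketch mirroring the new test; it assumes a torch-sparse build that registers the temporal sampling op and uses a `HeteroData` object as a stand-in for both stores:

```python
import torch
from torch_geometric.data import HeteroData
from torch_geometric.data.feature_store import TensorAttr
from torch_geometric.loader import NeighborLoader

# `HeteroData` implements both the FeatureStore and GraphStore interfaces,
# so the same object can serve as the (feature_store, graph_store) tuple.
data = HeteroData()
data['paper'].x = torch.randn(100, 16)
data['paper'].time = torch.arange(100)
data['paper', 'to', 'paper'].edge_index = torch.randint(0, 100, (2, 500))

loader = NeighborLoader(
    (data, data),                  # (feature_store, graph_store)
    num_neighbors=[10, 10],
    # A plain 'paper' string would raise NotImplementedError here:
    input_nodes=TensorAttr(group_name='paper', attr_name='x'),
    time_attr='time',              # fetched from the feature store at init
    batch_size=32,
)

for batch in loader:
    print(batch['paper'].time.shape)
    break
```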
