Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GraphStore: Data, HeteroData respect is_sorted #4922

Merged
merged 7 commits into from
Jul 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908))
- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))
- Added a `NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883))
- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4922](https://github.com/pyg-team/pytorch_geometric/pull/4922))
- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))
- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
Expand Down
9 changes: 9 additions & 0 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,15 @@ def test_custom_neighbor_loader(FeatureStore, GraphStore):
edge_type=('author', 'to', 'paper'),
layout='csc', size=(200, 100))

# COO (sorted):
edge_index = get_edge_index(200, 200, 100)
edge_index = edge_index[:, edge_index[1].argsort()]
data['author', 'to', 'author'].edge_index = edge_index
coo = (edge_index[0], edge_index[1])
graph_store.put_edge_index(edge_index=coo,
edge_type=('author', 'to', 'author'),
layout='coo', size=(200, 200), is_sorted=True)

# Construct neighbor loaders:
loader1 = NeighborLoader(data, batch_size=20,
input_nodes=('paper', range(100)),
Expand Down
20 changes: 13 additions & 7 deletions torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,12 @@ def _put_edge_index(self, edge_index: EdgeTensorType,
attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
setattr(self, attr_name, attr_val)

# Set edge attributes:
if not hasattr(self, '_edge_attrs'):
self._edge_attrs = {}

self._edge_attrs[edge_attr.layout.value] = edge_attr

# Set size, if possible:
size = edge_attr.size
if size is not None:
Expand All @@ -866,13 +872,13 @@ def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
def get_all_edge_attrs(self) -> List[EdgeAttr]:
r"""Returns `EdgeAttr` objects corresponding to the edge indices stored
in `Data` and their layouts"""
out = []
for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
if attr_name in self:
out.append(
EdgeAttr(edge_type=None, layout=layout,
size=(self.num_nodes, self.num_nodes)))
return out
if not hasattr(self, '_edge_attrs'):
return []

edge_attrs = self._edge_attrs.values()
for attr in edge_attrs:
attr.size = (self.num_nodes, self.num_nodes)
return edge_attrs


###############################################################################
Expand Down
101 changes: 57 additions & 44 deletions torch_geometric/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,12 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:
Raises:
KeyError: if the edge index corresponding to attr was not found.
"""

edge_attr = self._edge_attr_cls.cast(*args, **kwargs)
edge_attr.layout = EdgeLayout(edge_attr.layout)
# Override is_sorted for CSC and CSR:
# TODO treat is_sorted specially in this function, where is_sorted=True
# returns an edge index sorted by column.
edge_attr.is_sorted = edge_attr.is_sorted or (edge_attr.layout in [
EdgeLayout.CSC, EdgeLayout.CSR
])
Expand All @@ -131,9 +134,57 @@ def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:

# Layout Conversion #######################################################

def _edge_to_layout(
    self,
    attr: EdgeAttr,
    layout: EdgeLayout,
) -> Tuple[Tensor, Tensor, OptTensor]:
    r"""Converts the edge index identified by :obj:`attr` to the requested
    :obj:`layout`.

    Args:
        attr (EdgeAttr): The attribute of the stored edge index to
            convert. Its :obj:`layout` is treated as the source layout.
        layout (EdgeLayout): The layout to convert to.

    Returns:
        A :obj:`(row, col, perm)` tuple. For a CSR target, :obj:`row`
        actually holds the row pointer; for a CSC target, :obj:`col`
        actually holds the column pointer. :obj:`perm` is the third value
        returned by :obj:`to_csc` (presumably the permutation applied to
        the edges -- confirm against :obj:`to_csc`), and :obj:`None` for a
        COO target.
    """
    from_tuple = self.get_edge_index(attr)

    if layout == EdgeLayout.COO:
        if attr.layout == EdgeLayout.COO:
            # Already COO: pass the indices through unchanged. Without
            # this guard the CSC branch below would misinterpret the
            # column indices as a column pointer.
            row, col = from_tuple[0], from_tuple[1]
        elif attr.layout == EdgeLayout.CSR:
            # Stored as (rowptr, col): expand the row pointer.
            col = from_tuple[1]
            row = torch.ops.torch_sparse.ptr2ind(from_tuple[0],
                                                 col.numel())
        else:
            # Stored as (row, colptr): expand the column pointer.
            row = from_tuple[0]
            col = torch.ops.torch_sparse.ptr2ind(from_tuple[1],
                                                 row.numel())
        perm = None

    elif layout == EdgeLayout.CSR:
        # We convert to CSR by converting to CSC on the transpose
        if attr.layout == EdgeLayout.COO:
            adj = edge_tensor_type_to_adj_type(
                attr, (from_tuple[1], from_tuple[0]))
        else:
            adj = edge_tensor_type_to_adj_type(attr, from_tuple).t()

        # NOTE we set is_sorted=False here as is_sorted refers to
        # the edge_index being sorted by the destination node
        # (column), but here we deal with the transpose
        attr_copy = copy.copy(attr)
        attr_copy.is_sorted = False
        attr_copy.size = None if attr.size is None else (attr.size[1],
                                                         attr.size[0])

        # Actually rowptr, col, perm
        row, col, perm = to_csc(adj, attr_copy, device='cpu')

    else:
        adj = edge_tensor_type_to_adj_type(attr, from_tuple)

        # Actually colptr, row, perm
        col, row, perm = to_csc(adj, attr, device='cpu')

    return row, col, perm

# TODO support `replace` to replace the existing edge index.
def _to_layout(self, layout: EdgeLayout,
store: bool = False) -> ConversionOutputType:
def _all_edges_to_layout(
self,
layout: EdgeLayout,
store: bool = False,
) -> ConversionOutputType:
# Obtain all edge attributes, grouped by type:
edge_attrs = self.get_all_edge_attrs()
edge_type_to_attrs: Dict[Any, List[EdgeAttr]] = defaultdict(list)
Expand Down Expand Up @@ -165,45 +216,7 @@ def _to_layout(self, layout: EdgeLayout,
else:
from_attr = edge_attrs[edge_layouts.index(EdgeLayout.CSR)]

from_tuple = self.get_edge_index(from_attr)

# Convert to the new layout:
if layout == EdgeLayout.COO:
if from_attr.layout == EdgeLayout.CSR:
col = from_tuple[1]
row = torch.ops.torch_sparse.ptr2ind(
from_tuple[0], col.numel())
else:
row = from_tuple[0]
col = torch.ops.torch_sparse.ptr2ind(
from_tuple[1], row.numel())
perm = None

elif layout == EdgeLayout.CSR:
# We convert to CSR by converting to CSC on the transpose
if from_attr.layout == EdgeLayout.COO:
adj = edge_tensor_type_to_adj_type(
from_attr, (from_tuple[1], from_tuple[0]))
else:
adj = edge_tensor_type_to_adj_type(
from_attr, from_tuple).t()

# NOTE we set is_sorted=False here as is_sorted refers to
# the edge_index being sorted by the destination node
# (column), but here we deal with the transpose
from_attr_copy = copy.copy(from_attr)
from_attr_copy.is_sorted = False
from_attr_copy.size = None if from_attr.size is None else (
from_attr.size[1], from_attr.size[0])

# Actually rowptr, col, perm
row, col, perm = to_csc(adj, from_attr_copy, device='cpu')

else:
adj = edge_tensor_type_to_adj_type(from_attr, from_tuple)

# Actually colptr, row, perm
col, row, perm = to_csc(adj, from_attr, device='cpu')
row, col, perm = self._edge_to_layout(from_attr, layout)

row_dict[from_attr.edge_type] = row
col_dict[from_attr.edge_type] = col
Expand Down Expand Up @@ -235,17 +248,17 @@ def _to_layout(self, layout: EdgeLayout,
def coo(self, store: bool = False) -> ConversionOutputType:
r"""Converts the edge indices in the graph store to COO format,
optionally storing the converted edge indices in the graph store."""
return self._to_layout(EdgeLayout.COO, store)
return self._all_edges_to_layout(EdgeLayout.COO, store)

def csr(self, store: bool = False) -> ConversionOutputType:
r"""Converts the edge indices in the graph store to CSR format,
optionally storing the converted edge indices in the graph store."""
return self._to_layout(EdgeLayout.CSR, store)
return self._all_edges_to_layout(EdgeLayout.CSR, store)

def csc(self, store: bool = False) -> ConversionOutputType:
r"""Converts the edge indices in the graph store to CSC format,
optionally storing the converted edge indices in the graph store."""
return self._to_layout(EdgeLayout.CSC, store)
return self._all_edges_to_layout(EdgeLayout.CSC, store)

# Additional methods ######################################################

Expand Down
18 changes: 13 additions & 5 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,13 @@ def _put_edge_index(self, edge_index: EdgeTensorType,
attr_val = edge_tensor_type_to_adj_type(edge_attr, edge_index)
setattr(self[edge_attr.edge_type], attr_name, attr_val)

# Set edge attributes:
if not hasattr(self[edge_attr.edge_type], '_edge_attrs'):
self[edge_attr.edge_type]._edge_attrs = {}

self[edge_attr.edge_type]._edge_attrs[
edge_attr.layout.value] = edge_attr

key = self._to_canonical(edge_attr.edge_type)
src, _, dst = key

Expand Down Expand Up @@ -781,11 +788,12 @@ def get_all_edge_attrs(self) -> List[EdgeAttr]:
indices stored in `HeteroData` and their layouts."""
out = []
for edge_type, edge_store in self.edge_items():
for layout, attr_name in EDGE_LAYOUT_TO_ATTR_NAME.items():
if attr_name in edge_store:
out.append(
EdgeAttr(edge_type=edge_type, layout=layout,
size=self[edge_type].size()))
if not hasattr(self[edge_type], '_edge_attrs'):
continue
edge_attrs = self[edge_type]._edge_attrs.values()
for attr in edge_attrs:
attr.size = self[edge_type].size()
out.extend(edge_attrs)
return out


Expand Down