Skip to content

Commit 06c34b2

Browse files
committed
update
1 parent f35c85f commit 06c34b2

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

torch_geometric/data/lightning_datamodule.py

+1
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ def __init__(
274274
directed=kwargs.get('directed', True),
275275
input_type=get_input_nodes(data, input_train_nodes)[0],
276276
time_attr=kwargs.get('time_attr', None),
277+
is_sorted=kwargs.get('is_sorted', False),
277278
)
278279
self.input_train_nodes = input_train_nodes
279280
self.input_val_nodes = input_val_nodes

torch_geometric/loader/neighbor_loader.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(
2626
input_type: Optional[Any] = None,
2727
share_memory: bool = False,
2828
time_attr: Optional[str] = None,
29+
is_sorted: bool = False,
2930
):
3031
self.data_cls = data.__class__
3132
self.num_neighbors = num_neighbors
@@ -41,7 +42,8 @@ def __init__(
4142
f"'{data.__class__.__name__}' object")
4243

4344
# Convert the graph data into a suitable format for sampling.
44-
out = to_csc(data, device='cpu', share_memory=share_memory)
45+
out = to_csc(data, device='cpu', share_memory=share_memory,
46+
is_sorted=is_sorted)
4547
self.colptr, self.row, self.perm = out
4648
assert isinstance(num_neighbors, (list, tuple))
4749

@@ -54,7 +56,8 @@ def __init__(
5456
# Convert the graph data into a suitable format for sampling.
5557
# NOTE: Since C++ cannot take dictionaries with tuples as key as
5658
# input, edge type triplets are converted into single strings.
57-
out = to_hetero_csc(data, device='cpu', share_memory=share_memory)
59+
out = to_hetero_csc(data, device='cpu', share_memory=share_memory,
60+
is_sorted=is_sorted)
5861
self.colptr_dict, self.row_dict, self.perm_dict = out
5962

6063
self.node_types, self.edge_types = data.metadata()

torch_geometric/loader/utils.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def to_csc(
3434
data: Union[Data, EdgeStorage],
3535
device: Optional[torch.device] = None,
3636
share_memory: bool = False,
37+
is_sorted: bool = False,
3738
) -> Tuple[Tensor, Tensor, OptTensor]:
3839
# Convert the graph data into a suitable format for sampling (CSC format).
3940
# Returns the `colptr` and `row` indices of the graph, as well as an
@@ -47,17 +48,18 @@ def to_csc(
4748

4849
elif hasattr(data, 'edge_index'):
4950
(row, col) = data.edge_index
50-
size = data.size()
51-
perm = (col * size[0]).add_(row).argsort()
51+
if not is_sorted:
52+
size = data.size()
53+
perm = (col * size[0]).add_(row).argsort()
54+
row = row[perm]
5255
colptr = torch.ops.torch_sparse.ind2ptr(col[perm], size[1])
53-
row = row[perm]
5456
else:
5557
raise AttributeError("Data object does not contain attributes "
5658
"'adj_t' or 'edge_index'")
5759

5860
colptr = colptr.to(device)
5961
row = row.to(device)
60-
perm = perm if perm is not None else perm.to(device)
62+
perm = perm.to(device) if perm is not None else None
6163

6264
if not colptr.is_cuda and share_memory:
6365
colptr.share_memory_()
@@ -72,6 +74,7 @@ def to_hetero_csc(
7274
data: HeteroData,
7375
device: Optional[torch.device] = None,
7476
share_memory: bool = False,
77+
is_sorted: bool = False,
7578
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:
7679
# Convert the heterogeneous graph data into a suitable format for sampling
7780
# (CSC format).
@@ -83,7 +86,7 @@ def to_hetero_csc(
8386

8487
for store in data.edge_stores:
8588
key = edge_type_to_str(store._key)
86-
out = to_csc(store, device, share_memory)
89+
out = to_csc(store, device, share_memory, is_sorted)
8790
colptr_dict[key], row_dict[key], perm_dict[key] = out
8891

8992
return colptr_dict, row_dict, perm_dict

0 commit comments

Comments
 (0)