@@ -34,6 +34,7 @@ def to_csc(
34
34
data : Union [Data , EdgeStorage ],
35
35
device : Optional [torch .device ] = None ,
36
36
share_memory : bool = False ,
37
+ is_sorted : bool = False ,
37
38
) -> Tuple [Tensor , Tensor , OptTensor ]:
38
39
# Convert the graph data into a suitable format for sampling (CSC format).
39
40
# Returns the `colptr` and `row` indices of the graph, as well as an
@@ -47,17 +48,18 @@ def to_csc(
47
48
48
49
elif hasattr (data , 'edge_index' ):
49
50
(row , col ) = data .edge_index
50
- size = data .size ()
51
- perm = (col * size [0 ]).add_ (row ).argsort ()
51
+ if not is_sorted :
52
+ size = data .size ()
53
+ perm = (col * size [0 ]).add_ (row ).argsort ()
54
+ row = row [perm ]
52
55
colptr = torch .ops .torch_sparse .ind2ptr (col [perm ], size [1 ])
53
- row = row [perm ]
54
56
else :
55
57
raise AttributeError ("Data object does not contain attributes "
56
58
"'adj_t' or 'edge_index'" )
57
59
58
60
colptr = colptr .to (device )
59
61
row = row .to (device )
60
- perm = perm if perm is not None else perm . to ( device )
62
+ perm = perm . to ( device ) if perm is not None else None
61
63
62
64
if not colptr .is_cuda and share_memory :
63
65
colptr .share_memory_ ()
@@ -72,6 +74,7 @@ def to_hetero_csc(
72
74
data : HeteroData ,
73
75
device : Optional [torch .device ] = None ,
74
76
share_memory : bool = False ,
77
+ is_sorted : bool = False ,
75
78
) -> Tuple [Dict [str , Tensor ], Dict [str , Tensor ], Dict [str , OptTensor ]]:
76
79
# Convert the heterogeneous graph data into a suitable format for sampling
77
80
# (CSC format).
@@ -83,7 +86,7 @@ def to_hetero_csc(
83
86
84
87
for store in data .edge_stores :
85
88
key = edge_type_to_str (store ._key )
86
- out = to_csc (store , device , share_memory )
89
+ out = to_csc (store , device , share_memory , is_sorted )
87
90
colptr_dict [key ], row_dict [key ], perm_dict [key ] = out
88
91
89
92
return colptr_dict , row_dict , perm_dict
0 commit comments