|
4 | 4 |
|
5 | 5 | from torch_geometric.data.graph_store import EdgeLayout
|
6 | 6 | from torch_geometric.testing.graph_store import MyGraphStore
|
| 7 | +from torch_geometric.utils.sort_edge_index import sort_edge_index |
| 8 | + |
| 9 | + |
| 10 | +def get_edge_index(num_src_nodes, num_dst_nodes, num_edges): |
| 11 | + row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long) |
| 12 | + col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long) |
| 13 | + return torch.stack([row, col], dim=0) |
7 | 14 |
|
8 | 15 |
|
9 | 16 | def test_graph_store():
|
@@ -38,3 +45,52 @@ def assert_equal_tensor_tuple(expected, actual):
|
38 | 45 |
|
39 | 46 | with pytest.raises(KeyError):
|
40 | 47 | _ = graph_store['edge_2', 'coo']
|
| 48 | + |
| 49 | + |
| 50 | +def test_graph_store_conversion(): |
| 51 | + graph_store = MyGraphStore() |
| 52 | + edge_index = get_edge_index(100, 100, 300) |
| 53 | + edge_index = sort_edge_index(edge_index, sort_by_row=False) |
| 54 | + adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(100, 100)) |
| 55 | + |
| 56 | + coo = (edge_index[0], edge_index[1]) |
| 57 | + csr = adj.csr()[:2] |
| 58 | + csc = adj.csc()[-2::-1] |
| 59 | + |
| 60 | + # Put all edge indices: |
| 61 | + graph_store.put_edge_index(edge_index=coo, edge_type=('v', '1', 'v'), |
| 62 | + layout='coo', size=(100, 100), is_sorted=True) |
| 63 | + |
| 64 | + graph_store.put_edge_index(edge_index=csr, edge_type=('v', '2', 'v'), |
| 65 | + layout='csr', size=(100, 100)) |
| 66 | + |
| 67 | + graph_store.put_edge_index(edge_index=csc, edge_type=('v', '3', 'v'), |
| 68 | + layout='csc', size=(100, 100)) |
| 69 | + |
| 70 | + def assert_edge_index_equal(expected: torch.Tensor, actual: torch.Tensor): |
| 71 | + assert torch.equal(sort_edge_index(expected), sort_edge_index(actual)) |
| 72 | + |
| 73 | + # Convert to COO: |
| 74 | + row_dict, col_dict, perm_dict = graph_store.coo() |
| 75 | + assert len(row_dict) == len(col_dict) == len(perm_dict) == 3 |
| 76 | + for key in row_dict.keys(): |
| 77 | + actual = torch.stack((row_dict[key], col_dict[key])) |
| 78 | + assert_edge_index_equal(actual, edge_index) |
| 79 | + assert perm_dict[key] is None |
| 80 | + |
| 81 | + # Convert to CSR: |
| 82 | + rowptr_dict, col_dict, perm_dict = graph_store.csr() |
| 83 | + assert len(rowptr_dict) == len(col_dict) == len(perm_dict) == 3 |
| 84 | + for key in rowptr_dict: |
| 85 | + assert torch.equal(rowptr_dict[key], csr[0]) |
| 86 | + assert torch.equal(col_dict[key], csr[1]) |
| 87 | + if key == ('v', '1', 'v'): |
| 88 | + assert perm_dict[key] is not None |
| 89 | + |
| 90 | + # Convert to CSC: |
| 91 | + row_dict, colptr_dict, perm_dict = graph_store.csc() |
| 92 | + assert len(row_dict) == len(colptr_dict) == len(perm_dict) == 3 |
| 93 | + for key in row_dict: |
| 94 | + assert torch.equal(row_dict[key], csc[0]) |
| 95 | + assert torch.equal(colptr_dict[key], csc[1]) |
| 96 | + assert perm_dict[key] is None |
0 commit comments