|
1 | 1 | import torch
|
| 2 | +from torch_sparse import SparseTensor |
2 | 3 |
|
3 | 4 | from torch_geometric.data import Data, Dataset, HeteroData, InMemoryDataset
|
4 | 5 |
|
@@ -158,3 +159,57 @@ def _process(self):
|
158 | 159 | ds = DS3()
|
159 | 160 | assert not ds.enter_download
|
160 | 161 | assert not ds.enter_process
|
| 162 | + |
| 163 | + |
| 164 | +class MyTestDataset2(InMemoryDataset): |
| 165 | + def __init__(self, data_list): |
| 166 | + super().__init__('/tmp/MyTestDataset2') |
| 167 | + self.data, self.slices = self.collate(data_list) |
| 168 | + |
| 169 | + |
| 170 | +def test_lists_of_tensors_in_memory_dataset(): |
| 171 | + def tr(n, m): |
| 172 | + return torch.rand((n, m)) |
| 173 | + |
| 174 | + d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)]) |
| 175 | + d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)]) |
| 176 | + d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)]) |
| 177 | + d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)]) |
| 178 | + |
| 179 | + data_list = [d1, d2, d3, d4] |
| 180 | + |
| 181 | + dataset = MyTestDataset2(data_list) |
| 182 | + assert len(dataset) == 4 |
| 183 | + assert dataset[0].xs[1].shape == (11, 4) |
| 184 | + assert dataset[0].xs[2].shape == (1, 2) |
| 185 | + assert dataset[1].xs[0].shape == (5, 3) |
| 186 | + assert dataset[2].xs[1].shape == (15, 4) |
| 187 | + assert dataset[3].xs[1].shape == (16, 4) |
| 188 | + |
| 189 | + |
| 190 | +class MyTestDataset3(InMemoryDataset): |
| 191 | + def __init__(self, data_list): |
| 192 | + super().__init__('/tmp/MyTestDataset3') |
| 193 | + self.data, self.slices = self.collate(data_list) |
| 194 | + |
| 195 | + |
| 196 | +def test_lists_of_SparseTensors(): |
| 197 | + e1 = torch.tensor([[4, 1, 3, 2, 2, 3], [1, 3, 2, 3, 3, 2]]) |
| 198 | + e2 = torch.tensor([[0, 1, 4, 7, 2, 9], [7, 2, 2, 1, 4, 7]]) |
| 199 | + e3 = torch.tensor([[3, 5, 1, 2, 3, 3], [5, 0, 2, 1, 3, 7]]) |
| 200 | + e4 = torch.tensor([[0, 1, 9, 2, 0, 3], [1, 1, 2, 1, 3, 2]]) |
| 201 | + adj1 = SparseTensor.from_edge_index(e1, sparse_sizes=(11, 11)) |
| 202 | + adj2 = SparseTensor.from_edge_index(e2, sparse_sizes=(22, 22)) |
| 203 | + adj3 = SparseTensor.from_edge_index(e3, sparse_sizes=(12, 12)) |
| 204 | + adj4 = SparseTensor.from_edge_index(e4, sparse_sizes=(15, 15)) |
| 205 | + |
| 206 | + d1 = Data(adj_test=[adj1, adj2]) |
| 207 | + d2 = Data(adj_test=[adj3, adj4]) |
| 208 | + |
| 209 | + data_list = [d1, d2] |
| 210 | + dataset = MyTestDataset3(data_list) |
| 211 | + assert len(dataset) == 2 |
| 212 | + assert dataset[0].adj_test[0].sparse_sizes() == (11, 11) |
| 213 | + assert dataset[0].adj_test[1].sparse_sizes() == (22, 22) |
| 214 | + assert dataset[1].adj_test[0].sparse_sizes() == (12, 12) |
| 215 | + assert dataset[1].adj_test[1].sparse_sizes() == (15, 15) |
0 commit comments