@@ -247,13 +247,21 @@ end
247
247
@test g. num_nodes == 169343
248
248
@test g. num_edges == 1166243
249
249
250
- @test sum (count .([g. node_data. train_mask, g. node_data. test_mask, g. node_data. val_mask])) == g. num_nodes
250
+ train_mask = d. metadata[" node" ]. split[" time" ][1 ]. train
251
+ test_mask = d. metadata[" node" ]. split[" time" ][1 ]. test
252
+ val_mask = d. metadata[" node" ]. split[" time" ][1 ]. val
253
+
254
+ @test sum (count .([train_mask, test_mask, val_mask])) == g. num_nodes
251
255
end
252
256
253
257
@testset " OGBDataset - ogbg-molhiv" begin
254
258
d = OGBDataset (" ogbg-molhiv" )
255
259
256
- @test sum (count .([d. graph_data. train_mask, d. graph_data. test_mask, d. graph_data. val_mask])) == length (d)
260
+ train_mask = d. metadata[" graph" ]. split[" scaffold" ]. train
261
+ test_mask = d. metadata[" graph" ]. split[" scaffold" ]. test
262
+ val_mask = d. metadata[" graph" ]. split[" scaffold" ]. val
263
+
264
+ @test sum (count .([train_mask, test_mask, val_mask])) == length (d)
257
265
end
258
266
259
267
@testset " Reddit_full" begin
264
272
@test g. num_edges == 114615892
265
273
@test size (g. node_data. features) == (602 , g. num_nodes)
266
274
@test size (g. node_data. labels) == (g. num_nodes,)
267
- @test count (g. node_data. train_mask) == 153431
268
- @test count (g. node_data. val_mask) == 23831
269
- @test count (g. node_data. test_mask) == 55703
275
+ split = data. metadata[" node" ]. split[1 ]
276
+ @test count (split. train) == 153431
277
+ @test count (split. val) == 23831
278
+ @test count (split. test) == 55703
270
279
s, t = g. edge_index
271
280
@test length (s) == length (t) == g. num_edges
272
281
@test minimum (s) == minimum (t) == 1
281
290
@test g. num_edges == 23213838
282
291
@test size (g. node_data. features) == (602 , g. num_nodes)
283
292
@test size (g. node_data. labels) == (g. num_nodes,)
284
- @test count (g. node_data. train_mask) == 152410
285
- @test count (g. node_data. val_mask) == 23699
286
- @test count (g. node_data. test_mask) == 55334
293
+ split = data. metadata[" node" ]. split[1 ]
294
+ @test count (split. train) == 152410
295
+ @test count (split. val) == 23699
296
+ @test count (split. test) == 55334
287
297
s, t = g. edge_index
288
298
@test length (s) == length (t) == g. num_edges
289
299
@test minimum (s) == minimum (t) == 1
0 commit comments