Skip to content

Commit f0fba57

Browse files
committed
Fix tests?
1 parent 7387071 commit f0fba57

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

test/datasets/graphs.jl

+9-9
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
@test g.num_edges == 9104
1212
@test size(g.node_data.features) == (3703, g.num_nodes)
1313
@test size(g.node_data.targets) == (g.num_nodes,)
14-
@test sum(g.node_data.train_mask) == 120
15-
@test sum(g.node_data.val_mask) == 500
16-
@test sum(g.node_data.test_mask) == 1015
14+
@test sum(data.metadata["node"].split[1].train) == 120
15+
@test sum(data.metadata["node"].split[1].val) == 500
16+
@test sum(data.metadata["node"].split[1].test) == 1015
1717
@test g.edge_index isa Tuple{Vector{Int}, Vector{Int}}
1818
s, t = g.edge_index
1919
for a in (s, t)
@@ -36,9 +36,9 @@ end
3636
@test g.num_edges == 10556
3737
@test size(g.node_data.features) == (1433, g.num_nodes)
3838
@test size(g.node_data.targets) == (g.num_nodes,)
39-
@test sum(g.node_data.train_mask) == 140
40-
@test sum(g.node_data.val_mask) == 500
41-
@test sum(g.node_data.test_mask) == 1000
39+
@test sum(data.metadata["node"].split[1].train) == 140
40+
@test sum(data.metadata["node"].split[1].val) == 500
41+
@test sum(data.metadata["node"].split[1].test) == 1000
4242
@test g.edge_index isa Tuple{Vector{Int}, Vector{Int}}
4343
s, t = g.edge_index
4444
for a in (s, t)
@@ -109,9 +109,9 @@ end
109109
@test g.num_edges == 88648
110110
@test size(g.node_data.features) == (500, g.num_nodes)
111111
@test size(g.node_data.targets) == (g.num_nodes,)
112-
@test sum(g.node_data.train_mask) == 60
113-
@test sum(g.node_data.val_mask) == 500
114-
@test sum(g.node_data.test_mask) == 1000
112+
@test sum(data.metadata["node"].split[1].train) == 60
113+
@test sum(data.metadata["node"].split[1].val) == 500
114+
@test sum(data.metadata["node"].split[1].test) == 1000
115115
@test g.edge_index isa Tuple{Vector{Int}, Vector{Int}}
116116
s, t = g.edge_index
117117
for a in (s, t)

test/datasets/graphs_no_ci.jl

+18-8
Original file line numberDiff line numberDiff line change
@@ -247,13 +247,21 @@ end
247247
@test g.num_nodes == 169343
248248
@test g.num_edges == 1166243
249249

250-
@test sum(count.([g.node_data.train_mask, g.node_data.test_mask, g.node_data.val_mask])) == g.num_nodes
250+
train_mask = d.metadata["node"].split["time"][1].train
251+
test_mask = d.metadata["node"].split["time"][1].test
252+
val_mask = d.metadata["node"].split["time"][1].val
253+
254+
@test sum(count.([train_mask, test_mask, val_mask])) == g.num_nodes
251255
end
252256

253257
@testset "OGBDataset - ogbg-molhiv" begin
254258
d = OGBDataset("ogbg-molhiv")
255259

256-
@test sum(count.([d.graph_data.train_mask, d.graph_data.test_mask, d.graph_data.val_mask])) == length(d)
260+
train_mask = d.metadata["graph"].split["scaffold"].train
261+
test_mask = d.metadata["graph"].split["scaffold"].test
262+
val_mask = d.metadata["graph"].split["scaffold"].val
263+
264+
@test sum(count.([train_mask, test_mask, val_mask])) == length(d)
257265
end
258266

259267
@testset "Reddit_full" begin
@@ -264,9 +272,10 @@ end
264272
@test g.num_edges == 114615892
265273
@test size(g.node_data.features) == (602, g.num_nodes)
266274
@test size(g.node_data.labels) == (g.num_nodes,)
267-
@test count(g.node_data.train_mask) == 153431
268-
@test count(g.node_data.val_mask) == 23831
269-
@test count(g.node_data.test_mask) == 55703
275+
split = data.metadata["node"].split[1]
276+
@test count(split.train) == 153431
277+
@test count(split.val) == 23831
278+
@test count(split.test) == 55703
270279
s, t = g.edge_index
271280
@test length(s) == length(t) == g.num_edges
272281
@test minimum(s) == minimum(t) == 1
@@ -281,9 +290,10 @@ end
281290
@test g.num_edges == 23213838
282291
@test size(g.node_data.features) == (602, g.num_nodes)
283292
@test size(g.node_data.labels) == (g.num_nodes,)
284-
@test count(g.node_data.train_mask) == 152410
285-
@test count(g.node_data.val_mask) == 23699
286-
@test count(g.node_data.test_mask) == 55334
293+
split = data.metadata["node"].split[1]
294+
@test count(split.train) == 152410
295+
@test count(split.val) == 23699
296+
@test count(split.test) == 55334
287297
s, t = g.edge_index
288298
@test length(s) == length(t) == g.num_edges
289299
@test minimum(s) == minimum(t) == 1

0 commit comments

Comments
 (0)