|
66 | 66 | g1 = rand_graph(10, 20, graph_type = GRAPH_T) |
67 | 67 | g2 = rand_graph(5, 10, graph_type = GRAPH_T) |
68 | 68 | g12 = MLUtils.batch([g1, g2]) |
69 | | - gs = MLUtils.unbatch([g1, g2]) |
| 69 | + gs = MLUtils.unbatch(g12) |
70 | 70 | @test length(gs) == 2 |
71 | 71 | @test gs[1].num_nodes == 10 |
72 | 72 | @test gs[1].num_edges == 20 |
73 | 73 | @test gs[1].num_graphs == 1 |
74 | 74 | @test gs[2].num_nodes == 5 |
75 | 75 | @test gs[2].num_edges == 10 |
76 | 76 | @test gs[2].num_graphs == 1 |
| 77 | + |
| 78 | + if GRAPH_T == :coo |
| 79 | + @test gs[1] == g1 |
| 80 | + @test gs[2] == g2 |
| 81 | + |
| 82 | + @testset "coo zero-edge graphs" begin |
| 83 | + gempty = GNNGraph(2) |
| 84 | + gedge1 = GNNGraph(([1], [2]), num_nodes = 2) |
| 85 | + gedge2 = GNNGraph(([2], [1]), num_nodes = 2) |
| 86 | + |
| 87 | + for graphs in ([gempty, gedge1, gedge2], |
| 88 | + [gedge1, gempty, gedge2], |
| 89 | + [gedge1, gedge2, gempty]) |
| 90 | + @test MLUtils.unbatch(MLUtils.batch(graphs)) == collect(graphs) |
| 91 | + end |
| 92 | + end |
| 93 | + |
| 94 | + @testset "coo zero-edge graphs preserve features" begin |
| 95 | + g1f = GNNGraph(([1], [2]), num_nodes = 2, |
| 96 | + ndata = (x = Float32[1 2; 3 4],), |
| 97 | + edata = (e = Float32[10; 11;;],), |
| 98 | + gdata = 100f0) |
| 99 | + g2f = GNNGraph(2, |
| 100 | + ndata = (x = Float32[5 6; 7 8],), |
| 101 | + edata = (e = zeros(Float32, 2, 0),), |
| 102 | + gdata = 200f0) |
| 103 | + g3f = GNNGraph(([2], [1]), num_nodes = 2, |
| 104 | + ndata = (x = Float32[9 10; 11 12],), |
| 105 | + edata = (e = Float32[12; 13;;],), |
| 106 | + gdata = 300f0) |
| 107 | + |
| 108 | + gs_feat = [g1f, g2f, g3f] |
| 109 | + @test MLUtils.unbatch(MLUtils.batch(gs_feat)) == gs_feat |
| 110 | + end |
| 111 | + end |
77 | 112 | end |
78 | 113 | end |
79 | 114 |
|
|
0 commit comments