Skip to content

Commit b9cdc0f

Browse files
authored
Fix unbatch for COO batches with zero-edge graphs (#652)
Signed-off-by: Parvm1102 <parvmittal31757@gmail.com>
1 parent 5bcf181 commit b9cdc0f

2 files changed

Lines changed: 42 additions & 11 deletions

File tree

GNNGraphs/src/transform.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -795,18 +795,14 @@ function _unbatch_nodemasks(graph_indicator, num_graphs)
795795
end
796796

797797
function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes)
798-
edgemasks = []
799-
for i in 1:(num_graphs - 1)
800-
lastedgeid = findfirst(s) do x
801-
x > cumnum_nodes[i + 1] && x <= cumnum_nodes[i + 2]
802-
end
803-
firstedgeid = i == 1 ? 1 : last(edgemasks[i - 1]) + 1
804-
# if nothing make empty range
805-
lastedgeid = lastedgeid === nothing ? firstedgeid - 1 : lastedgeid - 1
798+
edgemasks = [Int[] for _ in 1:num_graphs]
806799

807-
push!(edgemasks, firstedgeid:lastedgeid)
800+
for (eid, src) in enumerate(s)
801+
graph_idx = searchsortedfirst(cumnum_nodes, src) - 1
802+
@assert 1 <= graph_idx <= num_graphs
803+
push!(edgemasks[graph_idx], eid)
808804
end
809-
push!(edgemasks, (last(edgemasks[end]) + 1):length(s))
805+
810806
return edgemasks
811807
end
812808

GNNGraphs/test/transform.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,49 @@ end
6666
g1 = rand_graph(10, 20, graph_type = GRAPH_T)
6767
g2 = rand_graph(5, 10, graph_type = GRAPH_T)
6868
g12 = MLUtils.batch([g1, g2])
69-
gs = MLUtils.unbatch([g1, g2])
69+
gs = MLUtils.unbatch(g12)
7070
@test length(gs) == 2
7171
@test gs[1].num_nodes == 10
7272
@test gs[1].num_edges == 20
7373
@test gs[1].num_graphs == 1
7474
@test gs[2].num_nodes == 5
7575
@test gs[2].num_edges == 10
7676
@test gs[2].num_graphs == 1
77+
78+
if GRAPH_T == :coo
79+
@test gs[1] == g1
80+
@test gs[2] == g2
81+
82+
@testset "coo zero-edge graphs" begin
83+
gempty = GNNGraph(2)
84+
gedge1 = GNNGraph(([1], [2]), num_nodes = 2)
85+
gedge2 = GNNGraph(([2], [1]), num_nodes = 2)
86+
87+
for graphs in ([gempty, gedge1, gedge2],
88+
[gedge1, gempty, gedge2],
89+
[gedge1, gedge2, gempty])
90+
@test MLUtils.unbatch(MLUtils.batch(graphs)) == collect(graphs)
91+
end
92+
end
93+
94+
@testset "coo zero-edge graphs preserve features" begin
95+
g1f = GNNGraph(([1], [2]), num_nodes = 2,
96+
ndata = (x = Float32[1 2; 3 4],),
97+
edata = (e = Float32[10; 11;;],),
98+
gdata = 100f0)
99+
g2f = GNNGraph(2,
100+
ndata = (x = Float32[5 6; 7 8],),
101+
edata = (e = zeros(Float32, 2, 0),),
102+
gdata = 200f0)
103+
g3f = GNNGraph(([2], [1]), num_nodes = 2,
104+
ndata = (x = Float32[9 10; 11 12],),
105+
edata = (e = Float32[12; 13;;],),
106+
gdata = 300f0)
107+
108+
gs_feat = [g1f, g2f, g3f]
109+
@test MLUtils.unbatch(MLUtils.batch(gs_feat)) == gs_feat
110+
end
111+
end
77112
end
78113
end
79114

0 commit comments

Comments
 (0)