Skip to content

Commit 3f134c9

Browse files
Test: Replaced raw Mooncake with Flux implementation and checked Mooncake compatibility for more layers (#642)
* Test: Replaced raw Mooncake with Flux implementation and checked Mooncake compatibility for more layers Signed-off-by: Parvm1102 <parvmittal31757@gmail.com> * Update GraphNeuralNetworks/test/test_module.jl * Update GraphNeuralNetworks/test/test_module.jl --------- Signed-off-by: Parvm1102 <parvmittal31757@gmail.com> Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
1 parent fa1aba3 commit 3f134c9

File tree

5 files changed

+46
-59
lines changed

5 files changed

+46
-59
lines changed

GraphNeuralNetworks/test/layers/basic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
Flux.testmode!(gnn)
2020

21-
test_gradients(gnn, g, x, rtol = 1e-5)
21+
test_gradients(gnn, g, x, rtol = 1e-5, test_mooncake = false)
2222

2323
@testset "constructor with names" begin
2424
m = GNNChain(GCNConv(din => d),
@@ -53,7 +53,7 @@
5353

5454
Flux.trainmode!(gnn)
5555

56-
test_gradients(gnn, g, x, rtol = 1e-4, atol=1e-4)
56+
test_gradients(gnn, g, x, rtol = 1e-4, atol=1e-4, test_mooncake = false)
5757
end
5858
end
5959

GraphNeuralNetworks/test/layers/conv.jl

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,7 @@ end
8686
for g in TEST_GRAPHS
8787
g = add_self_loops(g)
8888
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
89-
# Note: test_mooncake not enabled for ChebConv (Mooncake backward pass error)
90-
test_gradients(l, g, g.x, rtol = RTOL_LOW)
89+
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = false)
9190
end
9291

9392
@testset "bias=false" begin
@@ -198,8 +197,7 @@ end
198197
l = GATv2Conv(D_IN => D_OUT, tanh; heads, concat, dropout=0)
199198
for g in TEST_GRAPHS
200199
@test size(l(g, g.x)) == (concat ? heads * D_OUT : D_OUT, g.num_nodes)
201-
# Mooncake backward pass error for this layer on CI
202-
test_gradients(l, g, g.x, rtol = RTOL_LOW, atol=ATOL_LOW)
200+
test_gradients(l, g, g.x, rtol = RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
203201
end
204202
end
205203

@@ -208,8 +206,7 @@ end
208206
l = GATv2Conv((D_IN, ein) => D_OUT, add_self_loops = false, dropout=0)
209207
g = GNNGraph(TEST_GRAPHS[1], edata = rand(Float32, ein, TEST_GRAPHS[1].num_edges))
210208
@test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes)
211-
# Mooncake backward pass error for this layer on CI
212-
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, atol=ATOL_LOW)
209+
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
213210
end
214211

215212
@testset "num params" begin
@@ -568,31 +565,30 @@ end
568565
ein = 2
569566
heads = 3
570567
# used like in Kool et al., 2019
571-
# Mooncake backward pass error for this layer on CI
572568
l = TransformerConv(D_IN * heads => D_IN; heads, add_self_loops = true,
573569
root_weight = false, ff_channels = 10, skip_connection = true,
574570
batch_norm = false)
575571
# batch_norm=false here for tests to pass; true in paper
576572
for g in TEST_GRAPHS
577573
g = GNNGraph(g, ndata = rand(Float32, D_IN * heads, g.num_nodes))
578574
@test size(l(g, g.x)) == (D_IN * heads, g.num_nodes)
579-
test_gradients(l, g, g.x, rtol = RTOL_LOW)
575+
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
580576
end
581577
# used like in Shi et al., 2021
582578
l = TransformerConv((D_IN, ein) => D_IN; heads, gating = true,
583579
bias_qkv = true)
584580
for g in TEST_GRAPHS
585581
g = GNNGraph(g, edata = rand(Float32, ein, g.num_edges))
586582
@test size(l(g, g.x, g.e)) == (D_IN * heads, g.num_nodes)
587-
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW)
583+
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
588584
end
589585
# test averaging heads
590586
l = TransformerConv(D_IN => D_IN; heads, concat = false,
591587
bias_root = false,
592588
root_weight = false)
593589
for g in TEST_GRAPHS
594590
@test size(l(g, g.x)) == (D_IN, g.num_nodes)
595-
test_gradients(l, g, g.x, rtol = RTOL_LOW)
591+
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
596592
end
597593
end
598594

@@ -620,8 +616,7 @@ end
620616
l = DConv(D_IN => D_OUT, k)
621617
for g in TEST_GRAPHS
622618
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
623-
# Note: test_mooncake not enabled for DConv (Mooncake backward pass error)
624-
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
619+
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = false)
625620
end
626621
end
627622
end

GraphNeuralNetworks/test/layers/pool.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
@test u[:, [1]] ≈ sum(g.ndata.x[:, 1:n], dims = 2)
2020
@test p(g).gdata.u == u
2121

22-
test_gradients(p, g, g.x, rtol = 1e-5)
22+
test_gradients(p, g, g.x, rtol = 1e-5, test_mooncake = TEST_MOONCAKE)
2323
end
2424
end
2525

@@ -42,7 +42,7 @@ end
4242
for i in 1:ng])
4343

4444
@test size(p(g, g.x)) == (chout, ng)
45-
test_gradients(p, g, g.x, rtol = 1e-5)
45+
test_gradients(p, g, g.x, rtol = 1e-5, test_mooncake = TEST_MOONCAKE)
4646
end
4747
end
4848

GraphNeuralNetworks/test/layers/temporalconv.jl

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ end
3232
@test y === h
3333
@test size(h) == (out_channel, g.num_nodes)
3434
# with no initial state
35-
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
35+
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
3636
# with initial state
37-
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH)
37+
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
3838

3939
# Test with custom activation function
4040
custom_activation = tanh
@@ -45,9 +45,9 @@ end
4545
# Test that outputs differ when using different activation functions
4646
@test !isapprox(y, y_custom, rtol=RTOL_HIGH)
4747
# with no initial state
48-
test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
48+
test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
4949
# with initial state
50-
test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH)
50+
test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
5151
end
5252

5353
@testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin
@@ -61,9 +61,9 @@ end
6161
@test layer isa GNNRecurrence
6262
@test size(y) == (out_channel, timesteps, g.num_nodes)
6363
# with no initial state
64-
test_gradients(layer, g, x, rtol = RTOL_HIGH)
64+
test_gradients(layer, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
6565
# with initial state
66-
test_gradients(layer, g, x, state0, rtol = RTOL_HIGH)
66+
test_gradients(layer, g, x, state0, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
6767

6868
# Test with custom activation function
6969
custom_activation = tanh
@@ -74,15 +74,15 @@ end
7474
# Test that outputs differ when using different activation functions
7575
@test !isapprox(y, y_custom, rtol = RTOL_HIGH)
7676
# with no initial state
77-
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH)
77+
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
7878
# with initial state
79-
test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH)
79+
test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
8080

8181
# interplay with GNNChain
8282
model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
8383
y = model(g, x)
8484
@test size(y) == (1, timesteps, g.num_nodes)
85-
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW)
85+
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW, test_mooncake = TEST_MOONCAKE)
8686
end
8787

8888
@testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin
@@ -93,9 +93,9 @@ end
9393
@test size(h) == (out_channel, g.num_nodes)
9494
@test size(c) == (out_channel, g.num_nodes)
9595
# with no initial state
96-
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
96+
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
9797
# with initial state
98-
test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
98+
test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
9999
end
100100

101101
@testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin
@@ -107,15 +107,15 @@ end
107107
y = layer(g, x)
108108
@test size(y) == (out_channel, timesteps, g.num_nodes)
109109
# with no initial state
110-
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
110+
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
111111
# with initial state
112-
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
112+
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
113113

114114
# interplay with GNNChain
115115
model = GNNChain(GConvLSTM(in_channel => out_channel, 2), Dense(out_channel, 1))
116116
y = model(g, x)
117117
@test size(y) == (1, timesteps, g.num_nodes)
118-
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
118+
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false)
119119
end
120120

121121
@testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin
@@ -125,9 +125,9 @@ end
125125
@test y === h
126126
@test size(h) == (out_channel, g.num_nodes)
127127
# with no initial state
128-
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
128+
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
129129
# with initial state
130-
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
130+
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
131131
end
132132

133133

@@ -140,15 +140,15 @@ end
140140
y = layer(g, x)
141141
@test size(y) == (out_channel, timesteps, g.num_nodes)
142142
# with no initial state
143-
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
143+
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
144144
# with initial state
145-
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
145+
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
146146

147147
# interplay with GNNChain
148148
model = GNNChain(GConvGRU(in_channel => out_channel, 2), Dense(out_channel, 1))
149149
y = model(g, x)
150150
@test size(y) == (1, timesteps, g.num_nodes)
151-
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
151+
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false)
152152
end
153153

154154
@testitem "DCGRUCell" setup=[TemporalConvTestModule, TestModule] begin
@@ -158,9 +158,9 @@ end
158158
@test y === h
159159
@test size(h) == (out_channel, g.num_nodes)
160160
# with no initial state
161-
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
161+
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
162162
# with initial state
163-
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
163+
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
164164
end
165165

166166
@testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin
@@ -172,15 +172,15 @@ end
172172
y = layer(g, x)
173173
@test size(y) == (out_channel, timesteps, g.num_nodes)
174174
# with no initial state
175-
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
175+
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
176176
# with initial state
177-
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
177+
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
178178

179179
# interplay with GNNChain
180180
model = GNNChain(DCGRU(in_channel => out_channel, 2), Dense(out_channel, 1))
181181
y = model(g, x)
182182
@test size(y) == (1, timesteps, g.num_nodes)
183-
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
183+
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false)
184184
end
185185

186186
@testitem "EvolveGCNOCell" setup=[TemporalConvTestModule, TestModule] begin
@@ -189,9 +189,9 @@ end
189189
y, state = cell(g, g.x)
190190
@test size(y) == (out_channel, g.num_nodes)
191191
# with no initial state
192-
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
192+
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
193193
# with initial state
194-
test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
194+
test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
195195
end
196196

197197
@testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin
@@ -203,15 +203,15 @@ end
203203
y = layer(g, x)
204204
@test size(y) == (out_channel, timesteps, g.num_nodes)
205205
# with no initial state
206-
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
206+
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
207207
# with initial state
208-
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
208+
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
209209

210210
# interplay with GNNChain
211211
model = GNNChain(EvolveGCNO(in_channel => out_channel), Dense(out_channel, 1))
212212
y = model(g, x)
213213
@test size(y) == (1, timesteps, g.num_nodes)
214-
test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
214+
test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
215215
end
216216

217217
# @testitem "GINConv" setup=[TemporalConvTestModule, TestModule] begin

GraphNeuralNetworks/test/test_module.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,11 @@ function test_gradients(
123123
end
124124

125125
if test_mooncake
126-
# Mooncake gradient with respect to input, compared against Zygote.
126+
# Mooncake gradient with respect to input via Flux integration, compared against Zygote.
127127
loss_mc_x = (xs...) -> loss(f, graph, xs...)
128-
# TODO error without `invokelatest` when using TestItemRunner
129-
_cache_x = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_x, xs...)
130-
y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_x, loss_mc_x, xs...)
128+
y_mc, g_mc = Flux.withgradient(loss_mc_x, Flux.AutoMooncake(), xs...)
131129
@assert isapprox(y, y_mc; rtol, atol)
132-
for i in eachindex(xs)
133-
@assert isapprox(g[i], g_mc[i+1]; rtol, atol)
134-
end
130+
check_equal_leaves(g, g_mc; rtol, atol)
135131
end
136132

137133
if test_gpu
@@ -158,14 +154,10 @@ function test_gradients(
158154
end
159155

160156
if test_mooncake
161-
# Mooncake gradient with respect to f, compared against Zygote.
162-
ps_mc, re_mc = Flux.destructure(f)
163-
loss_mc_f = ps -> loss(re_mc(ps), graph, xs...)
164-
_cache_f = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_f, ps_mc)
165-
y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_f, loss_mc_f, ps_mc)
157+
# Mooncake gradient with respect to f via Flux integration, compared against Zygote.
158+
y_mc, g_mc = Flux.withgradient(f -> loss(f, graph, xs...), Flux.AutoMooncake(), f)
166159
@assert isapprox(y, y_mc; rtol, atol)
167-
g_mc_f = (re_mc(g_mc[2]),)
168-
check_equal_leaves(g, g_mc_f; rtol, atol)
160+
check_equal_leaves(g, g_mc; rtol, atol)
169161
end
170162

171163
if test_gpu

0 commit comments

Comments
 (0)