
Commit c4c69f4

Test: Add Mooncake AD testing to conv layer test infrastructure (JuliaGraphs#641)
* test: Add Mooncake AD testing to conv layer test infrastructure

Signed-off-by: Parvm1102 <parvmittal31757@gmail.com>
1 parent 5b25c89 commit c4c69f4

4 files changed: 61 additions & 24 deletions
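Note for reviewers: the tests added below drive Mooncake through its cache-based API. `Mooncake.value_and_gradient!!` returns the loss value together with a gradient tuple whose first entry is the tangent of the callable itself, with the gradients of the arguments following; that offset is why the `test_module.jl` hunks index `g_mc[i+1]` and `g_mc[2]`. A minimal standalone sketch (the toy `f` and `x` are illustrative, not part of this commit):

    import Mooncake

    f(x) = sum(abs2, x)          # toy scalar-valued function
    x = [1.0, 2.0, 3.0]

    # Build the rule cache once, then evaluate value and gradient together.
    cache = Mooncake.prepare_gradient_cache(f, x)
    y, grads = Mooncake.value_and_gradient!!(cache, f, x)

    # y == 14.0; grads[1] is the tangent of `f` itself,
    # grads[2] is the gradient w.r.t. x, here 2x == [2.0, 4.0, 6.0].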


GraphNeuralNetworks/test/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
 GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

GraphNeuralNetworks/test/layers/conv.jl

Lines changed: 25 additions & 20 deletions
@@ -10,20 +10,20 @@ end
     l = GCNConv(D_IN => D_OUT)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end

     l = GCNConv(D_IN => D_OUT, tanh, bias = false)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end

     l = GCNConv(D_IN => D_OUT, add_self_loops = false)
     for g in TEST_GRAPHS
         has_isolated_nodes(g) && continue
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end
 end

@@ -49,7 +49,7 @@ end
     l = GCNConv(1 => 1, add_self_loops = false, use_edge_weight = true)
     @test gradient(w -> sum(l(g, x, w)), w)[1] isa AbstractVector{Float32} # redundant test but more explicit
     @test size(l(g, x, w)) == (1, g.num_nodes)
-    test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+    test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
 end

 @testset "conv_weight" begin
@@ -86,6 +86,7 @@ end
     for g in TEST_GRAPHS
         g = add_self_loops(g)
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
+        # Note: test_mooncake not enabled for ChebConv (Mooncake backward pass error)
         test_gradients(l, g, g.x, rtol = RTOL_LOW)
     end

@@ -124,13 +125,13 @@ end
     l = GraphConv(D_IN => D_OUT)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end

     l = GraphConv(D_IN => D_OUT, tanh, bias = false, aggr = mean)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end

     @testset "bias=false" begin
@@ -157,7 +158,7 @@ end
         l = GATConv(D_IN => D_OUT; heads, concat, dropout=0)
         for g in TEST_GRAPHS
             @test size(l(g, g.x)) == (concat ? heads * D_OUT : D_OUT, g.num_nodes)
-            test_gradients(l, g, g.x, rtol = RTOL_LOW)
+            test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
         end
     end

@@ -166,7 +167,7 @@ end
     l = GATConv((D_IN, ein) => D_OUT, add_self_loops = false, dropout=0)
     g = GNNGraph(TEST_GRAPHS[1], edata = rand(Float32, ein, TEST_GRAPHS[1].num_edges))
     @test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes)
-    test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW)
+    test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
 end

 @testset "num params" begin
@@ -197,6 +198,7 @@ end
         l = GATv2Conv(D_IN => D_OUT, tanh; heads, concat, dropout=0)
         for g in TEST_GRAPHS
             @test size(l(g, g.x)) == (concat ? heads * D_OUT : D_OUT, g.num_nodes)
+            # Mooncake backward pass error for this layer on CI
             test_gradients(l, g, g.x, rtol = RTOL_LOW, atol=ATOL_LOW)
         end
     end
@@ -206,6 +208,7 @@ end
     l = GATv2Conv((D_IN, ein) => D_OUT, add_self_loops = false, dropout=0)
     g = GNNGraph(TEST_GRAPHS[1], edata = rand(Float32, ein, TEST_GRAPHS[1].num_edges))
     @test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes)
+    # Mooncake backward pass error for this layer on CI
     test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, atol=ATOL_LOW)
 end

@@ -239,7 +242,7 @@ end

     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end
 end

@@ -260,7 +263,7 @@ end
     l = EdgeConv(Dense(2 * D_IN, D_OUT), aggr = +)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end
 end

@@ -281,7 +284,7 @@ end
     l = GINConv(nn, 0.01, aggr = mean)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end

     @test !in(:eps, Flux.trainable(l))
@@ -307,7 +310,7 @@ end
     for g in TEST_GRAPHS
         g = GNNGraph(g, edata = rand(Float32, edim, g.num_edges))
         @test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, g.e, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, g.e, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end
 end

@@ -332,7 +335,7 @@ end
     l = SAGEConv(D_IN => D_OUT, tanh, bias = false, aggr = +)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end
 end

@@ -351,7 +354,7 @@ end
     l = ResGatedGraphConv(D_IN => D_OUT, tanh, bias = true)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end
 end

@@ -411,7 +414,7 @@ end
     Flux.trainable(l) == (; β = [1f0])
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_IN, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end
 end

@@ -437,7 +440,7 @@ end
         y = l(g, x, e)
         return mean(y[1]) + sum(y[2])
     end
-    test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW; loss)
+    test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW; loss, test_mooncake = TEST_MOONCAKE)
 end
 end

@@ -491,13 +494,13 @@ end
     l = SGConv(D_IN => D_OUT, k, add_self_loops = true)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end

     l = SGConv(D_IN => D_OUT, k, add_self_loops = true)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end
     end
 end
@@ -520,13 +523,13 @@ end
     l = TAGConv(D_IN => D_OUT, k, add_self_loops = true)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end

     l = TAGConv(D_IN => D_OUT, k, add_self_loops = true)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH)
+        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
     end
     end
 end
@@ -565,6 +568,7 @@ end
     ein = 2
     heads = 3
     # used like in Kool et al., 2019
+    # Mooncake backward pass error for this layer on CI
     l = TransformerConv(D_IN * heads => D_IN; heads, add_self_loops = true,
                         root_weight = false, ff_channels = 10, skip_connection = true,
                         batch_norm = false)
@@ -616,6 +620,7 @@ end
     l = DConv(D_IN => D_OUT, k)
     for g in TEST_GRAPHS
         @test size(l(g, g.x)) == (D_OUT, g.num_nodes)
+        # Note: test_mooncake not enabled for DConv (Mooncake backward pass error)
         test_gradients(l, g, g.x, rtol = RTOL_HIGH)
     end
 end

GraphNeuralNetworks/test/runtests.jl

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ using TestItemRunner
 ## for how to run the tests within VS Code.
 ## See test_module.jl for the test infrastructure.

+const TEST_MOONCAKE = VERSION >= v"1.12"
+
 ## Uncomment below to change the default test settings
 # ENV["GNN_TEST_CPU"] = "false"
 # ENV["GNN_TEST_CUDA"] = "true"

GraphNeuralNetworks/test/test_module.jl

Lines changed: 33 additions & 4 deletions
@@ -29,6 +29,11 @@ using ChainRulesTestUtils, FiniteDifferences
 using Zygote: Zygote
 using SparseArrays

+# Mooncake.jl requires Julia >= 1.12
+const TEST_MOONCAKE = VERSION >= v"1.12"
+if TEST_MOONCAKE
+    import Mooncake
+end

 # from Base
 export mean, randn, SparseArrays, AbstractSparseMatrix
@@ -45,7 +50,7 @@ export random_regular_graph, erdos_renyi
 # from this module
 export D_IN, D_OUT, GRAPH_TYPES, TEST_GRAPHS,
        test_gradients, finitediff_withgradient,
-       check_equal_leaves, gpu_backend
+       check_equal_leaves, gpu_backend, TEST_MOONCAKE


 const D_IN = 3
@@ -82,12 +87,13 @@ function test_gradients(
         test_grad_f = true,
         test_grad_x = true,
         compare_finite_diff = true,
+        test_mooncake = false,
         loss = (f, g, xs...) -> mean(f(g, xs...)),
     )

-    if !test_gpu && !compare_finite_diff
-        error("You should either compare finite diff vs CPU AD \
-               or CPU AD vs GPU AD.")
+    if !test_gpu && !compare_finite_diff && !test_mooncake
+        error("You should either compare finite diff vs CPU AD, \
+               CPU AD vs GPU AD, or test Mooncake AD.")
     end

     ## Let's make sure first that the forward pass works.
@@ -116,6 +122,18 @@
         check_equal_leaves(g, g_fd; rtol, atol)
     end

+    if test_mooncake
+        # Mooncake gradient with respect to input, compared against Zygote.
+        loss_mc_x = (xs...) -> loss(f, graph, xs...)
+        # TODO error without `invokelatest` when using TestItemRunner
+        _cache_x = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_x, xs...)
+        y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_x, loss_mc_x, xs...)
+        @assert isapprox(y, y_mc; rtol, atol)
+        for i in eachindex(xs)
+            @assert isapprox(g[i], g_mc[i+1]; rtol, atol)
+        end
+    end
+
     if test_gpu
         # Zygote gradient with respect to input on GPU.
         y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, graph_gpu, xs...), xs_gpu...)
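Note for reviewers: the TODO above records that these calls fail without `Base.invokelatest` under TestItemRunner. That symptom is consistent with a world-age restriction: a method defined after the enclosing function began executing cannot be called directly from it, while `invokelatest` dispatches in the newest world age. A self-contained illustration of the mechanism (hypothetical `newfun`, unrelated to this commit):

    function run_dynamic()
        # Defines `newfun` at runtime, in a newer world age than this call frame.
        eval(:(newfun(x) = 2x))
        # Calling newfun(3) directly here would raise a world-age MethodError.
        return Base.invokelatest(newfun, 3)   # dispatches in the latest world; returns 6
    end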
@@ -139,6 +157,17 @@
         check_equal_leaves(g, g_fd; rtol, atol)
     end

+    if test_mooncake
+        # Mooncake gradient with respect to f, compared against Zygote.
+        ps_mc, re_mc = Flux.destructure(f)
+        loss_mc_f = ps -> loss(re_mc(ps), graph, xs...)
+        _cache_f = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_f, ps_mc)
+        y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_f, loss_mc_f, ps_mc)
+        @assert isapprox(y, y_mc; rtol, atol)
+        g_mc_f = (re_mc(g_mc[2]),)
+        check_equal_leaves(g, g_mc_f; rtol, atol)
+    end
+
     if test_gpu
         # Zygote gradient with respect to f on GPU.
         y_gpu, g_gpu = Zygote.withgradient(f -> loss(f,graph_gpu, xs_gpu...), f_gpu)
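Note for reviewers: the parameter-gradient check above hands Mooncake a plain vector rather than the layer struct. `Flux.destructure` flattens the trainable arrays into one vector plus a reconstructor, and applying the reconstructor to the flat gradient vector (`re_mc(g_mc[2])`) yields a model-shaped object that `check_equal_leaves` can compare leaf-by-leaf against Zygote's structural gradient. A small sketch with an assumed `Dense` layer:

    using Flux

    m = Dense(3 => 2)             # any Flux layer flattens the same way
    ps, re = Flux.destructure(m)  # ps: flat parameter vector; re: reconstructor

    m2 = re(ps)                   # rebuilds a Dense(3 => 2) from the flat vector
    # Applying `re` to a flat gradient vector likewise produces an object with
    # the same structure as `m`, holding gradients in place of weights.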
