1010 l = GCNConv (D_IN => D_OUT)
1111 for g in TEST_GRAPHS
1212 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
13- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
13+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
1414 end
1515
1616 l = GCNConv (D_IN => D_OUT, tanh, bias = false )
1717 for g in TEST_GRAPHS
1818 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
19- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
19+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
2020 end
2121
2222 l = GCNConv (D_IN => D_OUT, add_self_loops = false )
2323 for g in TEST_GRAPHS
2424 has_isolated_nodes (g) && continue
2525 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
26- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
26+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
2727 end
2828 end
2929
4949 l = GCNConv (1 => 1 , add_self_loops = false , use_edge_weight = true )
5050 @test gradient (w -> sum (l (g, x, w)), w)[1 ] isa AbstractVector{Float32} # redundant test but more explicit
5151 @test size (l (g, x, w)) == (1 , g. num_nodes)
52- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
52+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
5353 end
5454
5555 @testset " conv_weight" begin
8686 for g in TEST_GRAPHS
8787 g = add_self_loops (g)
8888 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
89+ # Note: test_mooncake not enabled for ChebConv (Mooncake backward pass error)
8990 test_gradients (l, g, g. x, rtol = RTOL_LOW)
9091 end
9192
@@ -124,13 +125,13 @@ end
124125 l = GraphConv (D_IN => D_OUT)
125126 for g in TEST_GRAPHS
126127 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
127- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
128+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
128129 end
129130
130131 l = GraphConv (D_IN => D_OUT, tanh, bias = false , aggr = mean)
131132 for g in TEST_GRAPHS
132133 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
133- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
134+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
134135 end
135136
136137 @testset " bias=false" begin
157158 l = GATConv (D_IN => D_OUT; heads, concat, dropout= 0 )
158159 for g in TEST_GRAPHS
159160 @test size (l (g, g. x)) == (concat ? heads * D_OUT : D_OUT, g. num_nodes)
160- test_gradients (l, g, g. x, rtol = RTOL_LOW)
161+ test_gradients (l, g, g. x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE )
161162 end
162163 end
163164
166167 l = GATConv ((D_IN, ein) => D_OUT, add_self_loops = false , dropout= 0 )
167168 g = GNNGraph (TEST_GRAPHS[1 ], edata = rand (Float32, ein, TEST_GRAPHS[1 ]. num_edges))
168169 @test size (l (g, g. x, g. e)) == (D_OUT, g. num_nodes)
169- test_gradients (l, g, g. x, g. e, rtol = RTOL_LOW)
170+ test_gradients (l, g, g. x, g. e, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE )
170171 end
171172
172173 @testset " num params" begin
197198 l = GATv2Conv (D_IN => D_OUT, tanh; heads, concat, dropout= 0 )
198199 for g in TEST_GRAPHS
199200 @test size (l (g, g. x)) == (concat ? heads * D_OUT : D_OUT, g. num_nodes)
201+ # Mooncake backward pass error for this layer on CI
200202 test_gradients (l, g, g. x, rtol = RTOL_LOW, atol= ATOL_LOW)
201203 end
202204 end
206208 l = GATv2Conv ((D_IN, ein) => D_OUT, add_self_loops = false , dropout= 0 )
207209 g = GNNGraph (TEST_GRAPHS[1 ], edata = rand (Float32, ein, TEST_GRAPHS[1 ]. num_edges))
208210 @test size (l (g, g. x, g. e)) == (D_OUT, g. num_nodes)
211+ # Mooncake backward pass error for this layer on CI
209212 test_gradients (l, g, g. x, g. e, rtol = RTOL_LOW, atol= ATOL_LOW)
210213 end
211214
239242
240243 for g in TEST_GRAPHS
241244 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
242- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
245+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
243246 end
244247end
245248
260263 l = EdgeConv (Dense (2 * D_IN, D_OUT), aggr = + )
261264 for g in TEST_GRAPHS
262265 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
263- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
266+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
264267 end
265268end
266269
281284 l = GINConv (nn, 0.01 , aggr = mean)
282285 for g in TEST_GRAPHS
283286 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
284- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
287+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
285288 end
286289
287290 @test ! in (:eps , Flux. trainable (l))
307310 for g in TEST_GRAPHS
308311 g = GNNGraph (g, edata = rand (Float32, edim, g. num_edges))
309312 @test size (l (g, g. x, g. e)) == (D_OUT, g. num_nodes)
310- test_gradients (l, g, g. x, g. e, rtol = RTOL_HIGH)
313+ test_gradients (l, g, g. x, g. e, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
311314 end
312315end
313316
332335 l = SAGEConv (D_IN => D_OUT, tanh, bias = false , aggr = + )
333336 for g in TEST_GRAPHS
334337 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
335- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
338+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
336339 end
337340end
338341
351354 l = ResGatedGraphConv (D_IN => D_OUT, tanh, bias = true )
352355 for g in TEST_GRAPHS
353356 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
354- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
357+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
355358 end
356359end
357360
411414 Flux. trainable (l) == (; β = [1f0 ])
412415 for g in TEST_GRAPHS
413416 @test size (l (g, g. x)) == (D_IN, g. num_nodes)
414- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
417+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
415418 end
416419end
417420
437440 y = l (g, x, e)
438441 return mean (y[1 ]) + sum (y[2 ])
439442 end
440- test_gradients (l, g, g. x, g. e, rtol = RTOL_LOW; loss)
443+ test_gradients (l, g, g. x, g. e, rtol = RTOL_LOW; loss, test_mooncake = TEST_MOONCAKE )
441444 end
442445end
443446
@@ -491,13 +494,13 @@ end
491494 l = SGConv (D_IN => D_OUT, k, add_self_loops = true )
492495 for g in TEST_GRAPHS
493496 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
494- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
497+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
495498 end
496499
497500 l = SGConv (D_IN => D_OUT, k, add_self_loops = true )
498501 for g in TEST_GRAPHS
499502 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
500- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
503+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
501504 end
502505 end
503506end
@@ -520,13 +523,13 @@ end
520523 l = TAGConv (D_IN => D_OUT, k, add_self_loops = true )
521524 for g in TEST_GRAPHS
522525 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
523- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
526+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
524527 end
525528
526529 l = TAGConv (D_IN => D_OUT, k, add_self_loops = true )
527530 for g in TEST_GRAPHS
528531 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
529- test_gradients (l, g, g. x, rtol = RTOL_HIGH)
532+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
530533 end
531534 end
532535end
565568 ein = 2
566569 heads = 3
567570 # used like in Kool et al., 2019
571+ # Mooncake backward pass error for this layer on CI
568572 l = TransformerConv (D_IN * heads => D_IN; heads, add_self_loops = true ,
569573 root_weight = false , ff_channels = 10 , skip_connection = true ,
570574 batch_norm = false )
616620 l = DConv (D_IN => D_OUT, k)
617621 for g in TEST_GRAPHS
618622 @test size (l (g, g. x)) == (D_OUT, g. num_nodes)
623+ # Note: test_mooncake not enabled for DConv (Mooncake backward pass error)
619624 test_gradients (l, g, g. x, rtol = RTOL_HIGH)
620625 end
621626 end
0 commit comments