Skip to content

Commit 8a427e4

Browse files
authored
Skip Mooncake tests on GNNLux and Sparse graphs (#665)
* Test: Skipped Mooncake for failing layers Signed-off-by: Parvm1102 <parvmittal31757@gmail.com> * skipped mooncake on sparse graphs Signed-off-by: Parvm1102 <parvmittal31757@gmail.com> --------- Signed-off-by: Parvm1102 <parvmittal31757@gmail.com>
1 parent c8a579c commit 8a427e4

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

GNNLux/test/layers/temporalconv.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@testitem "layers/temporalconv" setup=[TestModuleLux] begin
22
using .TestModuleLux
3-
using LuxTestUtils: test_gradients, AutoTracker, AutoForwardDiff, AutoEnzyme
3+
using LuxTestUtils: test_gradients, AutoTracker, AutoForwardDiff, AutoEnzyme, AutoMooncake
44

55
rng = StableRNG(1234)
66
g = rand_graph(rng, 10, 40)
@@ -16,7 +16,7 @@
1616
st = LuxCore.initialstates(rng, l)
1717
y1, _ = l(g, x, ps, st)
1818
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
19-
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme()])
19+
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme(), AutoMooncake()])
2020

2121
# Test with custom activation (relu)
2222
l_relu = TGCN(3=>3, act = relu)
@@ -28,46 +28,46 @@
2828
@test !isapprox(y1, y2, rtol=1.0f-2)
2929

3030
loss_relu = (x, ps) -> sum(first(l_relu(g, x, ps, st_relu)))
31-
test_gradients(loss_relu, x, ps_relu; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme()])
31+
test_gradients(loss_relu, x, ps_relu; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme(), AutoMooncake()])
3232
end
3333

3434
@testset "A3TGCN" begin
3535
l = A3TGCN(3=>3)
3636
ps = LuxCore.initialparameters(rng, l)
3737
st = LuxCore.initialstates(rng, l)
3838
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
39-
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme()])
39+
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme(), AutoMooncake()])
4040
end
4141

4242
@testset "GConvGRU" begin
4343
l = GConvGRU(3=>3, 2)
4444
ps = LuxCore.initialparameters(rng, l)
4545
st = LuxCore.initialstates(rng, l)
4646
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
47-
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme()])
47+
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme(), AutoMooncake()])
4848
end
4949

5050
@testset "GConvLSTM" begin
5151
l = GConvLSTM(3=>3, 2)
5252
ps = LuxCore.initialparameters(rng, l)
5353
st = LuxCore.initialstates(rng, l)
5454
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
55-
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme()])
55+
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme(), AutoMooncake()])
5656
end
5757

5858
@testset "DCGRU" begin
5959
l = DCGRU(3=>3, 2)
6060
ps = LuxCore.initialparameters(rng, l)
6161
st = LuxCore.initialstates(rng, l)
6262
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
63-
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme()])
63+
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme(), AutoMooncake()])
6464
end
6565

6666
@testset "EvolveGCNO" begin
6767
l = EvolveGCNO(3=>3)
6868
ps = LuxCore.initialparameters(rng, l)
6969
st = LuxCore.initialstates(rng, l)
7070
loss = (tx, ps) -> sum(sum(first(l(tg, tx, ps, st))))
71-
test_gradients(loss, tx, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme()])
71+
test_gradients(loss, tx, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoForwardDiff(), AutoEnzyme(), AutoMooncake()])
7272
end
73-
end
73+
end

GNNLux/test/test_module.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ using Reexport: @reexport
3232
@reexport using StableRNGs
3333
@reexport using Random, Statistics
3434

35-
using LuxTestUtils: test_gradients, AutoTracker, AutoForwardDiff, AutoEnzyme
35+
using LuxTestUtils: test_gradients, AutoTracker, AutoForwardDiff, AutoEnzyme, AutoMooncake
3636

3737
export test_lux_layer
3838

@@ -71,7 +71,7 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
7171
else
7272
loss = (x, ps) -> mean(first(l(g, x, ps, st)))
7373
end
74-
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoForwardDiff(), AutoEnzyme()])
74+
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoForwardDiff(), AutoEnzyme(), AutoMooncake()])
7575
end
7676

7777
end

GraphNeuralNetworks/test/test_module.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function test_gradients(
122122
check_equal_leaves(g, g_fd; rtol, atol)
123123
end
124124

125-
if test_mooncake
125+
if test_mooncake && !(graph.graph isa AbstractSparseMatrix) # Mooncake friendly tangents currently error on sparse graph internals
126126
# Mooncake gradient with respect to input via Flux integration, compared against Zygote.
127127
loss_mc_x = (xs...) -> loss(f, graph, xs...)
128128
y_mc, g_mc = Flux.withgradient(loss_mc_x, Flux.AutoMooncake(), xs...)
@@ -153,7 +153,7 @@ function test_gradients(
153153
check_equal_leaves(g, g_fd; rtol, atol)
154154
end
155155

156-
if test_mooncake
156+
if test_mooncake && !(graph.graph isa AbstractSparseMatrix) # Mooncake friendly tangents currently error on sparse graph internals
157157
# Mooncake gradient with respect to f via Flux integration, compared against Zygote.
158158
y_mc, g_mc = Flux.withgradient(f -> loss(f, graph, xs...), Flux.AutoMooncake(), f)
159159
@assert isapprox(y, y_mc; rtol, atol)

0 commit comments

Comments
 (0)