Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions GraphNeuralNetworks/test/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

Flux.testmode!(gnn)

test_gradients(gnn, g, x, rtol = 1e-5)
test_gradients(gnn, g, x, rtol = 1e-5, test_mooncake = false)

@testset "constructor with names" begin
m = GNNChain(GCNConv(din => d),
Expand Down Expand Up @@ -53,7 +53,7 @@

Flux.trainmode!(gnn)

test_gradients(gnn, g, x, rtol = 1e-4, atol=1e-4)
test_gradients(gnn, g, x, rtol = 1e-4, atol=1e-4, test_mooncake = false)
end
end

Expand Down
19 changes: 7 additions & 12 deletions GraphNeuralNetworks/test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ end
for g in TEST_GRAPHS
g = add_self_loops(g)
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
# Note: test_mooncake not enabled for ChebConv (Mooncake backward pass error)
test_gradients(l, g, g.x, rtol = RTOL_LOW)
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = false)
end

@testset "bias=false" begin
Expand Down Expand Up @@ -198,8 +197,7 @@ end
l = GATv2Conv(D_IN => D_OUT, tanh; heads, concat, dropout=0)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (concat ? heads * D_OUT : D_OUT, g.num_nodes)
# Mooncake backward pass error for this layer on CI
test_gradients(l, g, g.x, rtol = RTOL_LOW, atol=ATOL_LOW)
test_gradients(l, g, g.x, rtol = RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -208,8 +206,7 @@ end
l = GATv2Conv((D_IN, ein) => D_OUT, add_self_loops = false, dropout=0)
g = GNNGraph(TEST_GRAPHS[1], edata = rand(Float32, ein, TEST_GRAPHS[1].num_edges))
@test size(l(g, g.x, g.e)) == (D_OUT, g.num_nodes)
# Mooncake backward pass error for this layer on CI
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, atol=ATOL_LOW)
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
end

@testset "num params" begin
Expand Down Expand Up @@ -568,31 +565,30 @@ end
ein = 2
heads = 3
# used like in Kool et al., 2019
# Mooncake backward pass error for this layer on CI
l = TransformerConv(D_IN * heads => D_IN; heads, add_self_loops = true,
root_weight = false, ff_channels = 10, skip_connection = true,
batch_norm = false)
# batch_norm=false here for tests to pass; true in paper
for g in TEST_GRAPHS
g = GNNGraph(g, ndata = rand(Float32, D_IN * heads, g.num_nodes))
@test size(l(g, g.x)) == (D_IN * heads, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_LOW)
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
end
# used like in Shi et al., 2021
l = TransformerConv((D_IN, ein) => D_IN; heads, gating = true,
bias_qkv = true)
for g in TEST_GRAPHS
g = GNNGraph(g, edata = rand(Float32, ein, g.num_edges))
@test size(l(g, g.x, g.e)) == (D_IN * heads, g.num_nodes)
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW)
test_gradients(l, g, g.x, g.e, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
end
# test averaging heads
l = TransformerConv(D_IN => D_IN; heads, concat = false,
bias_root = false,
root_weight = false)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_IN, g.num_nodes)
test_gradients(l, g, g.x, rtol = RTOL_LOW)
test_gradients(l, g, g.x, rtol = RTOL_LOW, test_mooncake = TEST_MOONCAKE)
end
end

Expand Down Expand Up @@ -620,8 +616,7 @@ end
l = DConv(D_IN => D_OUT, k)
for g in TEST_GRAPHS
@test size(l(g, g.x)) == (D_OUT, g.num_nodes)
# Note: test_mooncake not enabled for DConv (Mooncake backward pass error)
test_gradients(l, g, g.x, rtol = RTOL_HIGH)
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = false)
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions GraphNeuralNetworks/test/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@test u[:, [1]] ≈ sum(g.ndata.x[:, 1:n], dims = 2)
@test p(g).gdata.u == u

test_gradients(p, g, g.x, rtol = 1e-5)
test_gradients(p, g, g.x, rtol = 1e-5, test_mooncake = TEST_MOONCAKE)
end
end

Expand All @@ -42,7 +42,7 @@ end
for i in 1:ng])

@test size(p(g, g.x)) == (chout, ng)
test_gradients(p, g, g.x, rtol = 1e-5)
test_gradients(p, g, g.x, rtol = 1e-5, test_mooncake = TEST_MOONCAKE)
end
end

Expand Down
58 changes: 29 additions & 29 deletions GraphNeuralNetworks/test/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ end
@test y === h
@test size(h) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
# with initial state
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH)
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)

# Test with custom activation function
custom_activation = tanh
Expand All @@ -45,9 +45,9 @@ end
# Test that outputs differ when using different activation functions
@test !isapprox(y, y_custom, rtol=RTOL_HIGH)
# with no initial state
test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH)
test_gradients(cell_custom, g, g.x, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
# with initial state
test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH)
test_gradients(cell_custom, g, g.x, h_custom, loss=cell_loss, rtol=RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
end

@testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -61,9 +61,9 @@ end
@test layer isa GNNRecurrence
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol = RTOL_HIGH)
test_gradients(layer, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
# with initial state
test_gradients(layer, g, x, state0, rtol = RTOL_HIGH)
test_gradients(layer, g, x, state0, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)

# Test with custom activation function
custom_activation = tanh
Expand All @@ -74,15 +74,15 @@ end
# Test that outputs differ when using different activation functions
@test !isapprox(y, y_custom, rtol = RTOL_HIGH)
# with no initial state
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH)
test_gradients(layer_custom, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
# with initial state
test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH)
test_gradients(layer_custom, g, x, state0, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)

# interplay with GNNChain
model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW)
test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW, test_mooncake = TEST_MOONCAKE)
end

@testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -93,9 +93,9 @@ end
@test size(h) == (out_channel, g.num_nodes)
@test size(c) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
end

@testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -107,15 +107,15 @@ end
y = layer(g, x)
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)

# interplay with GNNChain
model = GNNChain(GConvLSTM(in_channel => out_channel, 2), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false)
end

@testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -125,9 +125,9 @@ end
@test y === h
@test size(h) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
end


Expand All @@ -140,15 +140,15 @@ end
y = layer(g, x)
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)

# interplay with GNNChain
model = GNNChain(GConvGRU(in_channel => out_channel, 2), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false)
end

@testitem "DCGRUCell" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -158,9 +158,9 @@ end
@test y === h
@test size(h) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
end

@testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -172,15 +172,15 @@ end
y = layer(g, x)
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)

# interplay with GNNChain
model = GNNChain(DCGRU(in_channel => out_channel, 2), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false)
end

@testitem "EvolveGCNOCell" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -189,9 +189,9 @@ end
y, state = cell(g, g.x)
@test size(y) == (out_channel, g.num_nodes)
# with no initial state
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
# with initial state
test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = false)
end

@testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin
Expand All @@ -203,15 +203,15 @@ end
y = layer(g, x)
@test size(y) == (out_channel, timesteps, g.num_nodes)
# with no initial state
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
# with initial state
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)

# interplay with GNNChain
model = GNNChain(EvolveGCNO(in_channel => out_channel), Dense(out_channel, 1))
y = model(g, x)
@test size(y) == (1, timesteps, g.num_nodes)
test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW)
test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW, test_mooncake = TEST_MOONCAKE)
end

# @testitem "GINConv" setup=[TemporalConvTestModule, TestModule] begin
Expand Down
24 changes: 10 additions & 14 deletions GraphNeuralNetworks/test/test_module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,13 @@ function test_gradients(
end

if test_mooncake
# Mooncake gradient with respect to input, compared against Zygote.
# Mooncake gradient with respect to input via Flux integration, compared against Zygote.
loss_mc_x = (xs...) -> loss(f, graph, xs...)
# TODO error without `invokelatest` when using TestItemRunner
_cache_x = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_x, xs...)
y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_x, loss_mc_x, xs...)
result = Flux.withgradient(loss_mc_x, Flux.AutoMooncake(), xs...)
y_mc = result.val # Extract value from NamedTuple
g_mc = result.grad # Extract gradients tuple
Comment thread
CarloLucibello marked this conversation as resolved.
Outdated
@assert isapprox(y, y_mc; rtol, atol)
for i in eachindex(xs)
@assert isapprox(g[i], g_mc[i+1]; rtol, atol)
end
check_equal_leaves(g, g_mc; rtol, atol)
end

if test_gpu
Expand All @@ -158,14 +156,12 @@ function test_gradients(
end

if test_mooncake
# Mooncake gradient with respect to f, compared against Zygote.
ps_mc, re_mc = Flux.destructure(f)
loss_mc_f = ps -> loss(re_mc(ps), graph, xs...)
_cache_f = Base.invokelatest(Mooncake.prepare_gradient_cache, loss_mc_f, ps_mc)
y_mc, g_mc = Base.invokelatest(Mooncake.value_and_gradient!!, _cache_f, loss_mc_f, ps_mc)
# Mooncake gradient with respect to f via Flux integration, compared against Zygote.
result = Flux.withgradient(f -> loss(f, graph, xs...), Flux.AutoMooncake(), f)
y_mc = result.val # Extract value from NamedTuple
g_mc_result = result.grad # Extract gradients tuple
Comment thread
CarloLucibello marked this conversation as resolved.
Outdated
@assert isapprox(y, y_mc; rtol, atol)
g_mc_f = (re_mc(g_mc[2]),)
check_equal_leaves(g, g_mc_f; rtol, atol)
check_equal_leaves(g, g_mc_result; rtol, atol)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
check_equal_leaves(g, g_mc_result; rtol, atol)
check_equal_leaves(g, g_mc; rtol, atol)

@CarloLucibello The `g_mc_result` here should also be changed to `g_mc`; otherwise the tests fail. Should I open a new PR to fix the mistake?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, sorry. Yes, please fix it.

end

if test_gpu
Expand Down
Loading