Skip to content

Commit 2155207

Browse files
AmoghAmogh
authored andcommitted
Updated self loop logic and tests
1 parent 8ac413e commit 2155207

File tree

3 files changed

+34
-29
lines changed

3 files changed

+34
-29
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,7 @@ function Base.show(io::IO, l::AGNNConv)
422422
end
423423

424424
function (l::AGNNConv)(g::AbstractGNNGraph, x, ps, st)
425-
β = l.trainable ? ps.β : l.init_beta
426-
m = (; β, l.add_self_loops)
425+
β = l.trainable ? ps.β[1] : l.init_beta[1]
427426
return GNNlib.agnn_conv(g, x, β; self_loops=l.add_self_loops), st
428427
end
429428

GNNlib/src/layers/conv.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -336,26 +336,22 @@ end
336336

337337
function agnn_conv(g::AbstractGNNGraph, x, β; self_loops=true)
338338

339-
if self_loops && !(g.num_nodes isa Integer)
339+
if self_loops && (!(g isa GNNHeteroGraph) || all(et -> et[1] == et[3], g.etypes))
340340
g = add_self_loops(g)
341341
end
342342

343343
xj, xi = expand_srcdst(g, x)
344-
T = eltype(xj)
345-
ϵ = T(1e-9)
346-
347-
xi_norm = xi ./ sqrt.(sum(abs2, xi, dims=1) .+ ϵ)
348-
xj_norm = xj ./ sqrt.(sum(abs2, xj, dims=1) .+ ϵ)
349-
350-
s, d = edge_index(g)
351344

352-
cos_dist = sum(xi_norm[:, d] .* xj_norm[:, s], dims=1)
345+
xi_norm = xi ./ sqrt.(sum(abs2, xi, dims=1) .+ eps(eltype(xi)))
346+
xj_norm = xj ./ sqrt.(sum(abs2, xj, dims=1) .+ eps(eltype(xj)))
353347

348+
cos_dist = apply_edges(xi_dot_xj, g, xi=xi_norm, xj=xj_norm)
349+
354350
α = softmax_edge_neighbors(g, β .* cos_dist)
355351

356-
m = α .* xj[:, s]
357-
358-
return GNNlib.aggregate_neighbors(g, +, m)
352+
return propagate(g, +; xj=xj, e=α) do xi_i, xj_j, α_e
353+
α_e .* xj_j
354+
end
359355
end
360356

361357
####################### MegNetConv ######################################

GraphNeuralNetworks/test/layers/conv.jl

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ end
397397
# end
398398

399399
@testitem "AGNNConv" setup=[TolSnippet, TestModule] begin
400-
using .TestModule
400+
using .TestModule
401401
@testset "Initialization & Basic Forward" begin
402402
l = AGNNConv(init_beta=2.0f0, trainable=true)
403403
@test l.β[1] == 2.0f0
@@ -406,25 +406,28 @@ end
406406
@test size(l(g, g.x)) == (D_IN, g.num_nodes)
407407
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
408408
end
409-
end
410-
@testset "Bipartite Support" begin
411-
l = AGNNConv(add_self_loops=false)
412-
in_channel = 16
413-
s = [rand(1:5, 14)..., 5]
414-
d = [rand(1:8, 14)..., 8]
415-
g = GNNGraph((s, d))
416-
x = (randn(Float32, in_channel, 5), randn(Float32, in_channel, 8))
417-
y = l(g, x)
418-
@test size(y) == (in_channel, 8)
419-
test_gradients(l, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
420-
end
409+
end
410+
@testset "Heterogeneous Graph Support" begin
411+
l = AGNNConv(add_self_loops=false)
412+
hg = rand_bipartite_heterograph((10, 15), 20)
413+
x = (A = randn(Float32, D_IN, 10), B = randn(Float32, D_IN, 15))
414+
hetero_layer = HeteroGraphConv(
415+
(:A, :to, :B) => l,
416+
(:B, :to, :A) => l;
417+
aggr = +
418+
)
419+
y = hetero_layer(hg, x)
420+
@test size(y.A) == (D_IN, 10)
421+
@test size(y.B) == (D_IN, 15)
422+
end
421423
@testset "Stability (Epsilon)" begin
422424
l = AGNNConv()
423425
g = TEST_GRAPHS[1]
424426
x_dead = randn(Float32, D_IN, g.num_nodes)
425-
x_dead[:, 1] .= 0.0f0
427+
x_dead[:, 1] .= 0.0f0
426428
y = l(g, x_dead)
427429
@test !any(isnan.(y))
430+
@test !any(isinf.(y))
428431
end
429432
end
430433

@@ -435,7 +438,14 @@ end
435438
g.graph isa AbstractSparseMatrix && continue
436439
@test size(l(g, g.x)) == (D_IN, g.num_nodes)
437440
test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_gpu = true, compare_finite_diff = false)
438-
end
441+
end
442+
l_bip = AGNNConv(add_self_loops=false)
443+
s = [1, 1, 2, 3]
444+
t = [1, 2, 1, 2]
445+
g = GNNGraph((s, t)) |> gpu
446+
x = (randn(Float32, D_IN, 3) |> gpu, randn(Float32, D_IN, 2) |> gpu)
447+
y = l_bip(g, x)
448+
@test size(y) == (D_IN, 2)
439449
end
440450

441451
@testitem "MEGNetConv" setup=[TolSnippet, TestModule] begin

0 commit comments

Comments
 (0)