
Commit 8ac413e

Amogh authored and committed
Added support for HeteroGraphConv in AGNNConv
1 parent 00ebef3 commit 8ac413e

7 files changed

Lines changed: 71 additions & 35 deletions


GNNLux/src/layers/conv.jl

Lines changed: 2 additions & 2 deletions
@@ -421,10 +421,10 @@ function Base.show(io::IO, l::AGNNConv)
     print(io, ")")
 end
 
-function (l::AGNNConv)(g, x::AbstractMatrix, ps, st)
+function (l::AGNNConv)(g::AbstractGNNGraph, x, ps, st)
     β = l.trainable ? ps.β : l.init_beta
     m = (; β, l.add_self_loops)
-    return GNNlib.agnn_conv(m, g, x), st
+    return GNNlib.agnn_conv(g, x, β; self_loops=l.add_self_loops), st
 end
 
 @doc raw"""
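
For context, a minimal sketch of how the updated GNNLux forward pass is exercised after this change (not part of the commit; `rand_graph` from GNNGraphs and the standard Lux setup call are assumed here):

    using GNNLux, GNNGraphs, Lux, Random

    rng = Random.default_rng()
    l = AGNNConv()                     # GNNLux (Lux-style) layer with default β
    ps, st = Lux.setup(rng, l)         # initialize parameters and state

    g = rand_graph(10, 30)             # 10 nodes, 30 edges (assumed helper)
    x = randn(Float32, 4, 10)          # node features, one column per node

    y, st = l(g, x, ps, st)            # new signature: (g::AbstractGNNGraph, x, ps, st)
    @assert size(y) == (4, 10)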

GNNlib/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "GNNlib"
 uuid = "a6a84749-d869-43f8-aacc-be26a1996e48"
 authors = ["Carlo Lucibello and contributors"]
-version = "1.2.0"
+version = "1.3.0-DEV"
 
 [workspace]
 projects = ["test", "docs", "../GNNGraphs"]

GNNlib/src/layers/conv.jl

Lines changed: 17 additions & 10 deletions
@@ -334,21 +334,28 @@ end
 
 ####################### AGNNConv ######################################
 
-function agnn_conv(l, g::GNNGraph, x::AbstractMatrix)
-    check_num_nodes(g, x)
-    if l.add_self_loops
+function agnn_conv(g::AbstractGNNGraph, x, β; self_loops=true)
+
+    if self_loops && !(g.num_nodes isa Integer)
         g = add_self_loops(g)
     end
 
-    xn = x ./ sqrt.(sum(x .^ 2, dims = 1))
-    cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn)
-    α = softmax_edge_neighbors(g, l.β .* cos_dist)
+    xj, xi = expand_srcdst(g, x)
+    T = eltype(xj)
+    ϵ = T(1e-9)
 
-    x = propagate(g, +; xj = x, e = α) do xi, xj, α
-        α .* xj
-    end
+    xi_norm = xi ./ sqrt.(sum(abs2, xi, dims=1) .+ ϵ)
+    xj_norm = xj ./ sqrt.(sum(abs2, xj, dims=1) .+ ϵ)
 
-    return x
+    s, d = edge_index(g)
+
+    cos_dist = sum(xi_norm[:, d] .* xj_norm[:, s], dims=1)
+
+    α = softmax_edge_neighbors(g, β .* cos_dist)
+
+    m = α .* xj[:, s]
+
+    return GNNlib.aggregate_neighbors(g, +, m)
 end
 
 ####################### MegNetConv ######################################
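
The rewritten kernel computes the AGNN update x_i' = Σ_{j ∈ N(i)} α_ij x_j with α_ij = softmax_{j ∈ N(i)}(β · cos(x_i, x_j)), working edge-wise over edge_index instead of going through propagate. The following standalone sketch (not part of the commit; plain Julia, no GNNlib helpers) mirrors that computation on a tiny graph and illustrates the indexing convention used above (s = source/j nodes, d = destination/i nodes):

    # Pure-Julia illustration of the AGNN attention update; edge lists are arbitrary.
    s = [1, 2, 3]            # edge sources (j)
    d = [2, 3, 1]            # edge destinations (i)
    x = randn(Float32, 4, 3) # 4 features, 3 nodes (one column per node)
    β = 1.0f0
    ϵ = 1.0f-9

    xn = x ./ sqrt.(sum(abs2, x, dims=1) .+ ϵ)           # normalize each feature column
    cos_sim = vec(sum(xn[:, d] .* xn[:, s], dims=1))     # cosine similarity per edge

    # softmax of β * cos_sim over the incoming edges of each destination node
    α = similar(cos_sim)
    for i in unique(d)
        e = findall(==(i), d)
        w = exp.(β .* cos_sim[e])
        α[e] = w ./ sum(w)
    end

    # aggregate the messages α_ij * x_j at each destination node i
    y = zeros(Float32, size(x)...)
    for (k, (j, i)) in enumerate(zip(s, d))
        y[:, i] .+= α[k] .* x[:, j]
    end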

GraphNeuralNetworks/Project.toml

Lines changed: 6 additions & 4 deletions
@@ -1,12 +1,10 @@
 name = "GraphNeuralNetworks"
 uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
 authors = ["Carlo Lucibello and contributors"]
-version = "1.1.0"
-
-[workspace]
-projects = ["test", "docs"]
+version = "1.2.0-DEV"
 
 [deps]
+AutoStructs = "2e0df379-9877-4907-ab94-cd881f8d985b"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -21,6 +19,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
+AutoStructs = "0.1.1"
 ChainRulesCore = "1"
 ConcreteStructs = "0.2.3"
 Flux = "0.16.0"
@@ -34,3 +33,6 @@ Random = "1"
 Reexport = "1"
 Statistics = "1"
 julia = "1.10"
+
+[workspace]
+projects = ["test", "docs"]

GraphNeuralNetworks/src/layers/conv.jl

Lines changed: 5 additions & 1 deletion
@@ -4,6 +4,8 @@
 # the implementation of a single layer,
 # but it is done for GraphNeuralNetworks.jl and GNNLux.jl to be able to share the same code.
 
+using Flux, GraphNeuralNetworks, AutoStructs
+
 @doc raw"""
     GCNConv(in => out, σ=identity; [bias, init, add_self_loops, use_edge_weight])
 
@@ -999,7 +1001,9 @@ function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true)
     AGNNConv([init_beta], add_self_loops, trainable)
 end
 
-(l::AGNNConv)(g, x) = GNNlib.agnn_conv(l, g, x)
+function (l::AGNNConv)(g::AbstractGNNGraph, x)
+    return GNNlib.agnn_conv(g, x, l.β[1]; self_loops=l.add_self_loops)
+end
 
 @doc raw"""
     MEGNetConv(ϕe, ϕv; aggr=mean)
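
Downstream, the Flux-side layer is now callable both on a homogeneous graph and with a (source, target) feature tuple, which is what HeteroGraphConv feeds it per relation. An illustrative sketch follows (not from the commit; `rand_graph` is assumed from GNNGraphs, and the bipartite part mirrors the "Bipartite Support" testset added below):

    using GraphNeuralNetworks

    # Homogeneous graph: unchanged behaviour, dispatched on g::AbstractGNNGraph.
    g = rand_graph(10, 30)
    x = randn(Float32, 4, 10)
    l = AGNNConv(init_beta = 1.0f0)
    y = l(g, x)                                  # size (4, 10)

    # Bipartite-style call: distinct source/target node sets, features as an
    # (x_src, x_dst) tuple, as exercised by the new test.
    s = [rand(1:5, 14)..., 5]
    d = [rand(1:8, 14)..., 8]
    gb = GNNGraph((s, d))
    xb = (randn(Float32, 4, 5), randn(Float32, 4, 8))
    lb = AGNNConv(add_self_loops = false)
    yb = lb(gb, xb)                              # size (4, 8), one column per target node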

GraphNeuralNetworks/test/layers/conv.jl

Lines changed: 30 additions & 17 deletions
@@ -398,29 +398,42 @@ end
 
 @testitem "AGNNConv" setup=[TolSnippet, TestModule] begin
     using .TestModule
-    l = AGNNConv(trainable=false, add_self_loops=false)
-    @test l.β == [1.0f0]
-    @test l.add_self_loops == false
-    @test l.trainable == false
-    Flux.trainable(l) == (;)
-
-    l = AGNNConv(init_beta=2.0f0)
-    @test l.β == [2.0f0]
-    @test l.add_self_loops == true
-    @test l.trainable == true
-    Flux.trainable(l) == (; β = [1f0])
-    for g in TEST_GRAPHS
-        @test size(l(g, g.x)) == (D_IN, g.num_nodes)
-        test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
+    @testset "Initialization & Basic Forward" begin
+        l = AGNNConv(init_beta=2.0f0, trainable=true)
+        @test l.β[1] == 2.0f0
+        @test l.trainable == true
+        for g in TEST_GRAPHS
+            @test size(l(g, g.x)) == (D_IN, g.num_nodes)
+            test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
+        end
+    end
+    @testset "Bipartite Support" begin
+        l = AGNNConv(add_self_loops=false)
+        in_channel = 16
+        s = [rand(1:5, 14)..., 5]
+        d = [rand(1:8, 14)..., 8]
+        g = GNNGraph((s, d))
+        x = (randn(Float32, in_channel, 5), randn(Float32, in_channel, 8))
+        y = l(g, x)
+        @test size(y) == (in_channel, 8)
+        test_gradients(l, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
+    end
+    @testset "Stability (Epsilon)" begin
+        l = AGNNConv()
+        g = TEST_GRAPHS[1]
+        x_dead = randn(Float32, D_IN, g.num_nodes)
+        x_dead[:, 1] .= 0.0f0
+        y = l(g, x_dead)
+        @test !any(isnan.(y))
     end
 end
 
 @testitem "AGNNConv GPU" setup=[TolSnippet, TestModule] tags=[:gpu] begin
     using .TestModule
-    l = AGNNConv(trainable=false, add_self_loops=false)
+    l = AGNNConv()
     for g in TEST_GRAPHS
-        g.graph isa AbstractSparseMatrix && continue
-        @test size(l(g, g.x)) == (D_IN, g.num_nodes)
+        g.graph isa AbstractSparseMatrix && continue
+        @test size(l(g, g.x)) == (D_IN, g.num_nodes)
         test_gradients(l, g, g.x, rtol = RTOL_HIGH, test_gpu = true, compare_finite_diff = false)
     end
 end

GraphNeuralNetworks/test/layers/heteroconv.jl

Lines changed: 10 additions & 0 deletions
@@ -158,4 +158,14 @@
         y = layers(g, x);
         @test size(y.A) == (2,2) && size(y.B) == (2,3)
     end
+
+    @testset "AGNNConv" begin
+        x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3))
+
+        layers = HeteroGraphConv((:A, :to, :B) => AGNNConv(add_self_loops=false),
+                                 (:B, :to, :A) => AGNNConv(add_self_loops=false))
+
+        y = layers(hg, x)
+        @test size(y.A) == (4, 2) && size(y.B) == (4, 3)
+    end
 end
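
The heterograph `hg` used by the new testset is constructed earlier in heteroconv.jl and does not appear in this diff. A minimal sketch of a graph with the shape the assertions expect (2 nodes of type :A, 3 of type :B), assuming the pair-based GNNHeteroGraph constructor, could look like:

    using GraphNeuralNetworks

    # Hypothetical stand-in for the test file's `hg`; edge lists are arbitrary.
    hg = GNNHeteroGraph((:A, :to, :B) => ([1, 1, 2], [1, 2, 3]),
                        (:B, :to, :A) => ([1, 2, 3], [1, 2, 2]);
                        num_nodes = Dict(:A => 2, :B => 3))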

0 commit comments
