Skip to content

Commit 51e1b80

Browse files
AmoghAmogh
authored andcommitted
Adding NNConv support for HeteroGraphConv
1 parent b9cdc0f commit 51e1b80

4 files changed

Lines changed: 32 additions & 4 deletions

File tree

GNNGraphs/src/gnnheterograph/query.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,13 @@ function graph_indicator(g::GNNHeteroGraph, node_t::Symbol)
8989
end
9090
return gi
9191
end
92+
93+
edge_features(g::GNNHeteroGraph, edge_t::EType) = begin
94+
ds == g.edata[edge_t]
95+
isempty(ds) ? nothing : first(values(ds))
96+
end
97+
98+
edge_features(g::GNNHeteroGraph) = begin
99+
ds = only(values(g.edata))
100+
isempty(ds) ? nothing : first(values(ds))
101+
end

GNNlib/src/layers/conv.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,12 @@ end
257257

258258
####################### NNConv ######################################
259259

260-
function nn_conv(l, g::GNNGraph, x::AbstractMatrix, e)
260+
function nn_conv(l, g::AbstractGNNGraph, x, e)
261261
check_num_nodes(g, x)
262+
xj, xi = expand_srcdst(g, x)
262263
message = Fix1(nn_conv_message, l)
263-
m = propagate(message, g, l.aggr, xj = x, e = e)
264-
return l.σ.(l.weight * x .+ m .+ l.bias)
264+
m = propagate(message, g, l.aggr, xj = xj, e = e)
265+
return l.σ.(l.weight * xi .+ m .+ l.bias)
265266
end
266267

267268
function nn_conv_message(l, xi, xj, e)

GraphNeuralNetworks/src/layers/conv.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,8 @@ end
718718

719719
(l::NNConv)(g, x, e) = GNNlib.nn_conv(l, g, x, e)
720720

721+
(l::NNConv)(g, x) = GNNlib.nn_conv(l, g, x, edge_features(g))
722+
721723
(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
722724

723725
function Base.show(io::IO, l::NNConv)

GraphNeuralNetworks/test/layers/heteroconv.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,22 @@
155155
x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3))
156156
layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh),
157157
(:B, :to, :A) => GCNConv(4 => 2, tanh));
158-
y = layers(g, x);
158+
y = layers(g, x);
159159
@test size(y.A) == (2,2) && size(y.B) == (2,3)
160160
end
161+
162+
@testset "NNConv" begin
163+
nA, nB = 2, 3
164+
nedges = 6
165+
g = rand_bipartite_heterograph((nA, nB), nedges;
166+
edata = Dict((:A, :to, :B) => (e = rand(Float32, 5, nedges),),
167+
(:B, :to, :A) => (e = rand(Float32, 5, nedges),))
168+
)
169+
x = (A = rand(Float32, 4, nA), B = rand(Float32, 4, nB))
170+
nn = Dense(5 => 4 * 2)
171+
layers = HeteroGraphConv((:A, :to, :B) => NNConv(4 => 2, nn, tanh),
172+
(:B, :to, :A) => NNConv(4 => 2, nn, tanh))
173+
y = layers(g, x)
174+
@test size(y.A) == (2, nA) && size(y.B) == (2, nB)
175+
end
161176
end

0 commit comments

Comments
 (0)