Skip to content

Commit f43f822

Browse files
authored
Changes attended
1 parent 4b5d41b commit f43f822

1 file changed

Lines changed: 5 additions & 9 deletions

File tree

GNNlib/src/layers/conv.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,11 @@ function agnn_conv(l, g::AbstractGNNGraph, x)
343343
xj, xi = expand_srcdst(g, x)
344344

345345
xi_n = xi ./ sqrt.(sum(xi .^ 2, dims = 1))
346-
xj_n = xj ./ sqrt.(sum(xj .^ 2, dims = 1))
346+
if xj !== xi
347+
xj_n = xj ./ sqrt.(sum(xj .^ 2, dims = 1))
348+
else
349+
xj_n = xi_n
350+
end
347351
cos_dist = apply_edges(xi_dot_xj, g, xi = xi_n, xj = xj_n)
348352
α = softmax_edge_neighbors(g, l.β .* cos_dist)
349353

@@ -354,14 +358,6 @@ function agnn_conv(l, g::AbstractGNNGraph, x)
354358
return x
355359
end
356360

357-
"""
358-
_has_same_node_types(g::GNNHeteroGraph)
359-
360-
Return true if all edge types in the heterogeneous graph have the same source and
361-
target node types (i.e., no bipartite relations).
362-
"""
363-
_has_same_node_types(g::GNNHeteroGraph) = all(et -> et[1] == et[3], g.etypes)
364-
365361
####################### MegNetConv ######################################
366362

367363
function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)

0 commit comments

Comments
 (0)