Skip to content

Commit 9b794a9

Browse files
committed
fix: Handle multiple heads
1 parent e628595 commit 9b794a9

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

GNNlib/src/layers/conv.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,10 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} =
124124
_, chout = l.channel
125125
heads = l.heads
126126

127-
Wxi = Wxj = l.dense_x(xj)
128-
Wxi = Wxj = reshape(Wxj, chout, heads, :)
129-
130-
if xi !== xj
131-
Wxi = l.dense_x(xi)
132-
Wxi = reshape(Wxi, chout, heads, :)
133-
end
127+
Wxj = l.dense_x(xj)
128+
Wxj = reshape(Wxj, chout, heads, :)
129+
Wxi = l.dense_x(xi)
130+
Wxi = reshape(Wxi, chout, heads, :)
134131

135132
# a hand-written message passing
136133
message = Fix1(gat_message, l)
@@ -139,10 +136,12 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} =
139136
α = dropout(α, l.dropout)
140137
β = α .* m.Wxj
141138
x = aggregate_neighbors(g, +, β)
142-
width = size(x, 1)
143139

144140
if !l.concat
145141
x = mean(x, dims = 2)
142+
width = size(x, 1)
143+
else
144+
width = size(x, 1) * size(x, 2)
146145
end
147146
x = reshape(x, width, size(x, 3)) # return a matrix
148147
x = l.σ.(x .+ l.bias)
@@ -195,8 +194,11 @@ function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix}
195194

196195
if !l.concat
197196
x = mean(x, dims = 2)
197+
width = size(x, 1)
198+
else
199+
width = size(x, 1) * size(x, 2)
198200
end
199-
x = reshape(x, :, size(x, 3))
201+
x = reshape(x, width, size(x, 3))
200202
x = l.σ.(x .+ l.bias)
201203
return x
202204
end

0 commit comments

Comments
 (0)