Skip to content

Commit 266475b

Browse files
committed
fix: Handle multiple heads
1 parent e628595 commit 266475b

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

GNNlib/src/layers/conv.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} =
139139
α = dropout(α, l.dropout)
140140
β = α .* m.Wxj
141141
x = aggregate_neighbors(g, +, β)
142-
width = size(x, 1)
142+
width = size(x, 1) * size(x, 2)
143143

144144
if !l.concat
145145
x = mean(x, dims = 2)
@@ -193,10 +193,11 @@ function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix}
193193
β = α .* m.Wxj
194194
x = aggregate_neighbors(g, +, β)
195195

196+
width = size(x, 1) * size(x, 2)
196197
if !l.concat
197198
x = mean(x, dims = 2)
198199
end
199-
x = reshape(x, :, size(x, 3))
200+
x = reshape(x, width, size(x, 3))
200201
x = l.σ.(x .+ l.bias)
201202
return x
202203
end

0 commit comments

Comments
 (0)