@@ -124,13 +124,10 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} =
124124 _, chout = l. channel
125125 heads = l. heads
126126
127- Wxi = Wxj = l. dense_x (xj)
128- Wxi = Wxj = reshape (Wxj, chout, heads, :)
129-
130- if xi != = xj
131- Wxi = l. dense_x (xi)
132- Wxi = reshape (Wxi, chout, heads, :)
133- end
127+ Wxj = l. dense_x (xj)
128+ Wxj = reshape (Wxj, chout, heads, :)
129+ Wxi = l. dense_x (xi)
130+ Wxi = reshape (Wxi, chout, heads, :)
134131
135132 # a hand-written message passing
136133 message = Fix1 (gat_message, l)
@@ -139,10 +136,12 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} =
139136 α = dropout (α, l. dropout)
140137 β = α .* m. Wxj
141138 x = aggregate_neighbors (g, + , β)
142- width = size (x, 1 )
143139
144140 if ! l. concat
145141 x = mean (x, dims = 2 )
142+ width = size (x, 1 )
143+ else
144+ width = size (x, 1 ) * size (x, 2 )
146145 end
147146 x = reshape (x, width, size (x, 3 )) # return a matrix
148147 x = l. σ .(x .+ l. bias)
@@ -195,8 +194,11 @@ function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix}
195194
196195 if ! l. concat
197196 x = mean (x, dims = 2 )
197+ width = size (x, 1 )
198+ else
199+ width = size (x, 1 ) * size (x, 2 )
198200 end
199- x = reshape (x, : , size (x, 3 ))
201+ x = reshape (x, width , size (x, 3 ))
200202 x = l. σ .(x .+ l. bias)
201203 return x
202204end
0 commit comments