@@ -187,6 +187,98 @@ function Base.show(io::IO, a3tgcn::A3TGCN)
187187 print (io, " A3TGCN($(a3tgcn. in) => $(a3tgcn. out) )" )
188188end
189189
190+ struct GConvGRUCell <: GNNLayer
191+ conv_x_r:: ChebConv
192+ conv_h_r:: ChebConv
193+ conv_x_z:: ChebConv
194+ conv_h_z:: ChebConv
195+ conv_x_h:: ChebConv
196+ conv_h_h:: ChebConv
197+ k:: Int
198+ state0
199+ in:: Int
200+ out:: Int
201+ end
202+
203+ Flux. @functor GConvGRUCell
204+
205+ function GConvGRUCell (ch:: Pair{Int, Int} , k:: Int , n:: Int ;
206+ bias:: Bool = true ,
207+ init = Flux. glorot_uniform,
208+ init_state = Flux. zeros32)
209+ in, out = ch
210+ # reset gate
211+ conv_x_r = ChebConv (in => out, k; bias, init)
212+ conv_h_r = ChebConv (out => out, k; bias, init)
213+ # update gate
214+ conv_x_z = ChebConv (in => out, k; bias, init)
215+ conv_h_z = ChebConv (out => out, k; bias, init)
216+ # new gate
217+ conv_x_h = ChebConv (in => out, k; bias, init)
218+ conv_h_h = ChebConv (out => out, k; bias, init)
219+ state0 = init_state (out, n)
220+ return GConvGRUCell (conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, state0, in, out)
221+ end
222+
223+ function (ggru:: GConvGRUCell )(h, g:: GNNGraph , x)
224+ r = ggru. conv_x_r (g, x) .+ ggru. conv_h_r (g, h)
225+ r = Flux. sigmoid_fast (r)
226+ z = ggru. conv_x_z (g, x) .+ ggru. conv_h_z (g, h)
227+ z = Flux. sigmoid_fast (z)
228+ h̃ = ggru. conv_x_h (g, x) .+ ggru. conv_h_h (g, r .* h)
229+ h̃ = Flux. tanh_fast (h̃)
230+ h = (1 .- z) .* h̃ .+ z .* h
231+ return h, h
232+ end
233+
234+ function Base. show (io:: IO , ggru:: GConvGRUCell )
235+ print (io, " GConvGRUCell($(ggru. in) => $(ggru. out) )" )
236+ end
237+
238+ """
239+ GConvGRU(in => out, k, n; [bias, init, init_state])
240+
241+ Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
242+
243+ Performs a layer of ChebConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
244+
245+ # Arguments
246+
247+ - `in`: Number of input features.
248+ - `out`: Number of output features.
249+ - `k`: Chebyshev polynomial order.
250+ - `n`: Number of nodes in the graph.
251+ - `bias`: Add learnable bias. Default `true`.
252+ - `init`: Weights' initializer. Default `glorot_uniform`.
253+ - `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`.
254+
255+ # Examples
256+
257+ ```jldoctest
258+ julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
259+
260+ julia> ggru = GConvGRU(2 => 5, 2, g1.num_nodes);
261+
262+ julia> y = ggru(g1, x1);
263+
264+ julia> size(y)
265+ (5, 5)
266+
267+ julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
268+
269+ julia> z = ggru(g2, x2);
270+
271+ julia> size(z)
272+ (5, 5, 30)
273+ ```
274+ """
275+ GConvGRU (ch, k, n; kwargs... ) = Flux. Recur (GConvGRUCell (ch, k, n; kwargs... ))
276+ Flux. Recur (ggru:: GConvGRUCell ) = Flux. Recur (ggru, ggru. state0)
277+
278+ (l:: Flux.Recur{GConvGRUCell} )(g:: GNNGraph ) = GNNGraph (g, ndata = l (g, node_features (g)))
279+ _applylayer (l:: Flux.Recur{GConvGRUCell} , g:: GNNGraph , x) = l (g, x)
280+ _applylayer (l:: Flux.Recur{GConvGRUCell} , g:: GNNGraph ) = l (g)
281+
190282function (l:: GINConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
191283 return l .(tg. snapshots, x)
192284end
0 commit comments