@@ -279,6 +279,128 @@ Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
279279_applylayer (l:: Flux.Recur{GConvGRUCell} , g:: GNNGraph , x) = l (g, x)
280280_applylayer (l:: Flux.Recur{GConvGRUCell} , g:: GNNGraph ) = l (g)
281281
282+ struct GConvLSTMCell <: GNNLayer
283+ conv_x_i:: ChebConv
284+ conv_h_i:: ChebConv
285+ w_i
286+ b_i
287+ conv_x_f:: ChebConv
288+ conv_h_f:: ChebConv
289+ w_f
290+ b_f
291+ conv_x_c:: ChebConv
292+ conv_h_c:: ChebConv
293+ w_c
294+ b_c
295+ conv_x_o:: ChebConv
296+ conv_h_o:: ChebConv
297+ w_o
298+ b_o
299+ k:: Int
300+ state0
301+ in:: Int
302+ out:: Int
303+ end
304+
305+ Flux. @functor GConvLSTMCell
306+
307+ function GConvLSTMCell (ch:: Pair{Int, Int} , k:: Int , n:: Int ;
308+ bias:: Bool = true ,
309+ init = Flux. glorot_uniform,
310+ init_state = Flux. zeros32)
311+ in, out = ch
312+ # input gate
313+ conv_x_i = ChebConv (in => out, k; bias, init)
314+ conv_h_i = ChebConv (out => out, k; bias, init)
315+ w_i = init (out, 1 )
316+ b_i = bias ? Flux. create_bias (w_i, true , out) : false
317+ # forget gate
318+ conv_x_f = ChebConv (in => out, k; bias, init)
319+ conv_h_f = ChebConv (out => out, k; bias, init)
320+ w_f = init (out, 1 )
321+ b_f = bias ? Flux. create_bias (w_f, true , out) : false
322+ # cell state
323+ conv_x_c = ChebConv (in => out, k; bias, init)
324+ conv_h_c = ChebConv (out => out, k; bias, init)
325+ w_c = init (out, 1 )
326+ b_c = bias ? Flux. create_bias (w_c, true , out) : false
327+ # output gate
328+ conv_x_o = ChebConv (in => out, k; bias, init)
329+ conv_h_o = ChebConv (out => out, k; bias, init)
330+ w_o = init (out, 1 )
331+ b_o = bias ? Flux. create_bias (w_o, true , out) : false
332+ state0 = (init_state (out, n), init_state (out, n))
333+ return GConvLSTMCell (conv_x_i, conv_h_i, w_i, b_i,
334+ conv_x_f, conv_h_f, w_f, b_f,
335+ conv_x_c, conv_h_c, w_c, b_c,
336+ conv_x_o, conv_h_o, w_o, b_o,
337+ k, state0, in, out)
338+ end
339+
340+ function (gclstm:: GConvLSTMCell )((h, c), g:: GNNGraph , x)
341+ # input gate
342+ i = gclstm. conv_x_i (g, x) .+ gclstm. conv_h_i (g, h) .+ gclstm. w_i .* c .+ gclstm. b_i
343+ i = Flux. sigmoid_fast (i)
344+ # forget gate
345+ f = gclstm. conv_x_f (g, x) .+ gclstm. conv_h_f (g, h) .+ gclstm. w_f .* c .+ gclstm. b_f
346+ f = Flux. sigmoid_fast (f)
347+ # cell state
348+ c = f .* c .+ i .* Flux. tanh_fast (gclstm. conv_x_c (g, x) .+ gclstm. conv_h_c (g, h) .+ gclstm. w_c .* c .+ gclstm. b_c)
349+ # output gate
350+ o = gclstm. conv_x_o (g, x) .+ gclstm. conv_h_o (g, h) .+ gclstm. w_o .* c .+ gclstm. b_o
351+ o = Flux. sigmoid_fast (o)
352+ h = o .* Flux. tanh_fast (c)
353+ return (h,c), h
354+ end
355+
356+ function Base. show (io:: IO , gclstm:: GConvLSTMCell )
357+ print (io, " GConvLSTMCell($(gclstm. in) => $(gclstm. out) )" )
358+ end
359+
360+ """
361+ GConvLSTM(in => out, k, n; [bias, init, init_state])
362+
363+ Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
364+
365+ Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.
366+
367+ # Arguments
368+
369+ - `in`: Number of input features.
370+ - `out`: Number of output features.
371+ - `k`: Chebyshev polynomial order.
372+ - `n`: Number of nodes in the graph.
373+ - `bias`: Add learnable bias. Default `true`.
374+ - `init`: Weights' initializer. Default `glorot_uniform`.
375+ - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`.
376+
377+ # Examples
378+
379+ ```jldoctest
380+ julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
381+
382+ julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes);
383+
384+ julia> y = gclstm(g1, x1);
385+
386+ julia> size(y)
387+ (5, 5)
388+
389+ julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
390+
391+ julia> z = gclstm(g2, x2);
392+
393+ julia> size(z)
394+ (5, 5, 30)
395+ ```
396+ """
397+ GConvLSTM (ch, k, n; kwargs... ) = Flux. Recur (GConvLSTMCell (ch, k, n; kwargs... ))
398+ Flux. Recur (tgcn:: GConvLSTMCell ) = Flux. Recur (tgcn, tgcn. state0)
399+
400+ (l:: Flux.Recur{GConvLSTMCell} )(g:: GNNGraph ) = GNNGraph (g, ndata = l (g, node_features (g)))
401+ _applylayer (l:: Flux.Recur{GConvLSTMCell} , g:: GNNGraph , x) = l (g, x)
402+ _applylayer (l:: Flux.Recur{GConvLSTMCell} , g:: GNNGraph ) = l (g)
403+
282404function (l:: GINConv )(tg:: TemporalSnapshotsGNNGraph , x:: AbstractVector )
283405 return l .(tg. snapshots, x)
284406end
0 commit comments