@@ -255,3 +255,264 @@ function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
255255end
256256
257257
"""
    EGNNConv

Container layer for the E(n)-equivariant graph convolution (see the reference
implementation linked at the constructor). Holds the three sub-networks and
the feature-size bookkeeping.
"""
@concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)}
    ϕe           # edge/message MLP
    ϕx           # coordinate-update MLP (produces a scalar weight)
    ϕh           # node-feature update MLP
    num_features # named tuple (in, edge, out, hidden)
    residual::Bool # add input node features to the output (requires in == out)
end
265+
# Convenience constructor for the no-edge-feature case: delegates to the
# general constructor with an edge feature size of 0.
function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false)
    in_size, out_size = ch
    return EGNNConv((in_size, 0) => out_size; hidden_size, residual)
end
269+
# Follows the reference implementation at
# https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
#
# `ch = (in_size, edge_feat_size) => out_size`; `hidden_size` is the width of
# the internal MLPs. With `residual = true` the input node features are added
# to the output, which requires `in_size == out_size`.
function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1],
                  residual = false)
    (in_size, edge_feat_size), out_size = ch

    # Validate user input eagerly with a real exception: `@assert` may be
    # disabled at higher optimization levels, so it is not input validation.
    if residual && in_size != out_size
        throw(ArgumentError("Residual connection only possible if in_size == out_size"))
    end

    act_fn = swish

    # +1 for the radial feature: ||x_i - x_j||^2
    ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
               Dense(hidden_size => hidden_size, act_fn))

    # Use `act_fn` consistently (it is `swish`, same behavior as before).
    ϕh = Chain(Dense(in_size + hidden_size => hidden_size, act_fn),
               Dense(hidden_size => out_size))

    ϕx = Chain(Dense(hidden_size => hidden_size, act_fn),
               Dense(hidden_size => 1, use_bias = false))

    num_features = (in = in_size, edge = edge_feat_size, out = out_size,
                    hidden = hidden_size)
    return EGNNConv(ϕe, ϕx, ϕh, num_features, residual)
end
293+
# Node-feature output dimension of the layer.
function LuxCore.outputsize(l::EGNNConv)
    return (l.num_features.out,)
end
295+
# Forward without edge features: delegate with `e = nothing`.
function (l::EGNNConv)(g, h, x, ps, st)
    return l(g, h, x, nothing, ps, st)
end
297+
function (l::EGNNConv)(g, h, x, e, ps, st)
    # Bind parameters and state to each sub-network so GNNlib can call them
    # like plain functions.
    edge_net = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))
    coord_net = StatefulLuxLayer{true}(l.ϕx, ps.ϕx, _getstate(st, :ϕx))
    node_net = StatefulLuxLayer{true}(l.ϕh, ps.ϕh, _getstate(st, :ϕh))
    msg = (; ϕe = edge_net, ϕx = coord_net, ϕh = node_net,
           l.residual, l.num_features)
    return GNNlib.egnn_conv(msg, g, h, x, e), st
end
305+
function Base.show(io::IO, l::EGNNConv)
    nf = l.num_features
    print(io, "EGNNConv(($(nf.in), $(nf.edge)) => $(nf.out); hidden_size=$(nf.hidden)")
    # Only mention the residual flag when it is set.
    l.residual && print(io, ", residual=true")
    print(io, ")")
end
317+
"""
    DConv

Diffusion convolutional layer (forward pass delegates to `GNNlib.d_conv`).
"""
@concrete struct DConv <: GNNLayer
    in_dims::Int    # input feature dimension
    out_dims::Int   # output feature dimension
    k::Int          # number of diffusion steps
    init_weight     # weight initializer
    init_bias       # bias initializer
    use_bias::Bool  # whether a bias vector is learned
end
326+
# User-facing constructor: `DConv(in => out, k)`.
function DConv(ch::Pair{Int, Int}, k::Int;
               init_weight = glorot_uniform,
               init_bias = zeros32,
               use_bias = true)
    in_dims, out_dims = ch
    return DConv(in_dims, out_dims, k, init_weight, init_bias, use_bias)
end
334+
function LuxCore.initialparameters(rng::AbstractRNG, l::DConv)
    # Leading dimension 2 holds the two diffusion directions — TODO confirm
    # against GNNlib.d_conv.
    weights = l.init_weight(rng, 2, l.k, l.out_dims, l.in_dims)
    l.use_bias || return (; weights)
    return (; weights, bias = l.init_bias(rng, l.out_dims))
end
344+
# Output feature dimension.
function LuxCore.outputsize(l::DConv)
    return (l.out_dims,)
end

# Total number of trainable scalars: the weight tensor plus an optional bias.
function LuxCore.parameterlength(l::DConv)
    nweights = 2 * l.in_dims * l.out_dims * l.k
    return l.use_bias ? nweights + l.out_dims : nweights
end
348+
function (l::DConv)(g, x, ps, st)
    # Pack everything GNNlib.d_conv needs into a named tuple.
    conv_args = (; ps.weights, bias = _getbias(ps), l.k)
    return GNNlib.d_conv(conv_args, g, x), st
end
353+
function Base.show(io::IO, l::DConv)
    print(io, "DConv($(l.in_dims) => $(l.out_dims), k=$(l.k))")
end
357+
"""
    GATConv

Graph attention layer (forward pass delegates to `GNNlib.gat_conv`).
"""
@concrete struct GATConv <: GNNLayer
    dense_x        # node-feature projection
    dense_e        # edge-feature projection, or `nothing` when ein == 0
    init_weight    # initializer for the attention vector `a`
    init_bias      # initializer for the output bias
    use_bias::Bool
    σ              # output activation
    negative_slope # leaky-relu slope used in attention scores
    channel::Pair{NTuple{2, Int}, Int} # (in, ein) => out (out is per head)
    heads::Int
    concat::Bool   # concatenate heads (out*heads) vs. average them (out)
    add_self_loops::Bool
    dropout
end
372+
# Convenience constructor for the no-edge-feature case (ein = 0).
function GATConv(ch::Pair{Int, Int}, args...; kws...)
    return GATConv((ch[1], 0) => ch[2], args...; kws...)
end
375+
# `ch = (in, ein) => out` with `out` the per-head output size. Edge features
# together with self-loops are not supported yet.
function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
                 heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
                 init_weight = glorot_uniform, init_bias = zeros32,
                 use_bias::Bool = true,
                 add_self_loops = true, dropout = 0.0)
    (in, ein), out = ch
    # Real input validation instead of `@assert`, which may be compiled out.
    if add_self_loops && ein != 0
        throw(ArgumentError("Using edge features and setting add_self_loops=true at the same time is not yet supported."))
    end

    dense_x = Dense(in => out * heads, use_bias = false)
    dense_e = ein > 0 ? Dense(ein => out * heads, use_bias = false) : nothing
    # Float32 keeps attention math in single precision regardless of input type.
    negative_slope = convert(Float32, negative_slope)
    return GATConv(dense_x, dense_e, init_weight, init_bias, use_bias,
                   σ, negative_slope, ch, heads, concat, add_self_loops, dropout)
end
392+
# Heads are concatenated (out*heads) or averaged (out) depending on `concat`.
function LuxCore.outputsize(l::GATConv)
    out = l.channel[2]
    return (l.concat ? out * l.heads : out,)
end
##TODO: parameterlength
395+
function LuxCore.initialparameters(rng::AbstractRNG, l::GATConv)
    (in, ein), out = l.channel
    dense_x = LuxCore.initialparameters(rng, l.dense_x)
    # Attention vector: scores [Wx_i; Wx_j] (2out entries) plus the projected
    # edge features (out more) when ein > 0.
    # Pass `rng` so initialization is reproducible from the caller's generator
    # (it was previously ignored for `a`).
    a = l.init_weight(rng, ein > 0 ? 3out : 2out, l.heads)
    ps = (; dense_x, a)
    if ein > 0
        ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e))
    end
    if l.use_bias
        ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out))
    end
    return ps
end
409+
# Forward without edge features: delegate with `e = nothing`.
function (l::GATConv)(g, x, ps, st)
    return l(g, x, nothing, ps, st)
end
411+
function (l::GATConv)(g, x, e, ps, st)
    # Bind parameters/state so GNNlib can invoke the projections directly.
    dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
    dense_e = isnothing(l.dense_e) ? nothing :
              StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e))

    conv_args = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ,
                 ps.a, bias = _getbias(ps), dense_x, dense_e, l.negative_slope)
    return GNNlib.gat_conv(conv_args, g, x, e), st
end
421+
function Base.show(io::IO, l::GATConv)
    (in, ein), out = l.channel
    # `channel` stores the constructor's `ch` unchanged, so `out` is already
    # the per-head output size; dividing by `heads` misreported the dims.
    print(io, "GATConv(", ein == 0 ? in : (in, ein), " => ", out)
    l.σ == identity || print(io, ", ", l.σ)
    print(io, ", negative_slope=", l.negative_slope)
    print(io, ")")
end
429+
"""
    GATv2Conv

GATv2 attention layer (forward pass delegates to `GNNlib.gatv2_conv`).
"""
@concrete struct GATv2Conv <: GNNLayer
    dense_i        # projection of the target node features
    dense_j        # projection of the source node features
    dense_e        # edge-feature projection, or `nothing` when ein == 0
    init_weight    # initializer for the attention vector `a`
    init_bias      # initializer for the output bias
    use_bias::Bool
    σ              # output activation
    negative_slope # leaky-relu slope used in attention scores
    channel::Pair{NTuple{2, Int}, Int} # (in, ein) => out (out is per head)
    heads::Int
    concat::Bool   # concatenate heads (out*heads) vs. average them (out)
    add_self_loops::Bool
    dropout
end
445+
# Convenience constructor for the no-edge-feature case (ein = 0).
GATv2Conv(ch::Pair{Int, Int}, args...; kws...) = GATv2Conv((ch[1], 0) => ch[2], args...; kws...)
449+
# `ch = (in, ein) => out` with `out` the per-head output size. Edge features
# together with self-loops are not supported yet.
function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
                   σ = identity;
                   heads::Int = 1,
                   concat::Bool = true,
                   negative_slope = 0.2,
                   init_weight = glorot_uniform,
                   init_bias = zeros32,
                   use_bias::Bool = true,
                   add_self_loops = true,
                   dropout = 0.0)
    (in, ein), out = ch

    # Real input validation instead of `@assert`, which may be compiled out.
    if add_self_loops && ein != 0
        throw(ArgumentError("Using edge features and setting add_self_loops=true at the same time is not yet supported."))
    end

    dense_i = Dense(in => out * heads; use_bias, init_weight, init_bias)
    dense_j = Dense(in => out * heads; use_bias = false, init_weight)
    dense_e = ein > 0 ? Dense(ein => out * heads; use_bias = false, init_weight) : nothing

    # Convert for consistency with GATConv, which stores Float32 here.
    negative_slope = convert(Float32, negative_slope)
    return GATv2Conv(dense_i, dense_j, dense_e,
                     init_weight, init_bias, use_bias,
                     σ, negative_slope,
                     ch, heads, concat, add_self_loops, dropout)
end
479+
# Heads are concatenated (out*heads) or averaged (out) depending on `concat`.
function LuxCore.outputsize(l::GATv2Conv)
    out = l.channel[2]
    return (l.concat ? out * l.heads : out,)
end
##TODO: parameterlength
483+
function LuxCore.initialparameters(rng::AbstractRNG, l::GATv2Conv)
    (in, ein), out = l.channel
    dense_i = LuxCore.initialparameters(rng, l.dense_i)
    dense_j = LuxCore.initialparameters(rng, l.dense_j)
    # Pass `rng` so initialization is reproducible from the caller's generator
    # (it was previously ignored for `a`).
    a = l.init_weight(rng, out, l.heads)
    ps = (; dense_i, dense_j, a)
    if ein > 0
        ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e))
    end
    if l.use_bias
        ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out))
    end
    return ps
end
498+
# Forward without edge features: delegate with `e = nothing`.
function (l::GATv2Conv)(g, x, ps, st)
    return l(g, x, nothing, ps, st)
end
500+
function (l::GATv2Conv)(g, x, e, ps, st)
    # Bind parameters/state so GNNlib can invoke the projections directly.
    dense_i = StatefulLuxLayer{true}(l.dense_i, ps.dense_i, _getstate(st, :dense_i))
    dense_j = StatefulLuxLayer{true}(l.dense_j, ps.dense_j, _getstate(st, :dense_j))
    dense_e = isnothing(l.dense_e) ? nothing :
              StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e))

    conv_args = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ,
                 ps.a, bias = _getbias(ps), dense_i, dense_j, dense_e, l.negative_slope)
    return GNNlib.gatv2_conv(conv_args, g, x, e), st
end
511+
function Base.show(io::IO, l::GATv2Conv)
    (in, ein), out = l.channel
    # `channel` stores the constructor's `ch` unchanged, so `out` is already
    # the per-head output size; dividing by `heads` misreported the dims.
    print(io, "GATv2Conv(", ein == 0 ? in : (in, ein), " => ", out)
    l.σ == identity || print(io, ", ", l.σ)
    print(io, ", negative_slope=", l.negative_slope)
    print(io, ")")
end
0 commit comments