@@ -398,29 +398,42 @@ end
398398
399399@testitem " AGNNConv" setup= [TolSnippet, TestModule] begin
400400 using . TestModule
401- l = AGNNConv (trainable= false , add_self_loops= false )
402- @test l. β == [1.0f0 ]
403- @test l. add_self_loops == false
404- @test l. trainable == false
405- Flux. trainable (l) == (;)
406-
407- l = AGNNConv (init_beta= 2.0f0 )
408- @test l. β == [2.0f0 ]
409- @test l. add_self_loops == true
410- @test l. trainable == true
411- Flux. trainable (l) == (; β = [1f0 ])
412- for g in TEST_GRAPHS
413- @test size (l (g, g. x)) == (D_IN, g. num_nodes)
414- test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
401+ @testset " Initialization & Basic Forward" begin
402+ l = AGNNConv (init_beta= 2.0f0 , trainable= true )
403+ @test l. β[1 ] == 2.0f0
404+ @test l. trainable == true
405+ for g in TEST_GRAPHS
406+ @test size (l (g, g. x)) == (D_IN, g. num_nodes)
407+ test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
408+ end
409+ end
410+ @testset " Bipartite Support" begin
411+ l = AGNNConv (add_self_loops= false )
412+ in_channel = 16
413+ s = [rand (1 : 5 , 14 )... , 5 ]
414+ d = [rand (1 : 8 , 14 )... , 8 ]
415+ g = GNNGraph ((s, d))
416+ x = (randn (Float32, in_channel, 5 ), randn (Float32, in_channel, 8 ))
417+ y = l (g, x)
418+ @test size (y) == (in_channel, 8 )
419+ test_gradients (l, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
420+ end
421+ @testset " Stability (Epsilon)" begin
422+ l = AGNNConv ()
423+ g = TEST_GRAPHS[1 ]
424+ x_dead = randn (Float32, D_IN, g. num_nodes)
425+ x_dead[:, 1 ] .= 0.0f0
426+ y = l (g, x_dead)
427+ @test ! any (isnan .(y))
415428 end
416429end
417430
418431@testitem " AGNNConv GPU" setup= [TolSnippet, TestModule] tags= [:gpu ] begin
419432 using . TestModule
420- l = AGNNConv (trainable = false , add_self_loops = false )
433+ l = AGNNConv ()
421434 for g in TEST_GRAPHS
422- g. graph isa AbstractSparseMatrix && continue
423- @test size (l (g, g. x)) == (D_IN, g. num_nodes)
435+ g. graph isa AbstractSparseMatrix && continue
436+ @test size (l (g, g. x)) == (D_IN, g. num_nodes)
424437 test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_gpu = true , compare_finite_diff = false )
425438 end
426439end
0 commit comments