397397# end
398398
399399@testitem " AGNNConv" setup= [TolSnippet, TestModule] begin
400- using . TestModule
400+ using . TestModule
401401 @testset " Initialization & Basic Forward" begin
402402 l = AGNNConv (init_beta= 2.0f0 , trainable= true )
403403 @test l. β[1 ] == 2.0f0
@@ -406,25 +406,28 @@ end
406406 @test size (l (g, g. x)) == (D_IN, g. num_nodes)
407407 test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
408408 end
409- end
410- @testset " Bipartite Support" begin
411- l = AGNNConv (add_self_loops= false )
412- in_channel = 16
413- s = [rand (1 : 5 , 14 )... , 5 ]
414- d = [rand (1 : 8 , 14 )... , 8 ]
415- g = GNNGraph ((s, d))
416- x = (randn (Float32, in_channel, 5 ), randn (Float32, in_channel, 8 ))
417- y = l (g, x)
418- @test size (y) == (in_channel, 8 )
419- test_gradients (l, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE)
420- end
409+ end
410+ @testset " Heterogeneous Graph Support" begin
411+ l = AGNNConv (add_self_loops= false )
412+ hg = rand_bipartite_heterograph ((10 , 15 ), 20 )
413+ x = (A = randn (Float32, D_IN, 10 ), B = randn (Float32, D_IN, 15 ))
414+ hetero_layer = HeteroGraphConv (
415+ (:A , :to , :B ) => l,
416+ (:B , :to , :A ) => l;
417+ aggr = +
418+ )
419+ y = hetero_layer (hg, x)
420+ @test size (y. A) == (D_IN, 10 )
421+ @test size (y. B) == (D_IN, 15 )
422+ end
421423 @testset " Stability (Epsilon)" begin
422424 l = AGNNConv ()
423425 g = TEST_GRAPHS[1 ]
424426 x_dead = randn (Float32, D_IN, g. num_nodes)
425- x_dead[:, 1 ] .= 0.0f0
427+ x_dead[:, 1 ] .= 0.0f0
426428 y = l (g, x_dead)
427429 @test ! any (isnan .(y))
430+ @test ! any (isinf .(y))
428431 end
429432end
430433
435438 g. graph isa AbstractSparseMatrix && continue
436439 @test size (l (g, g. x)) == (D_IN, g. num_nodes)
437440 test_gradients (l, g, g. x, rtol = RTOL_HIGH, test_gpu = true , compare_finite_diff = false )
438- end
441+ end
442+ l_bip = AGNNConv (add_self_loops= false )
443+ s = [1 , 1 , 2 , 3 ]
444+ t = [1 , 2 , 1 , 2 ]
445+ g = GNNGraph ((s, t)) |> gpu
446+ x = (randn (Float32, D_IN, 3 ) |> gpu, randn (Float32, D_IN, 2 ) |> gpu)
447+ y = l_bip (g, x)
448+ @test size (y) == (D_IN, 2 )
439449end
440450
441451@testitem " MEGNetConv" setup= [TolSnippet, TestModule] begin
0 commit comments