|
155 | 155 | x = (A = rand(Float32, 4,2), B = rand(Float32, 4, 3)) |
156 | 156 | layers = HeteroGraphConv( (:A, :to, :B) => GCNConv(4 => 2, tanh), |
157 | 157 | (:B, :to, :A) => GCNConv(4 => 2, tanh)); |
158 | | - y = layers(g, x); |
| 158 | + y = layers(g, x); |
159 | 159 | @test size(y.A) == (2,2) && size(y.B) == (2,3) |
160 | 160 | end |
| 161 | + |
| 162 | + @testset "NNConv" begin |
| 163 | + nA, nB = 2, 3 |
| 164 | + nedges = 6 |
| 165 | + g = rand_bipartite_heterograph((nA, nB), nedges; |
| 166 | + edata = Dict((:A, :to, :B) => (e = rand(Float32, 5, nedges),), |
| 167 | + (:B, :to, :A) => (e = rand(Float32, 5, nedges),)) |
| 168 | + ) |
| 169 | + x = (A = rand(Float32, 4, nA), B = rand(Float32, 4, nB)) |
| 170 | + nn = Dense(5 => 4 * 2) |
| 171 | + layers = HeteroGraphConv((:A, :to, :B) => NNConv(4 => 2, nn, tanh), |
| 172 | + (:B, :to, :A) => NNConv(4 => 2, nn, tanh)) |
| 173 | + y = layers(g, x) |
| 174 | + @test size(y.A) == (2, nA) && size(y.B) == (2, nB) |
| 175 | + end |
161 | 176 | end |
0 commit comments