# pool.jl — tests for pooling layers (GlobalPool, GlobalAttentionPool, TopKPool,
# topk_index, Set2Set). Forked from JuliaGraphs/GraphNeuralNetworks.jl.
@testitem "GlobalPool" setup=[TestModule] begin
    using .TestModule
    @testset "GlobalPool $GRAPH_T" for GRAPH_T in GRAPH_TYPES
        p = GlobalPool(+)
        n = 10
        chin = 6
        # Single graph: sum-pooling over all nodes must equal the row-wise sum
        # of the node-feature matrix.
        X = rand(Float32, chin, n)  # fix: was hard-coded 6, use chin for consistency
        g = GNNGraph(random_regular_graph(n, 4), ndata = X, graph_type = GRAPH_T)
        u = p(g, X)
        @test u ≈ sum(X, dims = 2)
        # Batched graphs: pooling yields one vector per graph in the batch.
        ng = 3
        g = Flux.batch([GNNGraph(random_regular_graph(n, 4),
                                 ndata = rand(Float32, chin, n),
                                 graph_type = GRAPH_T)
                        for i in 1:ng])
        u = p(g, g.ndata.x)
        @test size(u) == (chin, ng)
        # First graph occupies the first n node columns of the batched features.
        @test u[:, [1]] ≈ sum(g.ndata.x[:, 1:n], dims = 2)
        # Calling the layer on the graph alone stores the result in gdata.u.
        @test p(g).gdata.u == u
        test_gradients(p, g, g.x, rtol = 1e-5, test_mooncake = TEST_MOONCAKE)
    end
end
@testitem "GlobalAttentionPool" setup=[TestModule] begin
    using .TestModule
    @testset "GlobalAttentionPool $GRAPH_T" for GRAPH_T in GRAPH_TYPES
        num_nodes, num_graphs = 10, 3
        chin, chout = 6, 5
        # fgate maps features to a scalar gate; ffeat maps them to the output dim.
        pool = GlobalAttentionPool(Dense(chin, 1), Dense(chin, chout))
        # Two Dense layers contribute a weight matrix and a bias vector each.
        @test length(Flux.trainables(pool)) == 4
        graphs = [GNNGraph(random_regular_graph(num_nodes, 4),
                           ndata = rand(Float32, chin, num_nodes),
                           graph_type = GRAPH_T) for _ in 1:num_graphs]
        g = Flux.batch(graphs)
        # One pooled feature vector per graph in the batch.
        @test size(pool(g, g.x)) == (chout, num_graphs)
        test_gradients(pool, g, g.x, rtol = 1e-5, test_mooncake = TEST_MOONCAKE)
    end
end
@testitem "TopKPool" setup=[TestModule] begin
    using .TestModule
    num_nodes = 10
    k, in_channel = 4, 7
    X = rand(in_channel, num_nodes)
    # The layer must work with both boolean and float adjacency matrices,
    # preserving the adjacency eltype in the pooled Ã.
    for T in [Bool, Float64]
        A = rand(T, num_nodes, num_nodes)
        pool = TopKPool(A, k, in_channel)
        @test eltype(pool.p) === Float32
        @test size(pool.p) == (in_channel,)
        @test eltype(pool.Ã) === T
        @test size(pool.Ã) == (k, k)
        out = pool(X)
        # Pooling keeps the channel dimension and selects k nodes.
        @test size(out) == (in_channel, k)
    end
end
@testitem "topk_index" begin
    # For a strictly decreasing vector the top-4 entries are the first four
    # indices; the same must hold for its row-vector (adjoint) form.
    scores = [8, 7, 6, 5, 4, 3, 2, 1]
    expected = [1, 2, 3, 4]
    @test topk_index(scores, 4) == expected
    @test topk_index(scores', 4) == expected
end
@testitem "Set2Set" setup=[TestModule] begin
    using .TestModule
    @testset "Set2Set $GRAPH_T" for GRAPH_T in GRAPH_TYPES
        n_in, n_iters = 3, 2
        n_layers = 1 # TODO test with more layers
        graphs = [rand_graph(10, 40, graph_type = GRAPH_T) for _ in 1:5]
        g = batch(graphs)
        g = GNNGraph(g, ndata = rand(Float32, n_in, g.num_nodes))
        layer = Set2Set(n_in, n_iters, n_layers)
        out = layer(g, node_features(g))
        # Set2Set concatenates query and readout, doubling the feature dimension.
        @test size(out) == (2 * n_in, g.num_graphs)
        ## TODO the numerical gradient seems to be 3 times smaller than zygote one
        # test_gradients(l, g, g.x, rtol = 1e-4, atol=1e-4)
    end
end