3232 @test y === h
3333 @test size (h) == (out_channel, g. num_nodes)
3434 # with no initial state
35- test_gradients (cell, g, g. x, loss= cell_loss, rtol= RTOL_HIGH)
35+ test_gradients (cell, g, g. x, loss= cell_loss, rtol= RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
3636 # with initial state
37- test_gradients (cell, g, g. x, h, loss= cell_loss, rtol= RTOL_HIGH)
37+ test_gradients (cell, g, g. x, h, loss= cell_loss, rtol= RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
3838
3939 # Test with custom activation function
4040 custom_activation = tanh
4545 # Test that outputs differ when using different activation functions
4646 @test ! isapprox (y, y_custom, rtol= RTOL_HIGH)
4747 # with no initial state
48- test_gradients (cell_custom, g, g. x, loss= cell_loss, rtol= RTOL_HIGH)
48+ test_gradients (cell_custom, g, g. x, loss= cell_loss, rtol= RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
4949 # with initial state
50- test_gradients (cell_custom, g, g. x, h_custom, loss= cell_loss, rtol= RTOL_HIGH)
50+ test_gradients (cell_custom, g, g. x, h_custom, loss= cell_loss, rtol= RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
5151end
5252
5353@testitem " TGCN" setup= [TemporalConvTestModule, TestModule] begin
6161 @test layer isa GNNRecurrence
6262 @test size (y) == (out_channel, timesteps, g. num_nodes)
6363 # with no initial state
64- test_gradients (layer, g, x, rtol = RTOL_HIGH)
64+ test_gradients (layer, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
6565 # with initial state
66- test_gradients (layer, g, x, state0, rtol = RTOL_HIGH)
66+ test_gradients (layer, g, x, state0, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
6767
6868 # Test with custom activation function
6969 custom_activation = tanh
7474 # Test that outputs differ when using different activation functions
7575 @test ! isapprox (y, y_custom, rtol = RTOL_HIGH)
7676 # with no initial state
77- test_gradients (layer_custom, g, x, rtol = RTOL_HIGH)
77+ test_gradients (layer_custom, g, x, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
7878 # with initial state
79- test_gradients (layer_custom, g, x, state0, rtol = RTOL_HIGH)
79+ test_gradients (layer_custom, g, x, state0, rtol = RTOL_HIGH, test_mooncake = TEST_MOONCAKE )
8080
8181 # interplay with GNNChain
8282 model = GNNChain (TGCN (in_channel => out_channel), Dense (out_channel, 1 ))
8383 y = model (g, x)
8484 @test size (y) == (1 , timesteps, g. num_nodes)
85- test_gradients (model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW)
85+ test_gradients (model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW, test_mooncake = TEST_MOONCAKE )
8686end
8787
8888@testitem " GConvLSTMCell" setup= [TemporalConvTestModule, TestModule] begin
9393 @test size (h) == (out_channel, g. num_nodes)
9494 @test size (c) == (out_channel, g. num_nodes)
9595 # with no initial state
96- test_gradients (cell, g, g. x, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW)
96+ test_gradients (cell, g, g. x, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
9797 # with initial state
98- test_gradients (cell, g, g. x, (h, c), loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW)
98+ test_gradients (cell, g, g. x, (h, c), loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
9999end
100100
101101@testitem " GConvLSTM" setup= [TemporalConvTestModule, TestModule] begin
@@ -107,15 +107,15 @@ end
107107 y = layer (g, x)
108108 @test size (y) == (out_channel, timesteps, g. num_nodes)
109109 # with no initial state
110- test_gradients (layer, g, x, rtol= RTOL_LOW, atol= ATOL_LOW)
110+ test_gradients (layer, g, x, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
111111 # with initial state
112- test_gradients (layer, g, x, state0, rtol= RTOL_LOW, atol= ATOL_LOW)
112+ test_gradients (layer, g, x, state0, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
113113
114114 # interplay with GNNChain
115115 model = GNNChain (GConvLSTM (in_channel => out_channel, 2 ), Dense (out_channel, 1 ))
116116 y = model (g, x)
117117 @test size (y) == (1 , timesteps, g. num_nodes)
118- test_gradients (model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
118+ test_gradients (model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false )
119119end
120120
121121@testitem " GConvGRUCell" setup= [TemporalConvTestModule, TestModule] begin
125125 @test y === h
126126 @test size (h) == (out_channel, g. num_nodes)
127127 # with no initial state
128- test_gradients (cell, g, g. x, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW)
128+ test_gradients (cell, g, g. x, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
129129 # with initial state
130- test_gradients (cell, g, g. x, h, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW)
130+ test_gradients (cell, g, g. x, h, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
131131end
132132
133133
@@ -140,15 +140,15 @@ end
140140 y = layer (g, x)
141141 @test size (y) == (out_channel, timesteps, g. num_nodes)
142142 # with no initial state
143- test_gradients (layer, g, x, rtol= RTOL_LOW, atol= ATOL_LOW)
143+ test_gradients (layer, g, x, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
144144 # with initial state
145- test_gradients (layer, g, x, state0, rtol= RTOL_LOW, atol= ATOL_LOW)
145+ test_gradients (layer, g, x, state0, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
146146
147147 # interplay with GNNChain
148148 model = GNNChain (GConvGRU (in_channel => out_channel, 2 ), Dense (out_channel, 1 ))
149149 y = model (g, x)
150150 @test size (y) == (1 , timesteps, g. num_nodes)
151- test_gradients (model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
151+ test_gradients (model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false )
152152end
153153
154154@testitem " DCGRUCell" setup= [TemporalConvTestModule, TestModule] begin
158158 @test y === h
159159 @test size (h) == (out_channel, g. num_nodes)
160160 # with no initial state
161- test_gradients (cell, g, g. x, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW)
161+ test_gradients (cell, g, g. x, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
162162 # with initial state
163- test_gradients (cell, g, g. x, h, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW)
163+ test_gradients (cell, g, g. x, h, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
164164end
165165
166166@testitem " DCGRU" setup= [TemporalConvTestModule, TestModule] begin
@@ -172,15 +172,15 @@ end
172172 y = layer (g, x)
173173 @test size (y) == (out_channel, timesteps, g. num_nodes)
174174 # with no initial state
175- test_gradients (layer, g, x, rtol= RTOL_LOW, atol= ATOL_LOW)
175+ test_gradients (layer, g, x, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
176176 # with initial state
177- test_gradients (layer, g, x, state0, rtol= RTOL_LOW, atol= ATOL_LOW)
177+ test_gradients (layer, g, x, state0, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
178178
179179 # interplay with GNNChain
180180 model = GNNChain (DCGRU (in_channel => out_channel, 2 ), Dense (out_channel, 1 ))
181181 y = model (g, x)
182182 @test size (y) == (1 , timesteps, g. num_nodes)
183- test_gradients (model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW)
183+ test_gradients (model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW, test_mooncake = false )
184184end
185185
186186@testitem " EvolveGCNOCell" setup= [TemporalConvTestModule, TestModule] begin
189189 y, state = cell (g, g. x)
190190 @test size (y) == (out_channel, g. num_nodes)
191191 # with no initial state
192- test_gradients (cell, g, g. x, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW)
192+ test_gradients (cell, g, g. x, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
193193 # with initial state
194- test_gradients (cell, g, g. x, state, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW)
194+ test_gradients (cell, g, g. x, state, loss= cell_loss, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = false )
195195end
196196
197197@testitem " EvolveGCNO" setup= [TemporalConvTestModule, TestModule] begin
@@ -203,15 +203,15 @@ end
203203 y = layer (g, x)
204204 @test size (y) == (out_channel, timesteps, g. num_nodes)
205205 # with no initial state
206- test_gradients (layer, g, x, rtol= RTOL_LOW, atol= ATOL_LOW)
206+ test_gradients (layer, g, x, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = TEST_MOONCAKE )
207207 # with initial state
208- test_gradients (layer, g, x, state0, rtol= RTOL_LOW, atol= ATOL_LOW)
208+ test_gradients (layer, g, x, state0, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = TEST_MOONCAKE )
209209
210210 # interplay with GNNChain
211211 model = GNNChain (EvolveGCNO (in_channel => out_channel), Dense (out_channel, 1 ))
212212 y = model (g, x)
213213 @test size (y) == (1 , timesteps, g. num_nodes)
214- test_gradients (model, g, x, rtol= RTOL_LOW, atol= ATOL_LOW)
214+ test_gradients (model, g, x, rtol= RTOL_LOW, atol= ATOL_LOW, test_mooncake = TEST_MOONCAKE )
215215end
216216
217217# @testitem "GINConv" setup=[TemporalConvTestModule, TestModule] begin
0 commit comments