11@testitem " layers/temporalconv" setup= [TestModuleLux] begin
22 using . TestModuleLux
3- using LuxTestUtils: test_gradients, AutoTracker, AutoForwardDiff, AutoEnzyme
3+ using LuxTestUtils: test_gradients, AutoTracker, AutoForwardDiff, AutoEnzyme, AutoMooncake
44
55 rng = StableRNG (1234 )
66 g = rand_graph (rng, 10 , 40 )
1616 st = LuxCore. initialstates (rng, l)
1717 y1, _ = l (g, x, ps, st)
1818 loss = (x, ps) -> sum (first (l (g, x, ps, st)))
19- test_gradients (loss, x, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme ()])
19+ test_gradients (loss, x, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme (), AutoMooncake () ])
2020
2121 # Test with custom activation (relu)
2222 l_relu = TGCN (3 => 3 , act = relu)
2828 @test ! isapprox (y1, y2, rtol= 1.0f-2 )
2929
3030 loss_relu = (x, ps) -> sum (first (l_relu (g, x, ps, st_relu)))
31- test_gradients (loss_relu, x, ps_relu; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme ()])
31+ test_gradients (loss_relu, x, ps_relu; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme (), AutoMooncake () ])
3232 end
3333
3434 @testset " A3TGCN" begin
3535 l = A3TGCN (3 => 3 )
3636 ps = LuxCore. initialparameters (rng, l)
3737 st = LuxCore. initialstates (rng, l)
3838 loss = (x, ps) -> sum (first (l (g, x, ps, st)))
39- test_gradients (loss, x, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme ()])
39+ test_gradients (loss, x, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme (), AutoMooncake () ])
4040 end
4141
4242 @testset " GConvGRU" begin
4343 l = GConvGRU (3 => 3 , 2 )
4444 ps = LuxCore. initialparameters (rng, l)
4545 st = LuxCore. initialstates (rng, l)
4646 loss = (x, ps) -> sum (first (l (g, x, ps, st)))
47- test_gradients (loss, x, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme ()])
47+ test_gradients (loss, x, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme (), AutoMooncake () ])
4848 end
4949
5050 @testset " GConvLSTM" begin
5151 l = GConvLSTM (3 => 3 , 2 )
5252 ps = LuxCore. initialparameters (rng, l)
5353 st = LuxCore. initialstates (rng, l)
5454 loss = (x, ps) -> sum (first (l (g, x, ps, st)))
55- test_gradients (loss, x, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme ()])
55+ test_gradients (loss, x, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme (), AutoMooncake () ])
5656 end
5757
5858 @testset " DCGRU" begin
5959 l = DCGRU (3 => 3 , 2 )
6060 ps = LuxCore. initialparameters (rng, l)
6161 st = LuxCore. initialstates (rng, l)
6262 loss = (x, ps) -> sum (first (l (g, x, ps, st)))
63- test_gradients (loss, x, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme ()])
63+ test_gradients (loss, x, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme (), AutoMooncake () ])
6464 end
6565
6666 @testset " EvolveGCNO" begin
6767 l = EvolveGCNO (3 => 3 )
6868 ps = LuxCore. initialparameters (rng, l)
6969 st = LuxCore. initialstates (rng, l)
7070 loss = (tx, ps) -> sum (sum (first (l (tg, tx, ps, st))))
71- test_gradients (loss, tx, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme ()])
71+ test_gradients (loss, tx, ps; atol= 1.0f-2 , rtol= 1.0f-2 , skip_backends= [AutoForwardDiff (), AutoEnzyme (), AutoMooncake () ])
7272 end
73- end
73+ end
0 commit comments