|
1 | 1 | @testmodule TestModuleLux begin |
2 | 2 |
|
3 | | -using Pkg |
4 | | - |
5 | | -## Uncomment below to change the default test settings |
6 | | -# ENV["GNN_TEST_CUDA"] = "true" |
7 | | -# ENV["GNN_TEST_AMDGPU"] = "true" |
8 | | -# ENV["GNN_TEST_Metal"] = "true" |
9 | | - |
10 | | -to_test(backend) = get(ENV, "GNN_TEST_$(backend)", "false") == "true" |
11 | | -has_dependecies(pkgs) = all(pkg -> haskey(Pkg.project().dependencies, pkg), pkgs) |
12 | | -deps_dict = Dict(:CUDA => ["CUDA", "cuDNN"], :AMDGPU => ["AMDGPU"], :Metal => ["Metal"]) |
13 | | - |
14 | | -for (backend, deps) in deps_dict |
15 | | - if to_test(backend) |
16 | | - if !has_dependecies(deps) |
17 | | - Pkg.add(deps) |
18 | | - end |
19 | | - @eval using $backend |
20 | | - if backend == :CUDA |
21 | | - @eval using cuDNN |
22 | | - end |
23 | | - @eval $backend.allowscalar(false) |
24 | | - end |
25 | | -end |
26 | | - |
27 | | -using Reexport: @reexport |
28 | | - |
29 | | -@reexport using Test |
30 | | -@reexport using GNNLux |
31 | | -@reexport using Lux |
32 | | -@reexport using StableRNGs |
33 | | -@reexport using Random, Statistics |
34 | | - |
35 | | -using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme |
36 | | - |
37 | | -export test_lux_layer |
38 | | - |
39 | | -function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; |
40 | | - outputsize=nothing, sizey=nothing, container=false, |
41 | | - atol=1.0f-2, rtol=1.0f-2, e=nothing) |
42 | | - |
43 | | - if container |
44 | | - @test l isa GNNContainerLayer |
45 | | - else |
46 | | - @test l isa GNNLayer |
47 | | - end |
48 | | - |
49 | | - ps = LuxCore.initialparameters(rng, l) |
50 | | - st = LuxCore.initialstates(rng, l) |
51 | | - @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) |
52 | | - @test LuxCore.statelength(l) == LuxCore.statelength(st) |
53 | | - |
54 | | - if e !== nothing |
55 | | - y, st′ = l(g, x, e, ps, st) |
56 | | - else |
57 | | - y, st′ = l(g, x, ps, st) |
58 | | - end |
59 | | - @test eltype(y) == eltype(x) |
60 | | - if outputsize !== nothing |
61 | | - @test LuxCore.outputsize(l) == outputsize |
62 | | - end |
63 | | - if sizey !== nothing |
64 | | - @test size(y) == sizey |
65 | | - elseif outputsize !== nothing |
66 | | - @test size(y) == (outputsize..., g.num_nodes) |
67 | | - end |
68 | | - |
69 | | - if e !== nothing |
70 | | - loss = (x, ps) -> sum(first(l(g, x, e, ps, st))) |
71 | | - else |
72 | | - loss = (x, ps) -> sum(first(l(g, x, ps, st))) |
73 | | - end |
74 | | - test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) |
75 | | -end |
| 3 | +include("test_utils.jl") |
76 | 4 |
|
77 | 5 | end |
0 commit comments