Skip to content

Commit 880d046

Browse files
authored
Try to test Enzyme on 1.10 (#270)
* Try to test Enzyme on 1.10 * Remove nonexistent test * Fix forward mode returns
1 parent 77a9008 commit 880d046

3 files changed

Lines changed: 34 additions & 12 deletions

File tree

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
2020
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
2121
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2222
CUDACore = "bd0ed864-bdfe-4181-a5ed-ce625a5fdea2"
23+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2324
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2425
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2526
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
26-
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2727

2828
[extensions]
2929
TensorOperationsBumperExt = "Bumper"
@@ -42,7 +42,7 @@ CUDACore = "6"
4242
ChainRulesCore = "1"
4343
ChainRulesTestUtils = "1"
4444
DynamicPolynomials = "0.5, 0.6"
45-
Enzyme = "0.13.115"
45+
Enzyme = "0.13.146"
4646
EnzymeTestUtils = "0.2"
4747
JLArrays = "0.3"
4848
LRUCache = "1"
@@ -74,6 +74,7 @@ cuRAND = "20fd9a0b-12d5-4c2f-a8af-7c34e9e60431"
7474
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
7575
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
7676
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
77+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
7778
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
7879
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
7980
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,15 @@ function EnzymeRules.forward(
164164
end
165165
end
166166
TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA, conjA, B_dB.val, pB, conjB, pAB, α, β, ba...)
167-
return C_dC
167+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
168+
return C_dC
169+
elseif EnzymeRules.needs_primal(config)
170+
return C_dC.val
171+
elseif EnzymeRules.needs_shadow(config)
172+
return C_dC.dval
173+
else
174+
return nothing
175+
end
168176
end
169177

170178
function EnzymeRules.augmented_primal(
@@ -275,7 +283,15 @@ function EnzymeRules.forward(
275283
end
276284
end
277285
TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA, conjA, α, β, ba...)
278-
return C_dC
286+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
287+
return C_dC
288+
elseif EnzymeRules.needs_primal(config)
289+
return C_dC.val
290+
elseif EnzymeRules.needs_shadow(config)
291+
return C_dC.dval
292+
else
293+
return nothing
294+
end
279295
end
280296

281297
function EnzymeRules.augmented_primal(
@@ -389,7 +405,15 @@ function EnzymeRules.forward(
389405
end
390406
# D = α * tr(A) + β * C
391407
TensorOperations.tensortrace!(C_dC.val, A_dA.val, p, q, conjA, α, β, ba...)
392-
return C_dC
408+
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
409+
return C_dC
410+
elseif EnzymeRules.needs_primal(config)
411+
return C_dC.val
412+
elseif EnzymeRules.needs_shadow(config)
413+
return C_dC.dval
414+
else
415+
return nothing
416+
end
393417
end
394418

395419
end

test/runtests.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,10 @@ if !is_buildkite
3939
@testset "mooncake" verbose = false begin
4040
include("mooncake.jl")
4141
end
42-
# mystery segfault on 1.10 for now
43-
@static if VERSION >= v"1.11.0"
44-
is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true"
45-
if !is_apple_ci
46-
@testset "enzyme" verbose = false begin
47-
include("enzyme.jl")
48-
end
42+
is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true"
43+
if !is_apple_ci
44+
@testset "enzyme" verbose = false begin
45+
include("enzyme.jl")
4946
end
5047
end
5148
end

0 commit comments

Comments
 (0)