Skip to content

Commit 7e76c77

Browse files
committed
Reproducer
1 parent 54a08e5 commit 7e76c77

3 files changed

Lines changed: 32 additions & 28 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6161

6262
[targets]
6363
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake", "Enzyme", "EnzymeTestUtils"]
64+
65+
[sources]
66+
EnzymeTestUtils = {path="/Users/khyatt/.julia/dev/Enzyme/lib/EnzymeTestUtils"}

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,22 @@ function EnzymeRules.augmented_primal(
190190
func::Const{typeof(svd_trunc!)},
191191
::Type{RT},
192192
A::Annotation,
193-
USVᴴ::Annotation,
193+
USVᴴϵ::Annotation,
194194
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
195195
) where {RT}
196196
# form cache if needed
197197
cache_A = copy(A.val)
198-
svd_compact!(A.val, USVᴴ.val, alg.val.alg)
199-
cache_USVᴴ = copy.(USVᴴ.val)
200-
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
201-
ϵ.val = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
202-
primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
203-
shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const)
204-
dU, dS, dVᴴ = USVᴴ.dval
198+
USVᴴ = USVᴴϵ.val[1:3]
199+
ϵ = USVᴴϵ.val[end]
200+
svd_compact!(A.val, USVᴴ, alg.val.alg)
201+
cache_USVᴴ = copy.(USVᴴ)
202+
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.val.trunc)
203+
if !isempty(ϵ)
204+
ϵ .= MatrixAlgebraKit.truncation_error!(diagview(USVᴴ[2]), ind)
205+
end
206+
primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ) : nothing
207+
shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴϵ, Const)
208+
dU, dS, dVᴴ, dϵ = USVᴴϵ.dval
205209
# This creates new output shadow matrices, we do this slicing
206210
# to ensure they have the correct eltype and dimensions.
207211
# These new shadow matrices are "filled in" with the accumulated
@@ -214,7 +218,7 @@ function EnzymeRules.augmented_primal(
214218
else
215219
(nothing, nothing, nothing)
216220
end
217-
shadow = EnzymeRules.needs_shadow(config) ? (shadow_USVᴴ..., ϵ.dval) : nothing
221+
shadow = EnzymeRules.needs_shadow(config) ? (shadow_USVᴴ..., USVᴴϵ.dval[end]) : nothing
218222
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_USVᴴ, shadow_USVᴴ, ind))
219223
end
220224
function EnzymeRules.reverse(
@@ -223,7 +227,7 @@ function EnzymeRules.reverse(
223227
dret::Type{RT},
224228
cache,
225229
A::Annotation,
226-
USVᴴ::Annotation,
230+
USVᴴϵ::Annotation,
227231
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
228232
) where {RT}
229233
cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache
@@ -233,8 +237,7 @@ function EnzymeRules.reverse(
233237
if !isa(A, Const) && !isa(USVᴴ, Const)
234238
svd_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind)
235239
end
236-
!isa(USVᴴ, Const) && make_zero!(USVᴴ.dval)
237-
!isa(ϵ, Const) && make_zero!.dval)
240+
!isa(USVᴴϵ, Const) && make_zero!(USVᴴϵ.dval)
238241
return (nothing, nothing, nothing, nothing)
239242
end
240243

test/enzyme.jl

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!
99

1010
is_ci = get(ENV, "CI", "false") == "true"
1111

12-
ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631
12+
#ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631
13+
ETs = (Float64,) # Enzyme/#2631
1314
include("ad_utils.jl")
1415
function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated)
1516
ΔA = randn(rng, eltype(A), size(A)...)
@@ -180,7 +181,7 @@ end
180181
test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag))
181182
test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg)
182183
end
183-
@testset "eig_trunc reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,)
184+
@testset "eig_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
184185
for r in 1:4:m
185186
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
186187
ind = MatrixAlgebraKit.findtruncated(diagview(D), truncalg.trunc)
@@ -298,7 +299,7 @@ end
298299
test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag))
299300
test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg)
300301
end
301-
@testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,)
302+
@testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
302303
for r in 1:4:m
303304
Ddiag = diagview(D)
304305
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
@@ -334,13 +335,13 @@ end
334335
@timedtestset "SVD AD Rules with eltype $T" for T in ETs
335336
rng = StableRNG(12345)
336337
m = 19
337-
@testset "size ($m, $n)" for n in (17, m, 23)
338+
@testset "size ($m, $n)" for n in (17,)# m, 23)
338339
atol = rtol = m * n * precision(T)
339340
A = randn(rng, T, m, n)
340341
minmn = min(m, n)
341342
@testset for alg in (
342343
LAPACK_QRIteration(),
343-
LAPACK_DivideAndConquer(),
344+
#LAPACK_DivideAndConquer(),
344345
)
345346
#=@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
346347
@testset "svd_compact" begin
@@ -378,8 +379,9 @@ end
378379
test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS, alg)
379380
end
380381
end=#
381-
@testset "svd_trunc reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,)
382-
for r in 1:4:minmn
382+
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
383+
@testset "svd_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
384+
#=for r in 1:4:minmn
383385
U, S, Vᴴ = svd_compact(A)
384386
ΔU = randn(rng, T, m, minmn)
385387
ΔS = randn(rng, real(T), minmn, minmn)
@@ -394,11 +396,9 @@ end
394396
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
395397
ΔUtrunc = ΔU[:, ind]
396398
ΔVᴴtrunc = ΔVᴴ[ind, :]
397-
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
398-
# broken due to Enzyme -- copying in gaugefix????
399-
test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
400-
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
401-
end
399+
test_reverse(svd_trunc!, RT, (copy(A), TA), ((copy(U), copy(S), copy(Vᴴ), [zero(real(T))]), TA), (truncalg, Const); atol = atol, rtol = rtol, output_tangent = (copy(ΔUtrunc), copy(ΔStrunc), copy(ΔVᴴtrunc), [zero(real(T))]), fdm = fdm)
400+
#test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
401+
end=#
402402
U, S, Vᴴ = svd_compact(A)
403403
ΔU = randn(rng, T, m, minmn)
404404
ΔS = randn(rng, real(T), minmn, minmn)
@@ -413,10 +413,8 @@ end
413413
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
414414
ΔUtrunc = ΔU[:, ind]
415415
ΔVᴴtrunc = ΔVᴴ[ind, :]
416-
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
417-
# broken due to Enzyme
418-
test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
419-
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
416+
test_reverse(svd_trunc!, RT, (A, TA), ((copy(U), copy(S), copy(Vᴴ), [zero(real(T))]), TA), (truncalg, Const); atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]), fdm = fdm)
417+
#test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
420418
end
421419
end
422420
end

0 commit comments

Comments
 (0)