Skip to content

Commit deb215c

Browse files
committed
Progress on trunc
1 parent 4bd6397 commit deb215c

3 files changed

Lines changed: 92 additions & 69 deletions

File tree

Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
4949
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
5050
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
5151
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
52+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
5253
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
5354
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
5455
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
@@ -60,3 +61,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6061

6162
[targets]
6263
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake", "Enzyme", "EnzymeTestUtils"]
64+
65+
[sources]
66+
Enzyme = {path="/Users/khyatt/.julia/dev/Enzyme"}
67+
EnzymeTestUtils = {path="/Users/khyatt/.julia/dev/Enzyme/lib/EnzymeTestUtils"}

ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -191,25 +191,24 @@ function EnzymeRules.augmented_primal(
191191
::Type{RT},
192192
A::Annotation,
193193
USVᴴ::Annotation,
194-
ϵ::Annotation{T},
195194
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
196-
) where {RT, T <: Real}
195+
) where {RT}
197196
# form cache if needed
198197
cache_A = copy(A.val)
199198
svd_compact!(A.val, USVᴴ.val, alg.val.alg)
200199
cache_USVᴴ = copy.(USVᴴ.val)
201200
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ.val, alg.val.trunc)
202-
ϵ.val = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
203-
primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
201+
ϵ.val = MatrixAlgebraKit.truncation_error!(diagview(USVᴴ.val[2]), ind)
202+
primal = EnzymeRules.needs_primal(config) ? (USVᴴ′..., ϵ.val) : nothing
204203
shadow_USVᴴ = if !isa(A, Const) && !isa(USVᴴ, Const)
205204
dU, dS, dVᴴ = USVᴴ.dval
206205
# This creates new output shadow matrices, we do this slicing
207206
# to ensure they have the correct eltype and dimensions.
208207
# These new shadow matrices are "filled in" with the accumulated
209208
# results from earlier in reverse-mode AD after this function exits
210209
# and before `reverse` is called.
211-
dStrunc = Diagonal(diagview(dS)[ind])
212-
dUtrunc = dU[:, ind]
210+
dStrunc = Diagonal(diagview(dS)[ind])
211+
dUtrunc = dU[:, ind]
213212
dVᴴtrunc = dVᴴ[ind, :]
214213
(dUtrunc, dStrunc, dVᴴtrunc)
215214
else
@@ -225,34 +224,64 @@ function EnzymeRules.reverse(
225224
cache,
226225
A::Annotation,
227226
USVᴴ::Annotation,
228-
ϵ::Annotation{T},
229227
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
230-
) where {RT, T <: Real}
228+
) where {RT}
231229
cache_A, cache_USVᴴ, shadow_USVᴴ, ind = cache
232-
U, S, Vᴴ = cache_USVᴴ
230+
U, S, Vᴴ = cache_USVᴴ
233231
dU, dS, dVᴴ = shadow_USVᴴ
234-
Aval = isnothing(cache_A) ? A.val : cache_A
232+
Aval = isnothing(cache_A) ? A.val : cache_A
235233
if !isa(A, Const) && !isa(USVᴴ, Const)
236234
svd_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind)
237235
end
238-
if !isa(USVᴴ, Const)
239-
make_zero!(USVᴴ.dval)
240-
end
241-
if !isa(ϵ, Const)
242-
make_zero!.dval)
243-
end
236+
!isa(USVᴴ, Const) && make_zero!(USVᴴ.dval)
237+
!isa(ϵ, Const) && make_zero!.dval)
244238
return (nothing, nothing, nothing, nothing)
245239
end
246240

241+
function EnzymeRules.augmented_primal(
242+
config::EnzymeRules.RevConfigWidth{1},
243+
func::Const{typeof(svd_trunc)},
244+
::Type{MixedDuplicated},
245+
A::Annotation,
246+
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
247+
)
248+
# form cache if needed
249+
cache_A = copy(A.val)
250+
U, S, Vᴴ, ϵ = svd_trunc(A.val, USVᴴ.val, alg.val.alg)
251+
primal = EnzymeRules.needs_primal(config) ? (U, S, Vᴴ, ϵ) : nothing
252+
dU = zero(U)
253+
dS = zero(S)
254+
dVᴴ = zero(Vᴴ)
255+
= zero(ϵ)
256+
shadow = EnzymeRules.needs_shadow(config) ? (dU, dS, dVᴴ, dϵ) : nothing
257+
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, (U, S, Vᴴ), (dU, dS, dVᴴ)))
258+
end
259+
function EnzymeRules.reverse(
260+
config::EnzymeRules.RevConfigWidth{1},
261+
func::Const{typeof(svd_trunc)},
262+
dret::Type{MixedDuplicated},
263+
cache,
264+
A::Annotation,
265+
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
266+
)
267+
cache_A, cache_USVᴴ, shadow_USVᴴ = cache
268+
U, S, Vᴴ = cache_USVᴴ
269+
dU, dS, dVᴴ = shadow_USVᴴ
270+
Aval = isnothing(cache_A) ? A.val : cache_A
271+
if !isa(A, Const) && !isa(USVᴴ, Const)
272+
svd_trunc_pullback!(A.dval, Aval, (U, S, Vᴴ), shadow_USVᴴ, ind)
273+
end
274+
return (nothing, nothing, nothing)
275+
end
276+
247277
function EnzymeRules.augmented_primal(
248278
config::EnzymeRules.RevConfigWidth{1},
249279
func::Const{typeof(eigh_trunc!)},
250280
::Type{RT},
251281
A::Annotation,
252282
DV::Annotation{Tuple{TD, TV}},
253-
ϵ::Annotation{T},
254283
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
255-
) where {RT, T, TD, TV}
284+
) where {RT, TD, TV}
256285
# form cache if needed
257286
cache_A = copy(A.val)
258287
MatrixAlgebraKit.eigh_full!(A.val, DV.val, alg.val.alg)
@@ -280,9 +309,8 @@ function EnzymeRules.reverse(
280309
cache,
281310
A::Annotation,
282311
DV::Annotation{Tuple{TD, TV}},
283-
ϵ::Annotation{T},
284312
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
285-
) where {RT, T, TD, TV}
313+
) where {RT, TD, TV}
286314
cache_A, cache_DV, cache_dDVtrunc, ind = cache
287315
Aval = cache_A
288316
D, V = cache_DV
@@ -301,9 +329,8 @@ function EnzymeRules.augmented_primal(
301329
::Type{RT},
302330
A::Annotation,
303331
DV::Annotation{Tuple{TD, TV}},
304-
ϵ::Annotation{T},
305332
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
306-
) where {RT, T, TD, TV}
333+
) where {RT, TD, TV}
307334
# form cache if needed
308335
cache_A = copy(A.val)
309336
eig_full!(A.val, DV.val, alg.val.alg)
@@ -329,9 +356,8 @@ function EnzymeRules.reverse(
329356
cache,
330357
A::Annotation,
331358
DV::Annotation{Tuple{TD, TV}},
332-
ϵ::Annotation{T},
333359
alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm},
334-
) where {RT, T, TD, TV}
360+
) where {RT, TD, TV}
335361
cache_A, cache_DV, cache_dDVtrunc = cache
336362
D, V = cache_DV
337363
Aval = cache_A

test/enzyme.jl

Lines changed: 37 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ is_ci = get(ENV, "CI", "false") == "true"
1111

1212
ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631
1313
include("ad_utils.jl")
14-
1514
function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated)
1615
ΔA = randn(rng, eltype(A), size(A)...)
1716
A_ΔA() = Duplicated(copy(A), copy(ΔA))
@@ -46,7 +45,7 @@ function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ =
4645
end
4746
return
4847
end
49-
48+
#=
5049
@timedtestset "QR AD Rules with eltype $T" for T in ETs
5150
rng = StableRNG(12345)
5251
m = 19
@@ -190,9 +189,9 @@ end
190189
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
191190
ΔVtrunc = ΔV[:, ind]
192191
# broken due to Enzyme
193-
#test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
192+
test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
194193
# broken due to Enzyme
195-
#test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg, (ΔDtrunc, ΔVtrunc, zero(real(T))))
194+
test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
196195
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
197196
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
198197
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -204,9 +203,9 @@ end
204203
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
205204
ΔVtrunc = ΔV[:, ind]
206205
# broken due to Enzyme
207-
#test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
206+
test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
208207
# broken due to Enzyme
209-
#test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=MixedDuplicated)
208+
test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
210209
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
211210
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
212211
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -309,8 +308,8 @@ end
309308
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
310309
ΔVtrunc = ΔV[:, ind]
311310
# broken due to Enzyme
312-
#test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
313-
#test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
311+
test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
312+
test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
314313
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
315314
dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
316315
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
@@ -323,15 +322,15 @@ end
323322
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
324323
ΔVtrunc = ΔV[:, ind]
325324
# broken due to Enzyme
326-
#test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
327-
#test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T)), return_act=RT))
325+
test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
326+
test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
328327
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
329328
dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
330329
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
331330
end
332331
end
333332
end
334-
333+
=#
335334
@timedtestset "SVD AD Rules with eltype $T" for T in ETs
336335
rng = StableRNG(12345)
337336
m = 19
@@ -343,7 +342,7 @@ end
343342
LAPACK_QRIteration(),
344343
LAPACK_DivideAndConquer(),
345344
)
346-
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
345+
#=@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
347346
@testset "svd_compact" begin
348347
U, S, Vᴴ = svd_compact(A)
349348
ΔU = randn(rng, T, m, minmn)
@@ -378,39 +377,16 @@ end
378377
test_reverse(svd_vals, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = ΔS, fdm = fdm)
379378
test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS, alg)
380379
end
381-
end
380+
end=#
382381
@testset "svd_trunc reverse: RT $RT, TA $TA" for RT in (MixedDuplicated,), TA in (Duplicated,)
383-
@testset "svd_trunc" begin
384-
for r in 1:4:minmn
385-
U, S, Vᴴ = svd_compact(A)
386-
ΔU = randn(rng, T, m, minmn)
387-
ΔS = randn(rng, real(T), minmn, minmn)
388-
ΔS2 = Diagonal(randn(rng, real(T), minmn))
389-
ΔVᴴ = randn(rng, T, minmn, n)
390-
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
391-
truncalg = TruncatedAlgorithm(alg, truncrank(r))
392-
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
393-
Strunc = Diagonal(diagview(S)[ind])
394-
Utrunc = U[:, ind]
395-
Vᴴtrunc = Vᴴ[ind, :]
396-
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
397-
ΔUtrunc = ΔU[:, ind]
398-
ΔVᴴtrunc = ΔVᴴ[ind, :]
399-
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
400-
# broken due to Enzyme
401-
#test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
402-
#test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
403-
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), copy(A), (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
404-
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
405-
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
406-
end
382+
for r in 1:4:minmn
407383
U, S, Vᴴ = svd_compact(A)
408384
ΔU = randn(rng, T, m, minmn)
409385
ΔS = randn(rng, real(T), minmn, minmn)
410386
ΔS2 = Diagonal(randn(rng, real(T), minmn))
411387
ΔVᴴ = randn(rng, T, minmn, n)
412388
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
413-
truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2))
389+
truncalg = TruncatedAlgorithm(alg, truncrank(r))
414390
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
415391
Strunc = Diagonal(diagview(S)[ind])
416392
Utrunc = U[:, ind]
@@ -419,18 +395,33 @@ end
419395
ΔUtrunc = ΔU[:, ind]
420396
ΔVᴴtrunc = ΔVᴴ[ind, :]
421397
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
422-
# broken due to Enzyme
423-
#test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
424-
#test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
425-
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), copy(A), (copy(U), copy(S), copy(Vᴴ)), (copy(ΔUtrunc), copy(ΔStrunc), copy(ΔVᴴtrunc)), ind)
426-
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), copy(A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
427-
@test isapprox(dA1, dA2; atol = atol, rtol = rtol)
398+
# broken due to Enzyme -- copying in gaugefix????
399+
test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
400+
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
428401
end
402+
U, S, Vᴴ = svd_compact(A)
403+
ΔU = randn(rng, T, m, minmn)
404+
ΔS = randn(rng, real(T), minmn, minmn)
405+
ΔS2 = Diagonal(randn(rng, real(T), minmn))
406+
ΔVᴴ = randn(rng, T, minmn, n)
407+
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
408+
truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2))
409+
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
410+
Strunc = Diagonal(diagview(S)[ind])
411+
Utrunc = U[:, ind]
412+
Vᴴtrunc = Vᴴ[ind, :]
413+
ΔStrunc = Diagonal(diagview(ΔS2)[ind])
414+
ΔUtrunc = ΔU[:, ind]
415+
ΔVᴴtrunc = ΔVᴴ[ind, :]
416+
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
417+
# broken due to Enzyme
418+
test_reverse(svd_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm = fdm)
419+
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), return_act=RT)
429420
end
430421
end
431422
end
432423
end
433-
424+
#=
434425
@timedtestset "Polar AD Rules with eltype $T" for T in ETs
435426
rng = StableRNG(12345)
436427
m = 19
@@ -513,3 +504,4 @@ end
513504
end
514505
end
515506
end
507+
=#

0 commit comments

Comments
 (0)