Skip to content

Commit 1abbd3b

Browse files
committed
some changes to the ad test utils
1 parent c86c7d7 commit 1abbd3b

6 files changed

Lines changed: 161 additions & 224 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Enzyme = "0.13.118"
3434
EnzymeTestUtils = "0.2.5"
3535
GenericLinearAlgebra = "0.3.19"
3636
GenericSchur = "0.5.6"
37-
JET = "0.9, 0.10"
37+
JET = "0.9, 0.10, 0.11"
3838
LinearAlgebra = "1"
3939
Mooncake = "0.5"
4040
ParallelTestRunner = "2"

src/pullbacks/svd.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
svd_rank(S, rank_atol) = searchsortedlast(S, rank_atol; rev = true)
1+
svd_rank(S; rank_atol = default_pullback_rank_atol(S)) = searchsortedlast(S, rank_atol; rev = true)
22

33
function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV))
44
mask = abs.(Sr' .- Sr) .< degeneracy_atol
@@ -43,7 +43,7 @@ function svd_pullback!(
4343
minmn = min(m, n)
4444
S = diagview(Smat)
4545
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
46-
r = svd_rank(S, rank_atol)
46+
r = svd_rank(S; rank_atol)
4747
Ur = view(U, :, 1:r)
4848
Vᴴr = view(Vᴴ, 1:r, :)
4949
Sr = view(S, 1:r)

test/testsuite/TestSuite.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module TestSuite
1111
using Test
1212
using MatrixAlgebraKit
1313
using MatrixAlgebraKit: diagview
14-
using LinearAlgebra: Diagonal, norm, istriu, istril, I
14+
using LinearAlgebra: Diagonal, norm, istriu, istril, I, mul!
1515
using Random, StableRNGs
1616
using Mooncake
1717
using AMDGPU, CUDA
@@ -85,9 +85,9 @@ function instantiate_unitary(T, A::ROCMatrix{<:Complex}, sz)
8585
end
8686
instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A))))
8787

88-
function instantiate_rank_deficient_matrix(T, sz; trunc = trunctol(rtol = 0.5))
88+
function instantiate_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz...), 2)))
8989
A = instantiate_matrix(T, sz)
90-
V, C = left_orth!(A; trunc = trunctol(rtol = 0.5))
90+
V, C = left_orth!(A; trunc)
9191
return mul!(A, V, C)
9292
end
9393

0 commit comments

Comments
 (0)