Skip to content

Commit 5677370

Browse files
lkdvosJutho
andauthored
fix: special case pullbacks for fully truncated decompositions (#233)
* Guard pullback implementations against empty `ind` * better guard against empty pullbacks * add test cases * update changelog * Apply suggestions from code review Co-authored-by: Jutho <Jutho@users.noreply.github.com> --------- Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent 978effe commit 5677370

5 files changed

Lines changed: 48 additions & 0 deletions

File tree

docs/src/changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec
3030

3131
### Fixed
3232

33+
- Pullbacks of `eig_trunc`, `eigh_trunc`, and `svd_trunc` no longer error when the truncation strategy keeps no values; `svd_pullback!` also handles the zero-rank case where every singular value falls below `rank_atol` ([#233](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/233)).
34+
3335
### Performance
3436

3537
## [0.6.7](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.6...v0.6.7) - 2026-05-06

src/pullbacks/eig.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ function check_and_prepare_eig_cotangents(
2626
ΔV₊ = nothing
2727
VᴴΔV₁ = zero!(similar(V, (p, p)))
2828
end
29+
2930
bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v
3031
return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v)
3132
end
@@ -81,6 +82,7 @@ function eig_pullback!(
8182
D = diagview(Dmat)
8283
n == length(D) || throw(DimensionMismatch())
8384
(n, n) == size(ΔA) || throw(DimensionMismatch())
85+
iszero(n) && return ΔA
8486
ViG = inv(V)'
8587

8688
ΔDmat, ΔV = ΔDV
@@ -144,6 +146,7 @@ function eig_trunc_pullback!(
144146
(n, n) == size(ΔA) || throw(DimensionMismatch())
145147
D = diagview(Dmat)
146148
p == length(D) || throw(DimensionMismatch())
149+
iszero(p) && return ΔA
147150
G = V' * V
148151
ViG = V / LinearAlgebra.cholesky!(G)
149152

src/pullbacks/eigh.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ function check_and_prepare_eigh_cotangents(
2727
ΔV₊ = nothing
2828
aVᴴΔV₁ = zero!(similar(V, (p, p)))
2929
end
30+
3031
bc = Base.broadcasted(transpose(D), D, aVᴴΔV₁) do d₁, d₂, v
3132
return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v)
3233
end
@@ -82,6 +83,7 @@ function eigh_pullback!(
8283
D = diagview(Dmat)
8384
n == length(D) || throw(DimensionMismatch())
8485
(n, n) == size(ΔA) || throw(DimensionMismatch())
86+
iszero(n) && return ΔA
8587

8688
ΔDmat, ΔV = ΔDV
8789
VᴴΔAV, = check_and_prepare_eigh_cotangents(
@@ -137,6 +139,7 @@ function eigh_trunc_pullback!(
137139
D = diagview(Dmat)
138140
p == length(D) || throw(DimensionMismatch())
139141
(n, n) == size(ΔA) || throw(DimensionMismatch())
142+
iszero(p) && return ΔA
140143

141144
ΔDmat, ΔV = ΔDV
142145
VᴴΔAV, ΔV₊ = check_and_prepare_eigh_cotangents(

src/pullbacks/svd.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ function check_and_prepare_svd_cotangents(
8181
ΔV₊ᴴ = nothing
8282
aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r)))
8383
end
84+
8485
bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v
8586
return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v)
8687
end
@@ -149,6 +150,7 @@ function svd_pullback!(
149150
(m, n) == size(ΔA) || throw(DimensionMismatch(lazy"size of ΔA ($(size(ΔA))) does not match size of USVᴴ ($m, $n)"))
150151
S = diagview(Smat)
151152
r = svd_rank(S; rank_atol)
153+
iszero(r) && return ΔA
152154

153155
U₁ = view(U, :, 1:r)
154156
V₁ᴴ = view(Vᴴ, 1:r, :)
@@ -220,6 +222,7 @@ function svd_trunc_pullback!(
220222
p = length(S)
221223
p == size(U, 2) || throw(DimensionMismatch(lazy"U has $p columns but S has $(length(S)) singular values"))
222224
p == size(Vᴴ, 1) || throw(DimensionMismatch(lazy"Vᴴ has $p rows but S has $(length(S)) singular values"))
225+
iszero(p) && return ΔA
223226

224227
# Extract and check the cotangents
225228
ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ

test/testsuite/chainrules.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,16 @@ function test_chainrules_eig(
329329
output_tangent = ΔDVtrunc, atol = atol, rtol = rtol
330330
)
331331
@test isequal(ΔDVtrunc, ΔDVtrunc_copy)
332+
@testset "empty truncation" begin
333+
truncalg = TruncatedAlgorithm(alg, truncrank(0))
334+
DV, DVtrunc, _, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg)
335+
@test isempty(diagview(DVtrunc[1]))
336+
ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc)
337+
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind)
338+
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc)
339+
@test iszero(dA1)
340+
@test iszero(dA2)
341+
end
332342
end
333343
end
334344
end
@@ -473,6 +483,16 @@ function test_chainrules_eigh(
473483
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
474484
)
475485
@test isequal(ΔDVtrunc, ΔDVtrunc_copy)
486+
@testset "empty truncation" begin
487+
truncalg = TruncatedAlgorithm(alg, truncrank(0))
488+
DV, DVtrunc, _, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg)
489+
@test isempty(diagview(DVtrunc[1]))
490+
ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc)
491+
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind)
492+
dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc)
493+
@test iszero(dA1)
494+
@test iszero(dA2)
495+
end
476496
end
477497
end
478498
end
@@ -624,6 +644,23 @@ function test_chainrules_svd(
624644
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
625645
)
626646
@test isequal(ΔUSVᴴtrunc, ΔUSVᴴtrunc_copy)
647+
@testset "empty truncation / zero rank" begin
648+
truncalg = TruncatedAlgorithm(alg, truncrank(0))
649+
USVᴴ, USVᴴtrunc, _, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
650+
@test isempty(diagview(USVᴴtrunc[2]))
651+
ind = MatrixAlgebraKit.findtruncated(diagview(USVᴴ[2]), truncalg.trunc)
652+
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind)
653+
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, USVᴴtrunc, ΔUSVᴴtrunc)
654+
@test iszero(dA1)
655+
@test iszero(dA2)
656+
# svd_pullback! short-circuits when every singular value is below rank_atol
657+
_, ΔUSVᴴ = ad_svd_compact_setup(A)
658+
huge_atol = 2 * maximum(diagview(USVᴴ[2]))
659+
dA3 = MatrixAlgebraKit.svd_pullback!(
660+
zero(A), A, USVᴴ, ΔUSVᴴ; rank_atol = huge_atol
661+
)
662+
@test iszero(dA3)
663+
end
627664
end
628665
end
629666
end

0 commit comments

Comments
 (0)