Skip to content

Commit 25e53b5

Browse files
committed
naively specialize diagonal pullback
1 parent 4008d91 commit 25e53b5

1 file changed

Lines changed: 20 additions & 0 deletions

File tree

src/pullbacks/eigh.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,16 @@ function eigh_pullback!(
6868
end
6969
return ΔA
7070
end
71+
function eigh_pullback!(
72+
ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
73+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
74+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
75+
)
76+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
77+
ΔA_full = eigh_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol)
78+
diagview(ΔA) .+= diagview(ΔA_full)
79+
return ΔA
80+
end
7181

7282
"""
7383
eigh_trunc_pullback!(
@@ -141,6 +151,16 @@ function eigh_trunc_pullback!(
141151
end
142152
return ΔA
143153
end
154+
function eigh_trunc_pullback!(
155+
ΔA::Diagonal, A, DV, ΔDV;
156+
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
157+
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
158+
)
159+
ΔA_full = zero!(similar(ΔA, size(ΔA)))
160+
ΔA_full = eigh_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol)
161+
diagview(ΔA) .+= diagview(ΔA_full)
162+
return ΔA
163+
end
144164

145165
"""
146166
eigh_vals_pullback!(

0 commit comments

Comments
 (0)