@@ -8,7 +8,7 @@ using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, Abstract
88using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
99import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
1010import MatrixAlgebraKit: heevj!, heevd!, geev!
11- import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!
11+ import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback!
1212using CUDA, CUDA. cuBLAS
1313using CUDA: i32
1414using LinearAlgebra
@@ -205,4 +205,8 @@ function svd_pullback!(ΔA::AnyCuMatrix, A, USVᴴ, ΔUSVᴴ, ind::AnyCuVector;
205205 return svd_pullback! (ΔA, A, USVᴴ, ΔUSVᴴ, collect (ind); kwargs... )
206206end
207207
208+ function eigh_pullback! (ΔA:: AnyCuMatrix , A, DV, ΔDV, ind:: AnyCuVector ; kwargs... )
209+ return eigh_pullback! (ΔA, A, DV, ΔDV, collect (ind); kwargs... )
210+ end
211+
208212end
0 commit comments