Skip to content

Commit e13c881

Browse files
committed
Don't use diagview(Vector)
1 parent fd9544f commit e13c881

2 files changed

Lines changed: 7 additions & 8 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ for (f!, f, f_full, pb, adj) in (
136136
copy!(D, diagview(DV[1]))
137137
V = DV[2]
138138
function $adj(::Mooncake.NoRData)
139-
$pb(dA, A, (D, V), (dD, nothing))
139+
$pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
140140
MatrixAlgebraKit.zero!(dD)
141141
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
142142
end
@@ -152,8 +152,8 @@ for (f!, f, f_full, pb, adj) in (
152152
output = diagview(DV[1])
153153
output_codual = Mooncake.CoDual(output, Mooncake.zero_tangent(output))
154154
function $adj(::Mooncake.NoRData)
155-
D_dD = Mooncake.arrayify(D_dD)
156-
$pb(dA, A, (D, V), (dD, nothing))
155+
D, dD = Mooncake.arrayify(output_codual)
156+
$pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
157157
MatrixAlgebraKit.zero!(dD)
158158
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
159159
end
@@ -181,7 +181,7 @@ for (f, pb, adj) in (
181181
function $adj(dy::Tuple{Mooncake.NoRData, Mooncake.NoRData, T}) where {T <: Real}
182182
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
183183
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
184-
abs() > MatrixAlgebraKit.defaulttol() && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
184+
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
185185
D, dD = Mooncake.arrayify(Dtrunc, dDtrunc_)
186186
V, dV = Mooncake.arrayify(Vtrunc, dVtrunc_)
187187
$pb(dA, A, (D, V), (dD, dV))
@@ -275,7 +275,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::Co
275275
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
276276
copy!(S, diagview(nS))
277277
function dsvd_vals_adjoint(::Mooncake.NoRData)
278-
svd_pullback!(dA, A, (U, S, Vᴴ), (nothing, dS, nothing))
278+
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
279279
MatrixAlgebraKit.zero!(dS)
280280
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
281281
end
@@ -294,7 +294,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_vals)}, A_dA::CoD
294294
S_codual = Mooncake.CoDual(diagview(S), Mooncake.fdata(Mooncake.zero_tangent(S)))
295295
function dsvd_vals_adjoint(::Mooncake.NoRData)
296296
S, dS = Mooncake.arrayify(S_codual)
297-
svd_pullback!(dA, A, (U, S, Vᴴ), (nothing, dS, nothing))
297+
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
298298
MatrixAlgebraKit.zero!(dS)
299299
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
300300
end
@@ -317,7 +317,7 @@ function Mooncake.rrule!!(::CoDual{typeof(MatrixAlgebraKit.svd_trunc)}, A_dA::Co
317317
function dsvd_trunc_adjoint(dy::Tuple{Mooncake.NoRData, Mooncake.NoRData, Mooncake.NoRData, T}) where {T <: Real}
318318
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
319319
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
320-
abs() > MatrixAlgebraKit.defaulttol() && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
320+
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
321321
U, dU = Mooncake.arrayify(Utrunc, dUtrunc_)
322322
S, dS = Mooncake.arrayify(Strunc, dStrunc_)
323323
Vᴴ, dVᴴ = Mooncake.arrayify(Vᴴtrunc, dVᴴtrunc_)

src/common/view.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# diagind: provided by LinearAlgebra.jl
22
diagview(D::Diagonal) = D.diag
33
diagview(D::AbstractMatrix) = view(D, diagind(D))
4-
diagview(D::AbstractVector) = D
54

65
# triangularind
76
function lowertriangularind(A::AbstractMatrix)

0 commit comments

Comments
 (0)