Skip to content

Commit 2bbe9d2

Browse files
committed
possibly fix implementation
1 parent 29ba61e commit 2bbe9d2

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -792,11 +792,15 @@ for (f!, f, adj) in (
792792
arg = $f!(A, arg, Mooncake.primal(alg_dalg))
793793

794794
function $adj(::NoRData)
795-
dA .+= $f(darg)
796-
dA === darg || zero!(darg)
795+
$f!(darg)
796+
if dA !== darg
797+
dA .+= darg
798+
zero!(darg)
799+
end
797800
copy!(arg, argc)
798801
return ntuple(Returns(NoRData()), 4)
799802
end
803+
800804
return arg_darg, $adj
801805
end
802806

0 commit comments

Comments
 (0)