Skip to content

Commit d32890c

Browse files
committed
correct bang-bang implementation
1 parent f6f959c commit d32890c

1 file changed

Lines changed: 4 additions & 8 deletions

File tree

ext/TensorKitMooncakeExt/tangent.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,25 +115,21 @@ Mooncake._add_to_primal_internal(c::Mooncake.MaybeCache, p::TensorMap, t::Tensor
115115
TensorMap(Mooncake._add_to_primal_internal(c, p.data, t.data, unsafe), space(p))
116116
function Mooncake.tangent_to_primal_internal!!(p::TensorMap, t::TensorMap, c::Mooncake.MaybeCache)
117117
data = Mooncake.tangent_to_primal_internal!!(p.data, t.data, c)
118-
data === p.data || copy!(p.data, data)
119-
return p
118+
return data === p.data ? p : TensorMap(data, space(p))
120119
end
121120
function Mooncake.primal_to_tangent_internal!!(t::TensorMap, p::TensorMap, c::Mooncake.MaybeCache)
122121
data = Mooncake.primal_to_tangent_internal!!(t.data, p.data, c)
123-
data === t.data || copy!(t.data, data)
124-
return t
122+
return data === t.data ? t : TensorMap(data, space(t))
125123
end
126124
Mooncake._add_to_primal_internal(c::Mooncake.MaybeCache, p::DiagonalTensorMap, t::DiagonalTensorMap, unsafe::Bool) =
127125
DiagonalTensorMap(Mooncake._add_to_primal_internal(c, p.data, t.data, unsafe), space(p))
128126
function Mooncake.tangent_to_primal_internal!!(p::DiagonalTensorMap, t::DiagonalTensorMap, c::Mooncake.MaybeCache)
129127
data = Mooncake.tangent_to_primal_internal!!(p.data, t.data, c)
130-
data === p.data || copy!(p.data, data)
131-
return p
128+
return data === p.data ? p : DiagonalTensorMap(data, space(p, 1))
132129
end
133130
function Mooncake.primal_to_tangent_internal!!(t::DiagonalTensorMap, p::DiagonalTensorMap, c::Mooncake.MaybeCache)
134131
data = Mooncake.primal_to_tangent_internal!!(t.data, p.data, c)
135-
data === t.data || copy!(t.data, data)
136-
return p
132+
return data === t.data ? t : DiagonalTensorMap(data, space(t, 1))
137133
end
138134

139135
# to convert from/to chainrules tangents

0 commit comments

Comments
 (0)