@@ -115,25 +115,21 @@ Mooncake._add_to_primal_internal(c::Mooncake.MaybeCache, p::TensorMap, t::Tensor
115115 TensorMap (Mooncake. _add_to_primal_internal (c, p. data, t. data, unsafe), space (p))
116116function Mooncake. tangent_to_primal_internal!! (p:: TensorMap , t:: TensorMap , c:: Mooncake.MaybeCache )
117117 data = Mooncake. tangent_to_primal_internal!! (p. data, t. data, c)
118- data === p. data || copy! (p. data, data)
119- return p
118+ return data === p. data ? p : TensorMap (data, space (p))
120119end
121120function Mooncake. primal_to_tangent_internal!! (t:: TensorMap , p:: TensorMap , c:: Mooncake.MaybeCache )
122121 data = Mooncake. primal_to_tangent_internal!! (t. data, p. data, c)
123- data === t. data || copy! (t. data, data)
124- return t
122+ return data === t. data ? t : TensorMap (data, space (t))
125123end
126124Mooncake. _add_to_primal_internal (c:: Mooncake.MaybeCache , p:: DiagonalTensorMap , t:: DiagonalTensorMap , unsafe:: Bool ) =
127125 DiagonalTensorMap (Mooncake. _add_to_primal_internal (c, p. data, t. data, unsafe), space (p))
128126function Mooncake. tangent_to_primal_internal!! (p:: DiagonalTensorMap , t:: DiagonalTensorMap , c:: Mooncake.MaybeCache )
129127 data = Mooncake. tangent_to_primal_internal!! (p. data, t. data, c)
130- data === p. data || copy! (p. data, data)
131- return p
128+ return data === p. data ? p : DiagonalTensorMap (data, space (p, 1 ))
132129end
133130function Mooncake. primal_to_tangent_internal!! (t:: DiagonalTensorMap , p:: DiagonalTensorMap , c:: Mooncake.MaybeCache )
134131 data = Mooncake. primal_to_tangent_internal!! (t. data, p. data, c)
135- data === t. data || copy! (t. data, data)
136- return p
132+ return data === t. data ? t : DiagonalTensorMap (data, space (t, 1 ))
137133end
138134
139135# to convert from/to chainrules tangents
0 commit comments