diff --git a/infini_train/src/kernels/cpu/embedding.cc b/infini_train/src/kernels/cpu/embedding.cc index 6b3a5aa6..a1da926f 100644 --- a/infini_train/src/kernels/cpu/embedding.cc +++ b/infini_train/src/kernels/cpu/embedding.cc @@ -46,7 +46,7 @@ std::shared_ptr EmbeddingBackward(const std::shared_ptr &input, for (int i = 0; i < input->NumElements(); ++i) { int idx = static_cast(static_cast(input->DataPtr())[i]); for (int j = 0; j < embedding_dim; ++j) { - static_cast(grad_weight->DataPtr())[idx * embedding_dim + j] // <-- 修复这里 + static_cast(grad_weight->DataPtr())[idx * embedding_dim + j] += static_cast(grad_output->DataPtr())[i * embedding_dim + j]; } }