We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f15ccf9 commit a82ad42Copy full SHA for a82ad42
1 file changed
infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc
@@ -224,7 +224,8 @@ void ParamAndGradBucketGroup::AccumulateParamGrad(const std::shared_ptr<Tensor>
224
if (overwrite) {
225
bucket_grad_view->CopyFrom(*grad);
226
} else {
227
- auto kernel = Dispatcher::Instance().GetKernel({parameter->GetDevice()->Type(), "AccumulateGrad"});
+ auto device = parameter->GetDevice();
228
+ auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"});
229
kernel.Call<void>(grad, learning_rate, bucket_grad_view);
230
}
231
0 commit comments