Skip to content

Commit a82ad42

Browse files
fix: adapt kernel call under new Device/DeviceGuard impl
1 parent f15ccf9 commit a82ad42

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

infini_train/src/nn/parallel/ddp/param_and_grad_buffer.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ void ParamAndGradBucketGroup::AccumulateParamGrad(const std::shared_ptr<Tensor>
224224
if (overwrite) {
225225
bucket_grad_view->CopyFrom(*grad);
226226
} else {
227-
auto kernel = Dispatcher::Instance().GetKernel({parameter->GetDevice()->Type(), "AccumulateGrad"});
227+
auto device = parameter->GetDevice();
228+
auto kernel = Dispatcher::Instance().GetKernel({device.type(), "AccumulateGrad"});
228229
kernel.Call<void>(grad, learning_rate, bucket_grad_view);
229230
}
230231
}

0 commit comments

Comments
 (0)