Skip to content

Commit b4d0a0a

Browse files
Chamberlain0w0 authored and kilinchange committed
fix: fix cast kernel count, fix ReduceAddCoalesced kernel to make sure profiler works correctly
1 parent d2ee257 commit b4d0a0a

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

infini_train/include/dispatcher.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,9 @@ class Dispatcher {
7474
template <typename RetT, class... ArgsT> RetT Call(KeyT key, ArgsT... args) const {
7575
auto kernel = this->GetKernel(key);
7676
tls_autocast_context.Autocast(key, args...);
77+
#ifdef PROFILE_MODE
78+
SetProfileContext(key.second, key.first);
79+
#endif
7780
return kernel.Call<RetT>(std::forward<ArgsT>(args)...);
7881
}
7982

infini_train/src/kernels/cuda/comm.cu

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@ std::vector<std::shared_ptr<Tensor>> Broadcast(const std::vector<std::shared_ptr
2525
std::vector<std::shared_ptr<Tensor>> ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads,
2626
Device destination) {
2727
std::vector<std::shared_ptr<Tensor>> outputs;
28-
auto kernel = Dispatcher::Instance().GetKernel({destination.type(), "AccumulateGrad"});
2928
std::vector<std::vector<std::shared_ptr<Tensor>>> to_destination_grads;
3029
for (int i = 0; i < grads[0].size(); ++i) {
3130
outputs.emplace_back(std::make_shared<Tensor>(grads[0][i]->Dims(), grads[0][i]->Dtype(), destination));
@@ -37,6 +36,9 @@ std::vector<std::shared_ptr<Tensor>> ReduceAddCoalesced(const std::vector<std::v
3736
to_destination_grads[i].push_back(std::make_shared<Tensor>(grads[i][j]->To(destination)));
3837
}
3938
}
39+
// NOTE(zbl): To ensure Profiler works correctly, there should not be any other kernel calls
40+
// between GetKernel and kernel.Call, otherwise ProfileContext would be tainted
41+
auto kernel = Dispatcher::Instance().GetKernel({destination.type(), "AccumulateGrad"});
4042
for (int i = 0; i < grads.size(); ++i) {
4143
for (int j = 0; j < grads[i].size(); ++j) {
4244
kernel.Call<void>(to_destination_grads[i][j], static_cast<float>(1.0), outputs[j]);

0 commit comments

Comments
 (0)