Skip to content

Commit 4a3212c

Browse files
committed
refactor: rename topk_mask to topk and align with torch.topk API
1 parent e9567ce commit 4a3212c

8 files changed

Lines changed: 358 additions & 267 deletions

File tree

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <vector>
5+
6+
#include "infini_train/include/autograd/function.h"
7+
8+
namespace infini_train {
9+
class Tensor;
10+
}
11+
12+
namespace infini_train::autograd {
13+
14+
// FIXME(dcj): Align this API with torch.topk and return both values and indices from Forward once
15+
// InfiniTrain autograd supports marking individual outputs as non-differentiable. Today indices
16+
// are exposed through TopIndices() to avoid waiting for gradients on metadata outputs.
17+
class TopK : public Function {
18+
public:
19+
static constexpr char kType[] = "TopKFunction";
20+
21+
explicit TopK(int64_t topk, int64_t dim = -1, bool largest = true, bool sorted = true)
22+
: Function(kType), topk_(topk), dim_(dim), largest_(largest), sorted_(sorted) {}
23+
24+
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
25+
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
26+
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
27+
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;
28+
29+
std::shared_ptr<Tensor> TopIndices() const;
30+
31+
private:
32+
int64_t topk_ = 1;
33+
int64_t dim_ = -1;
34+
bool largest_ = true;
35+
bool sorted_ = true;
36+
std::shared_ptr<Tensor> top_indices_;
37+
std::vector<int64_t> input_dims_;
38+
};
39+
40+
} // namespace infini_train::autograd

infini_train/include/autograd/topk_mask.h

Lines changed: 0 additions & 29 deletions
This file was deleted.

infini_train/src/autograd/topk.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "infini_train/include/autograd/topk.h"
2+
3+
#include "glog/logging.h"
4+
5+
#include "infini_train/include/dispatcher.h"
6+
#include "infini_train/include/tensor.h"
7+
8+
namespace infini_train::autograd {
9+
10+
std::vector<std::shared_ptr<Tensor>> TopK::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
11+
CHECK_EQ(input_tensors.size(), 1);
12+
CHECK_GT(topk_, 0);
13+
const auto &input = input_tensors[0];
14+
auto device = input->GetDevice().type();
15+
auto topk_outputs = Dispatcher::Instance().Call<std::vector<std::shared_ptr<Tensor>>>(
16+
{device, "TopKForward"}, input, topk_, dim_, largest_, sorted_);
17+
CHECK_EQ(topk_outputs.size(), 2);
18+
top_indices_ = topk_outputs[1];
19+
return {topk_outputs[0]};
20+
}
21+
22+
void TopK::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
23+
const std::vector<std::shared_ptr<Tensor>> &) {
24+
input_dims_ = input_tensors[0]->Dims();
25+
saved_tensors_ = {top_indices_};
26+
}
27+
28+
std::vector<std::shared_ptr<Tensor>> TopK::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
29+
CHECK_EQ(grad_outputs.size(), 1);
30+
const auto &top_grad = grad_outputs[0];
31+
const auto &top_indices = saved_tensors_[0];
32+
auto device = top_grad->GetDevice().type();
33+
return {Dispatcher::Instance().Call<std::shared_ptr<Tensor>>({device, "TopKBackward"}, top_grad, top_indices,
34+
input_dims_, dim_)};
35+
}
36+
37+
// Returns the indices tensor cached by the most recent Forward call (null if Forward has not run).
std::shared_ptr<Tensor> TopK::TopIndices() const {
    return top_indices_;
}
38+
39+
} // namespace infini_train::autograd

infini_train/src/autograd/topk_mask.cc

Lines changed: 0 additions & 32 deletions
This file was deleted.
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#include <cstddef>
2+
#include <cstring>
3+
#include <limits>
4+
#include <memory>
5+
#include <vector>
6+
7+
#include "glog/logging.h"
8+
9+
#include "infini_train/include/dispatcher.h"
10+
#include "infini_train/include/tensor.h"
11+
12+
namespace infini_train::kernels::cpu {
13+
14+
std::vector<std::shared_ptr<Tensor>> TopKForward(const std::shared_ptr<Tensor> &input, int64_t topk, int64_t dim,
15+
bool largest, bool sorted) {
16+
CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKForward currently supports float32 only";
17+
CHECK_GE(input->Dims().size(), 1);
18+
(void)sorted;
19+
20+
const auto &dims = input->Dims();
21+
if (dim < 0) {
22+
dim += static_cast<int64_t>(dims.size());
23+
}
24+
CHECK_GE(dim, 0);
25+
CHECK_LT(dim, static_cast<int64_t>(dims.size()));
26+
27+
const int64_t dim_size = dims[dim];
28+
CHECK_GT(dim_size, 0);
29+
CHECK_GT(topk, 0);
30+
CHECK_LE(topk, dim_size);
31+
32+
int64_t outer_size = 1;
33+
for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= dims[idx]; }
34+
int64_t inner_size = 1;
35+
for (size_t idx = static_cast<size_t>(dim) + 1; idx < dims.size(); ++idx) { inner_size *= dims[idx]; }
36+
37+
auto topk_dims = dims;
38+
topk_dims[dim] = topk;
39+
auto top_values = std::make_shared<Tensor>(topk_dims, input->Dtype(), input->GetDevice());
40+
auto top_indices = std::make_shared<Tensor>(topk_dims, DataType::kINT64, input->GetDevice());
41+
42+
const float *in = static_cast<const float *>(input->DataPtr());
43+
float *values = static_cast<float *>(top_values->DataPtr());
44+
int64_t *indices = static_cast<int64_t *>(top_indices->DataPtr());
45+
for (int64_t outer = 0; outer < outer_size; ++outer) {
46+
for (int64_t inner = 0; inner < inner_size; ++inner) {
47+
std::vector<bool> selected_indices(dim_size, false);
48+
for (int64_t selected = 0; selected < topk; ++selected) {
49+
int64_t best_idx = -1;
50+
float best_value
51+
= largest ? -std::numeric_limits<float>::infinity() : std::numeric_limits<float>::infinity();
52+
for (int64_t idx = 0; idx < dim_size; ++idx) {
53+
if (selected_indices[idx]) {
54+
continue;
55+
}
56+
const float value = in[outer * dim_size * inner_size + idx * inner_size + inner];
57+
const bool better = largest ? value > best_value : value < best_value;
58+
if (better) {
59+
best_value = value;
60+
best_idx = idx;
61+
}
62+
}
63+
CHECK_GE(best_idx, 0);
64+
selected_indices[best_idx] = true;
65+
const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner;
66+
values[out_offset] = best_value;
67+
indices[out_offset] = best_idx;
68+
}
69+
}
70+
}
71+
72+
return {top_values, top_indices};
73+
}
74+
75+
std::shared_ptr<Tensor> TopKBackward(const std::shared_ptr<Tensor> &grad_values, const std::shared_ptr<Tensor> &indices,
76+
const std::vector<int64_t> &input_dims, int64_t dim) {
77+
CHECK(indices->Dtype() == DataType::kINT64) << "CPU TopKBackward expects int64 indices";
78+
CHECK(grad_values->Dims() == indices->Dims());
79+
CHECK(!input_dims.empty());
80+
if (dim < 0) {
81+
dim += static_cast<int64_t>(input_dims.size());
82+
}
83+
CHECK_GE(dim, 0);
84+
CHECK_LT(dim, static_cast<int64_t>(input_dims.size()));
85+
86+
const int64_t dim_size = input_dims[dim];
87+
const int64_t topk = indices->Dims()[dim];
88+
int64_t outer_size = 1;
89+
for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= input_dims[idx]; }
90+
int64_t inner_size = 1;
91+
for (size_t idx = static_cast<size_t>(dim) + 1; idx < input_dims.size(); ++idx) { inner_size *= input_dims[idx]; }
92+
93+
auto grad_input = std::make_shared<Tensor>(input_dims, grad_values->Dtype(), grad_values->GetDevice());
94+
std::memset(grad_input->DataPtr(), 0, grad_input->SizeInBytes());
95+
96+
const size_t elem_size = kDataTypeToSize.at(grad_values->Dtype());
97+
const auto *src = static_cast<const std::byte *>(grad_values->DataPtr());
98+
auto *dst = static_cast<std::byte *>(grad_input->DataPtr());
99+
const auto *idx_ptr = static_cast<const int64_t *>(indices->DataPtr());
100+
for (int64_t outer = 0; outer < outer_size; ++outer) {
101+
for (int64_t inner = 0; inner < inner_size; ++inner) {
102+
for (int64_t selected = 0; selected < topk; ++selected) {
103+
const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner;
104+
const int64_t selected_idx = idx_ptr[out_offset];
105+
CHECK_GE(selected_idx, 0);
106+
CHECK_LT(selected_idx, dim_size);
107+
std::memcpy(dst + (outer * dim_size * inner_size + selected_idx * inner_size + inner) * elem_size,
108+
src + out_offset * elem_size, elem_size);
109+
}
110+
}
111+
}
112+
113+
return grad_input;
114+
}
115+
116+
} // namespace infini_train::kernels::cpu
117+
118+
// File-local shorthand: forwards to REGISTER_KERNEL with the CPU device type and the function of
// the same name from infini_train::kernels::cpu. The autograd layer looks these kernels up by
// the strings "TopKForward" / "TopKBackward" (presumably what REGISTER_KERNEL derives from the
// second argument; confirm against dispatcher.h).
#define REGISTER_CPU_TOPK_KERNEL(kernel)                                                                               \
    REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel, infini_train::kernels::cpu::kernel)

REGISTER_CPU_TOPK_KERNEL(TopKForward)
REGISTER_CPU_TOPK_KERNEL(TopKBackward)

// Keep the shorthand from leaking into other translation units' macro namespace.
#undef REGISTER_CPU_TOPK_KERNEL

infini_train/src/kernels/cpu/topk_mask.cc

Lines changed: 0 additions & 88 deletions
This file was deleted.

0 commit comments

Comments
 (0)