|
| 1 | +#include <cstddef> |
| 2 | +#include <cstring> |
| 3 | +#include <limits> |
| 4 | +#include <memory> |
| 5 | +#include <vector> |
| 6 | + |
| 7 | +#include "glog/logging.h" |
| 8 | + |
| 9 | +#include "infini_train/include/dispatcher.h" |
| 10 | +#include "infini_train/include/tensor.h" |
| 11 | + |
| 12 | +namespace infini_train::kernels::cpu { |
| 13 | + |
| 14 | +std::vector<std::shared_ptr<Tensor>> TopKForward(const std::shared_ptr<Tensor> &input, int64_t topk, int64_t dim, |
| 15 | + bool largest, bool sorted) { |
| 16 | + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKForward currently supports float32 only"; |
| 17 | + CHECK_GE(input->Dims().size(), 1); |
| 18 | + (void)sorted; |
| 19 | + |
| 20 | + const auto &dims = input->Dims(); |
| 21 | + if (dim < 0) { |
| 22 | + dim += static_cast<int64_t>(dims.size()); |
| 23 | + } |
| 24 | + CHECK_GE(dim, 0); |
| 25 | + CHECK_LT(dim, static_cast<int64_t>(dims.size())); |
| 26 | + |
| 27 | + const int64_t dim_size = dims[dim]; |
| 28 | + CHECK_GT(dim_size, 0); |
| 29 | + CHECK_GT(topk, 0); |
| 30 | + CHECK_LE(topk, dim_size); |
| 31 | + |
| 32 | + int64_t outer_size = 1; |
| 33 | + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= dims[idx]; } |
| 34 | + int64_t inner_size = 1; |
| 35 | + for (size_t idx = static_cast<size_t>(dim) + 1; idx < dims.size(); ++idx) { inner_size *= dims[idx]; } |
| 36 | + |
| 37 | + auto topk_dims = dims; |
| 38 | + topk_dims[dim] = topk; |
| 39 | + auto top_values = std::make_shared<Tensor>(topk_dims, input->Dtype(), input->GetDevice()); |
| 40 | + auto top_indices = std::make_shared<Tensor>(topk_dims, DataType::kINT64, input->GetDevice()); |
| 41 | + |
| 42 | + const float *in = static_cast<const float *>(input->DataPtr()); |
| 43 | + float *values = static_cast<float *>(top_values->DataPtr()); |
| 44 | + int64_t *indices = static_cast<int64_t *>(top_indices->DataPtr()); |
| 45 | + for (int64_t outer = 0; outer < outer_size; ++outer) { |
| 46 | + for (int64_t inner = 0; inner < inner_size; ++inner) { |
| 47 | + std::vector<bool> selected_indices(dim_size, false); |
| 48 | + for (int64_t selected = 0; selected < topk; ++selected) { |
| 49 | + int64_t best_idx = -1; |
| 50 | + float best_value |
| 51 | + = largest ? -std::numeric_limits<float>::infinity() : std::numeric_limits<float>::infinity(); |
| 52 | + for (int64_t idx = 0; idx < dim_size; ++idx) { |
| 53 | + if (selected_indices[idx]) { |
| 54 | + continue; |
| 55 | + } |
| 56 | + const float value = in[outer * dim_size * inner_size + idx * inner_size + inner]; |
| 57 | + const bool better = largest ? value > best_value : value < best_value; |
| 58 | + if (better) { |
| 59 | + best_value = value; |
| 60 | + best_idx = idx; |
| 61 | + } |
| 62 | + } |
| 63 | + CHECK_GE(best_idx, 0); |
| 64 | + selected_indices[best_idx] = true; |
| 65 | + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; |
| 66 | + values[out_offset] = best_value; |
| 67 | + indices[out_offset] = best_idx; |
| 68 | + } |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + return {top_values, top_indices}; |
| 73 | +} |
| 74 | + |
| 75 | +std::shared_ptr<Tensor> TopKBackward(const std::shared_ptr<Tensor> &grad_values, const std::shared_ptr<Tensor> &indices, |
| 76 | + const std::vector<int64_t> &input_dims, int64_t dim) { |
| 77 | + CHECK(indices->Dtype() == DataType::kINT64) << "CPU TopKBackward expects int64 indices"; |
| 78 | + CHECK(grad_values->Dims() == indices->Dims()); |
| 79 | + CHECK(!input_dims.empty()); |
| 80 | + if (dim < 0) { |
| 81 | + dim += static_cast<int64_t>(input_dims.size()); |
| 82 | + } |
| 83 | + CHECK_GE(dim, 0); |
| 84 | + CHECK_LT(dim, static_cast<int64_t>(input_dims.size())); |
| 85 | + |
| 86 | + const int64_t dim_size = input_dims[dim]; |
| 87 | + const int64_t topk = indices->Dims()[dim]; |
| 88 | + int64_t outer_size = 1; |
| 89 | + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= input_dims[idx]; } |
| 90 | + int64_t inner_size = 1; |
| 91 | + for (size_t idx = static_cast<size_t>(dim) + 1; idx < input_dims.size(); ++idx) { inner_size *= input_dims[idx]; } |
| 92 | + |
| 93 | + auto grad_input = std::make_shared<Tensor>(input_dims, grad_values->Dtype(), grad_values->GetDevice()); |
| 94 | + std::memset(grad_input->DataPtr(), 0, grad_input->SizeInBytes()); |
| 95 | + |
| 96 | + const size_t elem_size = kDataTypeToSize.at(grad_values->Dtype()); |
| 97 | + const auto *src = static_cast<const std::byte *>(grad_values->DataPtr()); |
| 98 | + auto *dst = static_cast<std::byte *>(grad_input->DataPtr()); |
| 99 | + const auto *idx_ptr = static_cast<const int64_t *>(indices->DataPtr()); |
| 100 | + for (int64_t outer = 0; outer < outer_size; ++outer) { |
| 101 | + for (int64_t inner = 0; inner < inner_size; ++inner) { |
| 102 | + for (int64_t selected = 0; selected < topk; ++selected) { |
| 103 | + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; |
| 104 | + const int64_t selected_idx = idx_ptr[out_offset]; |
| 105 | + CHECK_GE(selected_idx, 0); |
| 106 | + CHECK_LT(selected_idx, dim_size); |
| 107 | + std::memcpy(dst + (outer * dim_size * inner_size + selected_idx * inner_size + inner) * elem_size, |
| 108 | + src + out_offset * elem_size, elem_size); |
| 109 | + } |
| 110 | + } |
| 111 | + } |
| 112 | + |
| 113 | + return grad_input; |
| 114 | +} |
| 115 | + |
| 116 | +} // namespace infini_train::kernels::cpu |
| 117 | + |
| 118 | +#define REGISTER_CPU_TOPK_KERNEL(kernel_name) \ |
| 119 | + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) |
| 120 | + |
| 121 | +REGISTER_CPU_TOPK_KERNEL(TopKForward) |
| 122 | +REGISTER_CPU_TOPK_KERNEL(TopKBackward) |
| 123 | + |
| 124 | +#undef REGISTER_CPU_TOPK_KERNEL |
0 commit comments