|
#include "repetition_penalty_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../info.h"
#include "infinicore.h"
#include <algorithm>
#include <type_traits>
| 6 | + |
| 7 | +namespace op::repetition_penalty::cpu { |
| 8 | + |
// Out-of-line defaulted destructor: relies on the compiler-generated,
// member-wise destruction of the descriptor.
Descriptor::~Descriptor() = default;
| 10 | + |
| 11 | +infiniStatus_t Descriptor::create( |
| 12 | + infiniopHandle_t handle_, |
| 13 | + Descriptor **desc_ptr, |
| 14 | + infiniopTensorDescriptor_t logits_desc) { |
| 15 | + auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); |
| 16 | + |
| 17 | + auto result = RepetitionPenaltyInfo::create(logits_desc); |
| 18 | + CHECK_RESULT(result); |
| 19 | + |
| 20 | + *desc_ptr = new Descriptor( |
| 21 | + result.take(), |
| 22 | + 0, // No workspace needed for CPU |
| 23 | + nullptr, |
| 24 | + handle->device, handle->device_id); |
| 25 | + return INFINI_STATUS_SUCCESS; |
| 26 | +} |
| 27 | + |
// Returns the workspace size, in whatever unit _min_workspace_size is stored,
// that this descriptor reports as its minimum requirement.
size_t Descriptor::minWorkspaceSize() const {
    return _min_workspace_size;
}
| 31 | + |
| 32 | +template <typename T> |
| 33 | +void apply_penalty_cpu( |
| 34 | + T *logits, |
| 35 | + const float *repetition_penalties, |
| 36 | + const uint32_t *token_indices, |
| 37 | + const size_t *token_offsets, |
| 38 | + size_t num_seqs, |
| 39 | + size_t vocab_size) { |
| 40 | + |
| 41 | + for (size_t seq_idx = 0; seq_idx < num_seqs; seq_idx++) { |
| 42 | + float penalty = repetition_penalties[seq_idx]; |
| 43 | + if (penalty == 1.0f) { |
| 44 | + continue; // Skip if no penalty |
| 45 | + } |
| 46 | + |
| 47 | + size_t start = token_offsets[seq_idx]; |
| 48 | + size_t end = token_offsets[seq_idx + 1]; |
| 49 | + for (size_t i = start; i < end; i++) { |
| 50 | + uint32_t token_id = token_indices[i]; |
| 51 | + if (token_id >= vocab_size) { |
| 52 | + continue; // skip out-of-range ids |
| 53 | + } |
| 54 | + size_t offset = seq_idx * vocab_size + token_id; |
| 55 | + T logit_val_orig = logits[offset]; |
| 56 | + float logit_val = utils::cast<float>(logit_val_orig); |
| 57 | + |
| 58 | + // Match PyTorch behavior exactly: val / p if val > 0 else val * p |
| 59 | + if (logit_val > 0.0f) { |
| 60 | + logits[offset] = utils::cast<T>(logit_val / penalty); |
| 61 | + } else { |
| 62 | + // For val <= 0: multiply by penalty (covers negative and zero) |
| 63 | + logits[offset] = utils::cast<T>(logit_val * penalty); |
| 64 | + } |
| 65 | + } |
| 66 | + } |
| 67 | +} |
| 68 | + |
| 69 | +infiniStatus_t Descriptor::calculate( |
| 70 | + void *workspace, |
| 71 | + size_t workspace_size, |
| 72 | + void *logits, |
| 73 | + const float *repetition_penalties, |
| 74 | + const uint32_t *token_indices, |
| 75 | + const size_t *token_offsets, |
| 76 | + size_t total_indices, |
| 77 | + void *stream) const { |
| 78 | + |
| 79 | + switch (_info.dt_logits) { |
| 80 | + case INFINI_DTYPE_F16: |
| 81 | + apply_penalty_cpu<fp16_t>( |
| 82 | + reinterpret_cast<fp16_t *>(logits), |
| 83 | + repetition_penalties, |
| 84 | + token_indices, |
| 85 | + token_offsets, |
| 86 | + _info.num_seqs, |
| 87 | + _info.vocab_size); |
| 88 | + break; |
| 89 | + case INFINI_DTYPE_BF16: |
| 90 | + apply_penalty_cpu<bf16_t>( |
| 91 | + reinterpret_cast<bf16_t *>(logits), |
| 92 | + repetition_penalties, |
| 93 | + token_indices, |
| 94 | + token_offsets, |
| 95 | + _info.num_seqs, |
| 96 | + _info.vocab_size); |
| 97 | + break; |
| 98 | + case INFINI_DTYPE_F32: |
| 99 | + apply_penalty_cpu<float>( |
| 100 | + reinterpret_cast<float *>(logits), |
| 101 | + repetition_penalties, |
| 102 | + token_indices, |
| 103 | + token_offsets, |
| 104 | + _info.num_seqs, |
| 105 | + _info.vocab_size); |
| 106 | + break; |
| 107 | + case INFINI_DTYPE_F64: |
| 108 | + apply_penalty_cpu<double>( |
| 109 | + reinterpret_cast<double *>(logits), |
| 110 | + repetition_penalties, |
| 111 | + token_indices, |
| 112 | + token_offsets, |
| 113 | + _info.num_seqs, |
| 114 | + _info.vocab_size); |
| 115 | + break; |
| 116 | + default: |
| 117 | + return INFINI_STATUS_BAD_TENSOR_DTYPE; |
| 118 | + } |
| 119 | + |
| 120 | + return INFINI_STATUS_SUCCESS; |
| 121 | +} |
| 122 | + |
| 123 | +} // namespace op::repetition_penalty::cpu |
0 commit comments