Skip to content

Commit c5505db

Browse files
committed
graph compatible penalty unit test passed
Signed-off-by: Ceng23333 <441651826@qq.com>
1 parent 7c97894 commit c5505db

13 files changed

Lines changed: 887 additions & 5 deletions

File tree

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#ifndef __INFINIOP_REPETITION_PENALTY_API_H__
2+
#define __INFINIOP_REPETITION_PENALTY_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
#include <stdint.h>
6+
7+
typedef struct InfiniopDescriptor *infiniopRepetitionPenaltyDescriptor_t;
8+
9+
/**
10+
* @brief Creates a repetition penalty operator descriptor.
11+
*
12+
* @param handle InfiniCore handle
13+
* @param desc_ptr Output descriptor pointer
14+
* @param logits_desc Logits tensor descriptor [num_seqs, vocab_size] - will be modified in-place
15+
* @return infiniStatus_t Status code
16+
*/
17+
__C __export infiniStatus_t infiniopCreateRepetitionPenaltyDescriptor(
18+
infiniopHandle_t handle,
19+
infiniopRepetitionPenaltyDescriptor_t *desc_ptr,
20+
infiniopTensorDescriptor_t logits_desc);
21+
22+
/**
23+
* @brief Gets the workspace size required for repetition penalty operation.
24+
*
25+
* @param desc Operator descriptor
26+
* @param size Output workspace size
27+
* @return infiniStatus_t Status code
28+
*/
29+
__C __export infiniStatus_t infiniopGetRepetitionPenaltyWorkspaceSize(
30+
infiniopRepetitionPenaltyDescriptor_t desc,
31+
size_t *size);
32+
33+
/**
34+
* @brief Applies repetition penalty to logits in-place using token indices only.
35+
*
36+
* @param desc Operator descriptor
37+
* @param workspace Workspace buffer
38+
* @param workspace_size Workspace size
39+
* @param logits Logits tensor [num_seqs, vocab_size] - modified in-place (device pointer for GPU backends, host pointer for CPU)
40+
* @param repetition_penalties Repetition penalty values [num_seqs] - device pointer for GPU backends, host pointer for CPU
41+
* @param token_indices Flattened token ids to penalize (device pointer for GPU backends, host pointer for CPU)
42+
* @param token_offsets Prefix sums into token_indices, length = num_seqs + 1 (device pointer for GPU backends, host pointer for CPU)
43+
* @param total_indices Total number of token indices across all sequences (token_offsets[num_seqs])
44+
* @param stream Backend execution stream (e.g. a CUDA stream on GPU backends); not used by the CPU backend
45+
* @return infiniStatus_t Status code
46+
*
47+
* @note For CUDA graph compatibility:
48+
* - repetition_penalties and token buffers must be device pointers for GPU backends
49+
* - total_indices must be computed on host before graph capture
50+
* - The caller is responsible for copying penalty values and token buffers to device before graph capture
51+
*/
52+
__C __export infiniStatus_t infiniopApplyRepetitionPenalty(
53+
infiniopRepetitionPenaltyDescriptor_t desc,
54+
void *workspace,
55+
size_t workspace_size,
56+
void *logits,
57+
const float *repetition_penalties,
58+
const uint32_t *token_indices, // flattened token ids to penalize
59+
const size_t *token_offsets, // prefix sum, len = num_seqs + 1
60+
size_t total_indices, // total number of indices (token_offsets[num_seqs])
61+
void *stream);
62+
63+
/**
64+
* @brief Destroys a repetition penalty operator descriptor.
65+
*
66+
* @param desc Operator descriptor
67+
* @return infiniStatus_t Status code
68+
*/
69+
__C __export infiniStatus_t infiniopDestroyRepetitionPenaltyDescriptor(
70+
infiniopRepetitionPenaltyDescriptor_t desc);
71+
72+
#endif
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#include "repetition_penalty_cpu.h"
2+
#include "../../../devices/cpu/common_cpu.h"
3+
#include "../info.h"
4+
#include "infinicore.h"
5+
#include <algorithm>
6+
7+
namespace op::repetition_penalty::cpu {
8+
9+
Descriptor::~Descriptor() = default;
10+
11+
// Builds a CPU repetition-penalty descriptor for the given logits tensor.
//
// Validation of the logits descriptor is delegated to
// RepetitionPenaltyInfo::create; if that fails, CHECK_RESULT propagates the
// failure status and *desc_ptr is left untouched.
//
// handle_      generic handle, reinterpreted as device::cpu::Handle
// desc_ptr     receives the heap-allocated descriptor on success
// logits_desc  descriptor of the logits tensor
//
// Returns INFINI_STATUS_SUCCESS once *desc_ptr has been assigned.
// NOTE(review): the descriptor is allocated with `new`; presumably it is
// released through the matching destroy entry point - confirm.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t logits_desc) {

    // Validate first so that nothing is allocated for a bad tensor descriptor.
    auto info_result = RepetitionPenaltyInfo::create(logits_desc);
    CHECK_RESULT(info_result);

    auto *cpu_handle = reinterpret_cast<device::cpu::Handle *>(handle_);

    auto *descriptor = new Descriptor(
        info_result.take(),
        0,       // workspace bytes: the CPU path needs none
        nullptr, // no extra backend state is passed for the CPU
        cpu_handle->device,
        cpu_handle->device_id);

    *desc_ptr = descriptor;
    return INFINI_STATUS_SUCCESS;
}
27+
28+
size_t Descriptor::minWorkspaceSize() const {
29+
return _min_workspace_size;
30+
}
31+
32+
template <typename T>
33+
void apply_penalty_cpu(
34+
T *logits,
35+
const float *repetition_penalties,
36+
const uint32_t *token_indices,
37+
const size_t *token_offsets,
38+
size_t num_seqs,
39+
size_t vocab_size) {
40+
41+
for (size_t seq_idx = 0; seq_idx < num_seqs; seq_idx++) {
42+
float penalty = repetition_penalties[seq_idx];
43+
if (penalty == 1.0f) {
44+
continue; // Skip if no penalty
45+
}
46+
47+
size_t start = token_offsets[seq_idx];
48+
size_t end = token_offsets[seq_idx + 1];
49+
for (size_t i = start; i < end; i++) {
50+
uint32_t token_id = token_indices[i];
51+
if (token_id >= vocab_size) {
52+
continue; // skip out-of-range ids
53+
}
54+
size_t offset = seq_idx * vocab_size + token_id;
55+
T logit_val_orig = logits[offset];
56+
float logit_val = utils::cast<float>(logit_val_orig);
57+
58+
// Match PyTorch behavior exactly: val / p if val > 0 else val * p
59+
if (logit_val > 0.0f) {
60+
logits[offset] = utils::cast<T>(logit_val / penalty);
61+
} else {
62+
// For val <= 0: multiply by penalty (covers negative and zero)
63+
logits[offset] = utils::cast<T>(logit_val * penalty);
64+
}
65+
}
66+
}
67+
}
68+
69+
// Applies the repetition penalty to `logits` in place on the host.
//
// workspace, workspace_size  unused by the CPU backend
// logits                     [num_seqs, vocab_size] buffer whose element type
//                            is the dtype recorded in the descriptor
// repetition_penalties       one penalty per sequence
// token_indices              flattened token ids to penalize
// token_offsets              prefix sums into token_indices (num_seqs + 1)
// total_indices              total number of ids; the public API documents
//                            this as token_offsets[num_seqs]
// stream                     unused by the CPU backend
//
// Returns INFINI_STATUS_BAD_TENSOR_DTYPE for an unsupported logits dtype,
// INFINI_STATUS_SUCCESS otherwise.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *logits,
    const float *repetition_penalties,
    const uint32_t *token_indices,
    const size_t *token_offsets,
    size_t total_indices,
    void *stream) const {

    (void)workspace;
    (void)workspace_size;
    (void)stream;

    // With zero ids there is nothing to penalize, so the token buffers are
    // never dereferenced. This lets a caller pass empty (possibly null)
    // token buffers for a batch that has no history yet.
    const bool has_work = total_indices != 0;

    // One dispatch body shared by every dtype; the element type is deduced
    // from the typed pointer handed in by the switch below.
    const auto run = [&](auto *typed_logits) {
        if (has_work) {
            apply_penalty_cpu(
                typed_logits,
                repetition_penalties,
                token_indices,
                token_offsets,
                _info.num_seqs,
                _info.vocab_size);
        }
    };

    switch (_info.dt_logits) {
    case INFINI_DTYPE_F16:
        run(static_cast<fp16_t *>(logits));
        break;
    case INFINI_DTYPE_BF16:
        run(static_cast<bf16_t *>(logits));
        break;
    case INFINI_DTYPE_F32:
        run(static_cast<float *>(logits));
        break;
    case INFINI_DTYPE_F64:
        run(static_cast<double *>(logits));
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}
122+
123+
} // namespace op::repetition_penalty::cpu
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __REPETITION_PENALTY_CPU_H__
2+
#define __REPETITION_PENALTY_CPU_H__
3+
4+
#include "../repetition_penalty.h"
5+
6+
DESCRIPTOR(cpu)
7+
8+
#endif // __REPETITION_PENALTY_CPU_H__
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef __REPETITION_PENALTY_INFO_H__
2+
#define __REPETITION_PENALTY_INFO_H__
3+
4+
#include "../../../utils.h"
5+
#include "../../tensor.h"
6+
7+
namespace op::repetition_penalty {
8+
9+
struct RepetitionPenaltyInfo {
10+
infiniDtype_t dt_logits;
11+
size_t num_seqs;
12+
size_t vocab_size;
13+
14+
static utils::Result<RepetitionPenaltyInfo> create(
15+
infiniopTensorDescriptor_t logits_desc) {
16+
17+
CHECK_OR_RETURN(logits_desc->ndim() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
18+
19+
auto num_seqs = logits_desc->dim(0);
20+
auto vocab_size = logits_desc->dim(1);
21+
22+
CHECK_DTYPE(logits_desc->dtype(), INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
23+
24+
return utils::Result<RepetitionPenaltyInfo>({
25+
logits_desc->dtype(),
26+
num_seqs,
27+
vocab_size
28+
});
29+
}
30+
};
31+
32+
} // namespace op::repetition_penalty
33+
34+
#endif // __REPETITION_PENALTY_INFO_H__
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#ifndef __REPETITION_PENALTY_KERNEL_H__
2+
#define __REPETITION_PENALTY_KERNEL_H__
3+
4+
#include "../../../devices/metax/metax_common.h"
5+
#include "../info.h"
6+
7+
namespace op::repetition_penalty::metax {
8+
9+
// CUDA graph compatible kernel - all operations on device, no host-device memcpy
10+
template <typename T>
11+
static __global__ void applyRepetitionPenaltyKernel(
12+
T *__restrict__ logits,
13+
const float *__restrict__ repetition_penalties,
14+
const uint32_t *__restrict__ token_indices,
15+
const size_t *__restrict__ token_offsets,
16+
size_t num_seqs,
17+
size_t vocab_size,
18+
size_t total_indices) {
19+
20+
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
21+
if (idx >= total_indices) {
22+
return;
23+
}
24+
25+
// Binary search over token_offsets to find seq_idx such that
26+
// token_offsets[seq_idx] <= idx < token_offsets[seq_idx + 1]
27+
size_t lo = 0;
28+
size_t hi = num_seqs;
29+
while (lo < hi) {
30+
size_t mid = (lo + hi) >> 1;
31+
if (token_offsets[mid + 1] <= idx) {
32+
lo = mid + 1;
33+
} else {
34+
hi = mid;
35+
}
36+
}
37+
size_t seq_idx = lo;
38+
39+
uint32_t token_id = token_indices[idx];
40+
if (token_id >= vocab_size) {
41+
return;
42+
}
43+
44+
float penalty = repetition_penalties[seq_idx];
45+
if (penalty == 1.0f) {
46+
return; // No penalty, skip
47+
}
48+
49+
size_t offset = seq_idx * vocab_size + token_id;
50+
float logit_val = static_cast<float>(logits[offset]);
51+
if (logit_val > 0) {
52+
logits[offset] = static_cast<T>(logit_val / penalty);
53+
} else {
54+
logits[offset] = static_cast<T>(logit_val * penalty);
55+
}
56+
}
57+
58+
} // namespace op::repetition_penalty::metax
59+
60+
#endif // __REPETITION_PENALTY_KERNEL_H__
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __REPETITION_PENALTY_METAX_H__
2+
#define __REPETITION_PENALTY_METAX_H__
3+
4+
#include "../repetition_penalty.h"
5+
6+
DESCRIPTOR(metax)
7+
8+
#endif // __REPETITION_PENALTY_METAX_H__

0 commit comments

Comments
 (0)