Skip to content

Commit 217e587

Browse files
author
cloudforge1
committed
fix: add CPU fallback path for ngram_match and hybrid_mtp_ngram ops
Restore backward compatibility with existing CPU-only operator tests (test_ngram_match.py, test_hybrid_mtp_ngram.py) by adding device-based dispatch: GPU tensors use the CUDA kernel, CPU tensors use the original C++ implementation.
1 parent c349b12 commit 217e587

File tree

2 files changed

+358
-37
lines changed

2 files changed

+358
-37
lines changed

custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu

Lines changed: 178 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,154 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include <algorithm>
1516
#include <cstdlib>
17+
#include <cstring>
1618
#include <string>
1719
#include "paddle/extension.h"
1820

1921
#ifndef PD_BUILD_STATIC_OP
2022
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
2123
#endif
2224

25+
// ============================================================
26+
// CPU path — preserved from original for backward compatibility
27+
// with CPU-only callers and tests.
28+
// ============================================================
29+
// Inclusive prefix sum over a host int array: returns
// value[0] + value[1] + ... + value[num] (note: num is INCLUSIVE).
// Used to total the tokens already scheduled for batches 0..num.
static int sum_mixed_cpu(const int *value, int num) {
  int total = 0;
  int idx = 0;
  while (idx <= num) {
    total += value[idx];
    ++idx;
  }
  return total;
}
36+
37+
// CPU implementation of the mixed n-gram draft-token search (the original
// pre-CUDA code path, kept so CPU tensors and CPU-only tests keep working).
//
// For each batch slot, takes the trailing `ngram_size` tokens of the
// generated sequence (`pre_ids`) as a query n-gram and searches for it
// first in the prompt (`input_ids`), then in the generated tokens
// themselves.  On the first hit, the tokens following the match are copied
// into `draft_tokens` and `seq_lens_this_time` is grown accordingly.
// Longer n-grams are tried first (max_ngram_size down to min_ngram_size);
// the first match wins.
//
// A global per-step token budget (`threshold`, default 1024, overridable
// via the SPEC_TOKENUM_THRESHOLD environment variable) caps the total
// tokens scheduled across all batches, reserving one token for every
// batch not yet processed by this loop.
//
// All pointers are host memory.
//   input_ids / input_ids_len : per-batch prompt tokens and their lengths.
//   pre_ids / step_idx        : per-batch generated tokens; step_idx is the
//                               index of the latest generated token.
//   draft_token_num           : unused in this CPU path; kept so the
//                               parameter list mirrors the GPU kernel.
//   draft_tokens (out)        : draft tokens, appended after the first
//                               ori_seq_len_this_time entries of the row.
//   seq_lens_this_time (i/o)  : tokens scheduled per batch this step;
//                               also read by sum_mixed_cpu for budgeting.
//   seq_lens_decoder          : NOTE(review): >0 is treated as "batch still
//                               active" for budget reservation — confirm.
//   max_dec_len               : per-batch generation-length cap.
//   *_stride                  : row strides of the flattened 2-D tensors.
//   max_batch_size            : number of batch slots to scan.
static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
                                             const int64_t *input_ids_len,
                                             const int64_t *pre_ids,
                                             const int64_t *step_idx,
                                             const int *draft_token_num,
                                             int64_t *draft_tokens,
                                             int32_t *seq_lens_this_time,
                                             int32_t *seq_lens_decoder,
                                             int64_t *max_dec_len,
                                             int64_t input_ids_stride,
                                             int64_t pre_ids_stride,
                                             int64_t draft_tokens_stride,
                                             int64_t max_batch_size,
                                             int max_ngram_size = 3,
                                             int min_ngram_size = 1,
                                             const int max_draft_tokens = 10) {
  // Per-step token budget; overridable at runtime via environment variable.
  int threshold = 1024;
  char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
  if (env_var) {
    threshold = std::stoi(env_var);
  }
  // Count batches with decoder activity so the budget can reserve at least
  // one token for each batch this loop has not reached yet.
  int unprocessed_batch_size = 0;
  for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
    if (seq_lens_decoder[batch_idx] > 0) {
      unprocessed_batch_size++;
    }
  }
  for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
    const int ori_seq_len_this_time = seq_lens_this_time[batch_idx];
    // Max draft tokens this batch may still claim: bounded both by the
    // per-query cap and by the remaining generation length.
    int max_draft_tokens_query = std::min(
        static_cast<int64_t>(max_draft_tokens - ori_seq_len_this_time + 1),
        max_dec_len[batch_idx] - step_idx[batch_idx] - 1);

    // Skip inactive batches and batches with no room left to draft.
    if (ori_seq_len_this_time == 0 || max_draft_tokens_query <= 0) {
      continue;
    }

    const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
    int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
    const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
    const int64_t cur_step_idx = step_idx[batch_idx];
    const int64_t cur_input_ids_len = input_ids_len[batch_idx];
    // This batch is now being processed; stop reserving a token for it.
    unprocessed_batch_size--;

    // Tokens already scheduled for batches 0..batch_idx (inclusive); note
    // that seq_lens_this_time entries for earlier batches may already have
    // been enlarged by this very loop.
    auto sum_token_num = sum_mixed_cpu(seq_lens_this_time, batch_idx);
    int left_min_token_num = unprocessed_batch_size;

    // Shrink this batch's draft allowance if granting it in full would
    // exceed the global budget (one token reserved per remaining batch).
    if (sum_token_num + max_draft_tokens_query + left_min_token_num >
        threshold) {
      int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
      max_draft_tokens_query =
          std::min(max_draft_tokens_query, tmp_max_draft_tokens);
    }

    // No headroom at all — skip drafting for this batch.
    if (sum_token_num + left_min_token_num >= threshold - 1) {
      continue;
    }
    bool match_global = false;
    // Try the longest n-gram first; stop at the first successful match.
    for (int ngram_size = max_ngram_size;
         ngram_size >= min_ngram_size && !match_global;
         --ngram_size) {
      // Not enough generated tokens yet to form this n-gram.
      if (cur_step_idx < ngram_size) {
        continue;
      }
      // Query = the last `ngram_size` generated tokens.
      const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);

      // Pass 1: slide the query over the prompt tokens.
      for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global;
           ++i) {
        bool match_local = true;
        for (int j = 0; j < ngram_size; j++) {
          if (ngram[j] != cur_input_ids[i + j]) {
            match_local = false;
            break;
          }
        }
        if (match_local) {
          // Draft the tokens that follow the matched n-gram in the prompt,
          // up to the allowance and the prompt's end.
          int64_t start_idx = i + ngram_size;
          int64_t end_idx =
              std::min(start_idx + max_draft_tokens_query, cur_input_ids_len);
          if (start_idx >= end_idx) continue;

          int64_t cur_draft_token_num = end_idx - start_idx;
          seq_lens_this_time[batch_idx] =
              ori_seq_len_this_time + cur_draft_token_num;
          memcpy(cur_draft_tokens + ori_seq_len_this_time,
                 cur_input_ids + start_idx,
                 sizeof(int64_t) * cur_draft_token_num);
          match_global = true;
          break;
        }
      }
      // Pass 2: if the prompt had no match, slide the query over the
      // previously generated tokens (pre_ids).
      if (!match_global) {
        for (int64_t i = 0; i <= cur_step_idx - ngram_size && !match_global;
             ++i) {
          bool match_local = true;
          for (int j = 0; j < ngram_size; j++) {
            if (ngram[j] != cur_pre_ids[i + j]) {
              match_local = false;
              break;
            }
          }
          if (match_local) {
            int64_t start_idx = i + ngram_size;
            // Copy window is capped at cur_step_idx, so the latest
            // generated token itself is never drafted.
            int64_t end_idx =
                std::min(start_idx + max_draft_tokens_query, cur_step_idx);
            int64_t cur_draft_token_num = end_idx - start_idx;
            if (start_idx >= end_idx) continue;

            seq_lens_this_time[batch_idx] =
                ori_seq_len_this_time + cur_draft_token_num;
            memcpy(cur_draft_tokens + ori_seq_len_this_time,
                   cur_pre_ids + start_idx,
                   sizeof(int64_t) * cur_draft_token_num);
            match_global = true;
            break;
          }
        }
      }
    }
  }
}
158+
159+
// ============================================================
160+
// GPU path — CUDA kernel for zero-copy ngram matching.
161+
// ============================================================
162+
23163
// GPU kernel for hybrid MTP ngram matching — eliminates CPU↔GPU data copies.
24164
// Single-thread execution preserves sequential threshold semantics.
25165
// Key differences from ngram_match_kernel:
@@ -187,24 +327,44 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
187327
threshold = std::stoi(env_var);
188328
}
189329

190-
ngram_match_mixed_kernel<<<1, 1, 0, input_ids.stream()>>>(
191-
input_ids.data<int64_t>(),
192-
input_ids_len.data<int64_t>(),
193-
pre_ids.data<int64_t>(),
194-
step_idx.data<int64_t>(),
195-
draft_token_num.data<int>(),
196-
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
197-
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
198-
seq_lens_decoder.data<int32_t>(),
199-
max_dec_len.data<int64_t>(),
200-
input_ids_stride,
201-
pre_ids_stride,
202-
draft_tokens_stride,
203-
max_batch_size,
204-
max_ngram_size,
205-
min_ngram_size,
206-
max_draft_tokens,
207-
threshold);
330+
if (input_ids.is_gpu()) {
331+
ngram_match_mixed_kernel<<<1, 1, 0, input_ids.stream()>>>(
332+
input_ids.data<int64_t>(),
333+
input_ids_len.data<int64_t>(),
334+
pre_ids.data<int64_t>(),
335+
step_idx.data<int64_t>(),
336+
draft_token_num.data<int>(),
337+
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
338+
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
339+
seq_lens_decoder.data<int32_t>(),
340+
max_dec_len.data<int64_t>(),
341+
input_ids_stride,
342+
pre_ids_stride,
343+
draft_tokens_stride,
344+
max_batch_size,
345+
max_ngram_size,
346+
min_ngram_size,
347+
max_draft_tokens,
348+
threshold);
349+
} else {
350+
find_candidate_pred_tokens_mixed(
351+
input_ids.data<int64_t>(),
352+
input_ids_len.data<int64_t>(),
353+
pre_ids.data<int64_t>(),
354+
step_idx.data<int64_t>(),
355+
draft_token_num.data<int>(),
356+
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
357+
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
358+
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
359+
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
360+
input_ids_stride,
361+
pre_ids_stride,
362+
draft_tokens_stride,
363+
max_batch_size,
364+
max_ngram_size,
365+
min_ngram_size,
366+
max_draft_tokens);
367+
}
208368
}
209369

210370
PD_BUILD_STATIC_OP(hybrid_mtp_ngram)

0 commit comments

Comments
 (0)