|
12 | 12 | // See the License for the specific language governing permissions and |
13 | 13 | // limitations under the License. |
14 | 14 |
|
| 15 | +#include <algorithm> |
15 | 16 | #include <cstdlib> |
| 17 | +#include <cstring> |
16 | 18 | #include <string> |
17 | 19 | #include "paddle/extension.h" |
18 | 20 |
|
19 | 21 | #ifndef PD_BUILD_STATIC_OP |
20 | 22 | #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) |
21 | 23 | #endif |
22 | 24 |
|
| 25 | +// ============================================================ |
| 26 | +// CPU path — preserved from original for backward compatibility |
| 27 | +// with CPU-only callers and tests. |
| 28 | +// ============================================================ |
// Inclusive prefix sum over the first `num + 1` entries of `value`:
// returns value[0] + value[1] + ... + value[num]. The inclusive upper
// bound is intentional — the caller passes the current batch index and
// expects that batch's own count to be included in the total.
static int sum_mixed_cpu(const int *value, int num) {
  int total = 0;
  for (int idx = num; idx >= 0; --idx) {
    total += value[idx];
  }
  return total;
}
| 36 | + |
// CPU fallback for hybrid MTP ngram matching (mirrors the CUDA kernel's
// single-thread sequential semantics). For each active batch slot it takes
// the trailing ngram of the already-generated tokens (pre_ids) and searches
// for that ngram first in the prompt (input_ids), then in the generation
// history itself; on a hit, the tokens following the match are copied into
// draft_tokens and seq_lens_this_time is extended accordingly.
//
// Parameters:
//   input_ids / input_ids_len : per-batch prompt tokens and their lengths.
//   pre_ids / step_idx        : per-batch generated-token history; step_idx
//                               is the index of the latest generated token
//                               (per the trailing-ngram arithmetic below).
//   draft_token_num           : NOTE(review): currently unused in this CPU
//                               path — kept to match the GPU kernel's
//                               parameter list.
//   draft_tokens (out)        : receives the speculated tokens, appended
//                               after the first ori_seq_len_this_time slots.
//   seq_lens_this_time (in/out): token count per batch this step; extended
//                               by the number of drafted tokens on a match.
//   seq_lens_decoder          : >0 marks a batch slot as active/decoding.
//   max_dec_len               : per-batch decode-length cap.
//   *_stride                  : row strides of the 2-D buffers above.
//   max_batch_size            : number of batch slots to scan.
//   max/min_ngram_size        : ngram lengths tried, longest first.
//   max_draft_tokens          : cap on drafted tokens per batch per step.
static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
                                             const int64_t *input_ids_len,
                                             const int64_t *pre_ids,
                                             const int64_t *step_idx,
                                             const int *draft_token_num,
                                             int64_t *draft_tokens,
                                             int32_t *seq_lens_this_time,
                                             int32_t *seq_lens_decoder,
                                             int64_t *max_dec_len,
                                             int64_t input_ids_stride,
                                             int64_t pre_ids_stride,
                                             int64_t draft_tokens_stride,
                                             int64_t max_batch_size,
                                             int max_ngram_size = 3,
                                             int min_ngram_size = 1,
                                             const int max_draft_tokens = 10) {
  // Global token budget for this step; overridable via environment.
  int threshold = 1024;
  char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
  if (env_var) {
    threshold = std::stoi(env_var);
  }
  // Count of active (decoding) batches not yet visited by the loop below;
  // each still needs at least one token reserved from the budget.
  int unprocessed_batch_size = 0;
  for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
    if (seq_lens_decoder[batch_idx] > 0) {
      unprocessed_batch_size++;
    }
  }
  for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
    const int ori_seq_len_this_time = seq_lens_this_time[batch_idx];
    // How many tokens we may still draft for this batch: limited both by
    // the per-step cap and by the remaining decode-length allowance.
    int max_draft_tokens_query = std::min(
        static_cast<int64_t>(max_draft_tokens - ori_seq_len_this_time + 1),
        max_dec_len[batch_idx] - step_idx[batch_idx] - 1);

    if (ori_seq_len_this_time == 0 || max_draft_tokens_query <= 0) {
      continue;
    }

    const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
    int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
    const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
    const int64_t cur_step_idx = step_idx[batch_idx];
    const int64_t cur_input_ids_len = input_ids_len[batch_idx];
    // This batch is now being processed; it no longer needs a reservation.
    unprocessed_batch_size--;

    // Tokens already committed this step by batches 0..batch_idx inclusive
    // (sum_mixed_cpu uses an inclusive upper bound).
    auto sum_token_num = sum_mixed_cpu(seq_lens_this_time, batch_idx);
    int left_min_token_num = unprocessed_batch_size;

    // Shrink this batch's draft quota so the step total stays <= threshold
    // while leaving one token for each still-unvisited active batch.
    if (sum_token_num + max_draft_tokens_query + left_min_token_num >
        threshold) {
      int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num;
      max_draft_tokens_query =
          std::min(max_draft_tokens_query, tmp_max_draft_tokens);
    }

    // Budget already (nearly) exhausted: skip drafting for this batch.
    if (sum_token_num + left_min_token_num >= threshold - 1) {
      continue;
    }
    bool match_global = false;
    // Try the longest ngram first and stop at the first successful match.
    for (int ngram_size = max_ngram_size;
         ngram_size >= min_ngram_size && !match_global;
         --ngram_size) {
      if (cur_step_idx < ngram_size) {
        continue;
      }
      // Trailing ngram of the generation history, ending at cur_step_idx.
      const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);

      // Pass 1: look for the ngram inside the prompt (input_ids).
      for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global;
           ++i) {
        bool match_local = true;
        for (int j = 0; j < ngram_size; j++) {
          if (ngram[j] != cur_input_ids[i + j]) {
            match_local = false;
            break;
          }
        }
        if (match_local) {
          // Draft the tokens that followed the matched ngram in the prompt.
          int64_t start_idx = i + ngram_size;
          int64_t end_idx =
              std::min(start_idx + max_draft_tokens_query, cur_input_ids_len);
          if (start_idx >= end_idx) continue;

          int64_t cur_draft_token_num = end_idx - start_idx;
          seq_lens_this_time[batch_idx] =
              ori_seq_len_this_time + cur_draft_token_num;
          memcpy(cur_draft_tokens + ori_seq_len_this_time,
                 cur_input_ids + start_idx,
                 sizeof(int64_t) * cur_draft_token_num);
          match_global = true;
          break;
        }
      }
      // Pass 2: fall back to searching the generation history (pre_ids).
      if (!match_global) {
        for (int64_t i = 0; i <= cur_step_idx - ngram_size && !match_global;
             ++i) {
          bool match_local = true;
          for (int j = 0; j < ngram_size; j++) {
            if (ngram[j] != cur_pre_ids[i + j]) {
              match_local = false;
              break;
            }
          }
          if (match_local) {
            // Draft the tokens that followed the matched ngram in history;
            // end is capped at cur_step_idx so only committed tokens copy.
            int64_t start_idx = i + ngram_size;
            int64_t end_idx =
                std::min(start_idx + max_draft_tokens_query, cur_step_idx);
            int64_t cur_draft_token_num = end_idx - start_idx;
            if (start_idx >= end_idx) continue;

            seq_lens_this_time[batch_idx] =
                ori_seq_len_this_time + cur_draft_token_num;
            memcpy(cur_draft_tokens + ori_seq_len_this_time,
                   cur_pre_ids + start_idx,
                   sizeof(int64_t) * cur_draft_token_num);
            match_global = true;
            break;
          }
        }
      }
    }
  }
}
| 158 | + |
| 159 | +// ============================================================ |
| 160 | +// GPU path — CUDA kernel for zero-copy ngram matching. |
| 161 | +// ============================================================ |
| 162 | + |
23 | 163 | // GPU kernel for hybrid MTP ngram matching — eliminates CPU↔GPU data copies. |
24 | 164 | // Single-thread execution preserves sequential threshold semantics. |
25 | 165 | // Key differences from ngram_match_kernel: |
@@ -187,24 +327,44 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, |
187 | 327 | threshold = std::stoi(env_var); |
188 | 328 | } |
189 | 329 |
|
190 | | - ngram_match_mixed_kernel<<<1, 1, 0, input_ids.stream()>>>( |
191 | | - input_ids.data<int64_t>(), |
192 | | - input_ids_len.data<int64_t>(), |
193 | | - pre_ids.data<int64_t>(), |
194 | | - step_idx.data<int64_t>(), |
195 | | - draft_token_num.data<int>(), |
196 | | - const_cast<int64_t *>(draft_tokens.data<int64_t>()), |
197 | | - const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()), |
198 | | - seq_lens_decoder.data<int32_t>(), |
199 | | - max_dec_len.data<int64_t>(), |
200 | | - input_ids_stride, |
201 | | - pre_ids_stride, |
202 | | - draft_tokens_stride, |
203 | | - max_batch_size, |
204 | | - max_ngram_size, |
205 | | - min_ngram_size, |
206 | | - max_draft_tokens, |
207 | | - threshold); |
| 330 | + if (input_ids.is_gpu()) { |
| 331 | + ngram_match_mixed_kernel<<<1, 1, 0, input_ids.stream()>>>( |
| 332 | + input_ids.data<int64_t>(), |
| 333 | + input_ids_len.data<int64_t>(), |
| 334 | + pre_ids.data<int64_t>(), |
| 335 | + step_idx.data<int64_t>(), |
| 336 | + draft_token_num.data<int>(), |
| 337 | + const_cast<int64_t *>(draft_tokens.data<int64_t>()), |
| 338 | + const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()), |
| 339 | + seq_lens_decoder.data<int32_t>(), |
| 340 | + max_dec_len.data<int64_t>(), |
| 341 | + input_ids_stride, |
| 342 | + pre_ids_stride, |
| 343 | + draft_tokens_stride, |
| 344 | + max_batch_size, |
| 345 | + max_ngram_size, |
| 346 | + min_ngram_size, |
| 347 | + max_draft_tokens, |
| 348 | + threshold); |
| 349 | + } else { |
| 350 | + find_candidate_pred_tokens_mixed( |
| 351 | + input_ids.data<int64_t>(), |
| 352 | + input_ids_len.data<int64_t>(), |
| 353 | + pre_ids.data<int64_t>(), |
| 354 | + step_idx.data<int64_t>(), |
| 355 | + draft_token_num.data<int>(), |
| 356 | + const_cast<int64_t *>(draft_tokens.data<int64_t>()), |
| 357 | + const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()), |
| 358 | + const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()), |
| 359 | + const_cast<int64_t *>(max_dec_len.data<int64_t>()), |
| 360 | + input_ids_stride, |
| 361 | + pre_ids_stride, |
| 362 | + draft_tokens_stride, |
| 363 | + max_batch_size, |
| 364 | + max_ngram_size, |
| 365 | + min_ngram_size, |
| 366 | + max_draft_tokens); |
| 367 | + } |
208 | 368 | } |
209 | 369 |
|
210 | 370 | PD_BUILD_STATIC_OP(hybrid_mtp_ngram) |
|
0 commit comments