Skip to content

Commit 10804fd

Browse files
committed
update hybrid kernel to adapt cudagraph
1 parent 55f2706 commit 10804fd

7 files changed

Lines changed: 200 additions & 26 deletions

File tree

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,8 @@ void HybridMtpNgram(const paddle::Tensor& token_ids_all,
973973
const paddle::Tensor& max_dec_len,
974974
const int max_ngram_size,
975975
const int min_ngram_size,
976-
const int max_draft_tokens);
976+
const int max_draft_tokens,
977+
const bool pad_to_max);
977978

978979
// MTP
979980
void DraftModelPostprocess(const paddle::Tensor& base_model_draft_tokens,

custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu

Lines changed: 64 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ __global__ void ngram_match_mixed_gather_kernel(
144144
int32_t *seq_lens_this_time,
145145
int64_t draft_tokens_stride,
146146
int64_t max_batch_size,
147-
int threshold) {
147+
int threshold,
148+
int max_draft_tokens_param,
149+
bool pad_to_max) {
148150
typedef cub::BlockScan<int, NGRAM_GATHER_THREADS> BlockScanInt;
149151
__shared__ typename BlockScanInt::TempStorage temp_storage1;
150152
__shared__ typename BlockScanInt::TempStorage temp_storage2;
@@ -203,9 +205,8 @@ __global__ void ngram_match_mixed_gather_kernel(
203205
}
204206
actual = min(actual, tentative);
205207

206-
seq_lens_this_time[tid] = actual;
207-
208-
// Copy ngram draft tokens from scratch to output
208+
// Copy ngram draft tokens from scratch to output FIRST
209+
// (so subsequent padding doesn't overwrite real ngram hits)
209210
int ngram_to_copy = actual - ori;
210211
if (ngram_to_copy > 0) {
211212
int64_t *dst = draft_tokens + tid * draft_tokens_stride;
@@ -214,6 +215,37 @@ __global__ void ngram_match_mixed_gather_kernel(
214215
dst[ori + k] = src[ori + k];
215216
}
216217
}
218+
219+
// === Pad seq_lens_this_time to K+1 for cudagraph stability ===
220+
// Hybrid MTP-ngram produces variable seq_lens_this_time depending on how
221+
// many ngram positions hit (range: [num_model_steps+1, K+1]). cudagraph
222+
// captures launch params (grid dim, kernel args) at capture time; if the
223+
// captured slt differs from replay-time slt, downstream kernels read past
224+
// valid ranges of cu_seqlens / slot_mapping etc., causing CUDA 700.
225+
//
226+
// When pad_to_max=true (cudagraph enabled), force slt = K+1 =
227+
// max_draft_tokens + 1: positions beyond actual ngram hits get padded
228+
// with a placeholder token. The target model will verify these
229+
// placeholders and (almost always) reject them, but the verify cost is
230+
// fixed per iteration => grid dim is now invariant. When pad_to_max=
231+
// false (cudagraph disabled), keep the natural variable slt to avoid
232+
// wasting verify compute on placeholders.
233+
if (pad_to_max) {
234+
int target_slt = max_draft_tokens_param + 1;
235+
if (actual < target_slt) {
236+
int64_t *dst = draft_tokens + tid * draft_tokens_stride;
237+
// Reuse the last valid draft token as placeholder. It is a token the
238+
// model could plausibly have produced, so attention math stays
239+
// well-defined; rejection happens at the sampler level.
240+
int64_t pad_token = (actual > 0) ? dst[actual - 1] : 0;
241+
for (int k = actual; k < target_slt; k++) {
242+
dst[k] = pad_token;
243+
}
244+
actual = target_slt;
245+
}
246+
}
247+
248+
seq_lens_this_time[tid] = actual;
217249
}
218250
}
219251

@@ -376,7 +408,8 @@ void HybridMtpNgram(const paddle::Tensor &token_ids_all,
376408
const paddle::Tensor &max_dec_len,
377409
const int max_ngram_size,
378410
const int min_ngram_size,
379-
const int max_draft_tokens) {
411+
const int max_draft_tokens,
412+
const bool pad_to_max) {
380413
const int64_t max_model_len = token_ids_all.shape()[1];
381414

382415
auto pre_ids_shape = pre_ids.shape();
@@ -404,21 +437,28 @@ void HybridMtpNgram(const paddle::Tensor &token_ids_all,
404437
// counts tentative > 0, which is equivalent under this invariant.
405438

406439
// Allocate scratch buffers for Phase 1 → Phase 2 communication
440+
static paddle::Tensor s_draft_copy_mixed;
441+
static paddle::Tensor s_seqlens_copy_mixed;
442+
static paddle::Tensor s_seqlens_orig_mixed;
443+
static int64_t s_scratch_batch_mixed = 0;
444+
static int64_t s_scratch_stride_mixed = 0;
445+
446+
if (max_batch_size > s_scratch_batch_mixed ||
447+
draft_tokens_stride > s_scratch_stride_mixed) {
448+
s_draft_copy_mixed = paddle::empty({max_batch_size, draft_tokens_stride},
449+
paddle::DataType::INT64,
450+
token_ids_all.place());
451+
s_seqlens_copy_mixed = paddle::empty(
452+
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());
453+
s_seqlens_orig_mixed = paddle::empty(
454+
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());
455+
s_scratch_batch_mixed = max_batch_size;
456+
s_scratch_stride_mixed = draft_tokens_stride;
457+
}
458+
auto &draft_tokens_copy = s_draft_copy_mixed;
459+
auto &seq_lens_this_time_copy = s_seqlens_copy_mixed;
460+
auto &seq_lens_this_time_orig = s_seqlens_orig_mixed;
407461

408-
// Scratch copy of draft_tokens (Phase 1 writes tentative tokens here)
409-
auto draft_tokens_copy =
410-
paddle::empty({max_batch_size, draft_tokens_stride},
411-
paddle::DataType::INT64,
412-
token_ids_all.place());
413-
414-
// Scratch copy of seq_lens_this_time (Phase 1 writes tentative counts)
415-
auto seq_lens_this_time_copy = paddle::empty(
416-
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());
417-
418-
// Save a copy of original seq_lens_this_time for Phase 2
419-
// (Phase 1 reads from the original, Phase 2 needs ori values)
420-
auto seq_lens_this_time_orig = paddle::empty(
421-
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());
422462
cudaMemcpyAsync(seq_lens_this_time_orig.data<int32_t>(),
423463
seq_lens_this_time.data<int32_t>(),
424464
max_batch_size * sizeof(int32_t),
@@ -462,7 +502,9 @@ void HybridMtpNgram(const paddle::Tensor &token_ids_all,
462502
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
463503
draft_tokens_stride,
464504
max_batch_size,
465-
threshold);
505+
threshold,
506+
max_draft_tokens,
507+
pad_to_max);
466508
} else {
467509
find_candidate_pred_tokens_mixed(
468510
token_ids_all.data<int64_t>(),
@@ -496,7 +538,8 @@ PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
496538
"max_dec_len"})
497539
.Attrs({"max_ngram_size: int",
498540
"min_ngram_size: int",
499-
"max_draft_tokens: int"})
541+
"max_draft_tokens: int",
542+
"pad_to_max: bool"})
500543
.Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
501544
.SetKernelFn(PD_KERNEL(HybridMtpNgram))
502545
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},

fastdeploy/spec_decode/mtp_cuda.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,10 @@ def _update_status(self):
393393
)
394394

395395
def _extend_draft_token_with_ngram_match(self):
396+
# pad_to_max forces hybrid kernel to write a fixed seq_lens_this_time
397+
# = K + 1, padding unfilled ngram slots with a placeholder draft token.
398+
# Required when target cudagraph is enabled (capture-time seq_lens_this_time
399+
# must match replay-time seq_lens_this_time).
396400
hybrid_mtp_ngram(
397401
self.model_inputs["token_ids_all"],
398402
self.model_inputs["prompt_lens"],
@@ -406,6 +410,7 @@ def _extend_draft_token_with_ngram_match(self):
406410
self.max_ngram_size,
407411
self.min_ngram_size,
408412
self.max_draft_token_num,
413+
self.graph_opt_config.use_cudagraph,
409414
)
410415

411416
def padding_cudagraph_inputs(self) -> None:

tests/e2e/test_ernie_21b_mtp_ngram.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,9 @@ def setup_and_run_server():
157157
"--cache-queue-port",
158158
str(FD_CACHE_QUEUE_PORT),
159159
"--max-model-len",
160-
"4096",
160+
"32768",
161161
"--max-num-seqs",
162-
"8",
162+
"128",
163163
"--quantization",
164164
"wint4",
165165
"--enable-overlap-schedule",

tests/operators/test_hybrid_mtp_ngram.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def test_ngram_match_mixed(self):
8888
self.max_ngram_size,
8989
self.min_ngram_size,
9090
self.max_draft_tokens,
91+
False,
9192
)
9293

9394
np.testing.assert_allclose(self.seq_lens_this_time.numpy(), self.ref_seq_lens_this_time)

tests/spec_decode/test_ngram_gpu_kernel.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,7 @@ def test_correctness_basic(self):
846846
max_ngram_size,
847847
min_ngram_size,
848848
max_draft_tokens,
849+
False,
849850
)
850851
paddle.device.synchronize()
851852

@@ -889,6 +890,7 @@ def test_correctness_varied_seeds(self):
889890
3,
890891
1,
891892
10,
893+
False,
892894
)
893895
paddle.device.synchronize()
894896
np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt)
@@ -936,6 +938,7 @@ def test_large_batch_long_seq(self):
936938
3,
937939
1,
938940
10,
941+
False,
939942
)
940943
paddle.device.synchronize()
941944
finally:
@@ -979,6 +982,7 @@ def test_single_batch_long_seq(self):
979982
3,
980983
1,
981984
10,
985+
False,
982986
)
983987
paddle.device.synchronize()
984988
np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt)
@@ -1022,6 +1026,7 @@ def test_many_short_seqs(self):
10221026
3,
10231027
1,
10241028
10,
1029+
False,
10251030
)
10261031
paddle.device.synchronize()
10271032
finally:

tests/worker/test_reorder_split_prefill_and_decode.py

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import Mock
1+
from unittest.mock import Mock, patch
22

33
import paddle
44
import pytest
@@ -12,7 +12,11 @@
1212
SpeculativeConfig,
1313
StructuredOutputsConfig,
1414
)
15-
from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode
15+
from fastdeploy.worker.input_batch import (
16+
InputBatch,
17+
ProposerInputBatch,
18+
reorder_split_prefill_and_decode,
19+
)
1620

1721

1822
def create_mock_config():
@@ -61,6 +65,7 @@ def create_mock_config():
6165
scheduler_config = Mock(spec=SchedulerConfig)
6266
scheduler_config.max_num_seqs = 10
6367
scheduler_config.max_num_batched_tokens = 2048
68+
scheduler_config.max_extra_num_batched_tokens = 0
6469

6570
speculative_config = Mock(spec=SpeculativeConfig)
6671
speculative_config.method = None
@@ -315,5 +320,119 @@ def test_reorder_all_prefill(self):
315320
assert paddle.equal_all(input_batch.input_ids[i], original_data[i])
316321

317322

323+
class TestProposerInputBatchReset:
324+
"""Cover ProposerInputBatch.reset_model_inputs CUDA + token_ids_all branch
325+
(fastdeploy/worker/input_batch.py:972-985)."""
326+
327+
def _make_config(self):
328+
# Enable spec_decoding path so InputBatch.init_share_inputs allocates
329+
# cu_seqlens_q_output / draft_tokens / accept_num etc.
330+
fd_config = create_mock_config()
331+
fd_config.speculative_config.method = "mtp"
332+
fd_config.speculative_config.num_speculative_tokens = 1
333+
fd_config.speculative_config.num_model_steps = 1
334+
return fd_config
335+
336+
def _make_target(self, fd_config):
337+
target = InputBatch(fd_config)
338+
target.init_share_inputs()
339+
return target
340+
341+
def _make_proposer(self, fd_config, target):
342+
"""Construct a ProposerInputBatch and manually populate only the
343+
attributes that `reset_model_inputs` writes via `fill_paddle_tensor`.
344+
Skipping full `init_share_inputs` avoids depending on rope_emb,
345+
attention backends, and other heavy setup unrelated to the branch
346+
under test."""
347+
proposer = ProposerInputBatch(fd_config, target)
348+
349+
max_num_seqs = fd_config.scheduler_config.max_num_seqs
350+
hidden_size = fd_config.model_config.hidden_size
351+
max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
352+
353+
proposer.target_hidden_states = paddle.full([max_num_seqs, hidden_size], 0, dtype="bfloat16")
354+
proposer.draft_tokens = paddle.full([max_num_seqs, max_draft_token_num + 1], -1, dtype="int64")
355+
proposer.is_block_step = paddle.full([max_num_seqs, 1], False, dtype="bool")
356+
proposer.batch_drop = paddle.full([max_num_seqs, 1], False, dtype="bool")
357+
proposer.used_list_len = paddle.full([max_num_seqs], 0, dtype="int32")
358+
proposer.first_token_hidden_states = paddle.full([max_num_seqs, hidden_size], -1)
359+
proposer.batch_token_num = paddle.full([max_num_seqs], 0, dtype="int32")
360+
proposer.next_token_num = paddle.full([max_num_seqs], 0, dtype="int32")
361+
proposer.cu_batch_token_offset = paddle.full([max_num_seqs + 1], 0, dtype="int32")
362+
proposer.cu_next_token_offset = paddle.full([max_num_seqs + 1], 0, dtype="int32")
363+
proposer.mask_rollback = paddle.full([max_num_seqs, 1], 0, dtype="int32")
364+
proposer.recompute_token_num = paddle.full([max_num_seqs, 1], 0, dtype="int32")
365+
return proposer
366+
367+
@patch("fastdeploy.worker.input_batch.current_platform")
368+
def test_reset_rebinds_token_ids_all_on_cuda(self, mock_platform):
369+
"""When current_platform.is_cuda() and target has token_ids_all,
370+
reset_model_inputs must re-pull token_ids_all from target (line 973)
371+
and rebuild pre_ids from target.token_ids_all[bs_idx, prompt_len:]."""
372+
mock_platform.is_cuda.return_value = True
373+
mock_platform.is_xpu.return_value = False
374+
375+
fd_config = self._make_config()
376+
target = self._make_target(fd_config)
377+
proposer = self._make_proposer(fd_config, target)
378+
379+
max_num_seqs = fd_config.scheduler_config.max_num_seqs
380+
max_model_len = fd_config.model_config.max_model_len
381+
382+
# Rebind target.token_ids_all to a NEW tensor with known content so
383+
# we can distinguish "reset re-pulled it" from "init_share_inputs
384+
# already bound to the old reference".
385+
new_token_ids_all = paddle.arange(max_num_seqs * max_model_len, dtype="int64").reshape(
386+
[max_num_seqs, max_model_len]
387+
)
388+
target.token_ids_all = new_token_ids_all
389+
390+
# Set a non-zero prompt_len so the slice [:, prompt_len:] is non-trivial.
391+
prompt_len_value = 3
392+
target.prompt_lens = paddle.full([max_num_seqs, 1], prompt_len_value, dtype="int64")
393+
394+
proposer.reset_model_inputs()
395+
396+
# Line 973 effect: token_ids_all rebound to target's new tensor.
397+
assert proposer.token_ids_all is new_token_ids_all
398+
399+
# Line 975-985 effect: pre_ids has correct shape and the prefix is
400+
# token_ids_all[bs_idx, prompt_len:], suffix remains -1.
401+
assert proposer.pre_ids.shape == [max_num_seqs, max_model_len]
402+
valid_len = max_model_len - prompt_len_value
403+
expected_prefix = new_token_ids_all[:, prompt_len_value:]
404+
assert paddle.equal_all(proposer.pre_ids[:, :valid_len], expected_prefix)
405+
assert paddle.equal_all(
406+
proposer.pre_ids[:, valid_len:],
407+
paddle.full([max_num_seqs, prompt_len_value], -1, dtype="int64"),
408+
)
409+
410+
@patch("fastdeploy.worker.input_batch.current_platform")
411+
def test_reset_falls_back_to_pre_ids_clone_when_no_token_ids_all(self, mock_platform):
412+
"""When current_platform.is_cuda() but target lacks token_ids_all,
413+
reset_model_inputs takes the else branch (line 986-988): clone pre_ids,
414+
set token_ids_all to None."""
415+
mock_platform.is_cuda.return_value = True
416+
mock_platform.is_xpu.return_value = False
417+
418+
fd_config = self._make_config()
419+
target = self._make_target(fd_config)
420+
proposer = self._make_proposer(fd_config, target)
421+
422+
# Remove token_ids_all from target so the else branch fires.
423+
del target.token_ids_all
424+
# Provide a recognizable pre_ids on target.
425+
max_num_seqs = fd_config.scheduler_config.max_num_seqs
426+
max_model_len = fd_config.model_config.max_model_len
427+
target.pre_ids = paddle.full([max_num_seqs, max_model_len], 42, dtype="int64")
428+
429+
proposer.reset_model_inputs()
430+
431+
assert proposer.token_ids_all is None
432+
assert paddle.equal_all(proposer.pre_ids, paddle.full([max_num_seqs, max_model_len], 42, dtype="int64"))
433+
# Clone, not reference share.
434+
assert proposer.pre_ids is not target.pre_ids
435+
436+
318437
if __name__ == "__main__":
319438
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)