Skip to content

Commit 5b7ebb7

Browse files
authored
[BugFix][Speculative Decoding] Fix MTP triton sampler seed (#7946)
* add FD_ENABLE_TOP_P_ONE_OPT=0
* mtp support triton seed
* fix mtp triton sampler
* check test
1 parent 60e6223 commit 5b7ebb7

3 files changed

Lines changed: 12 additions & 5 deletions

File tree

fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -947,7 +947,6 @@ def _seeded_gumbel_kernel(
947947
"""Generate -log(u) with per-row Philox seeds, fully on GPU."""
948948
pid = tl.program_id(0)
949949
seed = tl.load(SEEDS_ptr + pid)
950-
seed = seed.to(tl.int32)
951950
for start in tl.range(0, VOCAB_SIZE, BLOCK_SIZE):
952951
offsets = start + tl.arange(0, BLOCK_SIZE)
953952
mask = offsets < VOCAB_SIZE

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -927,6 +927,7 @@ def _verify_and_sample(
927927
increment_value: int,
928928
accept_all_drafts: bool = False,
929929
reject_all_drafts: bool = False,
930+
topp_seed: Optional[paddle.Tensor] = None,
930931
) -> SamplerOutput:
931932
"""
932933
Verify draft tokens against target model output and produce final samples.
@@ -959,7 +960,7 @@ def _verify_and_sample(
959960

960961
if self.verify_strategy == VerifyStrategy.TARGET_MATCH:
961962
if FD_SAMPLING_CLASS.lower() == "triton":
962-
target_tokens = _random_sample(probs, topp_seed=sampling_metadata.seed)
963+
target_tokens = _random_sample(probs, topp_seed=topp_seed)
963964
else:
964965
# Only TARGET_MATCH needs stochastic sampling
965966
top_p, top_k, topp_seed = build_sampling_params(
@@ -1038,6 +1039,7 @@ def _normal_sample(
10381039
probs: paddle.Tensor,
10391040
sampling_metadata: SamplingMetadata,
10401041
share_inputs: List[paddle.Tensor],
1042+
topp_seed: Optional[paddle.Tensor],
10411043
) -> SamplerOutput:
10421044
"""
10431045
Normal sampling without draft token verification.
@@ -1060,7 +1062,7 @@ def _normal_sample(
10601062

10611063
# Sample tokens
10621064
if FD_SAMPLING_CLASS.lower() == "triton":
1063-
next_tokens = _random_sample(probs, topp_seed=sampling_metadata.seed)
1065+
next_tokens = _random_sample(probs, topp_seed=topp_seed)
10641066
else:
10651067
next_tokens = _sample_from_probs(
10661068
probs,
@@ -1164,6 +1166,7 @@ def forward_cuda(
11641166
)
11651167

11661168
logits_ori = None
1169+
topp_seed = None
11671170
if FD_SAMPLING_CLASS.lower() == "triton":
11681171
logits_ori = logits.clone()
11691172
top_p, top_k, _ = build_sampling_params(
@@ -1187,7 +1190,7 @@ def forward_cuda(
11871190
# Route based on spec_method
11881191
is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
11891192
if is_naive:
1190-
sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs)
1193+
sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs, topp_seed=topp_seed)
11911194
else:
11921195
sampler_output = self._verify_and_sample(
11931196
logits,
@@ -1199,6 +1202,7 @@ def forward_cuda(
11991202
increment_value,
12001203
accept_all_drafts,
12011204
reject_all_drafts,
1205+
topp_seed=topp_seed,
12021206
)
12031207

12041208
# Build logprobs via unified path (outside of sampling logic)

tests/layers/test_triton_sampler.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,7 @@ def triton_mode(monkeypatch):
121121
import fastdeploy.envs as envs
122122

123123
monkeypatch.setattr(envs, "FD_SAMPLING_CLASS", "triton")
124+
monkeypatch.setattr("fastdeploy.model_executor.layers.sample.sampler.FD_SAMPLING_CLASS", "triton")
124125

125126

126127
def _create_metadata(batch_size=1, min_seq_len=1, max_seq_len=3, max_num_logprobs=None, **overrides):
@@ -343,6 +344,7 @@ def test_verify_and_sample_target_match_triton(self, mock_ops, triton_mode, monk
343344
m = _create_metadata(batch_size=1)
344345
logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32")
345346
probs = paddle.nn.functional.softmax(logits, axis=-1)
347+
seeds = paddle.ones([probs.shape[0], 1], dtype="int64")
346348

347349
out = sampler._verify_and_sample(
348350
logits,
@@ -352,6 +354,7 @@ def test_verify_and_sample_target_match_triton(self, mock_ops, triton_mode, monk
352354
share_inputs=_spec_share_inputs(),
353355
token_num_output_cpu=1,
354356
increment_value=1,
357+
topp_seed=seeds,
355358
)
356359
assert out.sampled_token_ids is not None
357360

@@ -366,8 +369,9 @@ def test_normal_sample_triton(self, mock_ops, triton_mode, monkeypatch):
366369
m = _create_metadata(batch_size=1)
367370
logits = paddle.to_tensor([[1.0, 2.0, 3.0]], dtype="float32")
368371
probs = paddle.nn.functional.softmax(logits, axis=-1)
372+
seeds = paddle.ones([probs.shape[0], 1], dtype="int64")
369373

370-
out = sampler._normal_sample(logits, probs, m, share_inputs=_spec_share_inputs())
374+
out = sampler._normal_sample(logits, probs, m, share_inputs=_spec_share_inputs(), topp_seed=seeds)
371375
assert out.sampled_token_ids is not None
372376

373377
def test_forward_cuda_triton_logit_mask(self, mock_ops, triton_mode, monkeypatch):

0 commit comments

Comments
 (0)