Skip to content

Commit b8b7f35

Browse files
committed
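Fix MTP Triton sampling seed: pass the topp_seed returned by build_sampling_params to _random_sample (instead of sampling_metadata.seed) and wrap the Triton kernels with enable_compat_on_triton_kernel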
1 parent 532fc37 commit b8b7f35

2 files changed

Lines changed: 14 additions & 4 deletions

File tree

fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
import paddle
2929
from paddle.utils.deprecated import VisibleDeprecationWarning
3030

31+
from fastdeploy.model_executor.ops.triton_ops.triton_utils import (
32+
enable_compat_on_triton_kernel,
33+
)
34+
3135
# Suppress the VisibleDeprecationWarning from use_triton_in_paddle that fires
3236
# on every Triton kernel launch (paddle.device.cuda.current_stream /
3337
# synchronize). In serving hot-paths this produces thousands of log lines per
@@ -112,6 +116,7 @@ def _update_min_larger_stats(data, above_mask, min_larger, num_min_larger, senti
112116
return min_larger, num_min_larger
113117

114118

119+
@enable_compat_on_triton_kernel
115120
@triton.jit
116121
def _topk_topp_kernel(
117122
LOGITS,
@@ -936,6 +941,7 @@ def apply_top_k_top_p_triton(
936941
return logits
937942

938943

944+
@enable_compat_on_triton_kernel
939945
@triton.jit
940946
def _seeded_gumbel_kernel(
941947
OUT_ptr,

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,7 @@ def _verify_and_sample(
10971097
increment_value: int,
10981098
accept_all_drafts: bool = False,
10991099
reject_all_drafts: bool = False,
1100+
topp_seed: Optional[paddle.Tensor] = None,
11001101
) -> SamplerOutput:
11011102
"""
11021103
Verify draft tokens against target model output and produce final samples.
@@ -1129,7 +1130,7 @@ def _verify_and_sample(
11291130

11301131
if self.verify_strategy == VerifyStrategy.TARGET_MATCH:
11311132
if FD_SAMPLING_CLASS.lower() == "triton":
1132-
target_tokens = _random_sample(probs, topp_seed=sampling_metadata.seed)
1133+
target_tokens = _random_sample(probs, topp_seed=topp_seed)
11331134
else:
11341135
# Only TARGET_MATCH needs stochastic sampling
11351136
top_p, top_k, topp_seed = build_sampling_params(
@@ -1208,6 +1209,7 @@ def _normal_sample(
12081209
probs: paddle.Tensor,
12091210
sampling_metadata: SamplingMetadata,
12101211
share_inputs: List[paddle.Tensor],
1212+
topp_seed: Optional[paddle.Tensor],
12111213
) -> SamplerOutput:
12121214
"""
12131215
Normal sampling without draft token verification.
@@ -1230,7 +1232,7 @@ def _normal_sample(
12301232

12311233
# Sample tokens
12321234
if FD_SAMPLING_CLASS.lower() == "triton":
1233-
next_tokens = _random_sample(probs, topp_seed=sampling_metadata.seed)
1235+
next_tokens = _random_sample(probs, topp_seed=topp_seed)
12341236
else:
12351237
next_tokens = _sample_from_probs(
12361238
probs,
@@ -1333,9 +1335,10 @@ def forward_cuda(
13331335
)
13341336

13351337
logits_ori = None
1338+
topp_seed = None
13361339
if FD_SAMPLING_CLASS.lower() == "triton":
13371340
logits_ori = logits.clone()
1338-
top_p, top_k, _ = build_sampling_params(
1341+
top_p, top_k, topp_seed = build_sampling_params(
13391342
sampling_metadata.top_p,
13401343
sampling_metadata.top_k,
13411344
sampling_metadata.seed,
@@ -1356,7 +1359,7 @@ def forward_cuda(
13561359
# Route based on spec_method
13571360
is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE
13581361
if is_naive:
1359-
sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs)
1362+
sampler_output = self._normal_sample(logits, probs, sampling_metadata, share_inputs, topp_seed=topp_seed)
13601363
else:
13611364
sampler_output = self._verify_and_sample(
13621365
logits,
@@ -1368,6 +1371,7 @@ def forward_cuda(
13681371
increment_value,
13691372
accept_all_drafts,
13701373
reject_all_drafts,
1374+
topp_seed=topp_seed,
13711375
)
13721376

13731377
keep_sampling_mask = sampling_metadata.keep_sampling_mask

0 commit comments

Comments
 (0)