@@ -1097,6 +1097,7 @@ def _verify_and_sample(
10971097 increment_value : int ,
10981098 accept_all_drafts : bool = False ,
10991099 reject_all_drafts : bool = False ,
1100+ topp_seed : Optional [paddle .Tensor ] = None ,
11001101 ) -> SamplerOutput :
11011102 """
11021103 Verify draft tokens against target model output and produce final samples.
@@ -1129,7 +1130,7 @@ def _verify_and_sample(
11291130
11301131 if self .verify_strategy == VerifyStrategy .TARGET_MATCH :
11311132 if FD_SAMPLING_CLASS .lower () == "triton" :
1132- target_tokens = _random_sample (probs , topp_seed = sampling_metadata . seed )
1133+ target_tokens = _random_sample (probs , topp_seed = topp_seed )
11331134 else :
11341135 # Only TARGET_MATCH needs stochastic sampling
11351136 top_p , top_k , topp_seed = build_sampling_params (
@@ -1208,6 +1209,7 @@ def _normal_sample(
12081209 probs : paddle .Tensor ,
12091210 sampling_metadata : SamplingMetadata ,
12101211 share_inputs : List [paddle .Tensor ],
1212+ topp_seed : Optional [paddle .Tensor ],
12111213 ) -> SamplerOutput :
12121214 """
12131215 Normal sampling without draft token verification.
@@ -1230,7 +1232,7 @@ def _normal_sample(
12301232
12311233 # Sample tokens
12321234 if FD_SAMPLING_CLASS .lower () == "triton" :
1233- next_tokens = _random_sample (probs , topp_seed = sampling_metadata . seed )
1235+ next_tokens = _random_sample (probs , topp_seed = topp_seed )
12341236 else :
12351237 next_tokens = _sample_from_probs (
12361238 probs ,
@@ -1333,9 +1335,10 @@ def forward_cuda(
13331335 )
13341336
13351337 logits_ori = None
1338+ topp_seed = None
13361339 if FD_SAMPLING_CLASS .lower () == "triton" :
13371340 logits_ori = logits .clone ()
1338- top_p , top_k , _ = build_sampling_params (
1341+ top_p , top_k , topp_seed = build_sampling_params (
13391342 sampling_metadata .top_p ,
13401343 sampling_metadata .top_k ,
13411344 sampling_metadata .seed ,
@@ -1356,7 +1359,7 @@ def forward_cuda(
13561359 # Route based on spec_method
13571360 is_naive = self .spec_method is None or self .spec_method == SpecMethod .NAIVE
13581361 if is_naive :
1359- sampler_output = self ._normal_sample (logits , probs , sampling_metadata , share_inputs )
1362+ sampler_output = self ._normal_sample (logits , probs , sampling_metadata , share_inputs , topp_seed = topp_seed )
13601363 else :
13611364 sampler_output = self ._verify_and_sample (
13621365 logits ,
@@ -1368,6 +1371,7 @@ def forward_cuda(
13681371 increment_value ,
13691372 accept_all_drafts ,
13701373 reject_all_drafts ,
1374+ topp_seed = topp_seed ,
13711375 )
13721376
13731377 keep_sampling_mask = sampling_metadata .keep_sampling_mask
0 commit comments