@@ -927,6 +927,7 @@ def _verify_and_sample(
927927 increment_value : int ,
928928 accept_all_drafts : bool = False ,
929929 reject_all_drafts : bool = False ,
930+ topp_seed : Optional [paddle .Tensor ] = None ,
930931 ) -> SamplerOutput :
931932 """
932933 Verify draft tokens against target model output and produce final samples.
@@ -959,7 +960,7 @@ def _verify_and_sample(
959960
960961 if self .verify_strategy == VerifyStrategy .TARGET_MATCH :
961962 if FD_SAMPLING_CLASS .lower () == "triton" :
962- target_tokens = _random_sample (probs , topp_seed = sampling_metadata . seed )
963+ target_tokens = _random_sample (probs , topp_seed = topp_seed )
963964 else :
964965 # Only TARGET_MATCH needs stochastic sampling
965966 top_p , top_k , topp_seed = build_sampling_params (
@@ -1038,6 +1039,7 @@ def _normal_sample(
10381039 probs : paddle .Tensor ,
10391040 sampling_metadata : SamplingMetadata ,
10401041 share_inputs : List [paddle .Tensor ],
1042+ topp_seed : Optional [paddle .Tensor ],
10411043 ) -> SamplerOutput :
10421044 """
10431045 Normal sampling without draft token verification.
@@ -1060,7 +1062,7 @@ def _normal_sample(
10601062
10611063 # Sample tokens
10621064 if FD_SAMPLING_CLASS .lower () == "triton" :
1063- next_tokens = _random_sample (probs , topp_seed = sampling_metadata . seed )
1065+ next_tokens = _random_sample (probs , topp_seed = topp_seed )
10641066 else :
10651067 next_tokens = _sample_from_probs (
10661068 probs ,
@@ -1164,6 +1166,7 @@ def forward_cuda(
11641166 )
11651167
11661168 logits_ori = None
1169+ topp_seed = None
11671170 if FD_SAMPLING_CLASS .lower () == "triton" :
11681171 logits_ori = logits .clone ()
11691172 top_p , top_k , _ = build_sampling_params (
@@ -1187,7 +1190,7 @@ def forward_cuda(
11871190 # Route based on spec_method
11881191 is_naive = self .spec_method is None or self .spec_method == SpecMethod .NAIVE
11891192 if is_naive :
1190- sampler_output = self ._normal_sample (logits , probs , sampling_metadata , share_inputs )
1193+ sampler_output = self ._normal_sample (logits , probs , sampling_metadata , share_inputs , topp_seed = topp_seed )
11911194 else :
11921195 sampler_output = self ._verify_and_sample (
11931196 logits ,
@@ -1199,6 +1202,7 @@ def forward_cuda(
11991202 increment_value ,
12001203 accept_all_drafts ,
12011204 reject_all_drafts ,
1205+ topp_seed = topp_seed ,
12021206 )
12031207
12041208 # Build logprobs via unified path (outside of sampling logic)
0 commit comments