@@ -34,6 +34,20 @@ def _reset_cuda_generator_for_determinism():
3434 paddle .framework .core .default_cuda_generator (0 ).manual_seed (_DETERMINISTIC_RNG_SEED )
3535
3636
def dispatch_top_k_renorm_probs(probs, top_k):
    """Run the platform-specific ``top_k_renorm_probs`` op on ``probs``.

    The op is imported lazily: from the Iluvatar op package when
    ``current_platform.is_iluvatar()`` is true, otherwise from the GPU op
    package. If that import fails, a warning is logged and ``probs`` is
    returned unchanged, i.e. top-k filtering is skipped.

    Args:
        probs: Value passed as the first argument to the op
            (presumably a probability tensor -- confirm against callers).
        top_k: Value passed through to the op unchanged
            (presumably the per-request top-k limits -- confirm against callers).

    Returns:
        The result of ``top_k_renorm_probs(probs, top_k)``, or the input
        ``probs`` untouched when the op cannot be imported.
    """
    # Only the imports are guarded: a missing op package is the one failure
    # that should degrade to "skip top-k filtering".
    try:
        if current_platform.is_iluvatar():
            from fastdeploy.model_executor.ops.iluvatar import top_k_renorm_probs
        else:
            from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs
    except ImportError:
        logger.warning("top_k sampling is not supported on current platform, skipping top_k filtering.")
        return probs

    # The op runs outside the handler so that an ImportError raised while it
    # executes is not misreported as an unsupported platform and swallowed.
    return top_k_renorm_probs(probs, top_k)
49+
50+
3751def top_k_top_p_sampling (
3852 x : paddle .Tensor ,
3953 top_p : paddle .Tensor ,
@@ -70,7 +84,6 @@ def top_k_top_p_sampling(
7084
7185 """
7286 top_p_class = envs .FD_SAMPLING_CLASS .lower ()
73- topp_seed_device = None
7487
7588 # In deterministic mode, reset CUDA generator offset before sampling.
7689 # paddle.tensor.top_p_sampling uses the global GPU generator offset even
@@ -85,29 +98,17 @@ def top_k_top_p_sampling(
8598 _ = None
8699 else :
87100 if top_k_list and any (x > 0 for x in top_k_list ):
88- try :
89- if current_platform .is_iluvatar ():
90- from fastdeploy .model_executor .ops .iluvatar import (
91- top_k_renorm_probs ,
92- )
93- else :
94- from fastdeploy .model_executor .ops .gpu import top_k_renorm_probs
95- x = top_k_renorm_probs (x , top_k )
96- except ImportError :
97- logger .warning ("top_k sampling is not supported on current platform, skipping top_k filtering." )
101+ x = dispatch_top_k_renorm_probs (x , top_k )
98102
99103 if top_p_class == "air" :
100104 _ , ids = air_top_p_sampling (x , top_p , threshold , topp_seed , seed = seed , k = k , mode = mode )
101105
102106 elif top_p_class == "base_non_truncated" :
103- if topp_seed is not None :
104- topp_seed_device = paddle .empty (shape = topp_seed .shape , dtype = topp_seed .dtype )
105- topp_seed_device .copy_ (topp_seed , False )
106107 _ , ids = paddle .tensor .top_p_sampling (
107108 x ,
108109 top_p ,
109110 threshold = threshold ,
110- topp_seed = topp_seed_device ,
111+ topp_seed = topp_seed ,
111112 seed = seed ,
112113 k = k ,
113114 mode = "non-truncated" ,
@@ -122,14 +123,11 @@ def top_k_top_p_sampling(
122123
123124 _ , ids = native_top_p_sampling (x , top_p )
124125 else :
125- if topp_seed is not None :
126- topp_seed_device = paddle .empty (shape = topp_seed .shape , dtype = topp_seed .dtype )
127- topp_seed_device .copy_ (topp_seed , False )
128126 _ , ids = paddle .tensor .top_p_sampling (
129127 x ,
130128 top_p ,
131129 threshold = threshold ,
132- topp_seed = topp_seed_device ,
130+ topp_seed = topp_seed ,
133131 seed = seed ,
134132 k = k ,
135133 mode = "truncated" ,
0 commit comments