1414# limitations under the License.
1515"""
1616
17+ import os
18+
1719import paddle
1820
1921from fastdeploy .utils import data_processor_logger
2022
2123
2224def get_entropy (logits ):
23- # Check for -inf values in logits
2425 if paddle .any (paddle .isinf (logits ) & (logits < 0 )):
2526 data_processor_logger .debug ("Detected -inf values in logits, clipping to minimum value" )
2627 logits = paddle .clip (logits , min = 1e-9 )
@@ -32,7 +33,36 @@ def get_entropy(logits):
3233 return paddle .sum (p0 * (paddle .log (z0 ) - a0 ), axis = - 1 )
3334
3435
36+ def _log_entropy (share_inputs , i ):
37+ elist = share_inputs ["entropy_list" ][i ]
38+ data_processor_logger .info (
39+ f"req_id: { share_inputs ['req_ids' ][i ]} , entropy: { sum (elist )/ len (elist )} , steps: { len (elist )} , all_values: { elist } "
40+ )
41+ share_inputs ["entropy_list" ][i ] = []
42+
43+
44+ # ==============================================================================
45+ # ernie5_model_runner path (original logic from commit 361a310)
46+ # ==============================================================================
47+
48+
3549def calculate_logits_entropy (logits , share_inputs , temperature ):
50+ use_fd_runner = os .environ .get ("EB5_ENABLE_FD_RUNNER" , "0" ) == "1"
51+ if use_fd_runner :
52+ calculate_logits_entropy_fd (logits , share_inputs , temperature )
53+ else :
54+ calculate_logits_entropy_ori (logits , share_inputs , temperature )
55+
56+
57+ def speculate_calculate_logits_entropy (logits , share_inputs , temperature ):
58+ use_fd_runner = os .environ .get ("EB5_ENABLE_FD_RUNNER" , "0" ) == "1"
59+ if use_fd_runner :
60+ speculate_calculate_logits_entropy_fd (logits , share_inputs , temperature )
61+ else :
62+ speculate_calculate_logits_entropy_ori (logits , share_inputs , temperature )
63+
64+
65+ def calculate_logits_entropy_ori (logits , share_inputs , temperature ):
3666 real_bsz = share_inputs ["seq_lens_this_time" ].shape [0 ]
3767 real_seq_lens = paddle .where (
3868 share_inputs ["seq_lens_encoder" ][:real_bsz ].squeeze (1 ) != 0 ,
@@ -57,13 +87,10 @@ def calculate_logits_entropy(logits, share_inputs, temperature):
5787 and share_inputs ["seq_lens_decoder" ][i ] != 0
5888 and len (share_inputs ["entropy_list" ][i ]) != 0
5989 ):
60- data_processor_logger .info (
61- f"req_id: { share_inputs ['req_ids' ][i ]} , entropy: { sum (share_inputs ['entropy_list' ][i ])/ len (share_inputs ['entropy_list' ][i ])} "
62- )
63- share_inputs ["entropy_list" ][i ] = []
90+ _log_entropy (share_inputs , i )
6491
6592
66- def speculate_calculate_logits_entropy (logits , share_inputs , temperature ):
93+ def speculate_calculate_logits_entropy_ori (logits , share_inputs , temperature ):
6794 # get accepted logits
6895 real_bsz = share_inputs ["seq_lens_this_time" ].shape [0 ]
6996 total_accepted_num = paddle .sum (share_inputs ["accept_num" ])
@@ -100,7 +127,128 @@ def speculate_calculate_logits_entropy(logits, share_inputs, temperature):
100127 and share_inputs ["seq_lens_decoder" ][i ] != 0
101128 and len (share_inputs ["entropy_list" ][i ]) != 0
102129 ):
103- data_processor_logger .info (
104- f"req_id: { share_inputs ['req_ids' ][i ]} , entropy: { sum (share_inputs ['entropy_list' ][i ])/ len (share_inputs ['entropy_list' ][i ])} "
105- )
106- share_inputs ["entropy_list" ][i ] = []
130+ _log_entropy (share_inputs , i )
131+
132+
133+ # ==============================================================================
134+ # gpu_model_runner (FD runner) path
135+ # ==============================================================================
136+
137+
138+ def calculate_logits_entropy_fd (logits , share_inputs , temperature ):
139+ real_bsz = share_inputs ["seq_lens_this_time" ].shape [0 ]
140+ seq_lens_encoder = share_inputs ["seq_lens_encoder" ][:real_bsz ]
141+ seq_lens_this_time = share_inputs ["seq_lens_this_time" ]
142+ if seq_lens_encoder .ndim == 2 :
143+ seq_lens_encoder = seq_lens_encoder .squeeze (1 )
144+ if seq_lens_this_time .ndim == 2 :
145+ seq_lens_this_time = seq_lens_this_time .squeeze (1 )
146+ real_seq_lens = paddle .where (
147+ seq_lens_encoder != 0 ,
148+ paddle .ones ([1 ], dtype = "int32" ),
149+ seq_lens_this_time ,
150+ )
151+
152+ for i in range (real_bsz ):
153+ if int (real_seq_lens [i ]) == 0 :
154+ continue
155+ t = temperature [i ]
156+ if t > 0 and t != 1.0 :
157+ logits [i ] = logits [i ].scale_ (1 / t )
158+
159+ entropy_tensor = get_entropy (logits [:real_bsz ])
160+
161+ for i in range (real_bsz ):
162+ if int (real_seq_lens [i ]) == 0 :
163+ continue
164+ entropy_val = float (entropy_tensor [i ])
165+ share_inputs ["entropy_list" ][i ].append (entropy_val )
166+ if (
167+ share_inputs ["stop_flags" ][i ]
168+ and share_inputs ["seq_lens_decoder" ][i ] != 0
169+ and len (share_inputs ["entropy_list" ][i ]) != 0
170+ ):
171+ _log_entropy (share_inputs , i )
172+
173+
174+ def speculate_calculate_logits_entropy_fd (logits , share_inputs , temperature ):
175+ real_bsz = share_inputs ["seq_lens_this_time" ].shape [0 ]
176+ total_accepted_num = int (paddle .sum (share_inputs ["accept_num" ][:real_bsz ]))
177+
178+ if total_accepted_num == 0 :
179+ for i in range (real_bsz ):
180+ if (
181+ share_inputs ["stop_flags" ][i ]
182+ and share_inputs ["seq_lens_decoder" ][i ] != 0
183+ and len (share_inputs ["entropy_list" ][i ]) != 0
184+ ):
185+ _log_entropy (share_inputs , i )
186+ return
187+
188+ seq_lens_encoder = share_inputs ["seq_lens_encoder" ][:real_bsz ]
189+ seq_lens_this_time = share_inputs ["seq_lens_this_time" ]
190+ if seq_lens_encoder .ndim == 2 :
191+ seq_lens_encoder = seq_lens_encoder .squeeze (1 )
192+ if seq_lens_this_time .ndim == 2 :
193+ seq_lens_this_time = seq_lens_this_time .squeeze (1 )
194+ real_seq_lens = paddle .where (
195+ seq_lens_encoder != 0 ,
196+ paddle .ones ([1 ], dtype = "int32" ),
197+ seq_lens_this_time ,
198+ )
199+ seq_start_idx = paddle .concat ([paddle .zeros ([1 ], dtype = "int32" ), paddle .cumsum (real_seq_lens , dtype = "int32" )])
200+ repeated_starts = paddle .repeat_interleave (seq_start_idx [:- 1 ], share_inputs ["accept_num" ][:real_bsz ])
201+ offsets = paddle .concat ([paddle .arange (share_inputs ["accept_num" ][i ].item ()) for i in range (real_bsz )]).astype (
202+ "int32"
203+ )
204+ accepted_idx = repeated_starts + offsets
205+
206+ accepted_logits = paddle .index_select (logits , accepted_idx , axis = 0 )
207+
208+ batch_indices = paddle .arange (real_bsz , dtype = "int32" )
209+ batch_id_per_token = paddle .repeat_interleave (batch_indices , share_inputs ["accept_num" ][:real_bsz ])
210+ temp_per_token = temperature [batch_id_per_token ].flatten ()
211+ scale = paddle .where (
212+ temp_per_token > 0 ,
213+ 1.0 / temp_per_token ,
214+ paddle .ones_like (temp_per_token ),
215+ )
216+ accepted_logits = accepted_logits * scale .unsqueeze (- 1 )
217+
218+ entropy_tensor = get_entropy (accepted_logits )
219+ entropy = entropy_tensor .tolist ()
220+
221+ idx = 0
222+ for i in range (real_bsz ):
223+ accept_count = int (share_inputs ["accept_num" ][i ])
224+ if accept_count > 0 :
225+ for _ in range (accept_count ):
226+ e_val = entropy [idx ]
227+ share_inputs ["entropy_list" ][i ].append (e_val )
228+ idx += 1
229+ if (
230+ share_inputs ["stop_flags" ][i ]
231+ and share_inputs ["seq_lens_decoder" ][i ] != 0
232+ and len (share_inputs ["entropy_list" ][i ]) != 0
233+ ):
234+ _log_entropy (share_inputs , i )
235+
236+
237+ # ==============================================================================
238+ # Common utility
239+ # ==============================================================================
240+
241+
242+ def flush_entropy_on_stop (share_inputs ):
243+ """
244+ Flush entropy for requests whose stop_flags became True after entropy calculation.
245+ Called after unified_update_model_status which sets stop_flags for max_dec_len.
246+ """
247+ real_bsz = share_inputs ["seq_lens_this_time" ].shape [0 ]
248+ for i in range (real_bsz ):
249+ if (
250+ share_inputs ["stop_flags" ][i ]
251+ and share_inputs ["seq_lens_decoder" ][i ] != 0
252+ and len (share_inputs ["entropy_list" ][i ]) != 0
253+ ):
254+ _log_entropy (share_inputs , i )
0 commit comments