1414# limitations under the License.
1515"""
1616
17- import os
18-
1917import paddle
2018
2119from fastdeploy .utils import data_processor_logger
2220
2321
2422def get_entropy (logits ):
23+ # Check for -inf values in logits
2524 if paddle .any (paddle .isinf (logits ) & (logits < 0 )):
2625 data_processor_logger .debug ("Detected -inf values in logits, clipping to minimum value" )
2726 logits = paddle .clip (logits , min = 1e-9 )
@@ -33,36 +32,7 @@ def get_entropy(logits):
3332 return paddle .sum (p0 * (paddle .log (z0 ) - a0 ), axis = - 1 )
3433
3534
36- def _log_entropy (share_inputs , i ):
37- elist = share_inputs ["entropy_list" ][i ]
38- data_processor_logger .info (
39- f"req_id: { share_inputs ['req_ids' ][i ]} , entropy: { sum (elist )/ len (elist )} , steps: { len (elist )} , all_values: { elist } "
40- )
41- share_inputs ["entropy_list" ][i ] = []
42-
43-
44- # ==============================================================================
45- # ernie5_model_runner path (original logic from commit 361a310)
46- # ==============================================================================
47-
48-
4935def calculate_logits_entropy (logits , share_inputs , temperature ):
50- use_fd_runner = os .environ .get ("EB5_ENABLE_FD_RUNNER" , "0" ) == "1"
51- if use_fd_runner :
52- calculate_logits_entropy_fd (logits , share_inputs , temperature )
53- else :
54- calculate_logits_entropy_ori (logits , share_inputs , temperature )
55-
56-
57- def speculate_calculate_logits_entropy (logits , share_inputs , temperature ):
58- use_fd_runner = os .environ .get ("EB5_ENABLE_FD_RUNNER" , "0" ) == "1"
59- if use_fd_runner :
60- speculate_calculate_logits_entropy_fd (logits , share_inputs , temperature )
61- else :
62- speculate_calculate_logits_entropy_ori (logits , share_inputs , temperature )
63-
64-
65- def calculate_logits_entropy_ori (logits , share_inputs , temperature ):
6636 real_bsz = share_inputs ["seq_lens_this_time" ].shape [0 ]
6737 real_seq_lens = paddle .where (
6838 share_inputs ["seq_lens_encoder" ][:real_bsz ].squeeze (1 ) != 0 ,
@@ -87,10 +57,13 @@ def calculate_logits_entropy_ori(logits, share_inputs, temperature):
8757 and share_inputs ["seq_lens_decoder" ][i ] != 0
8858 and len (share_inputs ["entropy_list" ][i ]) != 0
8959 ):
90- _log_entropy (share_inputs , i )
60+ data_processor_logger .info (
61+ f"req_id: { share_inputs ['req_ids' ][i ]} , entropy: { sum (share_inputs ['entropy_list' ][i ])/ len (share_inputs ['entropy_list' ][i ])} "
62+ )
63+ share_inputs ["entropy_list" ][i ] = []
9164
9265
93- def speculate_calculate_logits_entropy_ori (logits , share_inputs , temperature ):
66+ def speculate_calculate_logits_entropy (logits , share_inputs , temperature ):
9467 # get accepted logits
9568 real_bsz = share_inputs ["seq_lens_this_time" ].shape [0 ]
9669 total_accepted_num = paddle .sum (share_inputs ["accept_num" ])
@@ -127,128 +100,7 @@ def speculate_calculate_logits_entropy_ori(logits, share_inputs, temperature):
127100 and share_inputs ["seq_lens_decoder" ][i ] != 0
128101 and len (share_inputs ["entropy_list" ][i ]) != 0
129102 ):
130- _log_entropy (share_inputs , i )
131-
132-
133- # ==============================================================================
134- # gpu_model_runner (FD runner) path
135- # ==============================================================================
136-
137-
138- def calculate_logits_entropy_fd (logits , share_inputs , temperature ):
139- real_bsz = share_inputs ["seq_lens_this_time" ].shape [0 ]
140- seq_lens_encoder = share_inputs ["seq_lens_encoder" ][:real_bsz ]
141- seq_lens_this_time = share_inputs ["seq_lens_this_time" ]
142- if seq_lens_encoder .ndim == 2 :
143- seq_lens_encoder = seq_lens_encoder .squeeze (1 )
144- if seq_lens_this_time .ndim == 2 :
145- seq_lens_this_time = seq_lens_this_time .squeeze (1 )
146- real_seq_lens = paddle .where (
147- seq_lens_encoder != 0 ,
148- paddle .ones ([1 ], dtype = "int32" ),
149- seq_lens_this_time ,
150- )
151-
152- for i in range (real_bsz ):
153- if int (real_seq_lens [i ]) == 0 :
154- continue
155- t = temperature [i ]
156- if t > 0 and t != 1.0 :
157- logits [i ] = logits [i ].scale_ (1 / t )
158-
159- entropy_tensor = get_entropy (logits [:real_bsz ])
160-
161- for i in range (real_bsz ):
162- if int (real_seq_lens [i ]) == 0 :
163- continue
164- entropy_val = float (entropy_tensor [i ])
165- share_inputs ["entropy_list" ][i ].append (entropy_val )
166- if (
167- share_inputs ["stop_flags" ][i ]
168- and share_inputs ["seq_lens_decoder" ][i ] != 0
169- and len (share_inputs ["entropy_list" ][i ]) != 0
170- ):
171- _log_entropy (share_inputs , i )
172-
173-
174- def speculate_calculate_logits_entropy_fd (logits , share_inputs , temperature ):
175- real_bsz = share_inputs ["seq_lens_this_time" ].shape [0 ]
176- total_accepted_num = int (paddle .sum (share_inputs ["accept_num" ][:real_bsz ]))
177-
178- if total_accepted_num == 0 :
179- for i in range (real_bsz ):
180- if (
181- share_inputs ["stop_flags" ][i ]
182- and share_inputs ["seq_lens_decoder" ][i ] != 0
183- and len (share_inputs ["entropy_list" ][i ]) != 0
184- ):
185- _log_entropy (share_inputs , i )
186- return
187-
188- seq_lens_encoder = share_inputs ["seq_lens_encoder" ][:real_bsz ]
189- seq_lens_this_time = share_inputs ["seq_lens_this_time" ]
190- if seq_lens_encoder .ndim == 2 :
191- seq_lens_encoder = seq_lens_encoder .squeeze (1 )
192- if seq_lens_this_time .ndim == 2 :
193- seq_lens_this_time = seq_lens_this_time .squeeze (1 )
194- real_seq_lens = paddle .where (
195- seq_lens_encoder != 0 ,
196- paddle .ones ([1 ], dtype = "int32" ),
197- seq_lens_this_time ,
198- )
199- seq_start_idx = paddle .concat ([paddle .zeros ([1 ], dtype = "int32" ), paddle .cumsum (real_seq_lens , dtype = "int32" )])
200- repeated_starts = paddle .repeat_interleave (seq_start_idx [:- 1 ], share_inputs ["accept_num" ][:real_bsz ])
201- offsets = paddle .concat ([paddle .arange (share_inputs ["accept_num" ][i ].item ()) for i in range (real_bsz )]).astype (
202- "int32"
203- )
204- accepted_idx = repeated_starts + offsets
205-
206- accepted_logits = paddle .index_select (logits , accepted_idx , axis = 0 )
207-
208- batch_indices = paddle .arange (real_bsz , dtype = "int32" )
209- batch_id_per_token = paddle .repeat_interleave (batch_indices , share_inputs ["accept_num" ][:real_bsz ])
210- temp_per_token = temperature [batch_id_per_token ].flatten ()
211- scale = paddle .where (
212- temp_per_token > 0 ,
213- 1.0 / temp_per_token ,
214- paddle .ones_like (temp_per_token ),
215- )
216- accepted_logits = accepted_logits * scale .unsqueeze (- 1 )
217-
218- entropy_tensor = get_entropy (accepted_logits )
219- entropy = entropy_tensor .tolist ()
220-
221- idx = 0
222- for i in range (real_bsz ):
223- accept_count = int (share_inputs ["accept_num" ][i ])
224- if accept_count > 0 :
225- for _ in range (accept_count ):
226- e_val = entropy [idx ]
227- share_inputs ["entropy_list" ][i ].append (e_val )
228- idx += 1
229- if (
230- share_inputs ["stop_flags" ][i ]
231- and share_inputs ["seq_lens_decoder" ][i ] != 0
232- and len (share_inputs ["entropy_list" ][i ]) != 0
233- ):
234- _log_entropy (share_inputs , i )
235-
236-
237- # ==============================================================================
238- # Common utility
239- # ==============================================================================
240-
241-
242- def flush_entropy_on_stop (share_inputs ):
243- """
244- Flush entropy for requests whose stop_flags became True after entropy calculation.
245- Called after unified_update_model_status which sets stop_flags for max_dec_len.
246- """
247- real_bsz = share_inputs ["seq_lens_this_time" ].shape [0 ]
248- for i in range (real_bsz ):
249- if (
250- share_inputs ["stop_flags" ][i ]
251- and share_inputs ["seq_lens_decoder" ][i ] != 0
252- and len (share_inputs ["entropy_list" ][i ]) != 0
253- ):
254- _log_entropy (share_inputs , i )
103+ data_processor_logger .info (
104+ f"req_id: { share_inputs ['req_ids' ][i ]} , entropy: { sum (share_inputs ['entropy_list' ][i ])/ len (share_inputs ['entropy_list' ][i ])} "
105+ )
106+ share_inputs ["entropy_list" ][i ] = []
0 commit comments