Skip to content

Commit 529ec9e

Browse files
authored
[Feature] Support computing entropy with fastdeploy runner (#7954)
1 parent d778309 commit 529ec9e

4 files changed

Lines changed: 331 additions & 11 deletions

File tree

fastdeploy/model_executor/entropy_utils.py

Lines changed: 158 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
# limitations under the License.
1515
"""
1616

17+
import os
18+
1719
import paddle
1820

1921
from fastdeploy.utils import data_processor_logger
2022

2123

2224
def get_entropy(logits):
23-
# Check for -inf values in logits
2425
if paddle.any(paddle.isinf(logits) & (logits < 0)):
2526
data_processor_logger.debug("Detected -inf values in logits, clipping to minimum value")
2627
logits = paddle.clip(logits, min=1e-9)
@@ -32,7 +33,36 @@ def get_entropy(logits):
3233
return paddle.sum(p0 * (paddle.log(z0) - a0), axis=-1)
3334

3435

36+
def _log_entropy(share_inputs, i):
37+
elist = share_inputs["entropy_list"][i]
38+
data_processor_logger.info(
39+
f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(elist)/len(elist)}, steps: {len(elist)}, all_values: {elist}"
40+
)
41+
share_inputs["entropy_list"][i] = []
42+
43+
44+
# ==============================================================================
45+
# ernie5_model_runner path (original logic from commit 361a310)
46+
# ==============================================================================
47+
48+
3549
def calculate_logits_entropy(logits, share_inputs, temperature):
    """Forward the entropy bookkeeping call to the implementation chosen by the environment.

    ``calculate_logits_entropy_fd`` is used when the ``EB5_ENABLE_FD_RUNNER``
    environment variable is exactly ``"1"``. Any other value, or an unset
    variable, selects ``calculate_logits_entropy_ori``. The variable is read on
    every call and the three arguments are passed through unchanged.

    Returns:
        None.
    """
    fd_runner_enabled = os.environ.get("EB5_ENABLE_FD_RUNNER", "0") == "1"
    impl = calculate_logits_entropy_fd if fd_runner_enabled else calculate_logits_entropy_ori
    impl(logits, share_inputs, temperature)
55+
56+
57+
def speculate_calculate_logits_entropy(logits, share_inputs, temperature):
    """Forward the speculative-decoding entropy call to the implementation chosen by the environment.

    ``speculate_calculate_logits_entropy_fd`` is used when the
    ``EB5_ENABLE_FD_RUNNER`` environment variable is exactly ``"1"``. Any other
    value, or an unset variable, selects ``speculate_calculate_logits_entropy_ori``.
    The variable is read on every call and the three arguments are passed
    through unchanged.

    Returns:
        None.
    """
    fd_runner_enabled = os.environ.get("EB5_ENABLE_FD_RUNNER", "0") == "1"
    impl = (
        speculate_calculate_logits_entropy_fd
        if fd_runner_enabled
        else speculate_calculate_logits_entropy_ori
    )
    impl(logits, share_inputs, temperature)
63+
64+
65+
def calculate_logits_entropy_ori(logits, share_inputs, temperature):
3666
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
3767
real_seq_lens = paddle.where(
3868
share_inputs["seq_lens_encoder"][:real_bsz].squeeze(1) != 0,
@@ -57,13 +87,10 @@ def calculate_logits_entropy(logits, share_inputs, temperature):
5787
and share_inputs["seq_lens_decoder"][i] != 0
5888
and len(share_inputs["entropy_list"][i]) != 0
5989
):
60-
data_processor_logger.info(
61-
f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(share_inputs['entropy_list'][i])/len(share_inputs['entropy_list'][i])}"
62-
)
63-
share_inputs["entropy_list"][i] = []
90+
_log_entropy(share_inputs, i)
6491

6592

66-
def speculate_calculate_logits_entropy(logits, share_inputs, temperature):
93+
def speculate_calculate_logits_entropy_ori(logits, share_inputs, temperature):
6794
# get accepted logits
6895
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
6996
total_accepted_num = paddle.sum(share_inputs["accept_num"])
@@ -100,7 +127,128 @@ def speculate_calculate_logits_entropy(logits, share_inputs, temperature):
100127
and share_inputs["seq_lens_decoder"][i] != 0
101128
and len(share_inputs["entropy_list"][i]) != 0
102129
):
103-
data_processor_logger.info(
104-
f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(share_inputs['entropy_list'][i])/len(share_inputs['entropy_list'][i])}"
105-
)
106-
share_inputs["entropy_list"][i] = []
130+
_log_entropy(share_inputs, i)
131+
132+
133+
# ==============================================================================
134+
# gpu_model_runner (FD runner) path
135+
# ==============================================================================
136+
137+
138+
def calculate_logits_entropy_fd(logits, share_inputs, temperature):
    """Record one entropy value per active batch row (FD runner, non-speculative path).

    The batch size is the leading dimension of ``share_inputs["seq_lens_this_time"]``.
    A row is active when its effective length is non-zero: the length is 1 where
    ``seq_lens_encoder`` is non-zero and the ``seq_lens_this_time`` entry otherwise.

    For each active row whose temperature is positive and different from 1.0, the
    row of ``logits`` is overwritten with its temperature-scaled value, so the
    caller's ``logits`` tensor is modified. Entropy is then computed with
    ``get_entropy`` over the first ``real_bsz`` rows and appended to the row's
    ``share_inputs["entropy_list"]`` entry. A row whose ``stop_flags`` entry is
    set and whose ``seq_lens_decoder`` entry is non-zero is reported through
    ``_log_entropy`` right after its value is appended.

    Returns:
        None.
    """
    this_time_lens = share_inputs["seq_lens_this_time"]
    real_bsz = this_time_lens.shape[0]
    encoder_lens = share_inputs["seq_lens_encoder"][:real_bsz]

    # Both length tensors may arrive with a trailing axis; flatten them to 1-D.
    if encoder_lens.ndim == 2:
        encoder_lens = encoder_lens.squeeze(1)
    if this_time_lens.ndim == 2:
        this_time_lens = this_time_lens.squeeze(1)

    single_token = paddle.ones([1], dtype="int32")
    real_seq_lens = paddle.where(encoder_lens != 0, single_token, this_time_lens)

    # The set of active rows is the same for both passes below, so resolve it once.
    active_rows = [row for row in range(real_bsz) if int(real_seq_lens[row]) != 0]

    for row in active_rows:
        t = temperature[row]
        if t > 0 and t != 1.0:
            logits[row] = logits[row].scale_(1 / t)

    entropy_tensor = get_entropy(logits[:real_bsz])

    for row in active_rows:
        share_inputs["entropy_list"][row].append(float(entropy_tensor[row]))
        if not share_inputs["stop_flags"][row]:
            continue
        if share_inputs["seq_lens_decoder"][row] != 0 and len(share_inputs["entropy_list"][row]) != 0:
            _log_entropy(share_inputs, row)
172+
173+
174+
def speculate_calculate_logits_entropy_fd(logits, share_inputs, temperature):
    """Record entropy for every accepted token of a speculative step (FD runner path).

    The batch size is the leading dimension of ``share_inputs["seq_lens_this_time"]``.
    If ``accept_num`` sums to zero over the batch, no entropy is computed: finished
    slots are reported through ``_log_entropy`` and the function returns.

    Otherwise each slot owns a contiguous span of rows in ``logits``. The span is
    one row where ``seq_lens_encoder`` is non-zero and ``seq_lens_this_time`` rows
    otherwise. The first ``accept_num[i]`` rows of each span are gathered and
    multiplied by ``1 / temperature`` of the owning slot (rows whose temperature
    is not positive are left unscaled). They are then passed to ``get_entropy``,
    and the resulting values are appended, in order, to the slot's
    ``share_inputs["entropy_list"]`` entry. After a slot's values are appended it
    is reported through ``_log_entropy`` when its ``stop_flags`` entry is set, its
    ``seq_lens_decoder`` entry is non-zero and its entropy list is not empty.

    Returns:
        None.
    """
    real_bsz = share_inputs["seq_lens_this_time"].shape[0]
    accept_num = share_inputs["accept_num"][:real_bsz]

    def flush_if_finished(slot):
        # Report and reset a slot only once it has stopped and holds values.
        if (
            share_inputs["stop_flags"][slot]
            and share_inputs["seq_lens_decoder"][slot] != 0
            and len(share_inputs["entropy_list"][slot]) != 0
        ):
            _log_entropy(share_inputs, slot)

    if int(paddle.sum(accept_num)) == 0:
        for slot in range(real_bsz):
            flush_if_finished(slot)
        return

    encoder_lens = share_inputs["seq_lens_encoder"][:real_bsz]
    this_time_lens = share_inputs["seq_lens_this_time"]
    # Both length tensors may arrive with a trailing axis; flatten them to 1-D.
    if encoder_lens.ndim == 2:
        encoder_lens = encoder_lens.squeeze(1)
    if this_time_lens.ndim == 2:
        this_time_lens = this_time_lens.squeeze(1)
    real_seq_lens = paddle.where(encoder_lens != 0, paddle.ones([1], dtype="int32"), this_time_lens)

    # Host-side copy of the per-slot accepted counts, reused for offsets and bookkeeping.
    accept_counts = [accept_num[slot].item() for slot in range(real_bsz)]

    # Row index of an accepted token = first row of its slot + position within the slot.
    span_bounds = paddle.concat([paddle.zeros([1], dtype="int32"), paddle.cumsum(real_seq_lens, dtype="int32")])
    span_starts = paddle.repeat_interleave(span_bounds[:-1], accept_num)
    within_span = paddle.concat([paddle.arange(count) for count in accept_counts]).astype("int32")
    accepted_logits = paddle.index_select(logits, span_starts + within_span, axis=0)

    # Look up the temperature of the slot that owns each accepted token.
    owner_slot = paddle.repeat_interleave(paddle.arange(real_bsz, dtype="int32"), accept_num)
    temp_per_token = temperature[owner_slot].flatten()
    scale = paddle.where(
        temp_per_token > 0,
        1.0 / temp_per_token,
        paddle.ones_like(temp_per_token),
    )
    accepted_logits = accepted_logits * scale.unsqueeze(-1)

    entropy = get_entropy(accepted_logits).tolist()

    cursor = 0
    for slot, count in enumerate(accept_counts):
        # Entropy values are laid out slot by slot, `count` consecutive entries each.
        share_inputs["entropy_list"][slot].extend(entropy[cursor : cursor + count])
        cursor += count
        flush_if_finished(slot)
235+
236+
237+
# ==============================================================================
238+
# Common utility
239+
# ==============================================================================
240+
241+
242+
def flush_entropy_on_stop(share_inputs):
    """
    Report and reset the entropy buffer of every finished slot.

    A slot is flushed through ``_log_entropy`` when its ``stop_flags`` entry is set,
    its ``seq_lens_decoder`` entry is non-zero and its ``entropy_list`` entry is not
    empty. The number of slots scanned is the leading dimension of
    ``share_inputs["seq_lens_this_time"]``.

    Intended to run after unified_update_model_status, which sets stop_flags for
    max_dec_len, so that requests stopped after the entropy calculation are still
    reported.
    """
    slot_count = share_inputs["seq_lens_this_time"].shape[0]
    for slot in range(slot_count):
        if not share_inputs["stop_flags"][slot]:
            continue
        if share_inputs["seq_lens_decoder"][slot] != 0 and len(share_inputs["entropy_list"][slot]) != 0:
            _log_entropy(share_inputs, slot)

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107

108108
from fastdeploy.model_executor.entropy_utils import (
109109
calculate_logits_entropy,
110+
flush_entropy_on_stop,
110111
speculate_calculate_logits_entropy,
111112
)
112113
from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
@@ -520,6 +521,9 @@ def post_process_speculate(
520521
model_output.max_dec_len, # max_dec_len
521522
)
522523

524+
if enable_entropy:
525+
flush_entropy_on_stop(share_inputs)
526+
523527

524528
def save_output_speculate(
525529
sampler_output: SamplerOutput,

fastdeploy/worker/gpu_model_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,7 @@ def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list:
12541254
self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
12551255
idx * block_num, (idx + 1) * block_num, 1
12561256
)
1257-
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
1257+
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:batch_size]
12581258

12591259
def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_profile_run=False) -> None:
12601260
"""Prepare the model inputs"""
@@ -2415,6 +2415,7 @@ def execute_model_overlap(
24152415
model_inputs, p_done_idxs, token_num_event = self._preprocess(
24162416
model_forward_batch, num_running_requests, self._cached_launch_token_num, self._cached_real_bsz
24172417
)
2418+
24182419
model_output = self._execute(model_inputs)
24192420
# save output (last batch)
24202421
if self._cached_model_output_data is not None:

0 commit comments

Comments
 (0)