Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 140 additions & 9 deletions fastdeploy/model_executor/entropy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@


def get_entropy(logits):
# Check for -inf values in logits
if paddle.any(paddle.isinf(logits) & (logits < 0)):
data_processor_logger.debug("Detected -inf values in logits, clipping to minimum value")
logits = paddle.clip(logits, min=1e-9)
Expand All @@ -32,6 +31,19 @@ def get_entropy(logits):
return paddle.sum(p0 * (paddle.log(z0) - a0), axis=-1)


def _log_entropy(share_inputs, i):
elist = share_inputs["entropy_list"][i]
data_processor_logger.info(
f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(elist)/len(elist)}, steps: {len(elist)}, all_values: {elist}"
)
share_inputs["entropy_list"][i] = []


# ==============================================================================
# ernie5_model_runner path (original logic from commit 361a310)
# ==============================================================================


def calculate_logits_entropy(logits, share_inputs, temperature):
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
real_seq_lens = paddle.where(
Expand All @@ -57,10 +69,7 @@ def calculate_logits_entropy(logits, share_inputs, temperature):
and share_inputs["seq_lens_decoder"][i] != 0
and len(share_inputs["entropy_list"][i]) != 0
):
data_processor_logger.info(
f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(share_inputs['entropy_list'][i])/len(share_inputs['entropy_list'][i])}"
)
share_inputs["entropy_list"][i] = []
_log_entropy(share_inputs, i)


def speculate_calculate_logits_entropy(logits, share_inputs, temperature):
Expand Down Expand Up @@ -100,7 +109,129 @@ def speculate_calculate_logits_entropy(logits, share_inputs, temperature):
and share_inputs["seq_lens_decoder"][i] != 0
and len(share_inputs["entropy_list"][i]) != 0
):
data_processor_logger.info(
f"req_id: {share_inputs['req_ids'][i]}, entropy: {sum(share_inputs['entropy_list'][i])/len(share_inputs['entropy_list'][i])}"
)
share_inputs["entropy_list"][i] = []
_log_entropy(share_inputs, i)


# ==============================================================================
# gpu_model_runner (FD runner) path
# ==============================================================================


def calculate_logits_entropy_fd(logits, share_inputs, temperature):
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
seq_lens_encoder = share_inputs["seq_lens_encoder"][:real_bsz]
seq_lens_this_time = share_inputs["seq_lens_this_time"]
if seq_lens_encoder.ndim == 2:
seq_lens_encoder = seq_lens_encoder.squeeze(1)
if seq_lens_this_time.ndim == 2:
seq_lens_this_time = seq_lens_this_time.squeeze(1)
real_seq_lens = paddle.where(
seq_lens_encoder != 0,
paddle.ones([1], dtype="int32"),
seq_lens_this_time,
)

for i in range(real_bsz):
if int(real_seq_lens[i]) == 0:
continue
t = temperature[i]
if t > 0 and t != 1.0:
logits[i] = logits[i].scale_(1 / t)

entropy_tensor = get_entropy(logits[:real_bsz])

for i in range(real_bsz):
if int(real_seq_lens[i]) == 0:
continue
entropy_val = float(entropy_tensor[i])
share_inputs["entropy_list"][i].append(entropy_val)
if (
share_inputs["stop_flags"][i]
and share_inputs["seq_lens_decoder"][i] != 0
and len(share_inputs["entropy_list"][i]) != 0
):
_log_entropy(share_inputs, i)


def speculate_calculate_logits_entropy_fd(logits, share_inputs, temperature):
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
total_accepted_num = int(paddle.sum(share_inputs["accept_num"][:real_bsz]))

if total_accepted_num == 0:
for i in range(real_bsz):
if (
share_inputs["stop_flags"][i]
and share_inputs["seq_lens_decoder"][i] != 0
and len(share_inputs["entropy_list"][i]) != 0
):
_log_entropy(share_inputs, i)
return

seq_lens_encoder = share_inputs["seq_lens_encoder"][:real_bsz]
seq_lens_this_time = share_inputs["seq_lens_this_time"]
if seq_lens_encoder.ndim == 2:
seq_lens_encoder = seq_lens_encoder.squeeze(1)
if seq_lens_this_time.ndim == 2:
seq_lens_this_time = seq_lens_this_time.squeeze(1)
real_seq_lens = paddle.where(
seq_lens_encoder != 0,
paddle.ones([1], dtype="int32"),
seq_lens_this_time,
)
seq_start_idx = paddle.concat([paddle.zeros([1], dtype="int32"), paddle.cumsum(real_seq_lens, dtype="int32")])
repeated_starts = paddle.repeat_interleave(seq_start_idx[:-1], share_inputs["accept_num"][:real_bsz])
offsets = paddle.concat([paddle.arange(share_inputs["accept_num"][i].item()) for i in range(real_bsz)]).astype(
"int32"
)
accepted_idx = repeated_starts + offsets

accepted_logits = paddle.empty([total_accepted_num, logits.shape[1]], dtype=logits.dtype)
for i in range(total_accepted_num):
accepted_logits[i] = logits[accepted_idx[i]]

batch_indices = paddle.arange(real_bsz, dtype="int32")
batch_id_per_token = paddle.repeat_interleave(batch_indices, share_inputs["accept_num"][:real_bsz])
for i in range(total_accepted_num):
bid = int(batch_id_per_token[i])
t = temperature[bid]
if t > 0 and t != 1.0:
accepted_logits[i] = accepted_logits[i].scale_(1 / t)

entropy_tensor = get_entropy(accepted_logits)
entropy = entropy_tensor.tolist()

for i in range(real_bsz):
accept_count = int(share_inputs["accept_num"][i])
if accept_count > 0:
req_id = share_inputs["req_ids"][i] if i < len(share_inputs["req_ids"]) else ""
is_valid_req = bool(req_id and str(req_id).strip())
for j in range(accept_count):
e_val = entropy.pop(0)
if is_valid_req:
share_inputs["entropy_list"][i].append(e_val)
if (
share_inputs["stop_flags"][i]
and share_inputs["seq_lens_decoder"][i] != 0
and len(share_inputs["entropy_list"][i]) != 0
):
_log_entropy(share_inputs, i)


# ==============================================================================
# Common utility
# ==============================================================================


def flush_entropy_on_stop(share_inputs):
"""
Flush entropy for requests whose stop_flags became True after entropy calculation.
Called after unified_update_model_status which sets stop_flags for max_dec_len.
"""
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
for i in range(real_bsz):
if (
share_inputs["stop_flags"][i]
and share_inputs["seq_lens_decoder"][i] != 0
and len(share_inputs["entropy_list"][i]) != 0
):
_log_entropy(share_inputs, i)
20 changes: 18 additions & 2 deletions fastdeploy/model_executor/pre_and_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
"""

import os
import queue
from typing import Dict, List, Optional, Union

Expand Down Expand Up @@ -107,8 +108,14 @@

from fastdeploy.model_executor.entropy_utils import (
calculate_logits_entropy,
calculate_logits_entropy_fd,
flush_entropy_on_stop,
speculate_calculate_logits_entropy,
speculate_calculate_logits_entropy_fd,
)

_USE_FD_RUNNER = os.environ.get("EB5_ENABLE_FD_RUNNER", "0") == "1"

from fastdeploy.model_executor.layers.moe.routing_indices_cache import (
RoutingReplayManager,
)
Expand Down Expand Up @@ -321,7 +328,10 @@ def post_process_normal(
)

if enable_entropy:
calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)
if _USE_FD_RUNNER:
calculate_logits_entropy_fd(sampler_output.logits, share_inputs, sampling_metadata.temperature)
else:
calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)

# Routing replay
if routing_replay_manager is not None:
Expand Down Expand Up @@ -471,7 +481,10 @@ def post_process_speculate(
)

if enable_entropy:
speculate_calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)
if _USE_FD_RUNNER:
speculate_calculate_logits_entropy_fd(sampler_output.logits, share_inputs, sampling_metadata.temperature)
else:
speculate_calculate_logits_entropy(sampler_output.logits, share_inputs, sampling_metadata.temperature)

# Routing replay
if routing_replay_manager is not None:
Expand Down Expand Up @@ -520,6 +533,9 @@ def post_process_speculate(
model_output.max_dec_len, # max_dec_len
)

if enable_entropy:
flush_entropy_on_stop(share_inputs)


def save_output_speculate(
sampler_output: SamplerOutput,
Expand Down
3 changes: 2 additions & 1 deletion fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,7 +1254,7 @@ def _dummy_prefill_inputs(self, input_length_list: List[int], max_dec_len_list:
self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
idx * block_num, (idx + 1) * block_num, 1
)
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"]
self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:batch_size]

def _prepare_inputs(self, cached_token_num=-1, cached_real_bsz=-1, is_dummy_or_profile_run=False) -> None:
"""Prepare the model inputs"""
Expand Down Expand Up @@ -2415,6 +2415,7 @@ def execute_model_overlap(
model_inputs, p_done_idxs, token_num_event = self._preprocess(
model_forward_batch, num_running_requests, self._cached_launch_token_num, self._cached_real_bsz
)

model_output = self._execute(model_inputs)
# save output (last batch)
if self._cached_model_output_data is not None:
Expand Down
Loading
Loading