-
Notifications
You must be signed in to change notification settings - Fork 749
【TI-Consisent】Added Metric logits_stats to the ZMQ branch #6979
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 9 commits
86c539b
44d4367
e045231
a8259f6
9abc1e1
4488f97
67e0aa1
6a576c8
11175eb
d237894
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -645,6 +645,7 @@ def _start_worker_service(self): | |
| "use_internode_ll_two_stage": self.cfg.parallel_config.use_internode_ll_two_stage, | ||
| "disable_sequence_parallel_moe": self.cfg.parallel_config.disable_sequence_parallel_moe, | ||
| "enable_logprob": self.cfg.model_config.enable_logprob, | ||
| "compute_logits_stats": self.cfg.model_config.compute_logits_stats, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. common_engine.py 中也得加这个参数
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 已添加 |
||
| "lm_head_fp32": self.cfg.model_config.lm_head_fp32, | ||
| "moe_gate_fp32": self.cfg.model_config.moe_gate_fp32, | ||
| "shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,8 @@ | |
| # limitations under the License. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import itertools | ||
| import time | ||
|
|
@@ -825,22 +827,74 @@ def _create_chat_logprobs( | |
| request_decode_flag: Optional[bool] = True, | ||
| ) -> Optional[LogProbs]: | ||
| """Create OpenAI-style logprobs for chat completions.""" | ||
| if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs): | ||
| if ( | ||
| output_top_logprobs is None | ||
| or len(output_top_logprobs) < 3 | ||
| or any(not lst for lst in output_top_logprobs[:3]) | ||
| ): | ||
| return None | ||
| logprobs_res: Optional[LogProbs] = None | ||
| for logprob_token_ids, logprobs, sampled_token_ranks in zip( | ||
| output_top_logprobs[0], output_top_logprobs[1], output_top_logprobs[2] | ||
| ): | ||
| top_logprobs = LogprobsLists( | ||
| logprob_token_ids=[logprob_token_ids], | ||
| logprobs=[logprobs], | ||
| sampled_token_ranks=[sampled_token_ranks], | ||
| ) | ||
|
|
||
| # Check if output_top_logprobs is a LogprobsLists object(NamedTuple) or a list | ||
| is_logprobslists = hasattr(output_top_logprobs, "logprob_token_ids") | ||
|
|
||
| # Extract logits stats if available | ||
| if is_logprobslists: | ||
| # output_top_logprobs is LogprobsLists namedtuple | ||
| has_logits_stats = output_top_logprobs.logits_min is not None | ||
| else: | ||
| # list from msgpack: [logprob_token_ids, logprobs, sampled_token_ranks, logits_min, logits_max, logits_mean, logits_std] | ||
| has_logits_stats = len(output_top_logprobs) >= 7 and output_top_logprobs[3] is not None | ||
|
|
||
| if is_logprobslists: | ||
| num_tokens = len(output_top_logprobs.logprobs) | ||
| _tk_ids = lambda idx: output_top_logprobs.logprob_token_ids[idx] | ||
| _lps = lambda idx: output_top_logprobs.logprobs[idx] | ||
| _ranks = lambda idx: output_top_logprobs.sampled_token_ranks[idx] | ||
| _lmin = lambda idx: output_top_logprobs.logits_min[idx] | ||
| _lmax = lambda idx: output_top_logprobs.logits_max[idx] | ||
| _lmean = lambda idx: output_top_logprobs.logits_mean[idx] | ||
| _lstd = lambda idx: output_top_logprobs.logits_std[idx] | ||
| else: | ||
| num_tokens = len(output_top_logprobs[1]) | ||
| _tk_ids = lambda idx: output_top_logprobs[0][idx] | ||
| _lps = lambda idx: output_top_logprobs[1][idx] | ||
| _ranks = lambda idx: output_top_logprobs[2][idx] | ||
| _lmin = lambda idx: output_top_logprobs[3][idx] | ||
| _lmax = lambda idx: output_top_logprobs[4][idx] | ||
| _lmean = lambda idx: output_top_logprobs[5][idx] | ||
| _lstd = lambda idx: output_top_logprobs[6][idx] | ||
|
|
||
| for idx in range(num_tokens): | ||
| logits_stats = None | ||
| if has_logits_stats: | ||
| top_logprobs = LogprobsLists( | ||
| logprob_token_ids=[_tk_ids(idx)], | ||
| logprobs=[_lps(idx)], | ||
| sampled_token_ranks=[_ranks(idx)], | ||
| logits_min=[_lmin(idx)], | ||
| logits_max=[_lmax(idx)], | ||
| logits_mean=[_lmean(idx)], | ||
| logits_std=[_lstd(idx)], | ||
| ) | ||
| logits_stats = { | ||
| "min": float(_lmin(idx)), | ||
| "max": float(_lmax(idx)), | ||
| "mean": float(_lmean(idx)), | ||
| "std": float(_lstd(idx)), | ||
| } | ||
| else: | ||
| top_logprobs = LogprobsLists( | ||
| logprob_token_ids=[_tk_ids(idx)], | ||
| logprobs=[_lps(idx)], | ||
| sampled_token_ranks=[_ranks(idx)], | ||
| ) | ||
| step_logprobs_res = self._build_logprobs_response( | ||
| request_logprobs=request_logprobs, | ||
| response_logprobs=top_logprobs, | ||
| request_top_logprobs=request_top_logprobs, | ||
| request_decode_flag=request_decode_flag, | ||
| logits_stats=logits_stats, | ||
| ) | ||
| if logprobs_res is None: | ||
| logprobs_res = step_logprobs_res | ||
|
|
@@ -854,6 +908,7 @@ def _build_logprobs_response( | |
| response_logprobs: Optional[LogprobsLists], | ||
| request_top_logprobs: int, | ||
| request_decode_flag: bool, | ||
| logits_stats: Optional[dict[str, float]] = None, | ||
| ) -> Optional[LogProbs]: | ||
|
Comment on lines
905
to
912
|
||
| """ | ||
| Construct a logprobs response object in line with the OpenAI style. | ||
|
|
@@ -901,6 +956,7 @@ def _build_logprobs_response( | |
| logprob=top_logprob_entries[0].logprob, | ||
| bytes=top_logprob_entries[0].bytes, | ||
| top_logprobs=top_logprob_entries[1:], # Here are the complete topk candidates | ||
| logits_stats=logits_stats, | ||
| ) | ||
|
|
||
| return LogProbs(content=[sampled_entry]) | ||
|
|
@@ -922,7 +978,7 @@ def _build_prompt_logprobs( | |
| tensors. | ||
| """ | ||
|
|
||
| token_ids, logprobs, ranks = prompt_logprobs_tensors | ||
| token_ids, logprobs, ranks = prompt_logprobs_tensors[:3] | ||
|
|
||
| # Normalize to plain Python lists (support both Tensor and list inputs) | ||
| if hasattr(token_ids, "tolist"): | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -83,6 +83,7 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.speculative_decoding = self.cfg.speculative_config.method is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.use_logprobs = self.cfg.model_config.enable_logprob | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.compute_logits_stats = self.cfg.model_config.compute_logits_stats | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.enable_draft_logprob = self.cfg.speculative_config.enable_draft_logprob | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.speculative_decoding: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -350,6 +351,26 @@ def _process_batch_output_use_zmq(self, receive_datas): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logprobs_list: LogprobsLists = stream_data.logprobs.tolists() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result.outputs.logprob = float(logprobs_list.logprobs[0][0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result.outputs.top_logprobs = logprobs_list | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Extract logits statistics if available | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.compute_logits_stats: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logprobs_list.logits_min is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "logits_min is None when compute_logits_stats is enabled" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logprobs_list.logits_max is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "logits_max is None when compute_logits_stats is enabled" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logprobs_list.logits_mean is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "logits_mean is None when compute_logits_stats is enabled" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logprobs_list.logits_std is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "logits_std is None when compute_logits_stats is enabled" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+356
to
+367
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | |
| logprobs_list.logits_min is not None | |
| ), "logits_min is None when compute_logits_stats is enabled" | |
| assert ( | |
| logprobs_list.logits_max is not None | |
| ), "logits_max is None when compute_logits_stats is enabled" | |
| assert ( | |
| logprobs_list.logits_mean is not None | |
| ), "logits_mean is None when compute_logits_stats is enabled" | |
| assert ( | |
| logprobs_list.logits_std is not None | |
| ), "logits_std is None when compute_logits_stats is enabled" | |
| missing_fields = [] | |
| if logprobs_list.logits_min is None: | |
| missing_fields.append("logits_min") | |
| if logprobs_list.logits_max is None: | |
| missing_fields.append("logits_max") | |
| if logprobs_list.logits_mean is None: | |
| missing_fields.append("logits_mean") | |
| if logprobs_list.logits_std is None: | |
| missing_fields.append("logits_std") | |
| if missing_fields: | |
| # When compute_logits_stats is enabled, all logits_* fields must be present | |
| raise ValueError( | |
| "Missing logits stats fields when compute_logits_stats is enabled: " | |
| + ", ".join(missing_fields) | |
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,11 @@ class LogprobsLists(NamedTuple): | |
| logprobs: list[list[float]] | ||
| # [num_reqs] | ||
| sampled_token_ranks: list[int] | ||
| # Logits statistics for each sequence (optional) | ||
| logits_min: Optional[list[float]] = None # [num_reqs] | ||
| logits_max: Optional[list[float]] = None # [num_reqs] | ||
| logits_mean: Optional[list[float]] = None # [num_reqs] | ||
| logits_std: Optional[list[float]] = None # [num_reqs] | ||
|
Comment on lines
44
to
+51
|
||
|
|
||
| def slice_columns(self, start: int, end: int): | ||
| """ | ||
|
|
@@ -54,6 +59,10 @@ def slice_columns(self, start: int, end: int): | |
| [row[start:end] for row in self.logprob_token_ids], | ||
| [row[start:end] for row in self.logprobs], | ||
| self.sampled_token_ranks, # unchanged | ||
| self.logits_min, # unchanged | ||
| self.logits_max, # unchanged | ||
| self.logits_mean, # unchanged | ||
| self.logits_std, # unchanged | ||
|
Comment on lines
58
to
+65
|
||
| ) | ||
|
|
||
| def slice_rows(self, start: int, end: int): | ||
|
|
@@ -65,6 +74,10 @@ def slice_rows(self, start: int, end: int): | |
| self.logprob_token_ids[start:end], | ||
| self.logprobs[start:end], | ||
| self.sampled_token_ranks[start:end], | ||
| self.logits_min[start:end] if self.logits_min is not None else None, | ||
| self.logits_max[start:end] if self.logits_max is not None else None, | ||
| self.logits_mean[start:end] if self.logits_mean is not None else None, | ||
| self.logits_std[start:end] if self.logits_std is not None else None, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -77,13 +90,22 @@ class LogprobsTensors(NamedTuple): | |
| logprobs: paddle.Tensor | ||
| # [num_reqs] | ||
| selected_token_ranks: paddle.Tensor | ||
| # Logits statistics for each sequence (optional) | ||
| logits_min: Optional[paddle.Tensor] = None # [num_reqs] | ||
| logits_max: Optional[paddle.Tensor] = None # [num_reqs] | ||
| logits_mean: Optional[paddle.Tensor] = None # [num_reqs] | ||
| logits_std: Optional[paddle.Tensor] = None | ||
|
|
||
| def tolists(self): | ||
| """Convert to lists.""" | ||
| return LogprobsLists( | ||
| self.logprob_token_ids.tolist(), | ||
| self.logprobs.tolist(), | ||
| self.selected_token_ranks.tolist(), | ||
| self.logits_min.tolist() if self.logits_min is not None else None, | ||
| self.logits_max.tolist() if self.logits_max is not None else None, | ||
| self.logits_mean.tolist() if self.logits_mean is not None else None, | ||
| self.logits_std.tolist() if self.logits_std is not None else None, | ||
| ) | ||
|
|
||
| @staticmethod | ||
|
|
@@ -97,6 +119,10 @@ def empty_cpu(num_positions: int, num_tokens_per_position: int) -> "LogprobsTens | |
| logprob_token_ids=logprob_token_ids, | ||
| logprobs=logprobs, | ||
| selected_token_ranks=selected_token_ranks, | ||
| logits_min=None, | ||
| logits_max=None, | ||
| logits_mean=None, | ||
| logits_std=None, | ||
| ) | ||
|
|
||
| @staticmethod | ||
|
|
@@ -110,6 +136,10 @@ def empty(num_positions: int, num_tokens_per_position: int) -> "LogprobsTensors" | |
| logprob_token_ids=logprob_token_ids, | ||
| logprobs=logprobs, | ||
| selected_token_ranks=selected_token_ranks, | ||
| logits_min=None, | ||
| logits_max=None, | ||
| logits_mean=None, | ||
| logits_std=None, | ||
| ) | ||
|
|
||
| def slice_rows(self, start: int, end: int): | ||
|
|
@@ -122,6 +152,26 @@ def slice_rows(self, start: int, end: int): | |
| paddle.to_tensor(self.logprob_token_ids.cpu()[start:end], place="cpu"), | ||
| paddle.to_tensor(self.logprobs.cpu()[start:end], place="cpu"), | ||
| paddle.to_tensor(self.selected_token_ranks.cpu()[start:end], place="cpu"), | ||
| ( | ||
| paddle.to_tensor(self.logits_min.cpu()[start:end], place="cpu") | ||
| if self.logits_min is not None | ||
| else None | ||
| ), | ||
| ( | ||
| paddle.to_tensor(self.logits_max.cpu()[start:end], place="cpu") | ||
| if self.logits_max is not None | ||
| else None | ||
| ), | ||
| ( | ||
| paddle.to_tensor(self.logits_mean.cpu()[start:end], place="cpu") | ||
| if self.logits_mean is not None | ||
| else None | ||
| ), | ||
| ( | ||
| paddle.to_tensor(self.logits_std.cpu()[start:end], place="cpu") | ||
| if self.logits_std is not None | ||
| else None | ||
| ), | ||
| ) | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR 标题目前为“【TI-Consisent】...”,不符合仓库要求的 `[CLASS]Title` 格式(模板里给出的 tag 列表如 `[Feature]`/`[BugFix]` 等)。
建议将标题改为类似 `[Feature] Add logits_stats metric for ZMQ logprobs`,并修正 Consisent 的拼写(应为 Consistent),以便后续检索与自动化流程识别。