[draft] MTP optimization #1266

Open

hiworldwzj wants to merge 32 commits into main from mtp_optimization

Conversation

@hiworldwzj (Collaborator)

No description provided.

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces Multi-Token Prediction (MTP) support, including a new 'eagle3' mode, dynamic MTP verification, and optimized Triton kernels for diverse attention. It also adds profiling capabilities and updates benchmark scripts. My feedback focuses on improving code quality by moving local imports to the top of files, replacing print statements with proper logging, removing dead/commented-out code, and optimizing weight loading.

Comment on lines +116 to +117
from lightllm.utils.envs_utils import get_env_start_args
args_mtp_step = get_env_start_args().mtp_step

medium

Avoid importing inside a function body. Move this import to the top of the file to improve readability and maintainability.
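The fix is mechanical; as a self-contained sketch of the pattern (a stdlib import is used here as a stand-in for the `lightllm.utils.envs_utils` module, and the function names are hypothetical):

```python
# Before (as in the diff): the import statement executes on every call.
def mtp_step_before():
    from os import getenv  # function-local import, stand-in for the real module
    return int(getenv("MTP_STEP", "1"))

# After: hoist the import to module scope, so it runs once at load time
# and is visible to readers scanning the file's dependencies.
from os import getenv

def mtp_step_after():
    return int(getenv("MTP_STEP", "1"))
```

Function-local imports are occasionally justified (e.g. to break a circular import), but when they are not needed, module-level placement is the idiomatic choice.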

Comment on lines +156 to +160
if os.path.exists(os.path.join(draft_model_path[0], "pytorch_model.bin")):
self.draft_model_weight_dict = torch.load(os.path.join(draft_model_path[0], "pytorch_model.bin"))
self.hidden_proj_weight = self.draft_model_weight_dict["fc.weight"].to(torch.bfloat16).to("cuda")
del self.draft_model_weight_dict
gc.collect()

medium

The use of torch.load with pytorch_model.bin can be slow and memory-intensive. Consider using safetensors for faster and safer model weight loading if possible.

Comment on lines +357 to +362
# infer_state.b_mark_shared_group = F.pad(
# infer_state.b_mark_shared_group,
# (0, infer_state.input_ids.shape[0] - infer_state.b_mark_shared_group.shape[0]),
# mode="constant",
# value=0,
# )

medium

The commented-out code block should be removed if it is no longer needed, or uncommented if it is intended to be part of the logic. Leaving dead code reduces maintainability.

Comment on lines +141 to +144
try:
attr_.copy_(attr_value, non_blocking=True)
except Exception as e:
print(f"Warning: copy tensor {attr_name} failed during cuda graph copy, error: {e}")

medium

Using print for error logging is not recommended in production code. Use the project's logger to ensure errors are captured in the standard logging infrastructure.
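A minimal sketch of the suggested change, using stdlib `logging` as a stand-in for whatever logger helper the project initializes per module (the `safe_copy` wrapper is hypothetical):

```python
import logging

logger = logging.getLogger(__name__)

def safe_copy(dst, src, name):
    """Copy src into dst, reporting failures through the logger rather than print()."""
    try:
        dst.copy_(src, non_blocking=True)
    except Exception:
        # logger.exception logs at ERROR level and attaches the traceback,
        # so the failure is captured by the standard logging pipeline
        # (handlers, formatters, log files) instead of vanishing on stdout.
        logger.exception("copy tensor %s failed during cuda graph copy", name)
```

Note that `logger.exception` also preserves the stack trace, which the original `print(f"... error: {e}")` discards.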

Comment on lines +619 to +631
# # 1. Split the original indexes into multiple groups according to the current group_sizes
# # group_sizes here should correspond to the size of each group before processing
# chunks = torch.split(draft_model_input.mem_indexes, mtp_group_sizes)
# # 2. Process each chunk: drop the first element ([:, 1:]) and append the matching eagle_mem_indexes_i element
# # Assume eagle_mem_indexes_i has shape (num_groups,)
# new_chunks = []
# for i, chunk in enumerate(chunks):
# # chunk[1:] mimics the original [:, 1:] operation
# # eagle_mem_indexes_i[i:i+1] ensures a length-1 tensor is taken for concatenation
# updated_chunk = torch.cat([chunk[1:], eagle_mem_indexes_i[i:i+1]], dim=0)
# new_chunks.append(updated_chunk)
# # 3. Merge back into a one-dimensional tensor
# draft_model_input.mem_indexes = torch.cat(new_chunks, dim=0)

medium

The commented-out code block should be removed to keep the codebase clean and maintainable.
