Skip to content

Commit 48c4adb

Browse files
authored
[fix]: Fix mtp config for Hy3 vllm calibration. (#329)
1 parent ceeec44 commit 48c4adb

2 files changed

Lines changed: 122 additions & 1 deletion

File tree

tools/fp8_quant_with_vllm_activation.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import math
44
import multiprocessing as mp
55
import os
6+
import re
67
import shutil
78
from argparse import ArgumentParser
89
from typing import Dict, List, Tuple
@@ -644,6 +645,56 @@ def process_moe_values(data: Dict[str, Dict]) -> Dict[str, Dict]:
644645
print("no moe_expert_stats")
645646
raise AssertionError("moe_expert_stats file is required")
646647
# print(ac_json_data["model.layers.0.mlp.gate_up_proj"])
648+
649+
# ------------------------------------------------------------------
650+
# MTP (Multi-Token Prediction) calibration stats merge
651+
#
652+
# When stage 1 was run with --enable-mtp, the MTP draft layer's
653+
# activation / MoE expert statistics are written to a SEPARATE pair of
654+
# files:
655+
# - mtp_activation_stats.json
656+
# - mtp_moe_expert_stats.json
657+
# Their keys carry an extra ``.mtp_block.`` segment that reflects the
658+
# vLLM hunyuan_mtp draft module structure, e.g.
659+
# model.layers.80.mtp_block.self_attn.qkv_proj
660+
# model.layers.80.mtp_block.mlp.experts.0.gate_up_proj
661+
# However, the HF safetensors weight names for the same MTP layer do
662+
# NOT contain ``mtp_block`` (e.g. ``model.layers.80.mlp.experts.0.
663+
# gate_proj.weight``), so process_safetensor() derives lookup keys
664+
# without that segment. To keep the two sides aligned, we strip the
665+
# ``.mtp_block`` segment from MTP stat keys before merging.
666+
# ------------------------------------------------------------------
667+
# Matches the literal ``.mtp_block.`` path segment (dots included) that the
# MTP calibration stat keys carry.
_MTP_BLOCK_RE = re.compile(r"\.mtp_block\.")


def _strip_mtp_block_keys(d: Dict[str, dict]) -> Dict[str, dict]:
    """Return a new dict whose keys have ``.mtp_block.`` collapsed to ``.``.

    Keys that do not contain the segment are carried over unchanged. Values
    are reused as-is (not copied) and the input mapping is not modified. If
    two input keys collapse to the same name, the entry seen later wins.

    Args:
        d: Mapping from stat key to its stats payload.

    Returns:
        A new dict with the rewritten keys, in the input's iteration order.
    """
    renamed: Dict[str, dict] = {}
    for stat_key, stat_value in d.items():
        renamed[_MTP_BLOCK_RE.sub(".", stat_key)] = stat_value
    return renamed
671+
672+
mtp_act_path = os.path.join(args.input_vllm_ac_json_path, "mtp_activation_stats.json")
673+
if os.path.isfile(mtp_act_path):
674+
with open(mtp_act_path, "r", encoding="utf8") as fp:
675+
mtp_act_stats = json.load(fp)
676+
merged = _strip_mtp_block_keys(mtp_act_stats)
677+
ac_json_data.update(merged)
678+
print(
679+
f"[mtp-stats] merged {len(merged)} MTP activation entries from "
680+
f"{mtp_act_path} (keys with .mtp_block. stripped)"
681+
)
682+
else:
683+
print(f"[mtp-stats] {mtp_act_path} not found; skipping MTP activation merge")
684+
685+
mtp_moe_path = os.path.join(args.input_vllm_ac_json_path, "mtp_moe_expert_stats.json")
686+
if os.path.isfile(mtp_moe_path):
687+
with open(mtp_moe_path, "r", encoding="utf8") as fp:
688+
mtp_moe_stats = json.load(fp)
689+
merged = _strip_mtp_block_keys(mtp_moe_stats)
690+
ac_json_data.update(merged)
691+
print(
692+
f"[mtp-stats] merged {len(merged)} MTP MoE expert entries from "
693+
f"{mtp_moe_path} (keys with .mtp_block. stripped)"
694+
)
695+
else:
696+
print(f"[mtp-stats] {mtp_moe_path} not found; skipping MTP MoE merge")
697+
647698
# ac_json_data = merge_vllm_act_moe_jsonl(ac_json_data)
648699
print(ac_json_data["model.layers.11.mlp.experts.1.gate_up_proj"])
649700
print(ac_json_data["model.layers.11.mlp.experts.22.gate_up_proj"])

tools/run_vllm_calibrate.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
remove_kvcache_perhead_hooks,
4040
remove_kvcache_perhead_value_hooks,
4141
setup_activation_hooks,
42+
setup_kvcache_perhead_hooks,
4243
setup_kvcache_perhead_value_hooks,
4344
setup_kvcache_value_hooks,
4445
setup_mtp_activation_hooks,
@@ -314,7 +315,7 @@ def main():
314315
speculative_config = None
315316
if args.enable_mtp:
316317
speculative_config = {
317-
"method": "hunyuan_mtp",
318+
"method": "mtp",
318319
"num_speculative_tokens": args.num_speculative_tokens,
319320
}
320321
print(f" MTP Enabled: True (num_speculative_tokens={args.num_speculative_tokens})")
@@ -382,6 +383,27 @@ def main():
382383
else:
383384
print(f"Worker {i}: No MTP draft model available")
384385

386+
# Per-head KV-cache hooks for the MTP draft model (only when the user
387+
# asked for per-head granularity at the main-model level). We reuse
388+
# ``setup_kvcache_perhead_hooks`` which finds Attention layers via
389+
# ``_find_layers`` and installs a separate KVCachePerHeadHook on each.
390+
# This coexists with the per-tensor ``KVCacheHook`` already registered
391+
# by ``setup_mtp_activation_hooks``; both fire on every forward, and we
392+
# later overwrite the per-tensor scalars in mtp_activation_stats.json
393+
# with the per-head lists so the stage-2 quantizer picks them up.
394+
if args.kv_granularity == "per-head":
395+
print("\n" + "=" * 80)
396+
print("Setting up MTP draft model per-head KV-cache hooks...")
397+
print("=" * 80)
398+
mtp_ph_results = llm.llm_engine.collective_rpc(
399+
lambda w: _apply_on_draft_model(w, setup_kvcache_perhead_hooks)
400+
)
401+
for i, result in enumerate(mtp_ph_results):
402+
if result is not None:
403+
print(f"Worker {i}: {result}")
404+
else:
405+
print(f"Worker {i}: No MTP draft model available (per-head KV)")
406+
385407
# Load dataset and prepare prompts
386408
print("\n" + "=" * 80)
387409
print("Loading dataset and preparing prompts...")
@@ -511,6 +533,54 @@ def main():
511533
stats_type="MTP MoE expert statistics",
512534
)
513535

536+
# ---------------------------------------------------------------
537+
# Per-head KV-cache stats for the MTP draft model.
538+
# The per-tensor KVCacheHook (registered inside
539+
# setup_mtp_activation_hooks) writes scalar min/max under keys like
540+
# model.layers.80.mtp_block.self_attn.attn.k_cache
541+
# model.layers.80.mtp_block.self_attn.attn.v_cache
542+
# which were already saved into mtp_activation_stats.json above.
543+
# When --kv-granularity=per-head, we additionally collect per-head
544+
# min/max from the parallel KVCachePerHeadHook and OVERWRITE those
545+
# same keys with list-valued min/max so the stage-2 quantizer
546+
# (which switches on isinstance(min, list) -> per_head) picks the
547+
# finer granularity. Hooks are then removed.
548+
# ---------------------------------------------------------------
549+
if args.kv_granularity == "per-head":
550+
print("\n[MTP] Collecting per-head KV-cache statistics from draft model...")
551+
mtp_ph_stats_list = llm.llm_engine.collective_rpc(
552+
lambda w: _apply_on_draft_model(w, get_kvcache_perhead_stats)
553+
)
554+
# Pick the first non-empty result (rank-0 is expected to carry the gathered data).
555+
mtp_ph_stats = next((r for r in (mtp_ph_stats_list or []) if r), None)
556+
if not mtp_ph_stats:
557+
print(
558+
"[MTP] WARNING: no per-head KV-cache stats collected from "
559+
"draft model; mtp_activation_stats.json will keep per-tensor "
560+
"scalars for k_cache/v_cache entries."
561+
)
562+
else:
563+
mtp_act_path = os.path.join(args.output_dir, "mtp_activation_stats.json")
564+
if os.path.exists(mtp_act_path):
565+
with open(mtp_act_path, "r", encoding="utf8") as _f:
566+
merged_mtp = json.load(_f)
567+
else:
568+
merged_mtp = {}
569+
# dict.update() will overwrite the per-tensor (scalar) values
570+
# for the same keys with the per-head (list) values.
571+
merged_mtp.update(mtp_ph_stats)
572+
with open(mtp_act_path, "w", encoding="utf8") as _f:
573+
json.dump(merged_mtp, _f, indent=2)
574+
print(
575+
f"[MTP] Merged {len(mtp_ph_stats)} per-head KV-cache entries "
576+
f"into {mtp_act_path} (per-tensor scalars overwritten)."
577+
)
578+
579+
# Clean up the per-head hooks on the draft model.
580+
llm.llm_engine.collective_rpc(
581+
lambda w: _apply_on_draft_model(w, remove_kvcache_perhead_hooks)
582+
)
583+
514584
print("\n" + "=" * 80)
515585
print("Calibration completed successfully!")
516586
print(f"Results saved to: {args.output_dir}")

0 commit comments

Comments
 (0)