Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tools/fp8_quant_with_vllm_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import multiprocessing as mp
import os
import re
import shutil
from argparse import ArgumentParser
from typing import Dict, List, Tuple
Expand Down Expand Up @@ -644,6 +645,56 @@ def process_moe_values(data: Dict[str, Dict]) -> Dict[str, Dict]:
print("no moe_expert_stats")
raise AssertionError("moe_expert_stats file is required")
# print(ac_json_data["model.layers.0.mlp.gate_up_proj"])

# ------------------------------------------------------------------
# MTP (Multi-Token Prediction) calibration stats merge
#
# When stage 1 was run with --enable-mtp, the MTP draft layer's
# activation / MoE expert statistics are written to a SEPARATE pair of
# files:
# - mtp_activation_stats.json
# - mtp_moe_expert_stats.json
# Their keys carry an extra ``.mtp_block.`` segment that reflects the
# vLLM hunyuan_mtp draft module structure, e.g.
# model.layers.80.mtp_block.self_attn.qkv_proj
# model.layers.80.mtp_block.mlp.experts.0.gate_up_proj
# However, the HF safetensors weight names for the same MTP layer do
# NOT contain ``mtp_block``, e.g.
#   model.layers.80.mlp.experts.0.gate_proj.weight
# so process_safetensor() derives lookup keys without that segment. To
# keep the two sides aligned, we strip the ``.mtp_block`` segment from
# MTP stat keys before merging.
# ------------------------------------------------------------------
# Literal ``.mtp_block.`` path segment found in MTP draft-layer stat keys.
_MTP_BLOCK_RE = re.compile(r"\.mtp_block\.")

def _strip_mtp_block_keys(d: Dict[str, dict]) -> Dict[str, dict]:
    """Return a new dict with every ``.mtp_block.`` in each key collapsed to ``.``.

    Values are carried over as-is (the same objects, not copies). Keys that
    do not contain the segment are kept unchanged. If two input keys map to
    the same rewritten key, the later one (in iteration order) wins.
    """
    renamed: Dict[str, dict] = {}
    for original_key, stats in d.items():
        renamed[_MTP_BLOCK_RE.sub(".", original_key)] = stats
    return renamed

mtp_act_path = os.path.join(args.input_vllm_ac_json_path, "mtp_activation_stats.json")
if os.path.isfile(mtp_act_path):
with open(mtp_act_path, "r", encoding="utf8") as fp:
mtp_act_stats = json.load(fp)
merged = _strip_mtp_block_keys(mtp_act_stats)
ac_json_data.update(merged)
print(
f"[mtp-stats] merged {len(merged)} MTP activation entries from "
f"{mtp_act_path} (keys with .mtp_block. stripped)"
)
else:
print(f"[mtp-stats] {mtp_act_path} not found; skipping MTP activation merge")

mtp_moe_path = os.path.join(args.input_vllm_ac_json_path, "mtp_moe_expert_stats.json")
if os.path.isfile(mtp_moe_path):
with open(mtp_moe_path, "r", encoding="utf8") as fp:
mtp_moe_stats = json.load(fp)
merged = _strip_mtp_block_keys(mtp_moe_stats)
ac_json_data.update(merged)
print(
f"[mtp-stats] merged {len(merged)} MTP MoE expert entries from "
f"{mtp_moe_path} (keys with .mtp_block. stripped)"
)
else:
print(f"[mtp-stats] {mtp_moe_path} not found; skipping MTP MoE merge")

# ac_json_data = merge_vllm_act_moe_jsonl(ac_json_data)
print(ac_json_data["model.layers.11.mlp.experts.1.gate_up_proj"])
print(ac_json_data["model.layers.11.mlp.experts.22.gate_up_proj"])
Expand Down
72 changes: 71 additions & 1 deletion tools/run_vllm_calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
remove_kvcache_perhead_hooks,
remove_kvcache_perhead_value_hooks,
setup_activation_hooks,
setup_kvcache_perhead_hooks,
setup_kvcache_perhead_value_hooks,
setup_kvcache_value_hooks,
setup_mtp_activation_hooks,
Expand Down Expand Up @@ -314,7 +315,7 @@ def main():
speculative_config = None
if args.enable_mtp:
speculative_config = {
"method": "hunyuan_mtp",
"method": "mtp",
"num_speculative_tokens": args.num_speculative_tokens,
}
print(f" MTP Enabled: True (num_speculative_tokens={args.num_speculative_tokens})")
Expand Down Expand Up @@ -382,6 +383,27 @@ def main():
else:
print(f"Worker {i}: No MTP draft model available")

# Per-head KV-cache hooks for the MTP draft model (only when the user
# asked for per-head granularity at the main-model level). We reuse
# ``setup_kvcache_perhead_hooks`` which finds Attention layers via
# ``_find_layers`` and installs a separate KVCachePerHeadHook on each.
# This coexists with the per-tensor ``KVCacheHook`` already registered
# by ``setup_mtp_activation_hooks``; both fire on every forward, and we
# later overwrite the per-tensor scalars in mtp_activation_stats.json
# with the per-head lists so the stage-2 quantizer picks them up.
if args.kv_granularity == "per-head":
print("\n" + "=" * 80)
print("Setting up MTP draft model per-head KV-cache hooks...")
print("=" * 80)
mtp_ph_results = llm.llm_engine.collective_rpc(
lambda w: _apply_on_draft_model(w, setup_kvcache_perhead_hooks)
)
for i, result in enumerate(mtp_ph_results):
if result is not None:
print(f"Worker {i}: {result}")
else:
print(f"Worker {i}: No MTP draft model available (per-head KV)")

# Load dataset and prepare prompts
print("\n" + "=" * 80)
print("Loading dataset and preparing prompts...")
Expand Down Expand Up @@ -511,6 +533,54 @@ def main():
stats_type="MTP MoE expert statistics",
)

# ---------------------------------------------------------------
# Per-head KV-cache stats for the MTP draft model.
# The per-tensor KVCacheHook (registered inside
# setup_mtp_activation_hooks) writes scalar min/max under keys like
# model.layers.80.mtp_block.self_attn.attn.k_cache
# model.layers.80.mtp_block.self_attn.attn.v_cache
# which were already saved into mtp_activation_stats.json above.
# When --kv-granularity=per-head, we additionally collect per-head
# min/max from the parallel KVCachePerHeadHook and OVERWRITE those
# same keys with list-valued min/max so the stage-2 quantizer
# (which switches on isinstance(min, list) -> per_head) picks the
# finer granularity. Hooks are then removed.
# ---------------------------------------------------------------
if args.kv_granularity == "per-head":
print("\n[MTP] Collecting per-head KV-cache statistics from draft model...")
mtp_ph_stats_list = llm.llm_engine.collective_rpc(
lambda w: _apply_on_draft_model(w, get_kvcache_perhead_stats)
)
# Pick the first non-empty result (rank-0 is expected to carry the gathered data).
mtp_ph_stats = next((r for r in (mtp_ph_stats_list or []) if r), None)
if not mtp_ph_stats:
print(
"[MTP] WARNING: no per-head KV-cache stats collected from "
"draft model; mtp_activation_stats.json will keep per-tensor "
"scalars for k_cache/v_cache entries."
)
else:
mtp_act_path = os.path.join(args.output_dir, "mtp_activation_stats.json")
if os.path.exists(mtp_act_path):
with open(mtp_act_path, "r", encoding="utf8") as _f:
merged_mtp = json.load(_f)
else:
merged_mtp = {}
# dict.update() will overwrite the per-tensor (scalar) values
# for the same keys with the per-head (list) values.
merged_mtp.update(mtp_ph_stats)
with open(mtp_act_path, "w", encoding="utf8") as _f:
json.dump(merged_mtp, _f, indent=2)
print(
f"[MTP] Merged {len(mtp_ph_stats)} per-head KV-cache entries "
f"into {mtp_act_path} (per-tensor scalars overwritten)."
)

# Clean up the per-head hooks on the draft model.
llm.llm_engine.collective_rpc(
lambda w: _apply_on_draft_model(w, remove_kvcache_perhead_hooks)
)

print("\n" + "=" * 80)
print("Calibration completed successfully!")
print(f"Results saved to: {args.output_dir}")
Expand Down
Loading