diff --git a/tools/fp8_quant_with_vllm_activation.py b/tools/fp8_quant_with_vllm_activation.py index a4dca258..e13b6a62 100644 --- a/tools/fp8_quant_with_vllm_activation.py +++ b/tools/fp8_quant_with_vllm_activation.py @@ -3,6 +3,7 @@ import math import multiprocessing as mp import os +import re import shutil from argparse import ArgumentParser from typing import Dict, List, Tuple @@ -644,6 +645,56 @@ def process_moe_values(data: Dict[str, Dict]) -> Dict[str, Dict]: print("no moe_expert_stats") raise AssertionError("moe_expert_stats file is required") # print(ac_json_data["model.layers.0.mlp.gate_up_proj"]) + + # ------------------------------------------------------------------ + # MTP (Multi-Token Prediction) calibration stats merge + # + # When stage 1 was run with --enable-mtp, the MTP draft layer's + # activation / MoE expert statistics are written to a SEPARATE pair of + # files: + # - mtp_activation_stats.json + # - mtp_moe_expert_stats.json + # Their keys carry an extra ``.mtp_block.`` segment that reflects the + # vLLM hunyuan_mtp draft module structure, e.g. + # model.layers.80.mtp_block.self_attn.qkv_proj + # model.layers.80.mtp_block.mlp.experts.0.gate_up_proj + # However, the HF safetensors weight names for the same MTP layer do + # NOT contain ``mtp_block`` (e.g. ``model.layers.80.mlp.experts.0. + # gate_proj.weight``), so process_safetensor() derives lookup keys + # without that segment. To keep the two sides aligned, we strip the + # ``.mtp_block`` segment from MTP stat keys before merging. + # ------------------------------------------------------------------ + _MTP_BLOCK_RE = re.compile(r"\.mtp_block\.") + + def _strip_mtp_block_keys(d: Dict[str, dict]) -> Dict[str, dict]: + return {_MTP_BLOCK_RE.sub(".", k): v for k, v in d.items()} + + mtp_act_path = os.path.join(args.input_vllm_ac_json_path, "mtp_activation_stats.json") + if os.path.isfile(mtp_act_path): + with open(mtp_act_path, "r", encoding="utf8") as fp: + mtp_act_stats = json.load(fp) + merged = _strip_mtp_block_keys(mtp_act_stats) + ac_json_data.update(merged) + print( + f"[mtp-stats] merged {len(merged)} MTP activation entries from " + f"{mtp_act_path} (keys with .mtp_block. stripped)" + ) + else: + print(f"[mtp-stats] {mtp_act_path} not found; skipping MTP activation merge") + + mtp_moe_path = os.path.join(args.input_vllm_ac_json_path, "mtp_moe_expert_stats.json") + if os.path.isfile(mtp_moe_path): + with open(mtp_moe_path, "r", encoding="utf8") as fp: + mtp_moe_stats = json.load(fp) + merged = _strip_mtp_block_keys(mtp_moe_stats) + ac_json_data.update(merged) + print( + f"[mtp-stats] merged {len(merged)} MTP MoE expert entries from " + f"{mtp_moe_path} (keys with .mtp_block. stripped)" + ) + else: + print(f"[mtp-stats] {mtp_moe_path} not found; skipping MTP MoE merge") + # ac_json_data = merge_vllm_act_moe_jsonl(ac_json_data) print(ac_json_data["model.layers.11.mlp.experts.1.gate_up_proj"]) print(ac_json_data["model.layers.11.mlp.experts.22.gate_up_proj"]) diff --git a/tools/run_vllm_calibrate.py b/tools/run_vllm_calibrate.py index 9a00e02d..a5ca8cbe 100644 --- a/tools/run_vllm_calibrate.py +++ b/tools/run_vllm_calibrate.py @@ -39,6 +39,7 @@ remove_kvcache_perhead_hooks, remove_kvcache_perhead_value_hooks, setup_activation_hooks, + setup_kvcache_perhead_hooks, setup_kvcache_perhead_value_hooks, setup_kvcache_value_hooks, setup_mtp_activation_hooks, @@ -314,7 +315,7 @@ def main(): speculative_config = None if args.enable_mtp: speculative_config = { - "method": "hunyuan_mtp", + "method": "mtp", "num_speculative_tokens": args.num_speculative_tokens, } print(f" MTP Enabled: True (num_speculative_tokens={args.num_speculative_tokens})") @@ -382,6 +383,27 @@ def main(): else: print(f"Worker {i}: No MTP draft model available") + # Per-head KV-cache hooks for the MTP draft model (only when the user + # asked for per-head granularity at the main-model level). We reuse + # ``setup_kvcache_perhead_hooks`` which finds Attention layers via + # ``_find_layers`` and installs a separate KVCachePerHeadHook on each. + # This coexists with the per-tensor ``KVCacheHook`` already registered + # by ``setup_mtp_activation_hooks``; both fire on every forward, and we + # later overwrite the per-tensor scalars in mtp_activation_stats.json + # with the per-head lists so the stage-2 quantizer picks them up. + if args.kv_granularity == "per-head": + print("\n" + "=" * 80) + print("Setting up MTP draft model per-head KV-cache hooks...") + print("=" * 80) + mtp_ph_results = llm.llm_engine.collective_rpc( + lambda w: _apply_on_draft_model(w, setup_kvcache_perhead_hooks) + ) + for i, result in enumerate(mtp_ph_results): + if result is not None: + print(f"Worker {i}: {result}") + else: + print(f"Worker {i}: No MTP draft model available (per-head KV)") + # Load dataset and prepare prompts print("\n" + "=" * 80) print("Loading dataset and preparing prompts...") @@ -511,6 +533,54 @@ def main(): stats_type="MTP MoE expert statistics", ) + # --------------------------------------------------------------- + # Per-head KV-cache stats for the MTP draft model. + # The per-tensor KVCacheHook (registered inside setup_mtp_activation + # _hooks) writes scalar min/max under keys like + # model.layers.80.mtp_block.self_attn.attn.k_cache + # model.layers.80.mtp_block.self_attn.attn.v_cache + # which were already saved into mtp_activation_stats.json above. + # When --kv-granularity=per-head, we additionally collect per-head + # min/max from the parallel KVCachePerHeadHook and OVERWRITE those + # same keys with list-valued min/max so the stage-2 quantizer + # (which switches on isinstance(min, list) -> per_head) picks the + # finer granularity. Hooks are then removed. + # --------------------------------------------------------------- + if args.kv_granularity == "per-head": + print("\n[MTP] Collecting per-head KV-cache statistics from draft model...") + mtp_ph_stats_list = llm.llm_engine.collective_rpc( + lambda w: _apply_on_draft_model(w, get_kvcache_perhead_stats) + ) + # Pick the first non-None result (rank-0 carries the gathered data). + mtp_ph_stats = next((r for r in (mtp_ph_stats_list or []) if r), None) + if not mtp_ph_stats: + print( + "[MTP] WARNING: no per-head KV-cache stats collected from " + "draft model; mtp_activation_stats.json will keep per-tensor " + "scalars for k_cache/v_cache entries." + ) + else: + mtp_act_path = os.path.join(args.output_dir, "mtp_activation_stats.json") + if os.path.exists(mtp_act_path): + with open(mtp_act_path, "r", encoding="utf8") as _f: + merged_mtp = json.load(_f) + else: + merged_mtp = {} + # dict.update() will overwrite the per-tensor (scalar) values + # for the same keys with the per-head (list) values. + merged_mtp.update(mtp_ph_stats) + with open(mtp_act_path, "w", encoding="utf8") as _f: + json.dump(merged_mtp, _f, indent=2) + print( + f"[MTP] Merged {len(mtp_ph_stats)} per-head KV-cache entries " + f"into {mtp_act_path} (per-tensor scalars overwritten)." + ) + + # Clean up the per-head hooks on the draft model. + llm.llm_engine.collective_rpc( + lambda w: _apply_on_draft_model(w, remove_kvcache_perhead_hooks) + ) + print("\n" + "=" * 80) print("Calibration completed successfully!") print(f"Results saved to: {args.output_dir}")