|
39 | 39 | remove_kvcache_perhead_hooks, |
40 | 40 | remove_kvcache_perhead_value_hooks, |
41 | 41 | setup_activation_hooks, |
| 42 | + setup_kvcache_perhead_hooks, |
42 | 43 | setup_kvcache_perhead_value_hooks, |
43 | 44 | setup_kvcache_value_hooks, |
44 | 45 | setup_mtp_activation_hooks, |
@@ -314,7 +315,7 @@ def main(): |
314 | 315 | speculative_config = None |
315 | 316 | if args.enable_mtp: |
316 | 317 | speculative_config = { |
317 | | - "method": "hunyuan_mtp", |
| 318 | + "method": "mtp", |
318 | 319 | "num_speculative_tokens": args.num_speculative_tokens, |
319 | 320 | } |
320 | 321 | print(f" MTP Enabled: True (num_speculative_tokens={args.num_speculative_tokens})") |
@@ -382,6 +383,27 @@ def main(): |
382 | 383 | else: |
383 | 384 | print(f"Worker {i}: No MTP draft model available") |
384 | 385 |
|
| 386 | + # Per-head KV-cache hooks for the MTP draft model (only when the user |
| 387 | + # asked for per-head granularity at the main-model level). We reuse |
| 388 | + # ``setup_kvcache_perhead_hooks`` which finds Attention layers via |
| 389 | + # ``_find_layers`` and installs a separate KVCachePerHeadHook on each. |
| 390 | + # This coexists with the per-tensor ``KVCacheHook`` already registered |
| 391 | + # by ``setup_mtp_activation_hooks``; both fire on every forward, and we |
| 392 | + # later overwrite the per-tensor scalars in mtp_activation_stats.json |
| 393 | + # with the per-head lists so the stage-2 quantizer picks them up. |
| 394 | + if args.kv_granularity == "per-head": |
| 395 | + print("\n" + "=" * 80) |
| 396 | + print("Setting up MTP draft model per-head KV-cache hooks...") |
| 397 | + print("=" * 80) |
| 398 | + mtp_ph_results = llm.llm_engine.collective_rpc( |
| 399 | + lambda w: _apply_on_draft_model(w, setup_kvcache_perhead_hooks) |
| 400 | + ) |
| 401 | + for i, result in enumerate(mtp_ph_results): |
| 402 | + if result is not None: |
| 403 | + print(f"Worker {i}: {result}") |
| 404 | + else: |
| 405 | + print(f"Worker {i}: No MTP draft model available (per-head KV)") |
| 406 | + |
385 | 407 | # Load dataset and prepare prompts |
386 | 408 | print("\n" + "=" * 80) |
387 | 409 | print("Loading dataset and preparing prompts...") |
@@ -511,6 +533,54 @@ def main(): |
511 | 533 | stats_type="MTP MoE expert statistics", |
512 | 534 | ) |
513 | 535 |
|
| 536 | + # --------------------------------------------------------------- |
| 537 | + # Per-head KV-cache stats for the MTP draft model. |
| 538 | + # The per-tensor KVCacheHook (registered inside
| 539 | + # setup_mtp_activation_hooks) writes scalar min/max under keys like
| 540 | + # model.layers.80.mtp_block.self_attn.attn.k_cache |
| 541 | + # model.layers.80.mtp_block.self_attn.attn.v_cache |
| 542 | + # which were already saved into mtp_activation_stats.json above. |
| 543 | + # When --kv-granularity=per-head, we additionally collect per-head |
| 544 | + # min/max from the parallel KVCachePerHeadHook and OVERWRITE those |
| 545 | + # same keys with list-valued min/max so the stage-2 quantizer |
| 546 | + # (which switches on isinstance(min, list) -> per_head) picks the |
| 547 | + # finer granularity. Hooks are then removed. |
| 548 | + # --------------------------------------------------------------- |
| 549 | + if args.kv_granularity == "per-head": |
| 550 | + print("\n[MTP] Collecting per-head KV-cache statistics from draft model...") |
| 551 | + mtp_ph_stats_list = llm.llm_engine.collective_rpc( |
| 552 | + lambda w: _apply_on_draft_model(w, get_kvcache_perhead_stats) |
| 553 | + ) |
| 554 | + # Pick the first non-empty result (None and empty dicts are both skipped); rank-0 is expected to carry the gathered data.
| 555 | + mtp_ph_stats = next((r for r in (mtp_ph_stats_list or []) if r), None) |
| 556 | + if not mtp_ph_stats: |
| 557 | + print( |
| 558 | + "[MTP] WARNING: no per-head KV-cache stats collected from " |
| 559 | + "draft model; mtp_activation_stats.json will keep per-tensor " |
| 560 | + "scalars for k_cache/v_cache entries." |
| 561 | + ) |
| 562 | + else: |
| 563 | + mtp_act_path = os.path.join(args.output_dir, "mtp_activation_stats.json") |
| 564 | + if os.path.exists(mtp_act_path): |
| 565 | + with open(mtp_act_path, "r", encoding="utf8") as _f: |
| 566 | + merged_mtp = json.load(_f) |
| 567 | + else: |
| 568 | + merged_mtp = {} |
| 569 | + # dict.update() will overwrite the per-tensor (scalar) values |
| 570 | + # for the same keys with the per-head (list) values. |
| 571 | + merged_mtp.update(mtp_ph_stats) |
| 572 | + with open(mtp_act_path, "w", encoding="utf8") as _f: |
| 573 | + json.dump(merged_mtp, _f, indent=2) |
| 574 | + print( |
| 575 | + f"[MTP] Merged {len(mtp_ph_stats)} per-head KV-cache entries " |
| 576 | + f"into {mtp_act_path} (per-tensor scalars overwritten)." |
| 577 | + ) |
| 578 | + |
| 579 | + # Clean up the per-head hooks on the draft model. |
| 580 | + llm.llm_engine.collective_rpc( |
| 581 | + lambda w: _apply_on_draft_model(w, remove_kvcache_perhead_hooks) |
| 582 | + ) |
| 583 | + |
514 | 584 | print("\n" + "=" * 80) |
515 | 585 | print("Calibration completed successfully!") |
516 | 586 | print(f"Results saved to: {args.output_dir}") |
|
0 commit comments