"""
KV-cache quantisation calibration.

Step 1 — Calibration run
~~~~~~~~~~~~~~~~~~~~~~~~
Use a config with ``"calibrate": true`` and ``self_attn_1_type`` set to
the **non-quant** attention (e.g. ``"sage_attn2"``). This creates a
``CalibRollingKVCachePool`` that stores the KV cache in bf16 as usual
while collecting the K mean and the per-channel V abs-max.

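Internally the pool keeps simple running aggregates. A minimal sketch of
the idea (illustrative only; names and shapes are assumptions, not the
actual ``CalibRollingKVCachePool`` internals)::

    # k, v: [tokens, heads, head_dim] bf16 blocks appended to the cache.
    k_sum += k.float().sum(dim=0)      # running per-channel K sum
    n_tokens += k.shape[0]
    km = k_sum / n_tokens              # K mean, exported as 'km'
    v_scale = torch.maximum(v_scale, v.float().abs().amax(dim=0))
    # per-channel V abs-max, exported as 'v_scale'
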
Config example (calibration)::

    {
        "self_attn_1_type": "sage_attn2",
        "ar_config": {
            ...
            "sage_quant_kv": {
                "calibrate": true,
                "smooth_k": true
            }
        }
    }

After inference, call ``save_calibration`` to export the stats::

    from lightx2v.common.kvcache.calibrate import save_calibration

    runner.run_main()
    save_calibration(runner.model.kv_cache_manager, "calib_kv.pt")

Step 2 — Quantised inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Switch to the quant attention and provide the calibration file::

    {
        "self_attn_1_type": "sage_attn2_kvquant",
        "ar_config": {
            ...
            "sage_quant_kv": {
                "smooth_k": true,
                "calib_path": "calib_kv.pt"
            }
        }
    }
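
At inference time the quant attention consumes these stats. Conceptually
(a hedged sketch of per-channel symmetric int8 quantisation; the real
path lives in the ``sage_attn2_kvquant`` kernel, not here)::

    calib = torch.load("calib_kv.pt")
    km, v_scale = calib["km"], calib["v_scale"]

    k = k - km                         # smooth_k: centre K on its calibrated mean
    v_q = torch.clamp((v / v_scale * 127).round(), -127, 127).to(torch.int8)
    # dequant on read: v ~ v_q.float() * v_scale / 127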
| 44 | +""" |
| 45 | + |
| 46 | +from __future__ import annotations |
| 47 | + |
| 48 | +import torch |
| 49 | +from loguru import logger |
| 50 | + |
| 51 | +from .quant import CalibRollingKVCachePool |
| 52 | + |
| 53 | + |
def save_calibration(
    kv_cache_manager,
    output_path: str,
) -> dict[str, torch.Tensor]:
    """Export and save KV cache calibration from a completed run.

    Parameters
    ----------
    kv_cache_manager : KVCacheManager
        The manager whose ``self_attn_kv_cache`` is a
        ``CalibRollingKVCachePool`` that has been used for at least one
        full inference pass.
    output_path : str
        File path to save the calibration dict (``torch.save`` format).

    Returns
    -------
    dict
        Calibration dict with keys ``'km'`` and ``'v_scale'``.
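
    Examples
    --------
    A typical flow (``runner`` construction omitted; the calibration
    config from the module docstring is assumed)::

        runner.run_main()
        calib = save_calibration(runner.model.kv_cache_manager, "calib_kv.pt")

        # Reload and inspect the exported stats later:
        stats = torch.load("calib_kv.pt")
        print(stats["km"].shape, stats["v_scale"].shape)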
    """
    pool = kv_cache_manager.self_attn_kv_cache
    if not isinstance(pool, CalibRollingKVCachePool):
        raise TypeError(
            f"Expected CalibRollingKVCachePool, got {type(pool).__name__}. "
            "Make sure the config has sage_quant_kv.calibrate=true and "
            "self_attn_1_type is NOT sage_attn2_kvquant."
        )

    calib = pool.export_calibration()
    torch.save(calib, output_path)
    logger.info(
        "KV calibration saved to {} — km {}, v_scale {}",
        output_path,
        list(calib["km"].shape),
        list(calib["v_scale"].shape),
    )
    return calib