
Commit d7ec87f

support kv quant/offload (#1035)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent de2bb45 · commit d7ec87f

20 files changed: 1856 additions & 281 deletions


configs/lingbot_fast/lingbot_fast_i2v.json

Lines changed: 7 additions & 8 deletions
@@ -15,13 +15,12 @@
     "vae_cpu_offload": true,
     "use_image_encoder": false,
     "dit_original_ckpt": "/data/nvme4/models/lingbot-world-base-cam/lingbot_world_fast/",
-    "sf_config": {
-        "local_attn_size": -1,
-        "num_frame_per_block": 3,
-        "timesteps_index": [0, 179, 358, 679]
+    "ar_config": {
+        "local_attn_size": 21,
+        "num_frame_per_chunk": 3,
+        "timesteps_index": [0, 179, 358, 679],
+        "sink_size": 3,
+        "kv_offload": false
     },
-    "parallel": {
-        "seq_p_size": 4,
-        "seq_p_attn_type": "ulysses"
-    }
+    "rms_norm_type": "self_forcing"
 }
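
In the renamed block, self-attention switches from an unbounded window ("local_attn_size": -1) to a rolling window of 21 frames, with the first "sink_size" frames kept permanently attendable as an attention sink; "num_frame_per_chunk" replaces "num_frame_per_block", and "kv_offload" toggles the CPU-offloaded cache pools added by this commit. A minimal sketch of the windowing this implies, assuming frame-level granularity; the helper below is illustrative, not this repo's API:

    def visible_frames(cur_frame: int, local_attn_size: int = 21, sink_size: int = 3) -> list[int]:
        # Sink frames stay attendable forever; everything else is a trailing window.
        window_start = max(sink_size, cur_frame - local_attn_size)
        return list(range(min(sink_size, cur_frame))) + list(range(window_start, cur_frame))

At cur_frame=40 this yields frames 0-2 plus 19-39; early in generation it degenerates to every frame seen so far.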
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+{
+    "infer_steps": 4,
+    "target_video_length": 161,
+    "text_len": 512,
+    "target_height": 720,
+    "target_width": 1280,
+    "self_attn_1_type": "sage_attn2_k_int8_v_fp8",
+    "cross_attn_1_type": "sage_attn2",
+    "cross_attn_2_type": "sage_attn2",
+    "sample_guide_scale": 1.0,
+    "sample_shift": 10.0,
+    "enable_cfg": false,
+    "cpu_offload": false,
+    "t5_cpu_offload": true,
+    "vae_cpu_offload": true,
+    "use_image_encoder": false,
+    "dit_original_ckpt": "/data/nvme4/models/lingbot-world-base-cam/lingbot_world_fast/",
+    "ar_config": {
+        "local_attn_size": 21,
+        "num_frame_per_chunk": 3,
+        "timesteps_index": [0, 179, 358, 679],
+        "sink_size": 3,
+        "kv_quant": {
+            "calibrate": false,
+            "smooth_k": true,
+            "calib_path": "calib_kv.pt"
+        },
+        "kv_offload": true
+    },
+    "rms_norm_type": "self_forcing"
+}
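
This new config exercises both features at once: "self_attn_1_type" of "sage_attn2_k_int8_v_fp8" by its name stores cached K as int8 and V as fp8 (the four "timesteps_index" entries match "infer_steps": 4), "kv_quant" with "calibrate": false consumes the precomputed stats at "calib_path", "smooth_k" subtracts the calibrated per-channel K mean before quantizing, and "kv_offload": true parks the rolling cache in CPU memory. A rough sketch of the quantization arithmetic those names imply; the scale layout and rounding here are assumptions, not the actual kernel:

    import torch

    def quantize_kv(k: torch.Tensor, v: torch.Tensor, km: torch.Tensor, v_scale: torch.Tensor):
        # Smooth K (cf. "smooth_k"): remove the calibrated per-channel mean,
        # then quantize to int8 with a per-channel scale.
        k = k.float() - km
        k_scale = k.abs().amax(dim=0).clamp_min(1e-8) / 127.0
        k_q = torch.round(k / k_scale).clamp_(-127, 127).to(torch.int8)
        # V: normalise by the calibrated per-channel abs-max, cast to fp8.
        v_q = (v.float() / v_scale.clamp_min(1e-8)).to(torch.float8_e4m3fn)
        return k_q, k_scale, v_q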
lightx2v/common/kvcache/__init__.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+"""
+KV cache for autoregressive transformer inference.
+
+- ``base``: cross-attention pool
+- ``rolling``: ``RollingKVCachePool`` (bf16 rolling-window cache)
+- ``quant``: ``CalibRollingKVCachePool`` / ``QuantRollingKVCachePool``
+- ``offload``: ``OffloadRollingKVCachePool`` / ``OffloadQuantRollingKVCachePool``
+- ``manager``: ``KVCacheManager``
+"""
+
+from .manager import KVCacheManager
+from .offload import OffloadQuantRollingKVCachePool, OffloadRollingKVCachePool
+from .quant import CalibRollingKVCachePool, QuantRollingKVCachePool
+from .rolling import RollingKVCachePool
+
+__all__ = [
+    "KVCacheManager",
+    "RollingKVCachePool",
+    "CalibRollingKVCachePool",
+    "QuantRollingKVCachePool",
+    "OffloadRollingKVCachePool",
+    "OffloadQuantRollingKVCachePool",
+]
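
Each pool in that list covers one combination of features. A plausible dispatch, assuming the ar_config flags shown in the configs above select the pool; the real wiring lives in KVCacheManager (manager.py, not part of this excerpt) and may differ:

    from lightx2v.common.kvcache import (
        CalibRollingKVCachePool,
        OffloadQuantRollingKVCachePool,
        OffloadRollingKVCachePool,
        QuantRollingKVCachePool,
        RollingKVCachePool,
    )

    def pick_pool_class(ar_config: dict):
        # Hypothetical mapping from config flags to pool classes.
        quant = "kv_quant" in ar_config
        offload = ar_config.get("kv_offload", False)
        if quant and ar_config["kv_quant"].get("calibrate", False):
            return CalibRollingKVCachePool  # bf16 storage plus stat collection
        if quant and offload:
            return OffloadQuantRollingKVCachePool
        if quant:
            return QuantRollingKVCachePool
        if offload:
            return OffloadRollingKVCachePool
        return RollingKVCachePool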

lightx2v/common/kvcache/base.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+import torch
+
+
+class BaseKVCachePool:
+    def __init__(
+        self,
+        num_layers: int,
+        cache_size: int,
+        num_heads: int,
+        head_dim: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> None:
+        self._num_layers = num_layers
+        self._cache_size = cache_size
+        self._num_heads = num_heads
+        self._head_dim = head_dim
+        self._device = device
+        self._dtype = dtype
+
+    def _init_kv_buffer(self):
+        self._k_buffer = torch.zeros(
+            (self._num_layers, self._cache_size, self._num_heads, self._head_dim),
+            dtype=self._dtype,
+            device=self._device,
+        )
+        self._v_buffer = torch.zeros(
+            (self._num_layers, self._cache_size, self._num_heads, self._head_dim),
+            dtype=self._dtype,
+            device=self._device,
+        )
+
+    def k_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor:
+        return self._k_buffer[layer_id][attn_start:local_end]
+
+    def v_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor:
+        return self._v_buffer[layer_id][attn_start:local_end]
+
+    def store_kv(self, k: torch.Tensor, v: torch.Tensor, layer_id: int) -> None:
+        self._k_buffer[layer_id, : k.shape[0]] = k
+        self._v_buffer[layer_id, : v.shape[0]] = v
+
+    def reset(self) -> None:
+        self._k_buffer.zero_()
+        self._v_buffer.zero_()
+
+    @property
+    def device(self) -> torch.device:
+        return self._device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self._dtype
+
+    @property
+    def num_layers(self) -> int:
+        return self._num_layers
+
+    @property
+    def cache_size(self) -> int:
+        return self._cache_size
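
Subclasses are expected to call _init_kv_buffer themselves, so a direct usage sketch of the base pool has to do the same; shapes follow the buffer layout [num_layers, cache_size, num_heads, head_dim] allocated above:

    import torch
    from lightx2v.common.kvcache.base import BaseKVCachePool

    pool = BaseKVCachePool(num_layers=2, cache_size=128, num_heads=8, head_dim=64,
                           dtype=torch.bfloat16, device=torch.device("cpu"))
    pool._init_kv_buffer()  # allocate the zeroed K/V buffers

    k = torch.randn(16, 8, 64, dtype=torch.bfloat16)  # 16 cached tokens
    v = torch.randn(16, 8, 64, dtype=torch.bfloat16)
    pool.store_kv(k, v, layer_id=0)                        # written from offset 0
    window = pool.k_cache(0, attn_start=4, local_end=16)   # view of tokens 4..15, shape [12, 8, 64]
    pool.reset()                                           # zero both buffers

store_kv always writes from offset 0 and k_cache/v_cache return views into the buffer, so callers mutate the cache in place; the rolling, quant, and offload subclasses layer their window and dtype management on top of this fixed slab.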
lightx2v/common/kvcache/calibrate.py

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+"""
+KV-cache quantisation calibration.
+
+Step 1 — Calibration run
+~~~~~~~~~~~~~~~~~~~~~~~~
+Use a config with ``"calibrate": true`` and ``self_attn_1_type`` set to
+the **non-quant** attention (e.g. ``"sage_attn2"``). This creates a
+``CalibRollingKVCachePool`` that stores bf16 KV normally while
+collecting K-mean and V per-channel abs-max.
+
+Config example (calibration)::
+
+    {
+        "self_attn_1_type": "sage_attn2",
+        "ar_config": {
+            ...
+            "sage_quant_kv": {
+                "calibrate": true,
+                "smooth_k": true
+            }
+        }
+    }
+
+After inference, call ``save_calibration`` to export the stats::
+
+    from lightx2v.common.kvcache.calibrate import save_calibration
+    runner.run_main()
+    save_calibration(runner.model.kv_cache_manager, "calib_kv.pt")
+
+Step 2 — Quantised inference
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Switch to the quant attention and provide the calibration file::
+
+    {
+        "self_attn_1_type": "sage_attn2_kvquant",
+        "ar_config": {
+            ...
+            "sage_quant_kv": {
+                "smooth_k": true,
+                "calib_path": "calib_kv.pt"
+            }
+        }
+    }
+"""
+
+from __future__ import annotations
+
+import torch
+from loguru import logger
+
+from .quant import CalibRollingKVCachePool
+
+
+def save_calibration(
+    kv_cache_manager,
+    output_path: str,
+) -> dict[str, torch.Tensor]:
+    """Export and save KV cache calibration from a completed run.
+
+    Parameters
+    ----------
+    kv_cache_manager : KVCacheManager
+        The manager whose ``self_attn_kv_cache`` is a
+        ``CalibRollingKVCachePool`` that has been used for at least one
+        full inference pass.
+    output_path : str
+        File path to save the calibration dict (``torch.save`` format).
+
+    Returns
+    -------
+    dict with keys ``'km'`` and ``'v_scale'``.
+    """
+    pool = kv_cache_manager.self_attn_kv_cache
+    if not isinstance(pool, CalibRollingKVCachePool):
+        raise TypeError(f"Expected CalibRollingKVCachePool, got {type(pool).__name__}. Make sure the config has sage_quant_kv.calibrate=true and self_attn_1_type is NOT sage_attn2_kvquant.")
+
+    calib = pool.export_calibration()
+    torch.save(calib, output_path)
+    logger.info(
+        "KV calibration saved to {} — km {}, v_scale {}",
+        output_path,
+        list(calib["km"].shape),
+        list(calib["v_scale"].shape),
+    )
+    return calib
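
save_calibration only exports; the accumulation itself happens inside CalibRollingKVCachePool (quant.py, not part of this excerpt). A toy stand-in for that collection logic, matching the "km" (running K mean) and "v_scale" (V per-channel abs-max) keys the docstring promises; the shapes and update rule are assumptions:

    import torch

    class CalibStats:
        # Hypothetical accumulator, not the repo's CalibRollingKVCachePool.
        def __init__(self, num_layers: int, num_heads: int, head_dim: int):
            self._k_sum = torch.zeros(num_layers, num_heads, head_dim)
            self._n_tokens = torch.zeros(num_layers)
            self.v_scale = torch.zeros(num_layers, num_heads, head_dim)

        def observe(self, layer_id: int, k: torch.Tensor, v: torch.Tensor) -> None:
            # k, v: [tokens, num_heads, head_dim] for one layer of one step.
            self._k_sum[layer_id] += k.float().sum(dim=0)
            self._n_tokens[layer_id] += k.shape[0]
            self.v_scale[layer_id] = torch.maximum(self.v_scale[layer_id], v.float().abs().amax(dim=0))

        def export_calibration(self) -> dict[str, torch.Tensor]:
            km = self._k_sum / self._n_tokens.clamp_min(1)[:, None, None]
            return {"km": km, "v_scale": self.v_scale}

With stats like these, quantised inference subtracts km from K before int8 rounding (the smooth_k path) and divides V by v_scale before the fp8 cast.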
