Commit 881aae2

update kv

1 parent c7cf058

16 files changed: 1390 additions & 675 deletions


configs/lingbot_fast/lingbot_fast_i2v.json

Lines changed: 8 additions & 2 deletions
@@ -4,7 +4,7 @@
     "text_len": 512,
     "target_height": 720,
     "target_width": 1280,
-    "self_attn_1_type": "sage_attn2",
+    "self_attn_1_type": "sage_attn2_k_int8_v_fp8",
     "cross_attn_1_type": "sage_attn2",
     "cross_attn_2_type": "sage_attn2",
     "sample_guide_scale": 1.0,
@@ -20,7 +20,13 @@
         "num_frame_per_chunk": 3,
         "timesteps_index": [0, 179, 358, 679],
         "sink_size": 3,
-        "kv_offload": false
+        "kv_quant": {
+            "calibrate": false,
+            "calib_path": "calib_kv.pt",
+            "quant_scheme": "sage",
+            "k_cache_type": "int8",
+            "v_cache_type": "fp8"
+        }
     },
     "rms_norm_type": "self_forcing"
 }
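
The kernel switch from sage_attn2 to sage_attn2_k_int8_v_fp8 lines up with the new kv_quant block: the self-attention KV cache now holds K in int8 and V in fp8, so the kernel can consume the quantized cache directly. Below is a minimal sketch of that storage format, assuming per-token symmetric int8 scales for K, mean-smoothing of K (the behaviour the old smooth_k flag named), and a plain fp8 downcast for V; the function name sage_quant_kv is illustrative, and the repo's SageQuantRollingKVCachePool may quantize at a different granularity.

    import torch

    def sage_quant_kv(k: torch.Tensor, v: torch.Tensor):
        # Smooth K by removing the per-channel mean before quantizing
        # (the SageAttention-style trick behind the old "smooth_k" flag).
        k_mean = k.mean(dim=0, keepdim=True)
        k_smoothed = k - k_mean
        # One symmetric int8 scale per token (last dim = head_dim).
        k_scale = k_smoothed.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
        k_int8 = (k_smoothed / k_scale).round().clamp(-127, 127).to(torch.int8)
        # V: plain downcast to fp8 (e4m3 keeps more mantissa than e5m2).
        v_fp8 = v.to(torch.float8_e4m3fn)
        return (k_int8, k_scale, k_mean), v_fp8

Dequantization reverses it: k is recovered as k_int8.float() * k_scale + k_mean, and V is upcast from fp8 just before (or inside) the attention kernel.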
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+{
+    "infer_steps": 4,
+    "target_video_length": 161,
+    "text_len": 512,
+    "target_height": 720,
+    "target_width": 1280,
+    "self_attn_1_type": "sage_attn2",
+    "cross_attn_1_type": "sage_attn2",
+    "cross_attn_2_type": "sage_attn2",
+    "sample_guide_scale": 1.0,
+    "sample_shift": 10.0,
+    "enable_cfg": false,
+    "cpu_offload": true,
+    "offload_granularity": "block",
+    "t5_cpu_offload": true,
+    "vae_cpu_offload": true,
+    "use_image_encoder": false,
+    "dit_original_ckpt": "/data/nvme4/models/lingbot-world-base-cam/lingbot_world_fast/",
+    "ar_config": {
+        "local_attn_size": 21,
+        "num_frame_per_chunk": 3,
+        "timesteps_index": [0, 179, 358, 679],
+        "sink_size": 3,
+        "kv_quant": {
+            "quant_scheme": "kivi",
+            "k_cache_type": "int4",
+            "v_cache_type": "int4",
+            "group_size": 64
+        },
+        "kv_offload": true
+    },
+    "rms_norm_type": "self_forcing"
+}
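
This new all-added config (its path is not shown in this view) picks the kivi scheme instead: both caches in int4, quantized in groups of 64 elements, with kv_offload kept on. Below is a minimal sketch of the group-wise asymmetric 4-bit quantization that group_size controls; the function names are illustrative, and a real KIVI implementation additionally quantizes K per-channel and V per-token and packs two 4-bit values per byte, which this sketch skips.

    import torch

    def int4_group_quant(x: torch.Tensor, group_size: int = 64):
        # One (scale, zero-point) pair per group of `group_size` consecutive
        # elements; x.numel() must be divisible by group_size.
        groups = x.reshape(-1, group_size)
        lo = groups.min(dim=-1, keepdim=True).values
        hi = groups.max(dim=-1, keepdim=True).values
        scale = (hi - lo).clamp(min=1e-6) / 15.0  # 4-bit range is 0..15
        q = ((groups - lo) / scale).round().clamp(0, 15).to(torch.uint8)
        return q, scale, lo

    def int4_group_dequant(q: torch.Tensor, scale: torch.Tensor, lo: torch.Tensor, shape):
        return (q.float() * scale + lo).reshape(shape)

With a 16-bit scale and zero-point per group of 64, the overhead is 32 bits per 64 values, about 4.5 effective bits per element, i.e. better than a 3x cache saving over bf16.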

configs/lingbot_fast/lingbot_fast_i2v_kv_quant_offload.json renamed to configs/lingbot_fast/lingbot_fast_i2v_kv_sagequant.json

Lines changed: 6 additions & 3 deletions
@@ -10,7 +10,8 @@
     "sample_guide_scale": 1.0,
     "sample_shift": 10.0,
     "enable_cfg": false,
-    "cpu_offload": false,
+    "cpu_offload": true,
+    "offload_granularity": "block",
     "t5_cpu_offload": true,
     "vae_cpu_offload": true,
     "use_image_encoder": false,
@@ -22,8 +23,10 @@
         "sink_size": 3,
         "kv_quant": {
             "calibrate": false,
-            "smooth_k": true,
-            "calib_path": "calib_kv.pt"
+            "calib_path": "calib_kv.pt",
+            "quant_scheme": "sage",
+            "k_cache_type": "int8",
+            "v_cache_type": "fp8"
         },
         "kv_offload": true
     },
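
The rename tracks the schema change: the old smooth_k flag disappears (its behaviour is folded into quant_scheme: "sage") and the cache dtypes become explicit k_cache_type / v_cache_type fields. A plausible dispatch from this block to the pool classes exported by lightx2v.common.kvcache is sketched below; make_kv_cache_pool_cls and the key lookups are assumptions, only the class names come from this commit.

    from lightx2v.common.kvcache import (
        CalibRollingKVCachePool,
        RollingKVCachePool,
        SageQuantRollingKVCachePool,
    )

    def make_kv_cache_pool_cls(ar_config: dict):
        kv_quant = ar_config.get("kv_quant")
        if kv_quant is None:
            # No quantization requested: plain bf16 rolling-window cache.
            return RollingKVCachePool
        if kv_quant.get("calibrate", False):
            # Calibration run: gather K/V statistics into calib_path.
            return CalibRollingKVCachePool
        if kv_quant.get("quant_scheme") == "sage":
            return SageQuantRollingKVCachePool
        raise ValueError(f"unknown quant_scheme: {kv_quant.get('quant_scheme')!r}")

Returning the class rather than an instance sidesteps the constructor signatures, which this diff does not show.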

lightx2v/common/kvcache/__init__.py

mode changed: 100644 → 100755
Lines changed: 4 additions & 15 deletions
@@ -1,23 +1,12 @@
-"""
-KV cache for autoregressive transformer inference.
-
-- ``base``: cross-attention pool
-- ``rolling``: ``RollingKVCachePool`` (bf16 rolling-window cache)
-- ``quant``: ``CalibRollingKVCachePool`` / ``QuantRollingKVCachePool``
-- ``offload``: ``OffloadRollingKVCachePool`` / ``OffloadQuantRollingKVCachePool``
-- ``manager``: ``KVCacheManager``
-"""
-
 from .manager import KVCacheManager
-from .offload import OffloadQuantRollingKVCachePool, OffloadRollingKVCachePool
-from .quant import CalibRollingKVCachePool, QuantRollingKVCachePool
+from .offload import KVOffloadPlugin
+from .quant import CalibRollingKVCachePool, SageQuantRollingKVCachePool
 from .rolling import RollingKVCachePool
 
 __all__ = [
     "KVCacheManager",
+    "KVOffloadPlugin",
     "RollingKVCachePool",
     "CalibRollingKVCachePool",
-    "QuantRollingKVCachePool",
-    "OffloadRollingKVCachePool",
-    "OffloadQuantRollingKVCachePool",
+    "SageQuantRollingKVCachePool"
 ]
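
The export list summarizes the refactor: the dedicated OffloadRollingKVCachePool / OffloadQuantRollingKVCachePool subclasses are gone, QuantRollingKVCachePool becomes the scheme-specific SageQuantRollingKVCachePool, and offload turns into a single KVOffloadPlugin. That reads as composition over inheritance: one plugin that wraps any pool instead of a subclass per (scheme, offload) pair. A sketch of that shape, with KVOffloadPlugin(pool) as an assumed constructor signature:

    from lightx2v.common.kvcache import KVOffloadPlugin

    def attach_offload(pool, kv_offload: bool):
        # Hypothetical wiring: wrap any pool with the offload plugin when
        # the config sets "kv_offload": true. The real attach mechanism is
        # not shown in this diff.
        return KVOffloadPlugin(pool) if kv_offload else pool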

lightx2v/common/kvcache/base.py

mode changed: 100644 → 100755
Lines changed: 71 additions & 0 deletions
@@ -1,4 +1,7 @@
 import torch
+import torch.distributed as dist
+
+from lightx2v.common.ops.attn.utils.all2all import all2all_head2seq, all2all_seq2head
 
 
 class BaseKVCachePool:
@@ -31,9 +34,13 @@ def _init_kv_buffer(self):
         )
 
     def k_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor:
+        if attn_start is None and local_end is None:
+            return self._k_buffer[layer_id]
         return self._k_buffer[layer_id][attn_start:local_end]
 
     def v_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor:
+        if attn_start is None and local_end is None:
+            return self._v_buffer[layer_id]
         return self._v_buffer[layer_id][attn_start:local_end]
 
     def store_kv(self, k: torch.Tensor, v: torch.Tensor, layer_id: int) -> None:
@@ -44,6 +51,70 @@ def reset(self) -> None:
         self._k_buffer.zero_()
         self._v_buffer.zero_()
 
+    def sp_kvcache_attn(
+        self,
+        q: torch.Tensor,
+        k_cache,
+        v_cache,
+        attention_module,
+        seq_p_group,
+        num_heads: int,
+        head_dim: int,
+        *,
+        attn_start: int | None = None,
+        local_end: int | None = None,
+    ) -> torch.Tensor:
+        world_size = dist.get_world_size(seq_p_group)
+        shard_heads = num_heads // world_size
+
+        full_q = all2all_seq2head(q, group=seq_p_group)
+        if isinstance(k_cache, tuple) or isinstance(v_cache, tuple):
+            full_k, full_v, full_kv_len = self._sp_quant_kv_to_head_shard(
+                k_cache=k_cache,
+                v_cache=v_cache,
+                shard_heads=shard_heads,
+                seq_p_group=seq_p_group,
+                attn_start=attn_start,
+                local_end=local_end,
+            )
+        else:
+            full_k = all2all_seq2head(k_cache, group=seq_p_group)
+            full_v = all2all_seq2head(v_cache, group=seq_p_group)
+            full_kv_len = int(full_k.size(0))
+
+        q_lens = torch.tensor([full_q.size(0)], dtype=torch.int32)
+        k_lens = torch.tensor([full_kv_len], dtype=torch.int32)
+        cu_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
+        cu_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
+
+        attn_out = attention_module.apply(
+            q=full_q,
+            k=full_k,
+            v=full_v,
+            cu_seqlens_q=cu_q,
+            cu_seqlens_kv=cu_k,
+            max_seqlen_q=full_q.size(0),
+            max_seqlen_kv=full_kv_len,
+        )
+        attn_out = attn_out.view(full_q.size(0), shard_heads, head_dim)
+        attn_out = all2all_head2seq(attn_out, group=seq_p_group)
+        return attn_out.reshape(q.size(0), num_heads * head_dim)
+
+    def _sp_quant_kv_to_head_shard(
+        self,
+        k_cache,
+        v_cache,
+        shard_heads: int,
+        seq_p_group,
+        *,
+        attn_start: int | None = None,
+        local_end: int | None = None,
+    ):
+        raise TypeError(
+            f"{self.__class__.__name__} does not support tuple K/V in SP path. "
+            "Please use a cache class that implements _sp_quant_kv_to_head_shard."
+        )
+
     @property
     def device(self) -> torch.device:
         return self._device
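
sp_kvcache_attn is the new sequence-parallel attention path over the cache: all2all_seq2head trades each rank's sequence shard for a head shard (after the exchange every rank holds the full KV length for num_heads // world_size heads), attention runs on that head shard, and all2all_head2seq converts the output back to the sequence split. Quantized pools keep K/V as tuples of payload and scales, so they must override the _sp_quant_kv_to_head_shard hook, which otherwise raises. Below is a sketch of the contract such an override must satisfy; the (values, scale) tuple layout and the dequantize-then-exchange order are assumptions, while the hook signature and the (full_k, full_v, full_kv_len) return triple come from this diff.

    import torch

    from lightx2v.common.ops.attn.utils.all2all import all2all_seq2head

    class QuantPoolSPSketch:
        def _sp_quant_kv_to_head_shard(self, k_cache, v_cache, shard_heads, seq_p_group, *, attn_start=None, local_end=None):
            # Dequantize the local sequence shard, then reuse the same
            # seq->head exchange as the bf16 path.
            k_q, k_scale = k_cache  # assumed (values, scale) layout
            v_q, v_scale = v_cache
            k = (k_q.to(torch.float32) * k_scale).to(torch.bfloat16)[attn_start:local_end]
            v = (v_q.to(torch.float32) * v_scale).to(torch.bfloat16)[attn_start:local_end]
            full_k = all2all_seq2head(k, group=seq_p_group)
            full_v = all2all_seq2head(v, group=seq_p_group)
            return full_k, full_v, int(full_k.size(0))

Dequantizing before the all2all keeps the exchange identical to the bf16 path at the cost of moving bf16 over the wire; an implementation could instead exchange the int8/fp8 payloads plus scales and dequantize on the head shard to cut traffic, which is likely where shard_heads comes in.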

lightx2v/common/kvcache/calibrate.py

Lines changed: 0 additions & 85 deletions
This file was deleted.
