@@ -1,4 +1,7 @@
 import torch
+import torch.distributed as dist
+
+from lightx2v.common.ops.attn.utils.all2all import all2all_head2seq, all2all_seq2head
 
 
 class BaseKVCachePool:
@@ -31,9 +34,13 @@ def _init_kv_buffer(self):
         )
 
     def k_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor:
+        if attn_start is None and local_end is None:
+            return self._k_buffer[layer_id]
         return self._k_buffer[layer_id][attn_start:local_end]
 
     def v_cache(self, layer_id: int, attn_start: int | None = None, local_end: int | None = None) -> torch.Tensor:
+        if attn_start is None and local_end is None:
+            return self._v_buffer[layer_id]
         return self._v_buffer[layer_id][attn_start:local_end]
 
     def store_kv(self, k: torch.Tensor, v: torch.Tensor, layer_id: int) -> None:
@@ -44,6 +51,70 @@ def reset(self) -> None:
         self._k_buffer.zero_()
         self._v_buffer.zero_()
 
+    def sp_kvcache_attn(
+        self,
+        q: torch.Tensor,
+        k_cache,
+        v_cache,
+        attention_module,
+        seq_p_group,
+        num_heads: int,
+        head_dim: int,
+        *,
+        attn_start: int | None = None,
+        local_end: int | None = None,
+    ) -> torch.Tensor:
+        world_size = dist.get_world_size(seq_p_group)
+        shard_heads = num_heads // world_size
+
+        full_q = all2all_seq2head(q, group=seq_p_group)
+        if isinstance(k_cache, tuple) or isinstance(v_cache, tuple):
+            full_k, full_v, full_kv_len = self._sp_quant_kv_to_head_shard(
+                k_cache=k_cache,
+                v_cache=v_cache,
+                shard_heads=shard_heads,
+                seq_p_group=seq_p_group,
+                attn_start=attn_start,
+                local_end=local_end,
+            )
+        else:
+            full_k = all2all_seq2head(k_cache, group=seq_p_group)
+            full_v = all2all_seq2head(v_cache, group=seq_p_group)
+            full_kv_len = int(full_k.size(0))
+
+        q_lens = torch.tensor([full_q.size(0)], dtype=torch.int32)
+        k_lens = torch.tensor([full_kv_len], dtype=torch.int32)
+        cu_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
+        cu_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
+
+        attn_out = attention_module.apply(
+            q=full_q,
+            k=full_k,
+            v=full_v,
+            cu_seqlens_q=cu_q,
+            cu_seqlens_kv=cu_k,
+            max_seqlen_q=full_q.size(0),
+            max_seqlen_kv=full_kv_len,
+        )
+        attn_out = attn_out.view(full_q.size(0), shard_heads, head_dim)
+        attn_out = all2all_head2seq(attn_out, group=seq_p_group)
+        return attn_out.reshape(q.size(0), num_heads * head_dim)
+
+    def _sp_quant_kv_to_head_shard(
+        self,
+        k_cache,
+        v_cache,
+        shard_heads: int,
+        seq_p_group,
+        *,
+        attn_start: int | None = None,
+        local_end: int | None = None,
+    ):
+        raise TypeError(
+            f"{self.__class__.__name__} does not support tuple K/V in SP path. "
+            "Please use a cache class that implements _sp_quant_kv_to_head_shard."
+        )
+
     @property
     def device(self) -> torch.device:
         return self._device
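
A minimal sketch of how the new SP path might be driven from a caller, assuming a concrete BaseKVCachePool subclass; run_sp_attention, pool, attention_module, and seq_p_group are illustrative names, not part of this commit.

# Minimal sketch (assumed names, not part of this commit): driving the SP path
# from an attention layer that holds a concrete BaseKVCachePool subclass.
def run_sp_attention(pool, q, layer_id, attention_module, seq_p_group, num_heads, head_dim):
    # With both bounds left as None, k_cache/v_cache now return the full per-layer buffers.
    k = pool.k_cache(layer_id)
    v = pool.v_cache(layer_id)
    # sp_kvcache_attn performs the seq<->head all2all re-sharding internally,
    # so q, k, and v are passed as this rank's local sequence shards.
    return pool.sp_kvcache_attn(q, k, v, attention_module, seq_p_group, num_heads, head_dim)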