@@ -70,6 +70,24 @@ def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[floa
    return attn_scores


+def _write_generate_kv_cache(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    slot_idx: torch.Tensor,
+    input_pos: torch.Tensor,
+):
+    """Write single-token decode K/V into the cache."""
+    b, s = k.shape[:2]
+    assert s == 1, f"Expected sequence length 1 for generate phase, got {s}"
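+    # Each sequence writes its new token's K/V at input_pos into its assigned cache slot (slot_idx).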
+    for i in range(b):
+        cache_idx = slot_idx[i].item()
+        pos = input_pos[i].item()
+        k_cache[cache_idx, pos] = k[i, 0]  # Remove sequence dim
+        v_cache[cache_idx, pos] = v[i, 0]  # Remove sequence dim
+
+
def _torch_generate_mha(
    q: torch.Tensor,
    k: torch.Tensor,
@@ -89,12 +107,7 @@ def _torch_generate_mha(
    assert s == 1, f"Expected sequence length 1 for generate phase, got {s}"
    n_kv_heads = k.shape[2]  # k has shape (b, 1, n_kv_heads, head_dim)

-    # Update KV cache for single token
-    for i in range(b):
-        cache_idx = slot_idx[i].item()
-        pos = input_pos[i].item()
-        k_cache[cache_idx, pos] = k[i, 0]  # Remove sequence dim
-        v_cache[cache_idx, pos] = v[i, 0]  # Remove sequence dim
+    _write_generate_kv_cache(k, v, k_cache, v_cache, slot_idx, input_pos)

    # Compute attention for each sequence using manual computation
    for i in range(b):
@@ -156,6 +169,60 @@ def _torch_generate_mha(
        out[i] = attn_out.squeeze(1)  # [n_heads, v_head_dim]


+def _torch_generate_mha_readonly(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    slot_idx: torch.Tensor,
+    input_pos: torch.Tensor,
+    scale: float,
+    out: torch.Tensor,
+    logit_cap: Optional[float] = None,
+    sliding_window_size: Optional[int] = None,
+    sinks: Optional[torch.Tensor] = None,
+):
+    """Generate-only attention using an existing KV cache without writing current-layer K/V."""
+    b, s, n_heads, head_dim = q.shape
+    assert s == 1, f"Expected sequence length 1 for generate phase, got {s}"
+    n_kv_heads = k_cache.shape[2]
+
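+    # Attend each sequence's single query token against its cached keys/values; nothing is written back.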
+    for i in range(b):
+        cache_idx = slot_idx[i].item()
+        pos = input_pos[i].item()
+        q_i = q[i, 0]
+
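+        # Sliding-window attention: restrict keys/values to the most recent window of cached positions.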
+        if sliding_window_size is not None and sliding_window_size > 0:
+            start_pos = max(0, pos - sliding_window_size + 1)
+            k_i = k_cache[cache_idx, start_pos : pos + 1]
+            v_i = v_cache[cache_idx, start_pos : pos + 1]
+        else:
+            k_i = k_cache[cache_idx, : pos + 1]
+            v_i = v_cache[cache_idx, : pos + 1]
+
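+        # Reshape to [n_heads, 1, head_dim] for q and [n_kv_heads, kv_len, head_dim] for k/v.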
+        q_i = q_i.unsqueeze(1)
+        k_i = k_i.transpose(0, 1)
+        v_i = v_i.transpose(0, 1)
+
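+        # GQA/MQA: repeat K/V heads so they match the number of query heads.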
+        if n_heads != n_kv_heads:
+            n_rep = n_heads // n_kv_heads
+            k_i = repeat_kv(k_i.unsqueeze(0), n_rep)[0]
+            v_i = repeat_kv(v_i.unsqueeze(0), n_rep)[0]
+
+        attn_scores = torch.matmul(q_i, k_i.transpose(-2, -1)) * scale
+        attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
+
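+        # Attention sinks: extra logits participate in the softmax but are dropped before the value matmul.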
+        if sinks is not None:
+            sinks = sinks.reshape(-1, 1, 1)
+            attn_weights = torch.cat([attn_scores, sinks], dim=-1)
+            attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+            attn_out = torch.matmul(attn_weights[..., : -sinks.size(-1)], v_i)
+        else:
+            attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
+            attn_out = torch.matmul(attn_weights, v_i)
+
+        out[i] = attn_out.squeeze(1)
+
+
def _torch_context_mha(
    q: torch.Tensor,
    k: torch.Tensor,
@@ -174,7 +241,6 @@ def _torch_context_mha(
    sinks: Optional[torch.Tensor] = None,
) -> None:
    """Context attention (multiple tokens, potentially multiple sequences) using existing torch functions."""
-    # Update KV cache first using existing function
    _update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start)

    # Compute attention for each sequence
@@ -293,9 +359,85 @@ def _torch_context_mha(
        out.copy_(torch.cat(attn_outputs, dim=0))


-@torch.library.custom_op(
-    "auto_deploy::torch_cached_attention_with_cache", mutates_args=("k_cache", "v_cache")
-)
+def _torch_context_mha_readonly(
+    q: torch.Tensor,
+    input_pos: torch.Tensor,
+    slot_idx: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    seq_len: torch.Tensor,
+    seq_start: torch.Tensor,
+    scale: float,
+    out: torch.Tensor,
+    logit_cap: Optional[float] = None,
+    sliding_window_size: Optional[int] = None,
+    sinks: Optional[torch.Tensor] = None,
+) -> None:
+    """Context attention using an existing KV cache without writing current-layer K/V."""
+    attn_outputs = []
+    for idx in range(seq_len.shape[0]):
+        seq_len_i = seq_len[idx].item()
+        input_pos_i = input_pos[idx].item()
+        slot_idx_i = slot_idx[idx].item()
+        seq_start_i = seq_start[idx].item()
+
+        if seq_len_i == 0:
+            continue
+
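+        # Slice this sequence's query tokens from the packed layout and read the full cached K/V prefix.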
+        q_seq = q[seq_start_i : seq_start_i + seq_len_i]
+        kv_seq_len = input_pos_i + seq_len_i
+        k_seq = k_cache[slot_idx_i, :kv_seq_len]
+        v_seq = v_cache[slot_idx_i, :kv_seq_len]
+
+        n_heads = q_seq.shape[1]
+        n_kv_heads = k_seq.shape[1]
+
+        q_seq_t = q_seq.transpose(0, 1).unsqueeze(0)
+        k_seq_t = k_seq.transpose(0, 1).unsqueeze(0)
+        v_seq_t = v_seq.transpose(0, 1).unsqueeze(0)
+
+        if n_heads != n_kv_heads:
+            n_rep = n_heads // n_kv_heads
+            k_seq_t = repeat_kv(k_seq_t, n_rep)
+            v_seq_t = repeat_kv(v_seq_t, n_rep)
+
+        attn_scores = torch.matmul(q_seq_t, k_seq_t.transpose(-2, -1)) * scale
+
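+        # Causal mask over absolute positions: query at position input_pos_i + j may attend to keys 0..input_pos_i + j.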
+        causal_mask = torch.triu(
+            torch.ones(seq_len_i, kv_seq_len, device=q.device, dtype=torch.bool),
+            diagonal=1 + input_pos_i,
+        )
+        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
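+        # Sliding-window mask: drop keys that are in the future or more than sliding_window_size - 1 positions behind the query.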
+        if sliding_window_size is not None and sliding_window_size > 0:
+            query_positions = torch.arange(input_pos_i, input_pos_i + seq_len_i, device=q.device)
+            key_positions = torch.arange(kv_seq_len, device=q.device)
+            pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0)
+            sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window_size)
+            attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf"))
+
+        attn_scores = _apply_logit_softcapping(attn_scores, logit_cap)
+
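+        # Attention sinks: append sink logits before the softmax, then drop them before multiplying by V.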
+        if sinks is not None:
+            new_sinks = sinks.reshape(1, -1, 1, 1).expand(1, n_heads, seq_len_i, 1)
+            attn_weights = torch.cat([attn_scores, new_sinks], dim=-1)
+            attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
+            attn_out = torch.matmul(attn_weights[..., : -new_sinks.size(-1)], v_seq_t)
+        else:
+            attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
+            attn_out = torch.matmul(attn_weights, v_seq_t)
+
+        attn_outputs.append(attn_out[0].transpose(0, 1))
+
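+    # Reassemble the per-sequence outputs in the original packed token order.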
+    if len(attn_outputs) == 0:
+        out.zero_()
+    elif len(attn_outputs) == 1:
+        out.copy_(attn_outputs[0])
+    else:
+        out.copy_(torch.cat(attn_outputs, dim=0))
+
+
+@torch.library.custom_op("auto_deploy::torch_cached_attention_with_cache", mutates_args=())
def torch_backend_mha_with_cache(
    # Q, K, V
    q: torch.Tensor,
@@ -320,6 +462,7 @@ def torch_backend_mha_with_cache(
    sinks: Optional[torch.Tensor] = None,
    sliding_window_size: Optional[int] = None,
    logit_cap: Optional[float] = None,
+    read_cache_only: bool = False,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Torch backend MHA with cache that takes q, k, v in BSND layout."""
@@ -359,12 +502,15 @@ def torch_backend_mha_with_cache(
    y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()

    # Compute attention
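+    # When this layer reuses another layer's KV cache (read_cache_only), skip the write and only read.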
+    if not read_cache_only:
+        if s == 1:
+            _write_generate_kv_cache(k, v, k_cache, v_cache, slot_idx, input_pos)
+        else:
+            _update_kv_cache(k, v, k_cache, v_cache, seq_len, input_pos, slot_idx, seq_start)
+
    if s == 1:
-        # Generate-only phase
-        _torch_generate_mha(
+        _torch_generate_mha_readonly(
            q,
-            k,
-            v,
            k_cache,
            v_cache,
            slot_idx,
@@ -376,11 +522,8 @@ def torch_backend_mha_with_cache(
            sinks,
        )
    else:
-        # Context phase
-        _torch_context_mha(
+        _torch_context_mha_readonly(
            q,
-            k,
-            v,
            input_pos,
            slot_idx,
            k_cache,
@@ -437,6 +580,7 @@ def torch_backend_mha_with_cache_fake(
    sinks: Optional[torch.Tensor] = None,
    sliding_window_size: Optional[int] = None,
    logit_cap: Optional[float] = None,
+    read_cache_only: bool = False,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if out is not None:
@@ -464,6 +608,10 @@ def get_source_attention_op(cls) -> OpOverloadPacket:
    def get_cached_attention_op(cls) -> MHACallable:
        return torch.ops.auto_deploy.torch_cached_attention_with_cache.default

+    @classmethod
+    def supports_shared_kv(cls) -> bool:
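+        """Whether this backend can attend over a KV cache shared from another layer."""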
+        return True
+
    @classmethod
    def get_standard_metadata_args(cls) -> List[str]:
        return ["batch_info_host", "seq_len", "input_pos", "slot_idx", "cu_seqlen"]
@@ -537,4 +685,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
            sinks,  # sinks parameter
            sliding_window,  # sliding window parameter
            logit_cap,  # logit cap parameter
+            cls.get_shared_kv_source_layer_idx(source_attn_node) is not None,  # read_cache_only
        ]