11# Copyright (c) OpenMMLab. All rights reserved.
22
33import functools
4+ from dataclasses import dataclass
45
56import torch
67
78from lmdeploy .pytorch .backends .attention import V4AttentionMetadata
8- from lmdeploy .pytorch .models .deepseek_v4_utils import (
9- build_compress_topk_indices ,
10- build_prefix_positions ,
11- build_window_positions ,
12- build_window_topk_indices ,
13- )
9+
10+
@dataclass
class CudaV4AttentionMetadata(V4AttentionMetadata):
    """CUDA-specific V4 attention metadata with pre-computed indices.

    Adds layer-invariant index tensors that are computed once per step
    in ``from_step_context``, so TritonV4AttentionImpl does not
    recompute them per layer.
    """

    # --- Decode pre-computed ---
    # NOTE(review): tensor fields default to None; they are populated only by
    # _precompute_decode / _precompute_prefill when windowed attention is on.
    is_padded: torch.Tensor = None  # [bsz] bool
    batch_offsets: torch.Tensor = None  # [bsz, 1, 1] int32
    compress_fallback_indices_r4: torch.Tensor = None  # [bsz, 1, max_comp] int32
    compress_fallback_topk_r4: torch.Tensor = None  # [bsz] int32
    compress_fallback_indices_r128: torch.Tensor = None  # [bsz, 1, max_comp] int32
    compress_fallback_topk_r128: torch.Tensor = None  # [bsz] int32

    # --- Prefill pre-computed ---
    prefill_window_kv_lens: torch.Tensor = None  # [bsz] long
    prefill_max_flat_kv_len_r4: int = None
    prefill_total_flat_kv_tokens_r4: int = None
    prefill_max_flat_kv_len_r128: int = None
    prefill_total_flat_kv_tokens_r128: int = None
    prefill_max_compress_width: int = None
    prefill_window_topk: torch.Tensor = None  # [total_q_tokens, window_size]
    prefill_compress_topk_r4: torch.Tensor = None  # [total_q_tokens, max_width]
    prefill_compress_topk_r128: torch.Tensor = None  # [total_q_tokens, max_width]

    @classmethod
    def from_step_context(cls, attn_metadata, step_ctx, **kwargs) -> 'CudaV4AttentionMetadata':
        """Build metadata from a step context and attach pre-computed indices.

        Args:
            attn_metadata: backend attention metadata (forwarded to the
                parent builder).
            step_ctx: current step context (forwarded to the parent builder).
            **kwargs: ``window_size`` (int, default 0) and ``slot`` (state
                cache slot indices, one per sequence). Pre-computation only
                runs when ``window_size > 0`` and ``slot`` is provided.

        Returns:
            Metadata object with decode- or prefill-specific index tensors
            filled in (chosen by ``meta.is_decoding``).
        """
        window_size = kwargs.get('window_size', 0)
        slot = kwargs.get('slot', None)

        # Parent classmethod builds the base fields; presumably it constructs
        # via ``cls`` so the result is this subtype — TODO confirm.
        meta = super().from_step_context(attn_metadata, step_ctx)

        if window_size > 0 and slot is not None:
            # Negative slot values mark padded batch entries — presumably
            # padding for batching/graph capture; verify against caller.
            meta.is_padded = slot < 0
            if meta.is_decoding:
                cls._precompute_decode(meta, window_size)
            else:
                cls._precompute_prefill(meta, window_size)

        return meta

    @staticmethod
    def _precompute_decode(meta, window_size):
        """Attach decode-path window positions and compress fallback indices.

        Fills ``extra_indices_in_kvcache`` / ``extra_topk_length`` (window
        ring-buffer positions and valid lengths), ``batch_offsets``, and the
        per-ratio compressed fallback index tensors on *meta*, all int32.
        """
        # Local import: keep the CUDA-only helpers off the module import path.
        from lmdeploy.pytorch.backends.cuda.attention.v4_utils import build_prefix_positions, build_window_positions
        kv_seqlens = meta.kv_seqlens
        block_offsets = meta.block_offsets.long()
        block_size = meta.block_size

        # Window ring-buffer positions plus per-sequence valid lengths
        # (third return value unused here).
        window_positions, window_lens, _ = build_window_positions(kv_seqlens.long(), window_size)
        meta.extra_indices_in_kvcache = window_positions.unsqueeze(1).to(torch.int32)
        meta.extra_topk_length = window_lens.to(torch.int32)

        bsz = kv_seqlens.numel()
        # Per-batch base offset into the flattened [bsz * window_size] window
        # cache; consumed by the decode path when shifting extra indices.
        meta.batch_offsets = (
            torch.arange(bsz, device=kv_seqlens.device, dtype=torch.int32).view(-1, 1, 1) * window_size)

        # Fallback compressed-KV indices for both supported compress ratios;
        # the decode path uses these when the indexer output is absent.
        for ratio in (4, 128):
            num_compressed = torch.div(kv_seqlens, ratio, rounding_mode='floor').long()
            # Capacity bound derived from the block table; at least 1 so the
            # positions tensor is never zero-width.
            max_comp = max(block_offsets.size(1) * block_size // ratio, 1)
            comp_positions, _ = build_prefix_positions(num_compressed, max_comp)
            indices = comp_positions.unsqueeze(1).to(torch.int32)
            topk = num_compressed.to(torch.int32)
            if ratio == 4:
                meta.compress_fallback_indices_r4 = indices
                meta.compress_fallback_topk_r4 = topk
            else:
                meta.compress_fallback_indices_r128 = indices
                meta.compress_fallback_topk_r128 = topk

    @staticmethod
    def _precompute_prefill(meta, window_size):
        """Attach prefill-path flatten size bounds and topk index tensors.

        Size bounds use the CPU-side ``max_kv_seqlen`` / ``sum_kv_seqlen``
        so no GPU ``.item()`` synchronization is needed.
        """
        from lmdeploy.pytorch.backends.cuda.attention.v4_utils import (
            build_compress_topk_indices,
            build_window_topk_indices,
        )
        kv_seqlens = meta.kv_seqlens
        q_seqlens = meta.q_seqlens
        start_pos = meta.start_pos
        total_lens = kv_seqlens
        max_kv = meta.max_kv_seqlen  # CPU-side per-sequence max kv length
        sum_kv = meta.sum_kv_seqlen  # CPU-side total kv tokens

        # Window contribution per sequence is capped at the window size.
        meta.prefill_window_kv_lens = total_lens.clamp(max=window_size)

        # Upper bounds for flatten_v4_kv: window part + compressed part,
        # one pair per compress ratio.
        for ratio in (4, 128):
            mfk = min(max_kv, window_size) + max_kv // ratio
            tfk = sum_kv + sum_kv // ratio
            if ratio == 4:
                meta.prefill_max_flat_kv_len_r4 = mfk
                meta.prefill_total_flat_kv_tokens_r4 = tfk
            else:
                meta.prefill_max_flat_kv_len_r128 = mfk
                meta.prefill_total_flat_kv_tokens_r128 = tfk

        # NOTE(review): ratio 4 is hard-coded here; it yields the larger
        # width of the two ratios so it bounds both — confirm intentional.
        meta.prefill_max_compress_width = max_kv // 4

        # Window topk is ratio-independent, so one tensor serves all layers.
        meta.prefill_window_topk = build_window_topk_indices(
            total_lens, window_size,
            q_seqlens=q_seqlens, start_pos=start_pos, causal=True)

        # Compressed topk per ratio, offset past the window region so the
        # concatenated [window | compressed] index space stays consistent.
        for ratio in (4, 128):
            max_width = max_kv // ratio
            compress_topk = build_compress_topk_indices(
                total_lens, ratio,
                offset=meta.prefill_window_kv_lens,
                q_seqlens=q_seqlens, start_pos=start_pos,
                causal=True, max_width=max_width)
            if ratio == 4:
                meta.prefill_compress_topk_r4 = compress_topk
            else:
                meta.prefill_compress_topk_r128 = compress_topk
14125
15126
16127class V4IndicesUpdater :
@@ -76,11 +187,14 @@ def build():
76187
77188
78189def _try_dynamic_compile (func , * args , ** kwargs ):
79- """Try to compile a function with torch.compile, fall back to eager."""
80- try :
81- return torch .compile (func , dynamic = True )
82- except Exception :
83- return func
190+ """Return the function as-is.
191+
192+ ``torch.compile(dynamic=True)`` is incompatible with CUDAGraph — it can
193+ produce compiled code that triggers guard failures or invalid memory
194+ accesses during graph replay, leading to segfaults. Eager execution is
195+ safe and the index conversion is not a bottleneck.
196+ """
197+ return func
84198
85199
86200class TritonV4AttentionImpl :
@@ -182,26 +296,20 @@ def _write_window_prefill(self, kv, attn_caches, slot, start_pos, q_seqlens, tot
182296 kv_tokens = selected .reshape (- 1 , self .head_size )
183297 self ._pack_window_fp8 (kv_tokens , attn_caches ['window_state_fp8' ], slot_expanded , pos_expanded )
184298
185- def _build_window_indices (self , total_lens ):
186- """Build window ring-buffer positions and lengths."""
187- window_positions , window_lens , _ = build_window_positions (total_lens .long (), self .window_size )
188- extra_indices_in_kvcache = window_positions .unsqueeze (1 ).to (torch .int32 )
189- extra_topk_length = window_lens .to (torch .int32 )
190- return extra_indices_in_kvcache , extra_topk_length
191299
192300 # ------------------------------------------------------------------
193301 # Unified forward
194302 # ------------------------------------------------------------------
195303
196- def forward (self , query , kv , attn_sink , attn_metadata : V4AttentionMetadata ,
304+ def forward (self , query , kv , attn_sink , attn_metadata : CudaV4AttentionMetadata ,
197305 caches , slot , index_out = None ):
198306 """Unified forward — dispatches to decoding or prefilling internally.
199307
200308 Args:
201309 query: Q tensor [bsz, 1, n_heads, head_dim] or [1, total_tokens, n_heads, head_dim]
202310 kv: KV tensor [bsz, 1, head_dim] or [1, total_tokens, head_dim]
203311 attn_sink: Learnable sink parameter
204- attn_metadata: V4AttentionMetadata with sequence info
312+ attn_metadata: CudaV4AttentionMetadata with sequence info
205313 caches: dict of attention cache tensors
206314 slot: state cache slot indices [bsz]
207315 index_out: V4IndexerOutput from the indexer call (if any)
@@ -217,27 +325,26 @@ def forward(self, query, kv, attn_sink, attn_metadata: V4AttentionMetadata,
217325 # Decode path
218326 # ------------------------------------------------------------------
219327
220- def _forward_decoding (self , query , kv , attn_sink , attn_metadata : V4AttentionMetadata ,
328+ def _forward_decoding (self , query , kv , attn_sink , attn_metadata : CudaV4AttentionMetadata ,
221329 caches , slot , index_out = None ):
222330 # Model sends [1, bsz, n_heads, head_dim] for decode; FlashMLA expects [bsz, 1, ...]
223331 if query .size (0 ) == 1 and query .size (1 ) > 1 :
224332 query = query .transpose (0 , 1 ).contiguous ()
225333 kv = kv .transpose (0 , 1 ).contiguous ()
226334
227- kv_seqlens = attn_metadata .kv_seqlens
228- q_seqlens = attn_metadata .q_seqlens
229- start_pos = (kv_seqlens .to (torch .long ) - q_seqlens .to (torch .long ))
230- total_lens = kv_seqlens
231- bsz = kv_seqlens .numel ()
335+ start_pos = attn_metadata .start_pos
336+ total_lens = attn_metadata .kv_seqlens
337+ bsz = total_lens .numel ()
232338 block_offsets = attn_metadata .block_offsets .long ()
233339 block_size = attn_metadata .block_size
234340
235341 # Write window state + FP8 pack
236342 window_state_fp8 = self ._write_window_decode (kv , caches , slot , start_pos , total_lens )
237343
238- # Window indices
239- extra_indices_in_kvcache , extra_topk_length = self ._build_window_indices (total_lens )
240- is_padded = slot < 0
344+ # Window indices (pre-computed once per step)
345+ extra_indices_in_kvcache = attn_metadata .extra_indices_in_kvcache
346+ extra_topk_length = attn_metadata .extra_topk_length
347+ is_padded = attn_metadata .is_padded
241348 extra_topk_length = torch .where (is_padded , 1 , extra_topk_length )
242349
243350 # Compressed indices
@@ -250,12 +357,12 @@ def _forward_decoding(self, query, kv, attn_sink, attn_metadata: V4AttentionMeta
250357 if index_out is not None :
251358 indices_in_kvcache = index_out .indices_in_kvcache
252359 topk_length = index_out .topk_length
360+ elif self .compress_ratio == 4 :
361+ indices_in_kvcache = attn_metadata .compress_fallback_indices_r4
362+ topk_length = attn_metadata .compress_fallback_topk_r4
253363 else :
254- num_compressed = torch .div (total_lens , self .compress_ratio , rounding_mode = 'floor' ).long ()
255- max_comp = max (block_offsets .size (1 ) * block_size // self .compress_ratio , 1 )
256- comp_positions , _ = build_prefix_positions (num_compressed , max_comp )
257- indices_in_kvcache = comp_positions .unsqueeze (1 ).to (torch .int32 )
258- topk_length = num_compressed .to (torch .int32 )
364+ indices_in_kvcache = attn_metadata .compress_fallback_indices_r128
365+ topk_length = attn_metadata .compress_fallback_topk_r128
259366 else :
260367 indices_in_kvcache = torch .full ((bsz , 1 , 1 ), - 1 , dtype = torch .int32 , device = query .device )
261368 topk_length = torch .zeros (bsz , dtype = torch .int32 , device = query .device )
@@ -265,7 +372,7 @@ def _forward_decoding(self, query, kv, attn_sink, attn_metadata: V4AttentionMeta
265372 # FlashMLA sparse decode
266373 extra_k_cache = window_state_fp8 .view (bsz , self .window_size , 1 , - 1 )
267374 extra_indices = extra_indices_in_kvcache
268- batch_offsets = torch . arange ( bsz , device = extra_indices . device , dtype = torch . int32 ). view ( - 1 , 1 , 1 ) * self . window_size # noqa: E501
375+ batch_offsets = attn_metadata . batch_offsets
269376 extra_indices = torch .where (extra_indices >= 0 , extra_indices + batch_offsets , extra_indices )
270377
271378 if block_offsets is not None and self .compress_ratio and indices_in_kvcache is not None :
@@ -307,29 +414,27 @@ def _forward_decoding(self, query, kv, attn_sink, attn_metadata: V4AttentionMeta
307414 # Prefill path
308415 # ------------------------------------------------------------------
309416
310- def _forward_prefilling (self , query , kv , attn_sink , attn_metadata : V4AttentionMetadata ,
417+ def _forward_prefilling (self , query , kv , attn_sink , attn_metadata : CudaV4AttentionMetadata ,
311418 caches , slot , index_out = None ):
312- kv_seqlens = attn_metadata .kv_seqlens
419+ start_pos = attn_metadata .start_pos
420+ total_lens = attn_metadata .kv_seqlens
313421 q_seqlens = attn_metadata .q_seqlens
314- start_pos = (kv_seqlens .to (torch .long ) - q_seqlens .to (torch .long ))
315- total_lens = kv_seqlens
316422 block_offsets = attn_metadata .block_offsets
317423
318- # CPU-side upper bounds for flatten_v4_kv (avoids GPU .item() sync)
319- max_kv = attn_metadata .max_kv_seqlen
320- sum_kv = attn_metadata .sum_kv_seqlen
321424 cr = self .compress_ratio if self .compress_ratio else 1
322- max_flat_kv_len = min (max_kv , self .window_size ) + max_kv // cr
323- total_flat_kv_tokens = sum_kv + sum_kv // cr
324- max_compress_width = max_kv // cr
425+ if cr == 4 :
426+ max_flat_kv_len = attn_metadata .prefill_max_flat_kv_len_r4
427+ total_flat_kv_tokens = attn_metadata .prefill_total_flat_kv_tokens_r4
428+ else :
429+ max_flat_kv_len = attn_metadata .prefill_max_flat_kv_len_r128
430+ total_flat_kv_tokens = attn_metadata .prefill_total_flat_kv_tokens_r128
325431
326- # Pre-compute window_kv_lens for Indexer offset
327- window_kv_lens = total_lens .clamp (max = self .window_size )
432+ window_kv_lens = attn_metadata .prefill_window_kv_lens
328433
329434 # Write window state + FP8 pack (batched)
330435 self ._write_window_prefill (kv , caches , slot , start_pos , q_seqlens , total_lens )
331436
332- # Build compress topk
437+ # Compress topk
333438 compress_topk = None
334439 if self .compress_ratio :
335440 if index_out is not None :
@@ -339,14 +444,10 @@ def _forward_prefilling(self, query, kv, attn_sink, attn_metadata: V4AttentionMe
339444 token_seq = torch .arange (total_tokens , device = compress_topk .device )
340445 seq_id = torch .searchsorted (cu_q_seqlens [1 :], token_seq , right = True )
341446 compress_topk = compress_topk + window_kv_lens [seq_id ].unsqueeze (- 1 )
447+ elif self .compress_ratio == 4 :
448+ compress_topk = attn_metadata .prefill_compress_topk_r4
342449 else :
343- compress_topk = build_compress_topk_indices (
344- total_lens , self .compress_ratio ,
345- offset = window_kv_lens ,
346- q_seqlens = q_seqlens ,
347- start_pos = start_pos ,
348- causal = True ,
349- max_width = max_compress_width )
450+ compress_topk = attn_metadata .prefill_compress_topk_r128
350451
351452 # Flatten window + compressed KV into contiguous tensor
352453 fp8_compressed_kv_cache = caches ['compressed_kv_fp8' ] if self .compress_ratio else None
@@ -357,13 +458,8 @@ def _forward_prefilling(self, query, kv, attn_sink, attn_metadata: V4AttentionMe
357458 total_flat_kv_tokens , max_flat_kv_len ,
358459 fp8_compressed_kv_cache = fp8_compressed_kv_cache , slot = slot )
359460
360- # Build topk indices
361- q_flat = query .squeeze (0 )
362- window_topk = build_window_topk_indices (
363- total_lens , self .window_size ,
364- q_seqlens = q_seqlens ,
365- start_pos = start_pos ,
366- causal = True )
461+ # Window topk (pre-computed once per step)
462+ window_topk = attn_metadata .prefill_window_topk
367463
368464 if compress_topk is not None :
369465 topk_indices = torch .cat ([window_topk , compress_topk ], dim = - 1 )
@@ -376,6 +472,7 @@ def _forward_prefilling(self, query, kv, attn_sink, attn_metadata: V4AttentionMe
376472 # FlashMLA sparse prefill
377473 topk_indices = self ._pad_sparse_indices (topk_indices ).to (torch .int32 )
378474
475+ q_flat = query .squeeze (0 )
379476 num_heads = q_flat .size (1 )
380477 target = 64 if num_heads < 64 else (128 if num_heads < 128 else num_heads )
381478 if target != num_heads :
0 commit comments