Commit 5d36638

update attn meta once

1 parent 4479987 commit 5d36638

7 files changed: 196 additions & 217 deletions
lmdeploy/pytorch/backends/attention.py

Lines changed: 14 additions & 6 deletions

@@ -25,10 +25,10 @@ class AttentionMetadata:
 
 @dataclass
 class V4AttentionMetadata:
-    """DeepSeek V4 attention metadata.
+    """DeepSeek V4 attention metadata base class.
 
     Built once per step from attn_metadata + step_ctx, then passed through all V4 sub-modules (Attention, Compressor,
-    Indexer).
+    Indexer). Backends should subclass this to add their own pre-computed fields and override ``from_step_context``.
     """
 
     is_decoding: bool
@@ -47,26 +47,34 @@ class V4AttentionMetadata:
     block_size: int = 0
     cu_seqlens_k: torch.Tensor = None
     sum_kv_seqlen: int = None
+    start_pos: torch.Tensor = None  # [bsz] long
 
     @classmethod
-    def from_step_context(cls, attn_metadata, step_ctx) -> 'V4AttentionMetadata':
+    def from_step_context(cls, attn_metadata, step_ctx, **kwargs) -> 'V4AttentionMetadata':
         """Build V4AttentionMetadata from the scheduler's attn_metadata and
-        step_ctx."""
+        step_ctx.
+
+        Subclasses can accept additional keyword arguments for backend-specific pre-computation.
+        """
         is_decoding = attn_metadata.is_decoding
         cache_config = step_ctx.cache_config
         max_kv_seqlen = (cache_config.block_size * cache_config.num_gpu_blocks
                          if is_decoding else step_ctx.max_kv_seqlen)
+        kv_seqlens = attn_metadata.kv_seqlens
+        q_seqlens = attn_metadata.q_seqlens
+
         return cls(
             is_decoding=is_decoding,
             block_offsets=attn_metadata.block_offsets,
             cu_q_seqlens=attn_metadata.cu_seqlens_q,
-            kv_seqlens=attn_metadata.kv_seqlens,
-            q_seqlens=attn_metadata.q_seqlens,
+            kv_seqlens=kv_seqlens,
+            q_seqlens=q_seqlens,
             max_kv_seqlen=max_kv_seqlen,
             max_q_seqlen=step_ctx.max_q_seqlen,
             block_size=cache_config.block_size,
             sum_kv_seqlen=step_ctx.sum_kv_seqlen,
             cu_seqlens_k=attn_metadata.cu_seqlens_k,
+            start_pos=(kv_seqlens.to(torch.long) - q_seqlens.to(torch.long)),
         )
 
 
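The intent of the base class is that the runner builds this object once per forward step and every V4 sub-module consumes the same instance. A minimal sketch of that call pattern, where `run_step`, `layers` and `hidden_states` are placeholders rather than the actual lmdeploy runner API:

from lmdeploy.pytorch.backends.attention import V4AttentionMetadata

def run_step(attn_metadata, step_ctx, layers, hidden_states):
    # Build the V4 metadata exactly once per step ...
    v4_meta = V4AttentionMetadata.from_step_context(attn_metadata, step_ctx)
    # ... then pass the same instance to every layer's Attention / Compressor / Indexer.
    for layer in layers:
        hidden_states = layer(hidden_states, v4_meta)
    return hidden_states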
lmdeploy/pytorch/backends/base.py

Lines changed: 10 additions & 0 deletions

@@ -67,6 +67,16 @@ def get_attention_metadata_cls():
         """Get attention metadata class."""
         raise NotImplementedError
 
+    @staticmethod
+    def get_v4_attention_metadata_cls():
+        """Get V4 attention metadata class.
+
+        Returns ``V4AttentionMetadata`` by default; backends with V4-specific
+        pre-computation should override this to return their subclass.
+        """
+        from lmdeploy.pytorch.backends.attention import V4AttentionMetadata
+        return V4AttentionMetadata
+
     @staticmethod
     @abstractmethod
     def get_k_block_shape(
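A backend with its own metadata subclass would override the new hook roughly as follows; every name here other than the hook itself is illustrative:

# Hypothetical backend override (sketch only).
class MyBackend:

    @staticmethod
    def get_v4_attention_metadata_cls():
        # Return a subclass carrying backend-specific pre-computed fields.
        from my_backend.v4 import MyV4AttentionMetadata  # hypothetical module
        return MyV4AttentionMetadata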

lmdeploy/pytorch/backends/cuda/attention/v4.py

Lines changed: 158 additions & 61 deletions
@@ -1,16 +1,127 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 
 import functools
+from dataclasses import dataclass
 
 import torch
 
 from lmdeploy.pytorch.backends.attention import V4AttentionMetadata
-from lmdeploy.pytorch.models.deepseek_v4_utils import (
-    build_compress_topk_indices,
-    build_prefix_positions,
-    build_window_positions,
-    build_window_topk_indices,
-)
+
+
+@dataclass
+class CudaV4AttentionMetadata(V4AttentionMetadata):
+    """CUDA-specific V4 attention metadata with pre-computed indices.
+
+    Adds layer-invariant index tensors that are computed once per step
+    in ``from_step_context``, so TritonV4AttentionImpl does not
+    recompute them per layer.
+    """
+
+    # --- Decode pre-computed ---
+    is_padded: torch.Tensor = None  # [bsz] bool
+    batch_offsets: torch.Tensor = None  # [bsz, 1, 1] int32
+    compress_fallback_indices_r4: torch.Tensor = None  # [bsz, 1, max_comp] int32
+    compress_fallback_topk_r4: torch.Tensor = None  # [bsz] int32
+    compress_fallback_indices_r128: torch.Tensor = None  # [bsz, 1, max_comp] int32
+    compress_fallback_topk_r128: torch.Tensor = None  # [bsz] int32
+
+    # --- Prefill pre-computed ---
+    prefill_window_kv_lens: torch.Tensor = None  # [bsz] long
+    prefill_max_flat_kv_len_r4: int = None
+    prefill_total_flat_kv_tokens_r4: int = None
+    prefill_max_flat_kv_len_r128: int = None
+    prefill_total_flat_kv_tokens_r128: int = None
+    prefill_max_compress_width: int = None
+    prefill_window_topk: torch.Tensor = None  # [total_q_tokens, window_size]
+    prefill_compress_topk_r4: torch.Tensor = None  # [total_q_tokens, max_width]
+    prefill_compress_topk_r128: torch.Tensor = None  # [total_q_tokens, max_width]
+
+    @classmethod
+    def from_step_context(cls, attn_metadata, step_ctx, **kwargs) -> 'CudaV4AttentionMetadata':
+        window_size = kwargs.get('window_size', 0)
+        slot = kwargs.get('slot', None)
+
+        meta = super().from_step_context(attn_metadata, step_ctx)
+
+        if window_size > 0 and slot is not None:
+            meta.is_padded = slot < 0
+            if meta.is_decoding:
+                cls._precompute_decode(meta, window_size)
+            else:
+                cls._precompute_prefill(meta, window_size)
+
+        return meta
+
+    @staticmethod
+    def _precompute_decode(meta, window_size):
+        from lmdeploy.pytorch.backends.cuda.attention.v4_utils import build_prefix_positions, build_window_positions
+        kv_seqlens = meta.kv_seqlens
+        block_offsets = meta.block_offsets.long()
+        block_size = meta.block_size
+
+        window_positions, window_lens, _ = build_window_positions(kv_seqlens.long(), window_size)
+        meta.extra_indices_in_kvcache = window_positions.unsqueeze(1).to(torch.int32)
+        meta.extra_topk_length = window_lens.to(torch.int32)
+
+        bsz = kv_seqlens.numel()
+        meta.batch_offsets = (
+            torch.arange(bsz, device=kv_seqlens.device, dtype=torch.int32).view(-1, 1, 1) * window_size)
+
+        for ratio in (4, 128):
+            num_compressed = torch.div(kv_seqlens, ratio, rounding_mode='floor').long()
+            max_comp = max(block_offsets.size(1) * block_size // ratio, 1)
+            comp_positions, _ = build_prefix_positions(num_compressed, max_comp)
+            indices = comp_positions.unsqueeze(1).to(torch.int32)
+            topk = num_compressed.to(torch.int32)
+            if ratio == 4:
+                meta.compress_fallback_indices_r4 = indices
+                meta.compress_fallback_topk_r4 = topk
+            else:
+                meta.compress_fallback_indices_r128 = indices
+                meta.compress_fallback_topk_r128 = topk
+
+    @staticmethod
+    def _precompute_prefill(meta, window_size):
+        from lmdeploy.pytorch.backends.cuda.attention.v4_utils import (
+            build_compress_topk_indices,
+            build_window_topk_indices,
+        )
+        kv_seqlens = meta.kv_seqlens
+        q_seqlens = meta.q_seqlens
+        start_pos = meta.start_pos
+        total_lens = kv_seqlens
+        max_kv = meta.max_kv_seqlen
+        sum_kv = meta.sum_kv_seqlen
+
+        meta.prefill_window_kv_lens = total_lens.clamp(max=window_size)
+
+        for ratio in (4, 128):
+            mfk = min(max_kv, window_size) + max_kv // ratio
+            tfk = sum_kv + sum_kv // ratio
+            if ratio == 4:
+                meta.prefill_max_flat_kv_len_r4 = mfk
+                meta.prefill_total_flat_kv_tokens_r4 = tfk
+            else:
+                meta.prefill_max_flat_kv_len_r128 = mfk
+                meta.prefill_total_flat_kv_tokens_r128 = tfk
+
+        meta.prefill_max_compress_width = max_kv // 4
+
+        meta.prefill_window_topk = build_window_topk_indices(
+            total_lens, window_size,
+            q_seqlens=q_seqlens, start_pos=start_pos, causal=True)
+
+        for ratio in (4, 128):
+            max_width = max_kv // ratio
+            compress_topk = build_compress_topk_indices(
+                total_lens, ratio,
+                offset=meta.prefill_window_kv_lens,
+                q_seqlens=q_seqlens, start_pos=start_pos,
+                causal=True, max_width=max_width)
+            if ratio == 4:
+                meta.prefill_compress_topk_r4 = compress_topk
+            else:
+                meta.prefill_compress_topk_r128 = compress_topk
 
 
 class V4IndicesUpdater:
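Assuming the model runner selects this class through the backend's `get_v4_attention_metadata_cls()` hook and forwards the sliding-window size and the state-cache slots, per-step construction would look roughly like the sketch below; the surrounding runner variables are illustrative, only the kwargs match what `from_step_context` reads above:

def build_v4_metadata(backend, attn_metadata, step_ctx, window_size, slot):
    # Illustrative wrapper, not lmdeploy runner code.
    meta_cls = backend.get_v4_attention_metadata_cls()
    return meta_cls.from_step_context(
        attn_metadata,
        step_ctx,
        window_size=window_size,  # > 0 enables the decode/prefill pre-computation
        slot=slot,                # [bsz] state-cache slots; negative entries mark padding
    )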
@@ -76,11 +187,14 @@ def build():
 
 
 def _try_dynamic_compile(func, *args, **kwargs):
-    """Try to compile a function with torch.compile, fall back to eager."""
-    try:
-        return torch.compile(func, dynamic=True)
-    except Exception:
-        return func
+    """Return the function as-is.
+
+    ``torch.compile(dynamic=True)`` is incompatible with CUDAGraph — it can
+    produce compiled code that triggers guard failures or invalid memory
+    accesses during graph replay, leading to segfaults. Eager execution is
+    safe and the index conversion is not a bottleneck.
+    """
+    return func
 
 
 class TritonV4AttentionImpl:
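If compilation is ever wanted back for the index conversion, one option (not part of this commit) is to make it opt-in, so it is only attempted when the caller knows CUDA graph capture and replay are not involved:

import torch

def maybe_compile(func, allow_compile: bool = False):
    # Sketch: stay eager by default; compile only when the caller guarantees
    # the function never runs inside CUDAGraph capture/replay.
    if not allow_compile:
        return func
    try:
        return torch.compile(func, dynamic=True)
    except Exception:
        return func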
@@ -182,26 +296,20 @@ def _write_window_prefill(self, kv, attn_caches, slot, start_pos, q_seqlens, tot
         kv_tokens = selected.reshape(-1, self.head_size)
         self._pack_window_fp8(kv_tokens, attn_caches['window_state_fp8'], slot_expanded, pos_expanded)
 
-    def _build_window_indices(self, total_lens):
-        """Build window ring-buffer positions and lengths."""
-        window_positions, window_lens, _ = build_window_positions(total_lens.long(), self.window_size)
-        extra_indices_in_kvcache = window_positions.unsqueeze(1).to(torch.int32)
-        extra_topk_length = window_lens.to(torch.int32)
-        return extra_indices_in_kvcache, extra_topk_length
 
     # ------------------------------------------------------------------
     # Unified forward
     # ------------------------------------------------------------------
 
-    def forward(self, query, kv, attn_sink, attn_metadata: V4AttentionMetadata,
+    def forward(self, query, kv, attn_sink, attn_metadata: CudaV4AttentionMetadata,
                 caches, slot, index_out=None):
         """Unified forward — dispatches to decoding or prefilling internally.
 
         Args:
             query: Q tensor [bsz, 1, n_heads, head_dim] or [1, total_tokens, n_heads, head_dim]
             kv: KV tensor [bsz, 1, head_dim] or [1, total_tokens, head_dim]
             attn_sink: Learnable sink parameter
-            attn_metadata: V4AttentionMetadata with sequence info
+            attn_metadata: CudaV4AttentionMetadata with sequence info
             caches: dict of attention cache tensors
             slot: state cache slot indices [bsz]
             index_out: V4IndexerOutput from the indexer call (if any)
@@ -217,27 +325,26 @@ def forward(self, query, kv, attn_sink, attn_metadata: V4AttentionMeta
     # Decode path
     # ------------------------------------------------------------------
 
-    def _forward_decoding(self, query, kv, attn_sink, attn_metadata: V4AttentionMetadata,
+    def _forward_decoding(self, query, kv, attn_sink, attn_metadata: CudaV4AttentionMetadata,
                           caches, slot, index_out=None):
         # Model sends [1, bsz, n_heads, head_dim] for decode; FlashMLA expects [bsz, 1, ...]
         if query.size(0) == 1 and query.size(1) > 1:
             query = query.transpose(0, 1).contiguous()
             kv = kv.transpose(0, 1).contiguous()
 
-        kv_seqlens = attn_metadata.kv_seqlens
-        q_seqlens = attn_metadata.q_seqlens
-        start_pos = (kv_seqlens.to(torch.long) - q_seqlens.to(torch.long))
-        total_lens = kv_seqlens
-        bsz = kv_seqlens.numel()
+        start_pos = attn_metadata.start_pos
+        total_lens = attn_metadata.kv_seqlens
+        bsz = total_lens.numel()
         block_offsets = attn_metadata.block_offsets.long()
         block_size = attn_metadata.block_size
 
         # Write window state + FP8 pack
         window_state_fp8 = self._write_window_decode(kv, caches, slot, start_pos, total_lens)
 
-        # Window indices
-        extra_indices_in_kvcache, extra_topk_length = self._build_window_indices(total_lens)
-        is_padded = slot < 0
+        # Window indices (pre-computed once per step)
+        extra_indices_in_kvcache = attn_metadata.extra_indices_in_kvcache
+        extra_topk_length = attn_metadata.extra_topk_length
+        is_padded = attn_metadata.is_padded
         extra_topk_length = torch.where(is_padded, 1, extra_topk_length)
 
         # Compressed indices
@@ -250,12 +357,12 @@ def _forward_decoding(self, query, kv, attn_sink, attn_metadata: V4AttentionMeta
             if index_out is not None:
                 indices_in_kvcache = index_out.indices_in_kvcache
                 topk_length = index_out.topk_length
+            elif self.compress_ratio == 4:
+                indices_in_kvcache = attn_metadata.compress_fallback_indices_r4
+                topk_length = attn_metadata.compress_fallback_topk_r4
             else:
-                num_compressed = torch.div(total_lens, self.compress_ratio, rounding_mode='floor').long()
-                max_comp = max(block_offsets.size(1) * block_size // self.compress_ratio, 1)
-                comp_positions, _ = build_prefix_positions(num_compressed, max_comp)
-                indices_in_kvcache = comp_positions.unsqueeze(1).to(torch.int32)
-                topk_length = num_compressed.to(torch.int32)
+                indices_in_kvcache = attn_metadata.compress_fallback_indices_r128
+                topk_length = attn_metadata.compress_fallback_topk_r128
         else:
             indices_in_kvcache = torch.full((bsz, 1, 1), -1, dtype=torch.int32, device=query.device)
             topk_length = torch.zeros(bsz, dtype=torch.int32, device=query.device)
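The r4/r128 fallback tensors carry the same values the removed per-layer code computed: floor(kv_len / ratio) compressed entries per sequence plus their prefix positions. A standalone illustration of the length arithmetic, using plain torch and example lengths rather than lmdeploy helpers:

import torch

kv_seqlens = torch.tensor([1000, 130])  # per-sequence KV lengths (example values)
for ratio in (4, 128):
    num_compressed = torch.div(kv_seqlens, ratio, rounding_mode='floor')
    print(ratio, num_compressed.tolist())
# -> 4 [250, 32]
# -> 128 [7, 1]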
@@ -265,7 +372,7 @@ def _forward_decoding(self, query, kv, attn_sink, attn_metadata: V4AttentionMeta
         # FlashMLA sparse decode
         extra_k_cache = window_state_fp8.view(bsz, self.window_size, 1, -1)
         extra_indices = extra_indices_in_kvcache
-        batch_offsets = torch.arange(bsz, device=extra_indices.device, dtype=torch.int32).view(-1, 1, 1) * self.window_size  # noqa: E501
+        batch_offsets = attn_metadata.batch_offsets
         extra_indices = torch.where(extra_indices >= 0, extra_indices + batch_offsets, extra_indices)
 
         if block_offsets is not None and self.compress_ratio and indices_in_kvcache is not None:
@@ -307,29 +414,27 @@ def _forward_decoding(self, query, kv, attn_sink, attn_metadata: V4AttentionMeta
     # Prefill path
     # ------------------------------------------------------------------
 
-    def _forward_prefilling(self, query, kv, attn_sink, attn_metadata: V4AttentionMetadata,
+    def _forward_prefilling(self, query, kv, attn_sink, attn_metadata: CudaV4AttentionMetadata,
                             caches, slot, index_out=None):
-        kv_seqlens = attn_metadata.kv_seqlens
+        start_pos = attn_metadata.start_pos
+        total_lens = attn_metadata.kv_seqlens
         q_seqlens = attn_metadata.q_seqlens
-        start_pos = (kv_seqlens.to(torch.long) - q_seqlens.to(torch.long))
-        total_lens = kv_seqlens
         block_offsets = attn_metadata.block_offsets
 
-        # CPU-side upper bounds for flatten_v4_kv (avoids GPU .item() sync)
-        max_kv = attn_metadata.max_kv_seqlen
-        sum_kv = attn_metadata.sum_kv_seqlen
         cr = self.compress_ratio if self.compress_ratio else 1
-        max_flat_kv_len = min(max_kv, self.window_size) + max_kv // cr
-        total_flat_kv_tokens = sum_kv + sum_kv // cr
-        max_compress_width = max_kv // cr
+        if cr == 4:
+            max_flat_kv_len = attn_metadata.prefill_max_flat_kv_len_r4
+            total_flat_kv_tokens = attn_metadata.prefill_total_flat_kv_tokens_r4
+        else:
+            max_flat_kv_len = attn_metadata.prefill_max_flat_kv_len_r128
+            total_flat_kv_tokens = attn_metadata.prefill_total_flat_kv_tokens_r128
 
-        # Pre-compute window_kv_lens for Indexer offset
-        window_kv_lens = total_lens.clamp(max=self.window_size)
+        window_kv_lens = attn_metadata.prefill_window_kv_lens
 
         # Write window state + FP8 pack (batched)
         self._write_window_prefill(kv, caches, slot, start_pos, q_seqlens, total_lens)
 
-        # Build compress topk
+        # Compress topk
         compress_topk = None
         if self.compress_ratio:
             if index_out is not None:
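The pre-computed r4/r128 bounds reproduce the formulas that were previously evaluated inline here: max_flat_kv_len = min(max_kv, window_size) + max_kv // ratio and total_flat_kv_tokens = sum_kv + sum_kv // ratio. A quick check of the arithmetic with illustrative numbers (max_kv = 8192, sum_kv = 20000, window_size = 2048):

max_kv, sum_kv, window_size = 8192, 20000, 2048  # example values only
for ratio in (4, 128):
    max_flat_kv_len = min(max_kv, window_size) + max_kv // ratio
    total_flat_kv_tokens = sum_kv + sum_kv // ratio
    print(ratio, max_flat_kv_len, total_flat_kv_tokens)
# -> 4 4096 25000
# -> 128 2112 20156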
@@ -339,14 +444,10 @@ def _forward_prefilling(self, query, kv, attn_sink, attn_metadata: V4AttentionMe
                 token_seq = torch.arange(total_tokens, device=compress_topk.device)
                 seq_id = torch.searchsorted(cu_q_seqlens[1:], token_seq, right=True)
                 compress_topk = compress_topk + window_kv_lens[seq_id].unsqueeze(-1)
+            elif self.compress_ratio == 4:
+                compress_topk = attn_metadata.prefill_compress_topk_r4
             else:
-                compress_topk = build_compress_topk_indices(
-                    total_lens, self.compress_ratio,
-                    offset=window_kv_lens,
-                    q_seqlens=q_seqlens,
-                    start_pos=start_pos,
-                    causal=True,
-                    max_width=max_compress_width)
+                compress_topk = attn_metadata.prefill_compress_topk_r128
 
         # Flatten window + compressed KV into contiguous tensor
         fp8_compressed_kv_cache = caches['compressed_kv_fp8'] if self.compress_ratio else None
@@ -357,13 +458,8 @@ def _forward_prefilling(self, query, kv, attn_sink, attn_metadata: V4AttentionMe
             total_flat_kv_tokens, max_flat_kv_len,
             fp8_compressed_kv_cache=fp8_compressed_kv_cache, slot=slot)
 
-        # Build topk indices
-        q_flat = query.squeeze(0)
-        window_topk = build_window_topk_indices(
-            total_lens, self.window_size,
-            q_seqlens=q_seqlens,
-            start_pos=start_pos,
-            causal=True)
+        # Window topk (pre-computed once per step)
+        window_topk = attn_metadata.prefill_window_topk
 
         if compress_topk is not None:
             topk_indices = torch.cat([window_topk, compress_topk], dim=-1)
@@ -376,6 +472,7 @@ def _forward_prefilling(self, query, kv, attn_sink, attn_metadata: V4AttentionMe
         # FlashMLA sparse prefill
         topk_indices = self._pad_sparse_indices(topk_indices).to(torch.int32)
 
+        q_flat = query.squeeze(0)
         num_heads = q_flat.size(1)
         target = 64 if num_heads < 64 else (128 if num_heads < 128 else num_heads)
         if target != num_heads:
A further file in this commit was renamed without content changes.
