Skip to content

Commit ec90f8a

Browse files
committed
guard fa3 prefill
1 parent 3b12b26 commit ec90f8a

2 files changed

Lines changed: 96 additions & 3 deletions

File tree

lmdeploy/pytorch/backends/cuda/attention/fa3.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from lmdeploy.messages import QuantPolicy
55
from lmdeploy.utils import get_logger
66

7-
from .default import TritonAttentionImpl, TritonAttentionMetadata
7+
from .default import TritonAttentionImpl, TritonAttentionMetadata, _cdiv
88

99
logger = get_logger('lmdeploy')
1010

@@ -261,13 +261,15 @@ def _forward_prefill(
261261
quant_policy = attn_metadata.quant_policy
262262

263263
# Flatten KV cache for varlen attention
264+
block_size = k_cache.size(1)
265+
out_size = _cdiv(kv_flatten_size, block_size) * block_size + block_size
264266
flatten_k, flatten_v = self.flatten_kv_cache(
265267
k_cache,
266268
v_cache,
267269
kv_seqlens,
268270
block_offsets,
269271
start_loc=kv_start_loc,
270-
out_size=kv_flatten_size,
272+
out_size=out_size,
271273
out_dtype=query.dtype,
272274
k_scales_zeros=k_scales_zeros,
273275
v_scales_zeros=v_scales_zeros,
@@ -293,7 +295,7 @@ def _forward_prefill(
293295
cu_seqlens_q=attn_metadata.cu_seqlens_q,
294296
cu_seqlens_k=attn_metadata.cu_seqlens_k,
295297
max_seqlen_q=max_q_seqlen,
296-
max_seqlen_k=kv_flatten_size,
298+
max_seqlen_k=attn_metadata.max_kv_seqlen,
297299
softmax_scale=self.scale,
298300
causal=self.causal,
299301
window_size=sliding_window,
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import torch
3+
4+
from lmdeploy.messages import QuantPolicy
5+
from lmdeploy.pytorch.backends.cuda.attention.default import TritonAttentionMetadata
6+
from lmdeploy.pytorch.backends.cuda.attention.fa3 import FA3Impl
7+
8+
_BLOCK_SIZE = 16
9+
_PREFILL_SEQLENS = (29, 18)
10+
11+
12+
def _make_prefill_metadata(q_seqlens, block_offsets):
13+
cu_seqlens = torch.nn.functional.pad(torch.cumsum(q_seqlens, dim=0, dtype=torch.int32), (1, 0))
14+
return TritonAttentionMetadata(
15+
is_decoding=False,
16+
block_offsets=block_offsets,
17+
q_start_loc=cu_seqlens[:-1],
18+
q_seqlens=q_seqlens,
19+
kv_start_loc=cu_seqlens[:-1],
20+
kv_seqlens=q_seqlens,
21+
quant_policy=QuantPolicy.NONE,
22+
kv_flatten_size=int(q_seqlens.sum().item()),
23+
cu_seqlens_q=cu_seqlens,
24+
cu_seqlens_k=cu_seqlens.clone(),
25+
max_kv_seqlen=int(q_seqlens.max().item()),
26+
max_q_seqlen=int(q_seqlens.max().item()),
27+
)
28+
29+
30+
def _make_recycled_block_offsets(device):
31+
return torch.tensor([
32+
[0, 2, 1],
33+
[3, 4, 0],
34+
],
35+
dtype=torch.int32,
36+
device=device)
37+
38+
39+
def _make_prefill_seqlens(device='cpu'):
40+
return torch.tensor(_PREFILL_SEQLENS, dtype=torch.int32, device=device)
41+
42+
43+
def _guarded_flatten_size(q_seqlens):
44+
kv_flatten_size = int(q_seqlens.sum().item())
45+
return (kv_flatten_size + _BLOCK_SIZE - 1) // _BLOCK_SIZE * _BLOCK_SIZE + _BLOCK_SIZE
46+
47+
48+
def _num_cache_blocks(block_offsets):
49+
return int(block_offsets.max().item()) + 1
50+
51+
52+
def test_fa3_prefill_uses_guarded_flatten_buffer_and_max_kv_seqlen():
53+
"""Regression test for FA3 prefill with recycled paged KV blocks."""
54+
impl = FA3Impl.__new__(FA3Impl)
55+
impl.scale = 1.0
56+
impl.causal = True
57+
impl.sliding_window = None
58+
impl.logit_softcapping = 0.0
59+
60+
q_seqlens = _make_prefill_seqlens()
61+
block_offsets = _make_recycled_block_offsets(device='cpu')
62+
metadata = _make_prefill_metadata(q_seqlens, block_offsets)
63+
64+
query = torch.empty((int(q_seqlens.sum().item()), 2, 8), dtype=torch.float16)
65+
k_cache = torch.empty((_num_cache_blocks(block_offsets), _BLOCK_SIZE, 2, 8), dtype=torch.float16)
66+
v_cache = torch.empty_like(k_cache)
67+
captured = {}
68+
69+
def fake_flatten_kv_cache(k_cache_arg, v_cache_arg, seqlens, offsets, **kwargs):
70+
captured['flatten_out_size'] = kwargs['out_size']
71+
captured['flatten_start_loc'] = kwargs['start_loc']
72+
return (
73+
torch.empty((kwargs['out_size'], 2, 8), dtype=kwargs['out_dtype']),
74+
torch.empty((kwargs['out_size'], 2, 8), dtype=kwargs['out_dtype']),
75+
)
76+
77+
def fake_flash_attn_varlen_func(**kwargs):
78+
captured['flash_max_seqlen_k'] = kwargs['max_seqlen_k']
79+
captured['flash_k_size'] = kwargs['k'].size(0)
80+
return torch.empty_like(kwargs['q'])
81+
82+
impl.flatten_kv_cache = fake_flatten_kv_cache
83+
impl.flash_attn_varlen_func_v3 = fake_flash_attn_varlen_func
84+
85+
out = impl._forward_prefill(query, k_cache, v_cache, metadata, max_q_seqlen=int(q_seqlens.max().item()))
86+
87+
assert out.shape == query.shape
88+
assert captured['flatten_start_loc'] is metadata.kv_start_loc
89+
assert captured['flatten_out_size'] == _guarded_flatten_size(q_seqlens)
90+
assert captured['flash_k_size'] == _guarded_flatten_size(q_seqlens)
91+
assert captured['flash_max_seqlen_k'] == metadata.max_kv_seqlen

0 commit comments

Comments
 (0)