
Commit 2ef9c6b

[ascend] fix prefix caching (#4448)
* fix prefix caching
* change attention layout from BSH to TND
* remove unused comments
1 parent 687385e commit 2ef9c6b

5 files changed: 48 additions & 51 deletions
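The rename at the heart of this commit: is_unpaged_prefill becomes is_prefill_no_cache, which says what the flag actually tests, namely a prefill step that reuses no cached prefix. Every backend derives it the same way: a sequence has cached history exactly when its KV length exceeds its query length. A minimal standalone sketch of the condition, with the step-context tensors mocked up:

import torch

# Mocked step-context fields; in the real code these come from step_context.
q_seqlens = torch.tensor([4, 7])    # new tokens fed in this step
kv_seqlens = torch.tensor([4, 12])  # total KV length incl. cached prefix

# Prefill bypasses the cache only if no sequence carries history.
is_prefill_no_cache = all((q_seqlens == kv_seqlens).tolist())
print(is_prefill_no_cache)  # False: sequence 1 reuses 5 cached tokens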


lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py

Lines changed: 31 additions & 37 deletions
@@ -155,13 +155,11 @@ def update_step_context(cls, step_context):
         """Update step context."""

         block_num, block_size, *_ = step_context.kv_caches[0][0].shape
-        is_unpaged_prefill = False
+        is_prefill_no_cache = False
         if not step_context.is_decoding:
-            is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
+            is_prefill_no_cache = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
         if step_context.block_offsets.dtype != torch.int32:
             step_context.block_offsets = step_context.block_offsets.to(torch.int32)
-        if not (step_context.is_decoding or is_unpaged_prefill):
-            step_context.block_offsets = step_context.block_offsets.repeat_interleave(step_context.q_seqlens, 0)
         if step_context.kv_seqlens.dtype != torch.int32:
             step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32)
         if step_context.q_seqlens.dtype != torch.int32:
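Beyond the rename, the substantive change in this hunk is deleting the repeat_interleave on block_offsets: the old path expanded the per-sequence block table to one row per query token for paged prefill (its counterpart, kv_seqlens_expanded, disappears further down). A sketch of what the deleted expansion computed, with made-up shapes:

import torch

# Hypothetical block table: one row of cache-block ids per sequence.
block_offsets = torch.tensor([[0, 1], [2, 3]], dtype=torch.int32)
q_seqlens = torch.tensor([2, 3])

# The removed line duplicated each row once per query token: 2 + 3 = 5 rows.
expanded = block_offsets.repeat_interleave(q_seqlens, 0)
print(expanded.shape)  # torch.Size([5, 2])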
@@ -175,7 +173,7 @@ def get_total_slots():
             cls.total_slots = cls.total_slots.view(block_num, block_size)
             return cls.total_slots

-        def get_cpu_seqlens(is_decoding, is_unpaged_prefill):
+        def get_cpu_seqlens(is_decoding, is_prefill_no_cache):
             """Get sequence lengths on CPU.

             Returns:
@@ -187,37 +185,43 @@ def get_cpu_seqlens(is_decoding, is_unpaged_prefill):
             """
             if is_decoding:
                 q_seqlens_cpu = None
-                kv_seqlens_cpu = kv_seqlens_expanded = step_context.kv_seqlens.cpu()
-            elif is_unpaged_prefill:
+                kv_seqlens_cpu = step_context.kv_seqlens.cpu()
+            elif is_prefill_no_cache:
                 q_seqlens_cpu = step_context.q_seqlens.cpu()
-                kv_seqlens_cpu = kv_seqlens_expanded = q_seqlens_cpu
+                kv_seqlens_cpu = q_seqlens_cpu
             else:
                 q_seqlens_cpu = step_context.q_seqlens.cpu()
                 kv_seqlens_cpu = step_context.kv_seqlens.cpu()
-                # Expand kv_seqlens to per-token for paged prefill attention
-                kv_seqlens_expanded = kv_seqlens_cpu.repeat_interleave(q_seqlens_cpu, 0)
-            return q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded
+            return q_seqlens_cpu, kv_seqlens_cpu

-        def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None):
+        def get_list_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None, kv_seqlens_cpu=None):
             if is_decoding:
                 q_seqlens_list, kv_seqlens_list = None, None
-            elif is_unpaged_prefill:
+            elif is_prefill_no_cache:
                 q_seqlens_list = kv_seqlens_list = q_seqlens_cpu.tolist()
             else:
                 q_seqlens_list, kv_seqlens_list = q_seqlens_cpu.tolist(), kv_seqlens_cpu.tolist()
             return q_seqlens_list, kv_seqlens_list

-        def get_max_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_list=None, kv_seqlens_list=None):
+        def get_max_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_list=None, kv_seqlens_list=None):
             if is_decoding:
                 max_q_seq_len, max_kv_seq_len = 1, None
-            elif is_unpaged_prefill:
+            elif is_prefill_no_cache:
                 max_q_seq_len = max_kv_seq_len = max(q_seqlens_list)
             else:
                 max_q_seq_len = max(q_seqlens_list)
                 max_kv_seq_len = max(kv_seqlens_list)
             return max_q_seq_len, max_kv_seq_len

-        def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list,
+        def update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None):
+            if is_decoding:
+                batch_size = step_context.q_seqlens.size(0)
+                return torch.arange(1, batch_size + 1, dtype=torch.int32)
+            elif is_prefill_no_cache:
+                return q_seqlens_cpu
+            return q_seqlens_cpu.cumsum(dim=0)
+
+        def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list,
                                                     max_q_seq_len, max_kv_seq_len):
             kv_start_indices, attention_mask = [], []
             if is_decoding:
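The new update_q_seqlens helper replaces raw per-sequence query lengths with the cumulative end offsets that the TND attention layout indexes tokens by; during decoding each query is a single token, so the cumulative form collapses to 1..batch_size. A standalone sketch of the three branches, with step_context swapped for an explicit batch_size argument:

import torch

def update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None, batch_size=None):
    # Mirrors the diff; batch_size stands in for step_context.q_seqlens.size(0).
    if is_decoding:
        # One token per sequence: cumulative ends are 1, 2, ..., batch_size.
        return torch.arange(1, batch_size + 1, dtype=torch.int32)
    elif is_prefill_no_cache:
        return q_seqlens_cpu                # kernel takes plain lengths here
    return q_seqlens_cpu.cumsum(dim=0)      # paged prefill: cumulative ends

print(update_q_seqlens(False, False, q_seqlens_cpu=torch.tensor([4, 7])))  # tensor([ 4, 11])
print(update_q_seqlens(True, False, batch_size=3))                         # tensor([1, 2, 3], dtype=torch.int32)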
@@ -236,25 +240,17 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s
                 slots = slot_tables[history_length:kv_seq_len]
                 kv_start_indices.append(slots)

-                if not is_unpaged_prefill:
-                    single_attention_mask = torch.triu(
-                        torch.ones(q_seq_len,
-                                   step_context.block_offsets.shape[1] * block_size,
-                                   dtype=torch.bool,
-                                   device=step_context.block_offsets.device),
-                        diagonal=kv_seq_len - q_seq_len + 1,
-                    )
-                    attention_mask.append(single_attention_mask)
-
-            if is_unpaged_prefill:
+            if is_prefill_no_cache:
                 attention_mask.append(
                     torch.triu(torch.ones(max_q_seq_len,
                                           max_kv_seq_len,
                                           dtype=step_context.kv_caches[0][0].dtype,
                                           device=step_context.block_offsets.device),
                                diagonal=max_kv_seq_len - max_q_seq_len + 1))
             else:
-                attention_mask = [torch.cat(attention_mask)]
+                attention_mask.append(
+                    torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=step_context.block_offsets.device),
+                               diagonal=1))

             kv_start_indices = torch.cat(kv_start_indices)
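Prefill with a cached prefix no longer builds per-sequence masks sized to the whole block table; the kernel now gets a single fixed 2048x2048 upper-triangular boolean mask. Reading this as the compressed causal-mask convention of Ascend fused-attention ops (the kernel tiles and offsets one fixed mask per sequence) is an inference from the shape, not something the diff states. The mask itself:

import torch

# True marks positions to mask out: everything strictly above the diagonal.
mask = torch.triu(torch.ones(2048, 2048, dtype=torch.bool), diagonal=1)
print(mask[:3, :3])
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False, False]])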
@@ -357,16 +353,16 @@ def get_moe_group_name(group):
             group_name = backend.get_hccl_comm_name(local_rank)
             return group_name

-        q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding,
-                                                                             is_unpaged_prefill)
-        q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu,
+        q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_prefill_no_cache)
+        q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu,
                                                            kv_seqlens_cpu)
-        max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list,
+        max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_list,
                                                         kv_seqlens_list)
         kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding,
-                                                                                   is_unpaged_prefill, q_seqlens_list,
+                                                                                   is_prefill_no_cache, q_seqlens_list,
                                                                                    kv_seqlens_list, max_q_seq_len,
                                                                                    max_kv_seq_len)
+        q_seqlens_cpu = update_q_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu)

         if not cls.enable_graph and step_context.kv_quant_policy == 8:
             record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
@@ -400,13 +396,11 @@ def get_moe_group_name(group):
             # Otherwise, q_start_loc is None.
             q_start_loc=cu_seqlens,
             q_seqlens=q_seqlens_cpu,
-            # kv_seqlens_expanded is only expanded in paged prefill,
-            # otherwise it equals kv_seqlens_cpu
-            kv_seqlens=kv_seqlens_expanded,
+            kv_seqlens=kv_seqlens_cpu,
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=attention_mask,
-            is_unpaged_prefill=is_unpaged_prefill,
+            is_prefill_no_cache=is_prefill_no_cache,
             max_q_seq_len=max_q_seq_len,
             max_kv_seq_len=max_kv_seq_len,
             quant_policy=step_context.kv_quant_policy,

lmdeploy/pytorch/backends/dlinfer/attention.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
     kv_start_indices: Tensor | None = None
     block_size: int = 64
     attention_mask: Sequence[Tensor] = tuple()
-    is_unpaged_prefill: bool | None = None
+    is_prefill_no_cache: bool | None = None
     max_q_seq_len: int = 1
     max_kv_seq_len: int = 1
     quant_meta: dict = None
@@ -80,7 +80,7 @@ def forward(
         kv_start_indices = attn_metadata.kv_start_indices
         block_size = attn_metadata.block_size
         attn_mask = attn_metadata.attention_mask
-        is_unpaged_prefill = attn_metadata.is_unpaged_prefill
+        is_prefill_no_cache = attn_metadata.is_prefill_no_cache
         max_q_seq_len = attn_metadata.max_q_seq_len
         max_kv_seq_len = attn_metadata.max_kv_seq_len
         quant_bits = attn_metadata.quant_policy
@@ -139,7 +139,7 @@ def forward(
             v_head_size=self.v_head_size,
             attn_mask=attn_mask,
             softmax_scale=self.scale,
-            is_unpaged_prefill=is_unpaged_prefill,
+            is_prefill_no_cache=is_prefill_no_cache,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
             quant_bits=quant_bits,

lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py

Lines changed: 4 additions & 4 deletions
@@ -60,7 +60,7 @@ def get_total_slots():
         kv_start_indices = []
         block_num, _, block_size, _ = step_context.kv_caches[0][0].shape

-        is_unpaged_prefill = False
+        is_prefill_no_cache = False
         q_start_loc = step_context.q_start_loc
         q_seqlens = step_context.q_seqlens
         kv_seqlens = step_context.kv_seqlens.to(torch.int32)
@@ -74,7 +74,7 @@ def get_total_slots():
         q_seqlens_list = step_context.q_seqlens.tolist()
         kv_seqlens_list = step_context.kv_seqlens.tolist()
         if not step_context.is_decoding:
-            is_unpaged_prefill = q_seqlens_list == kv_seqlens_list
+            is_prefill_no_cache = q_seqlens_list == kv_seqlens_list
         # get kv_indices
         for i in range(q_start_loc.size(0)):
             q_seq_len = q_seqlens_list[i]
@@ -86,7 +86,7 @@ def get_total_slots():
             slots = slot_tables[history_length:kv_seq_len]
             kv_start_indices.append(slots)
         kv_start_indices = torch.cat(kv_start_indices)
-        if not is_unpaged_prefill:
+        if not is_prefill_no_cache:
             cu_seq_lens_kv = torch.cat((torch.tensor([0], device=kv_seqlens.device), kv_seqlens.cumsum(0))).int()
         else:
             # collect kv_start_indices without using a for-loop,
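The camb backend only takes the flag rename; its KV bookkeeping is unchanged. For reference, the cu_seq_lens_kv line visible in the context above computes exclusive prefix-sum offsets into a flattened KV buffer, as in this standalone sketch:

import torch

kv_seqlens = torch.tensor([4, 12, 7], dtype=torch.int32)

# Leading zero plus running total: cu[i] is where sequence i's KV starts
# and cu[i + 1] where it ends in the flattened buffer.
cu_seq_lens_kv = torch.cat((torch.tensor([0]), kv_seqlens.cumsum(0))).int()
print(cu_seq_lens_kv)  # tensor([ 0,  4, 16, 23], dtype=torch.int32)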
@@ -108,7 +108,7 @@ def get_total_slots():
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=None,
-            is_unpaged_prefill=is_unpaged_prefill,
+            is_prefill_no_cache=is_prefill_no_cache,
             max_q_seq_len=max_q_seq_len,
             max_kv_seq_len=max_kv_seq_len,
         )

lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py

Lines changed: 3 additions & 3 deletions
@@ -52,9 +52,9 @@ def get_total_slots():
         kv_start_indices, attention_mask = [], []
         block_num, block_size, _, _ = step_context.kv_caches[0][1].shape

-        is_unpaged_prefill = False
+        is_prefill_no_cache = False
         if not step_context.is_decoding:
-            is_unpaged_prefill = \
+            is_prefill_no_cache = \
                 all((step_context.q_seqlens ==
                      step_context.kv_seqlens).tolist())
         q_start_loc = step_context.q_start_loc
@@ -99,7 +99,7 @@ def get_total_slots():
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=attention_mask,
-            is_unpaged_prefill=is_unpaged_prefill,
+            is_prefill_no_cache=is_prefill_no_cache,
             max_q_seq_len=max_q_seq_len,
             max_kv_seq_len=max_kv_seq_len,
         )

lmdeploy/pytorch/kernels/dlinfer/pagedattention.py

Lines changed: 7 additions & 4 deletions
@@ -25,12 +25,12 @@ def prefill_attention(
     head_size_v: int,
     attn_mask: Sequence[Tensor | None],
     softmax_scale: float | None,
-    is_unpaged_prefill: bool | None,
+    is_prefill_no_cache: bool | None,
     kv_scales: Tensor | None,
     kv_zeros: Tensor | None,
     quant_bits: int | None,
 ) -> Tensor:
-    if is_unpaged_prefill:
+    if is_prefill_no_cache:
         return ext_ops.prefill_attention(
             query_states,
             key_states,
@@ -79,6 +79,7 @@ def paged_token_attention(
     k_cache,
     v_cache,
     attn_output,
+    q_seqlens,
     kv_seq_len,
     max_kv_seq_len,
     block_offsets,
@@ -97,6 +98,7 @@
         v_cache,
         block_offsets,
         block_size,
+        q_seqlens,
         kv_seq_len,
         max_kv_seq_len,
         num_q_heads,
@@ -131,7 +133,7 @@ def paged_attention_fwd(
     v_head_size: int,
     attn_mask: Sequence[Tensor | None] = (),
     softmax_scale: float | None = None,
-    is_unpaged_prefill: bool | None = None,
+    is_prefill_no_cache: bool | None = None,
     kv_scales: Tensor | None = None,
     kv_zeros: Tensor | None = None,
     quant_bits: int | None = 0,
@@ -157,7 +159,7 @@
         v_head_size,
         attn_mask,
         softmax_scale,
-        is_unpaged_prefill,
+        is_prefill_no_cache,
         kv_scales=kv_scales,
         kv_zeros=kv_zeros,
         quant_bits=quant_bits,
@@ -168,6 +170,7 @@
         key_cache,
         value_cache,
         attn_output,
+        q_seqlens,
         kv_seqlens,
         max_kv_seq_len,
         block_offsets,
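Net effect of the plumbing above: the decode path (paged_token_attention) now receives q_seqlens alongside kv_seqlens. Given the arange construction in update_q_seqlens, a plausible reading is that the offsets locate each sequence's single decode token in the flattened TND batch; this is an inference about the dlinfer kernel contract, sketched below, not something this diff shows:

import torch

batch_size = 3
# As produced by update_q_seqlens in decoding: cumulative query end offsets.
q_seqlens = torch.arange(1, batch_size + 1, dtype=torch.int32)  # [1, 2, 3]

# Each sequence contributes exactly one token, so its position in the
# flattened token-major (TND) batch would be its end offset minus one.
token_index = q_seqlens - 1
print(token_index)  # tensor([0, 1, 2], dtype=torch.int32)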
