diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
index 169da9c150..796c2f7112 100644
--- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
@@ -155,13 +155,11 @@ def update_step_context(cls, step_context):
         """Update step context."""
         block_num, block_size, *_ = step_context.kv_caches[0][0].shape
-        is_unpaged_prefill = False
+        is_prefill_no_cache = False
         if not step_context.is_decoding:
-            is_unpaged_prefill = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
+            is_prefill_no_cache = all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
         if step_context.block_offsets.dtype != torch.int32:
             step_context.block_offsets = step_context.block_offsets.to(torch.int32)
-        if not (step_context.is_decoding or is_unpaged_prefill):
-            step_context.block_offsets = step_context.block_offsets.repeat_interleave(step_context.q_seqlens, 0)
         if step_context.kv_seqlens.dtype != torch.int32:
             step_context.kv_seqlens = step_context.kv_seqlens.to(torch.int32)
         if step_context.q_seqlens.dtype != torch.int32:
@@ -175,7 +173,7 @@ def get_total_slots():
                 cls.total_slots = cls.total_slots.view(block_num, block_size)
             return cls.total_slots
 
-        def get_cpu_seqlens(is_decoding, is_unpaged_prefill):
+        def get_cpu_seqlens(is_decoding, is_prefill_no_cache):
             """Get sequence lengths on CPU.
 
             Returns:
@@ -187,37 +185,43 @@ def get_cpu_seqlens(is_decoding, is_unpaged_prefill):
             """
             if is_decoding:
                 q_seqlens_cpu = None
-                kv_seqlens_cpu = kv_seqlens_expanded = step_context.kv_seqlens.cpu()
-            elif is_unpaged_prefill:
+                kv_seqlens_cpu = step_context.kv_seqlens.cpu()
+            elif is_prefill_no_cache:
                 q_seqlens_cpu = step_context.q_seqlens.cpu()
-                kv_seqlens_cpu = kv_seqlens_expanded = q_seqlens_cpu
+                kv_seqlens_cpu = q_seqlens_cpu
             else:
                 q_seqlens_cpu = step_context.q_seqlens.cpu()
                 kv_seqlens_cpu = step_context.kv_seqlens.cpu()
-                # Expand kv_seqlens to per-token for paged prefill attention
-                kv_seqlens_expanded = kv_seqlens_cpu.repeat_interleave(q_seqlens_cpu, 0)
-            return q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded
+            return q_seqlens_cpu, kv_seqlens_cpu
 
-        def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None):
+        def get_list_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None, kv_seqlens_cpu=None):
             if is_decoding:
                 q_seqlens_list, kv_seqlens_list = None, None
-            elif is_unpaged_prefill:
+            elif is_prefill_no_cache:
                 q_seqlens_list = kv_seqlens_list = q_seqlens_cpu.tolist()
             else:
                 q_seqlens_list, kv_seqlens_list = q_seqlens_cpu.tolist(), kv_seqlens_cpu.tolist()
             return q_seqlens_list, kv_seqlens_list
 
-        def get_max_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_list=None, kv_seqlens_list=None):
+        def get_max_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_list=None, kv_seqlens_list=None):
             if is_decoding:
                 max_q_seq_len, max_kv_seq_len = 1, None
-            elif is_unpaged_prefill:
+            elif is_prefill_no_cache:
                 max_q_seq_len = max_kv_seq_len = max(q_seqlens_list)
             else:
                 max_q_seq_len = max(q_seqlens_list)
                 max_kv_seq_len = max(kv_seqlens_list)
             return max_q_seq_len, max_kv_seq_len
 
-        def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_seqlens_list, kv_seqlens_list,
+        def update_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens_cpu=None):
+            if is_decoding:
+                batch_size = step_context.q_seqlens.size(0)
+                return torch.arange(1, batch_size + 1, dtype=torch.int32)
+            elif is_prefill_no_cache:
+                return q_seqlens_cpu
+            return q_seqlens_cpu.cumsum(dim=0)
+
+        def get_kv_start_indices_and_attention_mask(is_decoding, is_prefill_no_cache, q_seqlens_list, kv_seqlens_list,
                                                     max_q_seq_len, max_kv_seq_len):
             kv_start_indices, attention_mask = [], []
             if is_decoding:
@@ -236,17 +240,7 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s
                     slots = slot_tables[history_length:kv_seq_len]
                     kv_start_indices.append(slots)
 
-                    if not is_unpaged_prefill:
-                        single_attention_mask = torch.triu(
-                            torch.ones(q_seq_len,
-                                       step_context.block_offsets.shape[1] * block_size,
-                                       dtype=torch.bool,
-                                       device=step_context.block_offsets.device),
-                            diagonal=kv_seq_len - q_seq_len + 1,
-                        )
-                        attention_mask.append(single_attention_mask)
-
-                if is_unpaged_prefill:
+                if is_prefill_no_cache:
                     attention_mask.append(
                         torch.triu(torch.ones(max_q_seq_len,
                                               max_kv_seq_len,
@@ -254,7 +248,9 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s
                                               device=step_context.block_offsets.device),
                                    diagonal=max_kv_seq_len - max_q_seq_len + 1))
                 else:
-                    attention_mask = [torch.cat(attention_mask)]
+                    attention_mask.append(
+                        torch.triu(torch.ones(2048, 2048, dtype=torch.bool, device=step_context.block_offsets.device),
+                                   diagonal=1))
 
                 kv_start_indices = torch.cat(kv_start_indices)
@@ -357,16 +353,16 @@ def get_moe_group_name(group):
             group_name = backend.get_hccl_comm_name(local_rank)
             return group_name
 
-        q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding,
-                                                                             is_unpaged_prefill)
-        q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu,
+        q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_prefill_no_cache)
+        q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu,
                                                            kv_seqlens_cpu)
-        max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list,
+        max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_list,
                                                         kv_seqlens_list)
         kv_start_indices, attention_mask = get_kv_start_indices_and_attention_mask(step_context.is_decoding,
-                                                                                   is_unpaged_prefill, q_seqlens_list,
+                                                                                   is_prefill_no_cache, q_seqlens_list,
                                                                                    kv_seqlens_list, max_q_seq_len,
                                                                                    max_kv_seq_len)
+        q_seqlens_cpu = update_q_seqlens(step_context.is_decoding, is_prefill_no_cache, q_seqlens_cpu)
 
         if not cls.enable_graph and step_context.kv_quant_policy == 8:
             record_file = os.getenv('ASCEND_QUANT_RECORD_FILE')
@@ -387,13 +383,11 @@ def get_moe_group_name(group):
             step_context.block_offsets,
             q_start_loc=None,
             q_seqlens=q_seqlens_cpu,
-            # kv_seqlens_expanded is only expanded in paged prefill,
-            # otherwise it equals kv_seqlens_cpu
-            kv_seqlens=kv_seqlens_expanded,
+            kv_seqlens=kv_seqlens_cpu,
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=attention_mask,
-            is_unpaged_prefill=is_unpaged_prefill,
+            is_prefill_no_cache=is_prefill_no_cache,
             max_q_seq_len=max_q_seq_len,
             max_kv_seq_len=max_kv_seq_len,
             quant_policy=step_context.kv_quant_policy,
diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py
index 78afe49040..745c2d152a 100644
--- a/lmdeploy/pytorch/backends/dlinfer/attention.py
+++ b/lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -13,7 +13,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
     kv_start_indices: Tensor | None = None
     block_size: int = 64
     attention_mask: Sequence[Tensor] = tuple()
-    is_unpaged_prefill: bool | None = None
+    is_prefill_no_cache: bool | None = None
     max_q_seq_len: int = 1
     max_kv_seq_len: int = 1
     quant_meta: dict = None
@@ -79,7 +79,7 @@ def forward(
         kv_start_indices = attn_metadata.kv_start_indices
         block_size = attn_metadata.block_size
         attn_mask = attn_metadata.attention_mask
-        is_unpaged_prefill = attn_metadata.is_unpaged_prefill
+        is_prefill_no_cache = attn_metadata.is_prefill_no_cache
         max_q_seq_len = attn_metadata.max_q_seq_len
         max_kv_seq_len = attn_metadata.max_kv_seq_len
         quant_bits = attn_metadata.quant_policy
@@ -138,7 +138,7 @@ def forward(
             v_head_size=self.v_head_size,
             attn_mask=attn_mask,
             softmax_scale=self.scale,
-            is_unpaged_prefill=is_unpaged_prefill,
+            is_prefill_no_cache=is_prefill_no_cache,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
             quant_bits=quant_bits,
diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
index 18f04de73b..2c6379334e 100644
--- a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
@@ -60,7 +60,7 @@ def get_total_slots():
 
         kv_start_indices = []
         block_num, _, block_size, _ = step_context.kv_caches[0][0].shape
-        is_unpaged_prefill = False
+        is_prefill_no_cache = False
         q_start_loc = step_context.q_start_loc
         q_seqlens = step_context.q_seqlens
         kv_seqlens = step_context.kv_seqlens.to(torch.int32)
@@ -74,7 +74,7 @@ def get_total_slots():
         q_seqlens_list = step_context.q_seqlens.tolist()
         kv_seqlens_list = step_context.kv_seqlens.tolist()
         if not step_context.is_decoding:
-            is_unpaged_prefill = q_seqlens_list == kv_seqlens_list
+            is_prefill_no_cache = q_seqlens_list == kv_seqlens_list
             # get kv_indices
             for i in range(q_start_loc.size(0)):
                 q_seq_len = q_seqlens_list[i]
@@ -86,7 +86,7 @@ def get_total_slots():
                 slots = slot_tables[history_length:kv_seq_len]
                 kv_start_indices.append(slots)
             kv_start_indices = torch.cat(kv_start_indices)
-            if not is_unpaged_prefill:
+            if not is_prefill_no_cache:
                 cu_seq_lens_kv = torch.cat((torch.tensor([0], device=kv_seqlens.device), kv_seqlens.cumsum(0))).int()
         else:
             # collect kv_start_indices without using a for-loop,
@@ -108,7 +108,7 @@ def get_total_slots():
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=None,
-            is_unpaged_prefill=is_unpaged_prefill,
+            is_prefill_no_cache=is_prefill_no_cache,
             max_q_seq_len=max_q_seq_len,
             max_kv_seq_len=max_kv_seq_len,
         )
diff --git a/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
index 8420b159dc..78376df43f 100644
--- a/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/maca/op_backend.py
@@ -52,9 +52,9 @@ def get_total_slots():
 
         kv_start_indices, attention_mask = [], []
         block_num, block_size, _, _ = step_context.kv_caches[0][1].shape
-        is_unpaged_prefill = False
+        is_prefill_no_cache = False
         if not step_context.is_decoding:
-            is_unpaged_prefill = \
+            is_prefill_no_cache = \
                 all((step_context.q_seqlens == step_context.kv_seqlens).tolist())
 
         q_start_loc = step_context.q_start_loc
@@ -99,7 +99,7 @@ def get_total_slots():
             kv_start_indices=kv_start_indices,
             block_size=block_size,
             attention_mask=attention_mask,
-            is_unpaged_prefill=is_unpaged_prefill,
+            is_prefill_no_cache=is_prefill_no_cache,
             max_q_seq_len=max_q_seq_len,
             max_kv_seq_len=max_kv_seq_len,
         )
diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
index 6916d8a082..5e75de0a5b 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -25,12 +25,12 @@ def prefill_attention(
     head_size_v: int,
     attn_mask: Sequence[Tensor | None],
     softmax_scale: float | None,
-    is_unpaged_prefill: bool | None,
+    is_prefill_no_cache: bool | None,
     kv_scales: Tensor | None,
     kv_zeros: Tensor | None,
     quant_bits: int | None,
 ) -> Tensor:
-    if is_unpaged_prefill:
+    if is_prefill_no_cache:
         return ext_ops.prefill_attention(
             query_states,
             key_states,
@@ -79,6 +79,7 @@ def paged_token_attention(
     k_cache,
     v_cache,
     attn_output,
+    q_seqlens,
     kv_seq_len,
     max_kv_seq_len,
     block_offsets,
@@ -97,6 +98,7 @@ def paged_token_attention(
         v_cache,
         block_offsets,
         block_size,
+        q_seqlens,
         kv_seq_len,
         max_kv_seq_len,
         num_q_heads,
@@ -131,7 +133,7 @@ def paged_attention_fwd(
     v_head_size: int,
     attn_mask: Sequence[Tensor | None] = (),
     softmax_scale: float | None = None,
-    is_unpaged_prefill: bool | None = None,
+    is_prefill_no_cache: bool | None = None,
     kv_scales: Tensor | None = None,
     kv_zeros: Tensor | None = None,
     quant_bits: int | None = 0,
@@ -157,7 +159,7 @@ def paged_attention_fwd(
             v_head_size,
             attn_mask,
             softmax_scale,
-            is_unpaged_prefill,
+            is_prefill_no_cache,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
             quant_bits=quant_bits,
@@ -168,6 +170,7 @@ def paged_attention_fwd(
             key_cache,
             value_cache,
             attn_output,
+            q_seqlens,
             kv_seqlens,
             max_kv_seq_len,
             block_offsets,
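
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the renamed is_prefill_no_cache flag and of the reshaped q_seqlens that the new update_q_seqlens closure hands to the attention metadata in the Ascend backend. make_q_seqlens is a hypothetical stand-in for that closure, plain tensors stand in for step_context.q_seqlens / step_context.kv_seqlens, and the comments reflect one reading of the intent rather than documented behavior.

# Standalone sketch, assuming only torch (hypothetical helper mirroring update_q_seqlens).
import torch


def make_q_seqlens(is_decoding, is_prefill_no_cache, q_seqlens):
    # Decoding: one new token per request, so a running count per batch slot.
    if is_decoding:
        batch_size = q_seqlens.size(0)
        return torch.arange(1, batch_size + 1, dtype=torch.int32)
    # Prefill with no cached history: per-request query lengths are used as-is.
    if is_prefill_no_cache:
        return q_seqlens
    # Paged prefill (some requests carry KV history): cumulative query lengths.
    return q_seqlens.cumsum(dim=0)


q_seqlens = torch.tensor([3, 5], dtype=torch.int32)
kv_seqlens = torch.tensor([3, 9], dtype=torch.int32)

# The renamed flag: a prefill step in which no request has cached history.
is_prefill_no_cache = bool((q_seqlens == kv_seqlens).all())

print(is_prefill_no_cache)                                    # False: the second request has 4 cached tokens
print(make_q_seqlens(False, is_prefill_no_cache, q_seqlens))  # cumulative lengths [3, 8]
print(make_q_seqlens(True, is_prefill_no_cache, q_seqlens))   # running count [1, 2]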