Commit d6e6a47

[Opt] GSA NPU performance optimization (#647)
Improve the performance of GSAOnDevice on NPU
1 parent: d63a467

1 file changed: ucm/sparse/gsa_on_device/gsa_on_device.py (101 additions, 51 deletions)
@@ -13,6 +13,7 @@
 from vllm.attention.ops.flashmla import get_mla_metadata
 from vllm.config import VllmConfig
 from vllm.forward_context import ForwardContext
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.mla.common import MLACommonMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import Request, RequestStatus
@@ -136,6 +137,18 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
 
         self.seq_len_threshhold = self.gsa_on_device_config.seq_len_threshhold
 
+        assert (
+            self.seq_len_threshhold
+            >= self.gsa_on_device_config.vllm_hash_attention_topk
+        ), "seq_len_threshhold must be larger than or equal to vllm_hash_attention_topk"
+        assert (
+            self.gsa_on_device_config.vllm_hash_attention_topk % self.block_size == 0
+        ), "vllm_hash_attention_topk must be divisible by block_size"
+        assert (
+            self.gsa_on_device_config.vllm_hash_attention_topk
+            <= vllm_config.model_config.max_model_len
+        ), "vllm_hash_attention_topk must be less than max_model_len"
+
         if role == UcmSparseRole.WORKER:
             if self.is_cuda:  # cuda only variables
                 device_properties = torch.cuda.get_device_properties(self.device)
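
Note: the three new guards only relate existing config knobs to each other. A minimal sketch of the constraints with made-up values (the numbers are illustrative, not from the repo):

# illustrative values; only the relations between them come from the new assertions
block_size = 128
vllm_hash_attention_topk = 2048   # must be a multiple of block_size
seq_len_threshhold = 4096         # must be >= vllm_hash_attention_topk
max_model_len = 32768             # must be >= vllm_hash_attention_topk

assert seq_len_threshhold >= vllm_hash_attention_topk
assert vllm_hash_attention_topk % block_size == 0
assert vllm_hash_attention_topk <= max_model_len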
@@ -233,7 +246,9 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
             [
                 self.max_batch_size,
                 self.num_key_heads,
-                self.hash_topk_tokens // self.block_size,
+                cdiv(
+                    vllm_config.model_config.max_model_len, self.block_size
+                ),
             ],
             dtype=torch.int32,
             device=self.device,
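
Note: the buffer's last dimension is now sized from max_model_len rather than the fixed top-k, presumably so it can hold a block table for any sequence length. cdiv is vllm's ceiling-division helper; a quick sketch of the arithmetic with an assumed block size of 128:

def ceil_div(a: int, b: int) -> int:
    # ceiling division without floats; same result as vllm.utils.cdiv for positive ints
    return -(a // -b)

max_model_len, block_size = 32768, 128
num_block_slots = ceil_div(max_model_len, block_size)  # 256 slots per (request, head)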
@@ -340,36 +355,47 @@ def attention_begin(
         else:  # NPU
             if not self.is_tensor_computed:
                 if self.decode_mask.any():  # with at least one decode request
-                    decode_req_ids = torch.nonzero(
-                        self.decode_mask, as_tuple=False
-                    ).flatten()
-                    decode_req_ids_npu = torch.nonzero(
-                        self.decode_mask_npu, as_tuple=False
-                    ).flatten()
-                    batch_size_for_hamming = self.decode_mask.sum().item()
-                    self.query_lens_device = attn_metadata.query_lens_device[
-                        decode_req_ids_npu
-                    ]
+                    if self.slice_enabled:
+                        # if slice_enabled, the batch_size_for_hamming is the number of decode requests
+                        self.batch_size_for_hamming = self.num_decode_requests
+                    else:
+                        # if not slice_enabled, the batch_size_for_hamming is the number of all requests
+                        self.batch_size_for_hamming = len(
+                            attn_metadata.seq_lens
+                        )
+                        # only get decode_mask_npu when slice_enabled is False
+                        self.decode_mask_npu = (
+                            attn_metadata.query_lens_device == 1
+                        ) & (
+                            attn_metadata.seq_lens_device
+                            >= self.seq_len_threshhold
+                        )
+
                     self.topk_for_hamming = self.topk_for_hamming_full[
-                        :batch_size_for_hamming
+                        : self.batch_size_for_hamming
                     ]
                     self.chunk_sizes_for_hamming = (
                         self.chunk_sizes_for_hamming_full[
-                            :batch_size_for_hamming
+                            : self.batch_size_for_hamming
                         ]
                     )
+
                     self.seq_lens_for_hamming = attn_metadata.seq_lens_device[
-                        decode_req_ids_npu
+                        : self.batch_size_for_hamming
                     ]
                     self.max_seq_len_for_hamming = torch.max(
-                        attn_metadata.seq_lens[decode_req_ids]
+                        attn_metadata.seq_lens[: self.batch_size_for_hamming]
                     ).item()
+                    self.block_table_decode = self.ori_block_table_decode[
+                        : self.batch_size_for_hamming
+                    ]
+
                 self.is_tensor_computed = True
 
             k_hash_compute = self.hash_encoder.compute_hash(key)
-            assert (
-                k_hash_compute.shape[0] == attn_metadata.slot_mapping.numel()
-            ), f"shape mismatch: k_hash_compute.shape[0]={k_hash_compute.shape[0]} != attn_metadata.slot_mapping.numel()={attn_metadata.slot_mapping.numel()}"
+            # assert (
+            #     k_hash_compute.shape[0] == attn_metadata.slot_mapping.numel()
+            # ), f"shape mismatch: k_hash_compute.shape[0]={k_hash_compute.shape[0]} != attn_metadata.slot_mapping.numel()={attn_metadata.slot_mapping.numel()}"
             k_hash_compute = (
                 k_hash_compute.transpose(0, 1)
                 .reshape(-1, k_hash_compute.shape[-1])
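
Note: the removed torch.nonzero / index_select gathers are what this patch optimizes away on NPU. When the decode requests form a contiguous prefix of the batch (slice_enabled, computed in build_decode_attention_meta_npu further down), a cheap prefix slice of length batch_size_for_hamming is used; otherwise the whole batch is processed and decode_mask_npu is passed to the kernel. A minimal standalone sketch of that selection (tensor values are illustrative):

import torch

query_lens = torch.tensor([1, 1, 1, 32])          # three decode requests, then one prefill
seq_lens = torch.tensor([9000, 8000, 7000, 10])
seq_len_threshhold = 4096

decode_mask = (query_lens == 1) & (seq_lens >= seq_len_threshhold)
num_decode_requests = int(decode_mask.sum())
# slice_enabled: all decode requests sit in a contiguous prefix of the batch
slice_enabled = bool(decode_mask[:num_decode_requests].all()) if num_decode_requests else False

if slice_enabled:
    batch_size_for_hamming = num_decode_requests   # prefix slice, no gather needed
else:
    batch_size_for_hamming = len(seq_lens)         # fall back to full batch plus mask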
@@ -433,7 +459,10 @@ def attention_begin(
        if not is_rollback_layer:
            if is_skip_hash_layer:
                # skip this layer's hash: reuse the previous top-k result
-               attn_metadata.block_tables = self.topk_block_table
+               if self.is_cuda:
+                   attn_metadata.block_table = self.topk_block_table
+               else:
+                   attn_metadata.block_tables = self.topk_block_table
                attn_metadata.seq_lens = self.topk_seq_lens
            else:
                if self.is_cuda:
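
Note: the new branch exists because the two backends expose the table under different names in their attention metadata: the CUDA path writes block_table while the Ascend path writes block_tables. A hypothetical helper (not part of the patch) showing the same dispatch:

def set_topk_block_table(attn_metadata, table, is_cuda: bool) -> None:
    # hypothetical convenience wrapper; the patch inlines this if/else at each call site
    if is_cuda:
        attn_metadata.block_table = table
    else:
        attn_metadata.block_tables = table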
@@ -475,26 +504,22 @@ def attention_begin(
                     attn_metadata.seq_lens[self.decode_mask] = (
                         self.topk_seq_lens_qwen
                     )
-                else:  # NPU
 
-                    decode_req_ids = torch.nonzero(
-                        self.decode_mask_npu, as_tuple=False
-                    ).flatten()
-                    decode_token_idx = q_start[:-1].index_select(
-                        0, decode_req_ids
-                    )
-                    q_decode = query.index_select(0, decode_token_idx)
+                    # topk for skip layer
+                    self.topk_block_table = attn_metadata.block_table
+                    self.topk_seq_lens = attn_metadata.seq_lens
+                else:  # NPU
+                    if self.slice_enabled:
+                        q_decode = query[: self.batch_size_for_hamming]
+                    else:
+                        q_decode = query.index_select(0, q_start[:-1])
 
                     q_hash = (
                         self.hash_encoder.compute_hash(q_decode)
                         .unsqueeze(2)
                         .contiguous()
                     )
 
-                    block_table_decode = attn_metadata.block_table.index_select(
-                        0, decode_req_ids
-                    )
-
                     ucm_custom_ops.hamming_dist_top_k(
                         q_hash,
                         k_hash,
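
Note: on the NPU side the decode queries are now taken either as a prefix slice (slice_enabled) or as one token per request at the query start offsets, instead of being gathered by decode_req_ids. Roughly (shapes and values are illustrative):

import torch

num_tokens, head_dim = 40, 64
query = torch.randn(num_tokens, head_dim)
q_start = torch.tensor([0, 1, 2, 8, 40])   # per-request start offsets; last entry = total tokens
batch_size_for_hamming = 3
slice_enabled = True

if slice_enabled:
    q_decode = query[:batch_size_for_hamming]          # decode tokens are the leading rows
else:
    q_decode = query.index_select(0, q_start[:-1])     # first token of every request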
@@ -505,23 +530,37 @@
                         self.hamming_keep_chunks_head,
                         self.hamming_keep_chunks_tail,
                         0,  # support_offload is disabled
-                        block_table_decode,
-                        self.hamming_output[: len(decode_req_ids)],
+                        self.block_table_decode,
+                        (
+                            self.decode_mask_npu
+                            if not self.slice_enabled
+                            else None
+                        ),
+                        self.hamming_output[: self.batch_size_for_hamming],
                     )
-                    topk = self.hamming_output.shape[-1]
-                    attn_metadata.block_table[decode_req_ids, :topk] = (
-                        self.hamming_output[: len(decode_req_ids), 0, :]
-                    )
-                    attn_metadata.block_table[decode_req_ids, topk:] = 0
+                    new_seq_lens = self.topk_seq_lens_qwen
+                    attn_metadata.seq_lens = new_seq_lens
+                    if (
+                        self.slice_enabled
+                        and attn_metadata.attn_state
+                        != AscendAttentionState.DecodeOnly
+                    ):
+                        new_block_tables = attn_metadata.block_tables.clone()
+                        new_block_tables[: self.batch_size_for_hamming] = (
+                            self.hamming_output[
+                                : self.batch_size_for_hamming, 0, :
+                            ]
+                        )
+                    else:
+                        new_block_tables = self.hamming_output[
+                            : self.batch_size_for_hamming, 0, :
+                        ]
 
-                    # we have already computed the topk_seq_lens_qwen in `build_decode_attention_meta_npu()`
-                    attn_metadata.seq_lens[self.decode_mask] = (
-                        self.topk_seq_lens_qwen
-                    )
+                    attn_metadata.block_tables = new_block_tables
 
-                    # topk for skip layer
-                    self.topk_block_table = attn_metadata.block_table
-                    self.topk_seq_lens = attn_metadata.seq_lens
+                    # topk for skip layer
+                    self.topk_block_table = attn_metadata.block_tables
+                    self.topk_seq_lens = attn_metadata.seq_lens
 
         return query, key, value, output
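
Note: instead of scattering the top-k result into block_table row by row and zero-filling the tail, the NPU path now builds new_block_tables in one step: clone-and-overwrite the decode prefix when the batch also contains non-decode requests, or use the hamming output directly for a decode-only batch. A minimal sketch of the two cases, assuming hamming_output has shape [batch, 1, max_blocks] (the real condition also checks slice_enabled and AscendAttentionState.DecodeOnly):

import torch

max_blocks = 256
block_tables = torch.zeros(4, max_blocks, dtype=torch.int32)
hamming_output = torch.randint(0, 1000, (3, 1, max_blocks), dtype=torch.int32)
batch_size_for_hamming = 3
mixed_batch = True   # stand-in for: slice_enabled and attn_state != DecodeOnly

if mixed_batch:
    # keep the non-decode rows, overwrite only the decode prefix
    new_block_tables = block_tables.clone()
    new_block_tables[:batch_size_for_hamming] = hamming_output[:batch_size_for_hamming, 0, :]
else:
    # decode-only batch: the top-k output can be used as the block table directly
    new_block_tables = hamming_output[:batch_size_for_hamming, 0, :]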

@@ -549,7 +588,10 @@ def attention_finished(
             attn_metadata.decode.num_splits = self.origin_num_splits
         else:  # check req decode phase
             if self.decode_mask.any():
-                attn_metadata.block_table = self.ori_block_table_decode
+                if self.is_cuda:
+                    attn_metadata.block_table = self.ori_block_table_decode
+                else:
+                    attn_metadata.block_tables = self.ori_block_table_decode
                 attn_metadata.seq_lens = self.ori_seq_lens_decode
 
     def request_begin(self, request_id: ReqType, prompt_token_ids: List[int]):
@@ -645,20 +687,28 @@ def build_decode_attention_meta_npu(self, query_lens, seq_lens, block_table):
 
         # self.decode_mask is on cpu in vllm-ascend under NPU device
         self.decode_mask = (query_lens == 1) & (seq_lens >= self.seq_len_threshhold)
-        self.decode_mask = self.decode_mask.pin_memory()
+        # self.decode_mask = self.decode_mask.pin_memory()
+
+        self.num_decode_requests = self.decode_mask.sum().item()
+        if self.num_decode_requests > 0:
+            self.slice_enabled = (
+                self.decode_mask[: self.num_decode_requests].all().item()
+            )
+        else:
+            self.slice_enabled = False
 
         self.ori_seq_lens_decode = seq_lens.clone()
         self.ori_block_table_decode = block_table.clone()
 
-        self.decode_mask_npu = self.decode_mask.to(self.device, non_blocking=True)
-
         if self.decode_mask.any():
-            decode_seq_lens = seq_lens[self.decode_mask]
+            # self.decode_mask_npu = self.decode_mask.to(self.device, non_blocking=True)
            self.topk_seq_lens_qwen = update_seq_lens(
-                decode_seq_lens,
+                seq_lens,
                topk_token=self.hash_topk_tokens,
                block_size=self.block_size,
            )
+            # (ldeng) set the seq_lens for the non-decode requests to the original seq_lens
+            self.topk_seq_lens_qwen[~self.decode_mask] = seq_lens[~self.decode_mask]
 
     def maybe_init_cudagraph_buffers_for_topk(self, n, tile_scheduler_metadata):
         sm_parts = tile_scheduler_metadata.size(0)
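
Note: update_seq_lens is now applied to the full seq_lens tensor and the lengths of the non-decode requests are restored afterwards, so no gather over decode indices is needed. A sketch of the restore step with a stand-in update rule (the real update_seq_lens is not shown in this diff; capping at the top-k token budget is only an assumption for illustration):

import torch

hash_topk_tokens = 2048
seq_lens = torch.tensor([9000, 8000, 64, 10])
decode_mask = torch.tensor([True, True, False, False])

# stand-in for update_seq_lens: cap every sequence at the top-k token budget
topk_seq_lens = torch.minimum(seq_lens, torch.tensor(hash_topk_tokens))
# non-decode requests keep their original lengths, as the new "(ldeng)" comment states
topk_seq_lens[~decode_mask] = seq_lens[~decode_mask]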
