Commit 2559c3e

resolve conflicts
1 parent e0e2e1e commit 2559c3e

2 files changed

Lines changed: 78 additions & 41 deletions

File tree

ucm/sparse/gsa_on_device/configs/gsa_on_device_qwen3_coder_30B_A3B_config.json

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
     "is_mla": false,
     "hash_weight_type": "random",
     "num_hidden_layers": 48,
-    "seq_len_threshhold": 2048,
+    "seq_len_threshhold": 4096,
     "chunk_size": 128,
     "chunk_repre_method": "max",
     "head_dim": 128,

ucm/sparse/gsa_on_device/gsa_on_device.py

Lines changed: 77 additions & 40 deletions
@@ -1,4 +1,3 @@
-from dataclasses import dataclass, field
 from importlib import resources
 from pathlib import Path
 from typing import Dict, List, Optional, Union
@@ -35,8 +34,6 @@
 )
 from ucm.sparse.gsa_on_device.hash_encoder import reshape_and_cache_khash_triton

-from vllm.utils import cdiv
-
 from ucm.sparse.gsa_on_device.gsa_on_device_config import GSAOnDeviceConfig
 from ucm.sparse.gsa_on_device.hash_encoder import HashEncoder
 from ucm.utils import Config
@@ -231,6 +228,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
         self.decode_only = False

         if not self.is_cuda:  # NPU only variables
+            self.decode_mask = None
             self.decode_mask_npu = None
             self.is_tensor_computed = False

@@ -383,34 +381,37 @@ def cache_k_hash_gqa_cuda(
     def cache_k_hash_gqa_npu(self, key, k_hash, attn_metadata):
         if not self.is_tensor_computed:
             if self.decode_mask.any():  # with at least one decode request
-                decode_req_ids = torch.nonzero(
-                    self.decode_mask, as_tuple=False
-                ).flatten()
-                decode_req_ids_npu = torch.nonzero(
-                    self.decode_mask_npu, as_tuple=False
-                ).flatten()
-                batch_size_for_hamming = self.decode_mask.sum().item()
-                self.query_lens_device = attn_metadata.query_lens_device[
-                    decode_req_ids_npu
-                ]
+                if self.slice_enabled:
+                    # if slice_enabled, the batch_size_for_hamming is the number of decode requests
+                    self.batch_size_for_hamming = self.num_decode_requests
+                else:
+                    # if not slice_enabled, the batch_size_for_hamming is the number of all requests
+                    self.batch_size_for_hamming = len(attn_metadata.seq_lens)
+                    # only get decode_mask_npu when slice_enabled is False
+                    self.decode_mask_npu = (attn_metadata.query_lens_device == 1) & (
+                        attn_metadata.seq_lens_device >= self.seq_len_threshhold
+                    )
                 self.topk_for_hamming = self.topk_for_hamming_full[
-                    :batch_size_for_hamming
+                    : self.batch_size_for_hamming
                 ]
                 self.chunk_sizes_for_hamming = self.chunk_sizes_for_hamming_full[
-                    :batch_size_for_hamming
+                    : self.batch_size_for_hamming
                 ]
                 self.seq_lens_for_hamming = attn_metadata.seq_lens_device[
-                    decode_req_ids_npu
+                    : self.batch_size_for_hamming
                 ]
                 self.max_seq_len_for_hamming = torch.max(
-                    attn_metadata.seq_lens[decode_req_ids]
+                    attn_metadata.seq_lens[: self.batch_size_for_hamming]
                 ).item()
+                self.block_table_decode = self.ori_block_table_decode[
+                    : self.batch_size_for_hamming
+                ]
             self.is_tensor_computed = True

         k_hash_compute = self.hash_encoder.compute_hash(key)
-        assert (
-            k_hash_compute.shape[0] == attn_metadata.slot_mapping.numel()
-        ), f"shape mismatch: k_hash_compute.shape[0]={k_hash_compute.shape[0]} != attn_metadata.slot_mapping.numel()={attn_metadata.slot_mapping.numel()}"
+        # assert (
+        #     k_hash_compute.shape[0] == attn_metadata.slot_mapping.numel()
+        # ), f"shape mismatch: k_hash_compute.shape[0]={k_hash_compute.shape[0]} != attn_metadata.slot_mapping.numel()={attn_metadata.slot_mapping.numel()}"
         k_hash_compute = (
             k_hash_compute.transpose(0, 1)
             .reshape(-1, k_hash_compute.shape[-1])
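
For context on this hunk: the refactor replaces per-request `torch.nonzero` gathers with plain prefix slices, which is only equivalent when the decode requests sit at the front of the batch; `build_decode_attention_meta_npu` (added later in this diff) records that condition as `slice_enabled`. A small self-contained sketch of the equivalence, with illustrative values only:

import torch

decode_mask = torch.tensor([True, True, False, False])
seq_lens = torch.tensor([8192, 6000, 300, 150])

num_decode = int(decode_mask.sum())                   # 2
slice_enabled = bool(decode_mask[:num_decode].all())  # True: contiguous prefix

# Old path: gather decode rows by explicit indices.
old = seq_lens[torch.nonzero(decode_mask, as_tuple=False).flatten()]
# New path: a cheap prefix slice, no index tensor materialized on device.
new = seq_lens[:num_decode]
assert torch.equal(old, new)
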
@@ -510,17 +511,18 @@ def update_decode_topk_gqa_cuda(self, query, k_hash, attn_metadata):
             0, self.decode_req_ids, self.topk_seq_lens_qwen
         )
         attn_metadata.seq_lens = self.new_seq_lens
+        # topk for skip layer
+        self.topk_block_table = attn_metadata.block_table
+        self.topk_seq_lens = attn_metadata.seq_lens

     def update_decode_topk_gqa_npu(self, query, k_hash, attn_metadata):
         q_start = attn_metadata.query_start_loc
-        decode_req_ids = torch.nonzero(self.decode_mask_npu, as_tuple=False).flatten()
-        decode_token_idx = q_start[:-1].index_select(0, decode_req_ids)
-        q_decode = query.index_select(0, decode_token_idx)
-
+        if self.slice_enabled:
+            q_decode = query[: self.batch_size_for_hamming]
+        else:
+            q_decode = query.index_select(0, q_start[:-1])
         q_hash = self.hash_encoder.compute_hash(q_decode).unsqueeze(2).contiguous()

-        block_table_decode = attn_metadata.block_table.index_select(0, decode_req_ids)
-
         ucm_custom_ops.hamming_dist_top_k(
             q_hash,
             k_hash,
@@ -531,17 +533,26 @@ def update_decode_topk_gqa_npu(self, query, k_hash, attn_metadata):
             self.hamming_keep_chunks_head,
             self.hamming_keep_chunks_tail,
             0,  # support_offload is disabled
-            block_table_decode,
-            self.hamming_output[: len(decode_req_ids)],
+            self.block_table_decode,
+            (self.decode_mask_npu if not self.slice_enabled else None),
+            self.hamming_output[: self.batch_size_for_hamming],
         )
-        topk = self.hamming_output.shape[-1]
-        attn_metadata.block_table[decode_req_ids, :topk] = self.hamming_output[
-            : len(decode_req_ids), 0, :
-        ]
-        attn_metadata.block_table[decode_req_ids, topk:] = 0
-
-        # we have already computed the topk_seq_lens_qwen in `build_decode_attention_meta_npu()`
-        attn_metadata.seq_lens[self.decode_mask] = self.topk_seq_lens_qwen
+        new_seq_lens = self.topk_seq_lens_qwen
+        attn_metadata.seq_lens = new_seq_lens
+        if (
+            self.slice_enabled
+            and attn_metadata.attn_state != AscendAttentionState.DecodeOnly
+        ):
+            new_block_tables = attn_metadata.block_tables.clone()
+            new_block_tables[: self.batch_size_for_hamming] = self.hamming_output[
+                : self.batch_size_for_hamming, 0, :
+            ]
+        else:
+            new_block_tables = self.hamming_output[: self.batch_size_for_hamming, 0, :]
+        attn_metadata.block_tables = new_block_tables
+        # topk for skip layer
+        self.topk_block_table = attn_metadata.block_tables
+        self.topk_seq_lens = attn_metadata.seq_lens

     def attention_begin(
         self,
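
What this hunk changes: instead of scattering top-k block ids into individual rows of `block_table` and zero-padding the tail, the NPU path now swaps in a new `block_tables` tensor wholesale, cloning first only when a mixed prefill/decode batch must keep its non-decode rows intact. A toy sketch of the pure-decode branch, with invented block ids:

import torch

# 2 decode requests, top-k of 4 KV-cache blocks each; hamming_dist_top_k
# writes the selected block ids into hamming_output, shape [B, 1, topk].
batch_size_for_hamming, topk = 2, 4
hamming_output = torch.tensor([[[3, 7, 1, 9]],
                               [[2, 5, 8, 0]]])

# Pure-decode branch: the sliced hamming output becomes the block table.
new_block_tables = hamming_output[:batch_size_for_hamming, 0, :]
assert new_block_tables.shape == (batch_size_for_hamming, topk)
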
@@ -608,9 +619,6 @@ def attention_begin(
                     self.update_decode_topk_gqa_npu(
                         query, k_hash, attn_metadata
                     )
-                    # topk for skip layer
-                    self.topk_block_table = attn_metadata.block_table
-                    self.topk_seq_lens = attn_metadata.seq_lens

         return query, key, value, output

@@ -719,6 +727,35 @@ def build_decode_hash(self, seq_lens):
         )
         return topk_seq_lens, topk_tile_scheduler_metadata, topk_num_splits

+    def build_decode_attention_meta_npu(self, query_lens, seq_lens, block_table):
+
+        from ucm.sparse.gsa_on_device.hamming_topk import update_seq_lens
+
+        # self.decode_mask is on cpu in vllm-ascend under NPU device
+        self.decode_mask = (query_lens == 1) & (seq_lens >= self.seq_len_threshhold)
+        # self.decode_mask = self.decode_mask.pin_memory()
+
+        self.num_decode_requests = self.decode_mask.sum().item()
+        if self.num_decode_requests > 0:
+            self.slice_enabled = (
+                self.decode_mask[: self.num_decode_requests].all().item()
+            )
+        else:
+            self.slice_enabled = False
+
+        self.ori_seq_lens_decode = seq_lens.clone()
+        self.ori_block_table_decode = block_table.clone()
+
+        if self.decode_mask.any():
+            # self.decode_mask_npu = self.decode_mask.to(self.device, non_blocking=True)
+            self.topk_seq_lens_qwen = update_seq_lens(
+                seq_lens,
+                topk_token=self.hash_topk_tokens,
+                block_size=self.block_size,
+            )
+            # (ldeng) set the seq_lens for the non-decode requests to the original seq_lens
+            self.topk_seq_lens_qwen[~self.decode_mask] = seq_lens[~self.decode_mask]
+
     def rebuild_prefix_cache_info_for_req(
         self,
         block_table_row: torch.Tensor,
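
`update_seq_lens` comes from `ucm.sparse.gsa_on_device.hamming_topk` and its body is not part of this diff. Judging only from how its result is consumed (capping the effective context of long decode requests to the top-k token budget while leaving other requests untouched), a hypothetical equivalent might look like the sketch below; the rounding by `block_size` is an assumption, not confirmed by the source:

import torch

def update_seq_lens_sketch(seq_lens, topk_token, block_size):
    # Hypothetical stand-in for hamming_topk.update_seq_lens: round the
    # token budget up to whole KV-cache blocks, then clamp each sequence
    # length to that budget; shorter sequences pass through unchanged.
    budget = ((topk_token + block_size - 1) // block_size) * block_size
    return torch.minimum(seq_lens, torch.full_like(seq_lens, budget))

seq_lens = torch.tensor([8192, 6000, 300])
print(update_seq_lens_sketch(seq_lens, topk_token=4096, block_size=128))
# tensor([4096, 4096,  300])
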
@@ -915,4 +952,4 @@ def update_states(self, scheduler_output: SchedulerOutput) -> None:
             req_data.req_ids, req_data.resumed_from_preemption
         ):
             if resumed_from_preemption:
-                    self._free_cached_request(req_id)
+                self._free_cached_request(req_id)
