152 changes: 101 additions & 51 deletions ucm/sparse/gsa_on_device/gsa_on_device.py
@@ -13,6 +13,7 @@
from vllm.attention.ops.flashmla import get_mla_metadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request, RequestStatus
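# --- Editor's illustrative sketch, not part of this diff ---
# cdiv is vLLM's ceiling division; below it is used to size the per-request
# hash block table for the worst case of max_model_len tokens. The example
# values (block_size=128, max_model_len=8192) are assumptions for illustration.
def _cdiv_sketch(a: int, b: int) -> int:
    # Equivalent to math.ceil(a / b) for positive integers.
    return (a + b - 1) // b

assert _cdiv_sketch(8192, 128) == 64   # exactly 64 blocks
assert _cdiv_sketch(8193, 128) == 65   # one extra token still needs a whole block
# --- end sketch ---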
@@ -136,6 +137,18 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):

self.seq_len_threshhold = self.gsa_on_device_config.seq_len_threshhold

assert (
self.seq_len_threshhold
>= self.gsa_on_device_config.vllm_hash_attention_topk
), "seq_len_threshhold must be larger than or equal to vllm_hash_attention_topk"
assert (
self.gsa_on_device_config.vllm_hash_attention_topk % self.block_size == 0
), "vllm_hash_attention_topk must be divisible by block_size"
assert (
self.gsa_on_device_config.vllm_hash_attention_topk
<= vllm_config.model_config.max_model_len
), "vllm_hash_attention_topk must be less than max_model_len"

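        # --- Editor's illustrative sketch, not part of this diff ---
        # The three assertions above tie the sparse-attention config values together.
        # The numbers below are assumed example values, used only for illustration:
        _block_size = 128
        _topk = 2048            # vllm_hash_attention_topk
        _threshold = 4096       # seq_len_threshhold
        _max_model_len = 32768
        assert _threshold >= _topk        # only sequences that exceed the top-k window are sparsified
        assert _topk % _block_size == 0   # the top-k budget maps to 2048 // 128 = 16 whole KV blocks
        assert _topk <= _max_model_len    # the selected window always fits inside the model context
        # --- end sketch ---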
if role == UcmSparseRole.WORKER:
if self.is_cuda: # cuda only variables
device_properties = torch.cuda.get_device_properties(self.device)
@@ -233,7 +246,9 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
[
self.max_batch_size,
self.num_key_heads,
self.hash_topk_tokens // self.block_size,
cdiv(
vllm_config.model_config.max_model_len, self.block_size
),
],
dtype=torch.int32,
device=self.device,
@@ -340,36 +355,47 @@ def attention_begin(
else: # NPU
if not self.is_tensor_computed:
if self.decode_mask.any(): # with at least one decode request
decode_req_ids = torch.nonzero(
self.decode_mask, as_tuple=False
).flatten()
decode_req_ids_npu = torch.nonzero(
self.decode_mask_npu, as_tuple=False
).flatten()
batch_size_for_hamming = self.decode_mask.sum().item()
self.query_lens_device = attn_metadata.query_lens_device[
decode_req_ids_npu
]
if self.slice_enabled:
# if slice_enabled, the batch_size_for_hamming is the number of decode requests
self.batch_size_for_hamming = self.num_decode_requests
else:
# if not slice_enabled, the batch_size_for_hamming is the number of all requests
self.batch_size_for_hamming = len(
attn_metadata.seq_lens
)
# only get decode_mask_npu when slice_enabled is False
self.decode_mask_npu = (
attn_metadata.query_lens_device == 1
) & (
attn_metadata.seq_lens_device
>= self.seq_len_threshhold
)

self.topk_for_hamming = self.topk_for_hamming_full[
:batch_size_for_hamming
: self.batch_size_for_hamming
]
self.chunk_sizes_for_hamming = (
self.chunk_sizes_for_hamming_full[
:batch_size_for_hamming
: self.batch_size_for_hamming
]
)

self.seq_lens_for_hamming = attn_metadata.seq_lens_device[
decode_req_ids_npu
: self.batch_size_for_hamming
]
self.max_seq_len_for_hamming = torch.max(
attn_metadata.seq_lens[decode_req_ids]
attn_metadata.seq_lens[: self.batch_size_for_hamming]
).item()
self.block_table_decode = self.ori_block_table_decode[
: self.batch_size_for_hamming
]

self.is_tensor_computed = True

k_hash_compute = self.hash_encoder.compute_hash(key)
assert (
k_hash_compute.shape[0] == attn_metadata.slot_mapping.numel()
), f"shape mismatch: k_hash_compute.shape[0]={k_hash_compute.shape[0]} != attn_metadata.slot_mapping.numel()={attn_metadata.slot_mapping.numel()}"
# assert (
# k_hash_compute.shape[0] == attn_metadata.slot_mapping.numel()
# ), f"shape mismatch: k_hash_compute.shape[0]={k_hash_compute.shape[0]} != attn_metadata.slot_mapping.numel()={attn_metadata.slot_mapping.numel()}"
k_hash_compute = (
k_hash_compute.transpose(0, 1)
.reshape(-1, k_hash_compute.shape[-1])
@@ -433,7 +459,10 @@ def attention_begin(
if not is_rollback_layer:
if is_skip_hash_layer:
                # Skip this layer's hash step and reuse the previous top-k result
attn_metadata.block_tables = self.topk_block_table
if self.is_cuda:
attn_metadata.block_table = self.topk_block_table
else:
attn_metadata.block_tables = self.topk_block_table
attn_metadata.seq_lens = self.topk_seq_lens
else:
if self.is_cuda:
Expand Down Expand Up @@ -475,26 +504,22 @@ def attention_begin(
attn_metadata.seq_lens[self.decode_mask] = (
self.topk_seq_lens_qwen
)
else: # NPU

decode_req_ids = torch.nonzero(
self.decode_mask_npu, as_tuple=False
).flatten()
decode_token_idx = q_start[:-1].index_select(
0, decode_req_ids
)
q_decode = query.index_select(0, decode_token_idx)
# topk for skip layer
self.topk_block_table = attn_metadata.block_table
self.topk_seq_lens = attn_metadata.seq_lens
else: # NPU
if self.slice_enabled:
q_decode = query[: self.batch_size_for_hamming]
else:
q_decode = query.index_select(0, q_start[:-1])

q_hash = (
self.hash_encoder.compute_hash(q_decode)
.unsqueeze(2)
.contiguous()
)

block_table_decode = attn_metadata.block_table.index_select(
0, decode_req_ids
)

ucm_custom_ops.hamming_dist_top_k(
q_hash,
k_hash,
Expand All @@ -505,23 +530,37 @@ def attention_begin(
self.hamming_keep_chunks_head,
self.hamming_keep_chunks_tail,
0, # support_offload is disabled
block_table_decode,
self.hamming_output[: len(decode_req_ids)],
self.block_table_decode,
(
self.decode_mask_npu
if not self.slice_enabled
else None
),
self.hamming_output[: self.batch_size_for_hamming],
)
topk = self.hamming_output.shape[-1]
attn_metadata.block_table[decode_req_ids, :topk] = (
self.hamming_output[: len(decode_req_ids), 0, :]
)
attn_metadata.block_table[decode_req_ids, topk:] = 0
new_seq_lens = self.topk_seq_lens_qwen
attn_metadata.seq_lens = new_seq_lens
if (
self.slice_enabled
and attn_metadata.attn_state
!= AscendAttentionState.DecodeOnly
):
new_block_tables = attn_metadata.block_tables.clone()
new_block_tables[: self.batch_size_for_hamming] = (
self.hamming_output[
: self.batch_size_for_hamming, 0, :
]
)
else:
new_block_tables = self.hamming_output[
: self.batch_size_for_hamming, 0, :
]

# we have already computed the topk_seq_lens_qwen in `build_decode_attention_meta_npu()`
attn_metadata.seq_lens[self.decode_mask] = (
self.topk_seq_lens_qwen
)
attn_metadata.block_tables = new_block_tables

# topk for skip layer
self.topk_block_table = attn_metadata.block_table
self.topk_seq_lens = attn_metadata.seq_lens
# topk for skip layer
self.topk_block_table = attn_metadata.block_tables
self.topk_seq_lens = attn_metadata.seq_lens

return query, key, value, output

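# --- Editor's illustrative sketch, not part of this diff ---
# Rough shape of the decode-time selection that attention_begin() performs:
# hash the decode queries, rank cached KV blocks by Hamming distance to the key
# hashes, and rewrite block_tables/seq_lens so attention only visits the selected
# blocks (attention_finished() restores the originals). Names and shapes here are
# hypothetical; the real kernel is ucm_custom_ops.hamming_dist_top_k.
import torch

def _select_topk_blocks_sketch(q_hash, k_hash, block_table, topk_blocks):
    # q_hash:      [batch, heads, bits]             0/1 hash bits of decode queries
    # k_hash:      [batch, heads, num_blocks, bits] 0/1 hash bits of cached key blocks
    # block_table: [batch, num_blocks]              physical block ids per request
    dist = (q_hash.unsqueeze(2) != k_hash).sum(-1)   # Hamming distance per block
    dist = dist.min(dim=1).values                    # best match across heads
    idx = dist.topk(topk_blocks, largest=False).indices
    return torch.gather(block_table, 1, idx)         # block ids to keep for this step
# --- end sketch ---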
@@ -549,7 +588,10 @@ def attention_finished(
attn_metadata.decode.num_splits = self.origin_num_splits
        else:  # request is in the decode phase
if self.decode_mask.any():
attn_metadata.block_table = self.ori_block_table_decode
if self.is_cuda:
attn_metadata.block_table = self.ori_block_table_decode
else:
attn_metadata.block_tables = self.ori_block_table_decode
attn_metadata.seq_lens = self.ori_seq_lens_decode

def request_begin(self, request_id: ReqType, prompt_token_ids: List[int]):
@@ -645,20 +687,28 @@ def build_decode_attention_meta_npu(self, query_lens, seq_lens, block_table):

        # self.decode_mask is on cpu in vllm-ascend under NPU device
self.decode_mask = (query_lens == 1) & (seq_lens >= self.seq_len_threshhold)
self.decode_mask = self.decode_mask.pin_memory()
# self.decode_mask = self.decode_mask.pin_memory()

self.num_decode_requests = self.decode_mask.sum().item()
if self.num_decode_requests > 0:
self.slice_enabled = (
self.decode_mask[: self.num_decode_requests].all().item()
)
else:
self.slice_enabled = False

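        # --- Editor's illustrative example, not part of this diff ---
        # slice_enabled is True only when every decode request sits contiguously at
        # the front of the batch, so a plain slice [:num_decode_requests] selects
        # exactly the decode rows. The masks below are assumed example values.
        import torch
        _mask = torch.tensor([True, True, False, False])    # two decodes, then prefills
        _n = int(_mask.sum())                                # 2 decode requests
        assert _mask[:_n].all()                              # contiguous -> slicing is safe
        _mask2 = torch.tensor([True, False, True, False])    # interleaved decodes
        assert not _mask2[: int(_mask2.sum())].all()         # slicing would grab a prefill row
        # --- end sketch ---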
self.ori_seq_lens_decode = seq_lens.clone()
self.ori_block_table_decode = block_table.clone()

self.decode_mask_npu = self.decode_mask.to(self.device, non_blocking=True)

if self.decode_mask.any():
decode_seq_lens = seq_lens[self.decode_mask]
# self.decode_mask_npu = self.decode_mask.to(self.device, non_blocking=True)
self.topk_seq_lens_qwen = update_seq_lens(
decode_seq_lens,
seq_lens,
topk_token=self.hash_topk_tokens,
block_size=self.block_size,
)
# (ldeng) set the seq_lens for the non-decode requests to the original seq_lens
self.topk_seq_lens_qwen[~self.decode_mask] = seq_lens[~self.decode_mask]

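    # --- Editor's illustrative sketch, not part of this diff ---
    # A plausible reading of what update_seq_lens() produces for the decode rows,
    # assuming it caps each sequence length at the hash top-k token budget; the
    # real helper may round or offset differently, so treat this as an assumption.
    import torch

    def _update_seq_lens_sketch(seq_lens: torch.Tensor, topk_token: int) -> torch.Tensor:
        # Requests shorter than the budget keep their true length; longer ones are
        # capped, since only topk_token tokens' worth of blocks survive selection.
        return torch.clamp(seq_lens, max=topk_token)
    # --- end sketch ---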
def maybe_init_cudagraph_buffers_for_topk(self, n, tile_scheduler_metadata):
sm_parts = tile_scheduler_metadata.size(0)