
Commit efa8f2e

[Opt] Add fused triton operator for MLA models to gsa_on_device
1 parent: 054a3a8
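
This commit replaces the two-step hash path in gsa_on_device.py (self.hash_code() to compute the hash codes, then ops.concat_and_cache_mla() to write them into the hash cache) with a single fused call, HashEncoder.compute_hash_and_cache_mla(), on both the regular slot-mapping path and the prefix-cache-hit path.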

File tree: 4 files changed, +570 -45 lines

4 files changed

+570
-45
lines changed

ucm/sparse/gsa_on_device/gsa_on_device.py

Lines changed: 22 additions & 16 deletions
@@ -218,6 +218,15 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
         self.qk_rope_head_dim = getattr(
             vllm_config.model_config.hf_text_config, "qk_rope_head_dim", None
         )
+        self.hash_encoder = HashEncoder(
+            input_dim=self.kv_lora_rank,
+            hash_bits=self.kv_lora_rank,
+            dtype=vllm_config.model_config.dtype,
+            device=self.device,
+            input_dim_rope=self.qk_rope_head_dim,
+            hash_bits_rope=self.qk_rope_head_dim,
+            is_mla=True,
+        )
         self.hash_encoder_nope = HashEncoder(
             input_dim=self.kv_lora_rank,
             hash_bits=self.kv_lora_rank,
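
HashEncoder itself is defined outside this diff. As orientation, here is a minimal sketch of the interface the new constructor call assumes, using SimHash-style sign-of-random-projection hashing; the internals below are an assumption, not the repository's implementation:

```python
import torch

class HashEncoderSketch:
    """Hypothetical stand-in for gsa_on_device's HashEncoder (assumption)."""

    def __init__(self, input_dim, hash_bits, dtype, device,
                 input_dim_rope=None, hash_bits_rope=None, is_mla=False):
        # One fixed random projection per component; sign(x @ P) yields
        # one +/-1 hash bit per projection direction (SimHash).
        self.proj = torch.randn(input_dim, hash_bits, dtype=dtype, device=device)
        self.is_mla = is_mla
        if is_mla:
            # MLA keys split into a latent ("nope") part and a rope part,
            # each hashed with its own projection.
            self.proj_rope = torch.randn(
                input_dim_rope, hash_bits_rope, dtype=dtype, device=device
            )

    def hash_code(self, nope, rope):
        return torch.sign(nope @ self.proj), torch.sign(rope @ self.proj_rope)
```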
@@ -426,14 +435,12 @@ def get_layer_state(self, layer_name: str):
     def cache_k_hash_mla_cuda(
         self, nope, rope, k_hash, attn_metadata, forward_context, layer_name
     ):
-        k_c_normed_hash, k_pe_hash = self.hash_code(nope=nope, rope=rope)
-        ops.concat_and_cache_mla(
-            k_c_normed_hash,
-            k_pe_hash.squeeze(1),
-            k_hash,
-            attn_metadata.slot_mapping.flatten(),
-            kv_cache_dtype="auto",
-            scale=self._k_scale,
+        self.hash_encoder.compute_hash_and_cache_mla(
+            x=nope,
+            x_rope=rope.squeeze(1),
+            slot_mapping=attn_metadata.slot_mapping.flatten(),
+            k_hash_cache=k_hash,
+            block_size=self.block_size,
         )
         if self.has_pc_hit:
             ## kvcache -> nope + rope
@@ -444,14 +451,13 @@ def cache_k_hash_mla_cuda(
             k_c_normed, k_pe = torch.split(k_cache, [512, 64], dim=-1)
             k_c_normed = k_c_normed.reshape(-1, k_c_normed.shape[2])
             k_pe = k_pe.reshape(-1, k_pe.shape[2])
-            k_c_normed_hash, k_pe_hash = self.hash_code(nope=k_c_normed, rope=k_pe)
-            ops.concat_and_cache_mla(
-                k_c_normed_hash,
-                k_pe_hash,
-                k_hash,
-                self.prefix_slot_mapping.flatten(),
-                kv_cache_dtype="auto",
-                scale=self._k_scale,
+
+            self.hash_encoder.compute_hash_and_cache_mla(
+                x=k_c_normed,
+                x_rope=k_pe,
+                slot_mapping=self.prefix_slot_mapping.flatten(),
+                k_hash_cache=k_hash,
+                block_size=self.block_size,
             )
 
     def cache_k_hash_mla_npu(
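
The fused Triton operator named in the commit title lands in the other three changed files, so it is not visible in this diff. Below is a minimal sketch of the fusion idea under stated assumptions: sign-of-projection hashing, contiguous tensors, and a paged hash cache that can be viewed flat as one row per slot because slot_mapping already encodes block_id * block_size + offset. All names here (_hash_and_cache_kernel, compute_hash_and_cache_mla_sketch, proj, proj_rope) are hypothetical:

```python
import triton
import triton.language as tl


@triton.jit
def _hash_and_cache_kernel(
    x_ptr,       # [num_tokens, D] component to hash (nope or rope part)
    proj_ptr,    # [D, D] random projection (hash_bits == D in this commit)
    slot_ptr,    # [num_tokens] flattened slot mapping
    cache_ptr,   # hash cache viewed as [num_slots, row_stride]
    row_stride,  # D_nope + D_rope: each cache row holds nope bits, then rope bits
    col_base,    # 0 for the nope bits, D_nope for the rope bits
    D: tl.constexpr,           # must be a power of two for tl.arange
    BLOCK_BITS: tl.constexpr,
):
    token = tl.program_id(0)
    slot = tl.load(slot_ptr + token)
    d = tl.arange(0, D)
    bits = tl.program_id(1) * BLOCK_BITS + tl.arange(0, BLOCK_BITS)
    x = tl.load(x_ptr + token * D + d)                      # one token row
    w = tl.load(proj_ptr + d[:, None] * D + bits[None, :])  # [D, BLOCK_BITS] tile
    dots = tl.sum(x[:, None] * w, axis=0)                   # tile of x @ proj
    h = tl.where(dots >= 0, 1.0, -1.0).to(x.dtype)          # sign -> +/-1 hash bit
    # Write straight into the slot's cache row: no intermediate hash tensor
    # and no separate concat-and-cache kernel. (A production kernel would
    # also mask padding slots, e.g. slot_mapping entries of -1.)
    tl.store(cache_ptr + slot * row_stride + col_base + bits, h)


def compute_hash_and_cache_mla_sketch(x, x_rope, proj, proj_rope,
                                      slot_mapping, k_hash_cache):
    T, D = x.shape            # e.g. D = 512 (kv_lora_rank)
    D_rope = x_rope.shape[1]  # e.g. 64 (qk_rope_head_dim)
    cache = k_hash_cache.view(-1, D + D_rope)
    # Two launches (nope part, rope part); the real operator presumably
    # fuses further and handles block layout and dtype conversion.
    _hash_and_cache_kernel[(T, D // 32)](
        x, proj, slot_mapping, cache, D + D_rope, 0, D=D, BLOCK_BITS=32
    )
    _hash_and_cache_kernel[(T, D_rope // 32)](
        x_rope, proj_rope, slot_mapping, cache, D + D_rope, D,
        D=D_rope, BLOCK_BITS=32
    )
```

The point of the fusion is visible in the store: hash bits go straight from registers into the slot's cache row, so the intermediate hash tensors and the extra concat-and-cache kernel launch from the old path disappear.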
