@@ -218,6 +218,15 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
218218 self .qk_rope_head_dim = getattr (
219219 vllm_config .model_config .hf_text_config , "qk_rope_head_dim" , None
220220 )
221+ self .hash_encoder = HashEncoder (
222+ input_dim = self .kv_lora_rank ,
223+ hash_bits = self .kv_lora_rank ,
224+ dtype = vllm_config .model_config .dtype ,
225+ device = self .device ,
226+ input_dim_rope = self .qk_rope_head_dim ,
227+ hash_bits_rope = self .qk_rope_head_dim ,
228+ is_mla = True ,
229+ )
221230 self .hash_encoder_nope = HashEncoder (
222231 input_dim = self .kv_lora_rank ,
223232 hash_bits = self .kv_lora_rank ,
@@ -426,14 +435,12 @@ def get_layer_state(self, layer_name: str):
426435 def cache_k_hash_mla_cuda (
427436 self , nope , rope , k_hash , attn_metadata , forward_context , layer_name
428437 ):
429- k_c_normed_hash , k_pe_hash = self .hash_code (nope = nope , rope = rope )
430- ops .concat_and_cache_mla (
431- k_c_normed_hash ,
432- k_pe_hash .squeeze (1 ),
433- k_hash ,
434- attn_metadata .slot_mapping .flatten (),
435- kv_cache_dtype = "auto" ,
436- scale = self ._k_scale ,
438+ self .hash_encoder .compute_hash_and_cache_mla (
439+ x = nope ,
440+ x_rope = rope .squeeze (1 ),
441+ slot_mapping = attn_metadata .slot_mapping .flatten (),
442+ k_hash_cache = k_hash ,
443+ block_size = self .block_size ,
437444 )
438445 if self .has_pc_hit :
439446 ## kvcache -> nope + rope
@@ -444,14 +451,13 @@ def cache_k_hash_mla_cuda(
444451 k_c_normed , k_pe = torch .split (k_cache , [512 , 64 ], dim = - 1 )
445452 k_c_normed = k_c_normed .reshape (- 1 , k_c_normed .shape [2 ])
446453 k_pe = k_pe .reshape (- 1 , k_pe .shape [2 ])
447- k_c_normed_hash , k_pe_hash = self .hash_code (nope = k_c_normed , rope = k_pe )
448- ops .concat_and_cache_mla (
449- k_c_normed_hash ,
450- k_pe_hash ,
451- k_hash ,
452- self .prefix_slot_mapping .flatten (),
453- kv_cache_dtype = "auto" ,
454- scale = self ._k_scale ,
454+
455+ self .hash_encoder .compute_hash_and_cache_mla (
456+ x = k_c_normed ,
457+ x_rope = k_pe ,
458+ slot_mapping = self .prefix_slot_mapping .flatten (),
459+ k_hash_cache = k_hash ,
460+ block_size = self .block_size ,
455461 )
456462
457463 def cache_k_hash_mla_npu (
0 commit comments