@@ -20,16 +20,16 @@ def cuda_hamming_topk(
2020 topk_token ,
2121 sink_token ,
2222 recent_token ,
23+ is_mla ,
2324):
2425 q_hash = q_hash .view (torch .int32 )
2526 k_hash = k_hash .view (torch .int32 )
26- assert k_hash .shape [1 ] == 1
27+ # assert k_hash.shape[1] == 1
2728 # assert k_hash.shape[-1] == 18 and q_hash.shape[-1] == 18
28- block_size = k_hash .shape [2 ]
29+ block_size = k_hash .shape [1 ]
2930 assert topk_token % block_size == 0
3031 assert recent_token > 0 and topk_token > (sink_token + recent_token )
3132 max_seqlen = block_size * block_table .shape [1 ]
32-
3333 output = hamming .hamming_score (
3434 k_hash ,
3535 q_hash ,
@@ -40,17 +40,23 @@ def cuda_hamming_topk(
4040 recent_token ,
4141 )
4242
43- block_output = torch . min (
44- output .view ( output . shape [ 0 ], output . shape [ - 1 ] // block_size , block_size ), dim = - 1
45- )[ 0 ]
43+ k_blocks = topk_token // block_size
44+ B , Hk , S = output .shape
45+ num_blocks = S // block_size
4646
47- ind = torch .topk (block_output , k = (topk_token // block_size ), dim = - 1 , largest = False )[
48- 1
49- ]
50- ind = torch .sort (ind , dim = - 1 , descending = False )[0 ]
47+ # block_output: [B, Hk, num_blocks]
48+ block_output = output .view (B , Hk , num_blocks , block_size ).amin (dim = - 1 )
5149
52- new_block_table = torch .gather (block_table , dim = - 1 , index = ind )
53- return new_block_table
50+ if is_mla :
51+ block_score = block_output [:, 0 , :]
52+ ind = torch .topk (block_score , k = k_blocks , dim = - 1 , largest = False ).indices
53+ ind = ind .sort (dim = - 1 ).values
54+ return torch .gather (block_table , dim = - 1 , index = ind )
55+
56+ block_score = block_output .amin (dim = 1 ) # [B, num_blocks]
57+ ind = torch .topk (block_score , k = k_blocks , dim = - 1 , largest = False ).indices
58+ ind = ind .sort (dim = - 1 ).values
59+ return torch .gather (block_table , dim = - 1 , index = ind )
5460
5561
5662def fake_hamming_topk (
@@ -66,7 +72,7 @@ def fake_hamming_topk(
6672 k_hash = k_hash .view (torch .int32 )
6773 assert k_hash .shape [1 ] == 1
6874 assert k_hash .shape [- 1 ] == 18 and q_hash .shape [- 1 ] == 18
69- block_size = k_hash .shape [2 ]
75+ block_size = k_hash .shape [1 ]
7076 assert topk_token % block_size == 0
7177 assert recent_token > 0 and topk_token > (sink_token + recent_token )
7278 max_seqlen = block_size * block_table .shape [1 ]
0 commit comments