# Copyright (c) OpenMMLab. All rights reserved.
-"""Triton kernel to pack BF16 tokens into FlashMLA MODEL1 sparse FP8 flat-
+"""Triton kernel to pack BF16 tokens into V4 FlashMLA sparse FP8 flat-
layout window cache, replacing the per-token Python loop in
_pack_window_state_tokens.

-FlashMLA MODEL1 flat layout per slot (viewed as flat bytes):
+V4 FlashMLA flat layout per slot (viewed as flat bytes):
    [token_0 NoPE+RoPE | token_1 NoPE+RoPE | ... | token_0 scales | token_1 scales | ...]
    NoPE+RoPE per token = 576 bytes (448 e4m3fn bytes + 128 bytes holding 64 bf16 values)
    Scales per token = 8 bytes (7 e8m0fnu + 1 padding byte)
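As a sanity check on the byte arithmetic, a minimal sketch of the offsets this layout implies; all names here are illustrative, not taken from the kernel:

```python
# Byte counts from the layout description above; names are hypothetical.
D_NOPE_BYTES = 448                          # e4m3fn NoPE bytes per token
D_ROPE_BYTES = 128                          # 64 bf16 RoPE values = 128 bytes
TOKEN_BYTES = D_NOPE_BYTES + D_ROPE_BYTES   # 576 bytes of NoPE+RoPE per token
SCALE_BYTES = 8                             # 7 e8m0fnu scales + 1 padding byte

def flat_offsets(window_size: int, token_idx: int):
    """Offsets of one token's NoPE+RoPE block and its scales within a slot."""
    nope_rope_off = token_idx * TOKEN_BYTES
    scale_off = window_size * TOKEN_BYTES + token_idx * SCALE_BYTES
    return nope_rope_off, scale_off

# Scales for every token sit after all NoPE+RoPE blocks in the slot.
assert flat_offsets(window_size=32, token_idx=0) == (0, 32 * 576)
```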
@@ -92,19 +92,19 @@ def pack_window_tokens_fp8(
    slot: torch.Tensor,
    positions: torch.Tensor,
):
-    """Pack BF16 tokens into FlashMLA MODEL1 sparse FP8 window cache.
+    """Pack BF16 tokens into V4 FlashMLA sparse FP8 window cache.

    Args:
        kv_tokens: [num_tokens, 512] BF16 tokens to pack.
        window_state_fp8_cache: [num_total_slots, window_size, packed_dim] FP8 cache.
        slot: [num_tokens] slot indices (which cache row to write to).
        positions: [num_tokens] ring-buffer positions within the window.
    """
-    from lmdeploy.pytorch.backends.cuda.attention.flashmla_utils import (
-        MODEL1_D_NOPE,
-        MODEL1_D_ROPE,
-        MODEL1_NUM_TILES,
-        MODEL1_TILE_SIZE,
+    from lmdeploy.pytorch.kernels.cuda.dsv4.layout import (
+        V4_FLASHMLA_D_NOPE,
+        V4_FLASHMLA_D_ROPE,
+        V4_FLASHMLA_NUM_TILES,
+        V4_FLASHMLA_TILE_SIZE,
    )

    assert kv_tokens.dim() == 2
@@ -113,7 +113,7 @@ def pack_window_tokens_fp8(
        return

    window_size = window_state_fp8_cache.size(1)
-    nope_rope_stride = MODEL1_D_NOPE + 2 * MODEL1_D_ROPE  # 576 bytes per token in NoPE+RoPE region
+    nope_rope_stride = V4_FLASHMLA_D_NOPE + 2 * V4_FLASHMLA_D_ROPE  # 576 bytes per token in NoPE+RoPE region
    num_slots = window_state_fp8_cache.size(0)

    # Create three views of the same FP8 cache buffer (same pattern as fill_compressed_kv)
@@ -122,16 +122,16 @@ def pack_window_tokens_fp8(
    # NoPE+RoPE region: [num_slots, window_size * 576] as e4m3fn
    nope_rope = flat[:, :window_size * nope_rope_stride].view(
        num_slots, window_size, nope_rope_stride)
-    nope_view = nope_rope[:, :, :MODEL1_D_NOPE]  # [num_slots, window_size, 448] e4m3fn
+    nope_view = nope_rope[:, :, :V4_FLASHMLA_D_NOPE]  # [num_slots, window_size, 448] e4m3fn

    # RoPE region: slice the RoPE part first (128 e4m3fn bytes = 64 bf16 elements),
-    # then view as bf16 — same pattern as quantize_model1_fp8_sparse
-    rope_e4 = nope_rope[:, :, MODEL1_D_NOPE:]  # [num_slots, window_size, 128] e4m3fn
+    # then view as bf16 — same pattern as quantize_v4_flashmla_sparse
+    rope_e4 = nope_rope[:, :, V4_FLASHMLA_D_NOPE:]  # [num_slots, window_size, 128] e4m3fn
    rope_view = rope_e4.view(torch.bfloat16)  # [num_slots, window_size, 64] bf16

    # Scale region: uint8 view
    scale_view = flat[:, window_size * nope_rope_stride:].view(
-        num_slots, window_size, 8)[:, :, :MODEL1_NUM_TILES].view(torch.uint8)
+        num_slots, window_size, 8)[:, :, :V4_FLASHMLA_NUM_TILES].view(torch.uint8)

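A side note on the `rope_e4.view(torch.bfloat16)` trick above: `torch.Tensor.view(dtype)` reinterprets the raw bytes in place rather than converting values. A minimal standalone sketch, assuming a torch build with float8 dtypes (torch >= 2.1):

```python
import torch

# Viewing a 1-byte dtype as bf16 halves the trailing dim and shares storage
# with the original buffer: no copy, no numeric conversion.
buf = torch.zeros(2, 128, dtype=torch.float8_e4m3fn)
as_bf16 = buf.view(torch.bfloat16)
assert as_bf16.shape == (2, 64)
assert as_bf16.data_ptr() == buf.data_ptr()
```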
    grid = (num_tokens,)
    _pack_window_tokens_fp8_kernel[grid](
@@ -151,8 +151,8 @@ def pack_window_tokens_fp8(
        stride_scale_pos=scale_view.stride(1),
        stride_slot=1,
        WINDOW_SIZE=window_size,
-        D_NOPE=MODEL1_D_NOPE,
-        D_ROPE=MODEL1_D_ROPE,
-        TILE_SIZE=MODEL1_TILE_SIZE,
-        NUM_TILES=MODEL1_NUM_TILES,
+        D_NOPE=V4_FLASHMLA_D_NOPE,
+        D_ROPE=V4_FLASHMLA_D_ROPE,
+        TILE_SIZE=V4_FLASHMLA_TILE_SIZE,
+        NUM_TILES=V4_FLASHMLA_NUM_TILES,
    )
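For orientation, a hypothetical invocation; the shapes and the `packed_dim` arithmetic are inferred from the docstring above (576 NoPE+RoPE bytes plus 8 scale bytes per token), not copied from a real call site:

```python
import torch

# All shapes below are illustrative assumptions.
num_slots, window_size = 8, 32
packed_dim = 576 + 8  # bytes per token: NoPE+RoPE block + scales
cache = torch.zeros(num_slots, window_size, packed_dim,
                    dtype=torch.float8_e4m3fn, device='cuda')
kv_tokens = torch.randn(3, 512, dtype=torch.bfloat16, device='cuda')
slot = torch.tensor([0, 0, 5], device='cuda')       # cache rows to write
positions = torch.tensor([0, 1, 0], device='cuda')  # ring-buffer positions
pack_window_tokens_fp8(kv_tokens, cache, slot, positions)
```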