Commit 4ed038f: refactor

1 parent 5d36638 commit 4ed038f
8 files changed: 118 additions & 164 deletions

lmdeploy/pytorch/backends/cuda/attention/flashmla_utils.py

Lines changed: 0 additions & 104 deletions
This file was deleted.

lmdeploy/pytorch/configurations/deepseek_v4.py

Lines changed: 5 additions & 3 deletions
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch

-from lmdeploy.pytorch.backends.cuda.attention.flashmla_utils import model1_fp8_sparse_token_dim
+from lmdeploy.pytorch.kernels.cuda.dsv4.layout import V4_FLASHMLA_D_NOPE, V4_FLASHMLA_D_ROPE, V4_FLASHMLA_NUM_TILES
 from lmdeploy.pytorch.config import BlockCacheSpec, ModelConfig, StateCacheSpec

 from .builder import AutoModelConfigBuilder
@@ -45,7 +45,8 @@ def _finalize_v4_cache_specs(model_config: ModelConfig, block_size: int):
                          'has an integral number of entries per block.')

     hf_config = model_config.hf_config
-    packed_token_dim = model1_fp8_sparse_token_dim(64)
+    # V4 FlashMLA sparse FP8: 448 fp8 NoPE + 128 bytes (64 bf16) RoPE + 7 e8m0 scales + 1 pad = 584
+    packed_token_dim = V4_FLASHMLA_D_NOPE + 2 * V4_FLASHMLA_D_ROPE + V4_FLASHMLA_NUM_TILES + 1
     num_layers = hf_config.num_hidden_layers
     compress_ratios = getattr(hf_config, 'compress_ratios', None) or [0] * num_layers
     ratio4_layers = [i for i, r in enumerate(compress_ratios) if r == 4]
@@ -97,7 +98,8 @@ def build(cls, hf_config, model_path: str | None = None, tp: int = 1, **kwargs):
         """
         bos_token_id = getattr(hf_config, 'bos_token_id', None)
         head_dim = getattr(hf_config, 'head_dim', 512)
-        packed_token_dim = model1_fp8_sparse_token_dim(64)
+        # V4 FlashMLA sparse FP8: 448 fp8 NoPE + 128 bytes (64 bf16) RoPE + 7 e8m0 scales + 1 pad = 584
+        packed_token_dim = V4_FLASHMLA_D_NOPE + 2 * V4_FLASHMLA_D_ROPE + V4_FLASHMLA_NUM_TILES + 1
         num_layers = hf_config.num_hidden_layers
         compress_ratios = getattr(hf_config, 'compress_ratios', None) or [0] * num_layers
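
For reference, the two hunks above replace the deleted model1_fp8_sparse_token_dim helper with inline arithmetic over the new layout constants. A minimal sanity check of the byte budget they encode (constants restated here rather than imported, so the snippet stands alone):

    # Byte budget of one packed token in the V4 FlashMLA sparse FP8 cache.
    D_NOPE, D_ROPE, NUM_TILES = 448, 64, 7           # values from dsv4/layout.py
    packed_token_dim = D_NOPE + 2 * D_ROPE + NUM_TILES + 1
    assert packed_token_dim == 584                   # 448 fp8 + 128 RoPE bytes + 7 scales + 1 pad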

lmdeploy/pytorch/kernels/cuda/dsv4/layout.py (new file)

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""DeepSeek-V4 FlashMLA sparse FP8 layout constants and helpers.
+
+The V4 FlashMLA sparse layout packs a 512-dim K cache head as:
+    [448 fp8 NoPE | 128 bytes (64 bf16) RoPE | 7 e8m0 scale bytes | 1 pad byte]
+    = 584 bytes per token.
+
+NoPE region: 7 tiles of 64 elements, each tile quantized to FP8 e4m3fn with
+    a per-tile e8m0fnu power-of-2 scale factor.
+RoPE region: 64 BF16 values stored as raw bytes (128 bytes).
+Scales: 7 e8m0fnu scale bytes + 1 padding byte = 8 bytes.
+"""
+import torch
+
+V4_FLASHMLA_HEAD_DIM = 512
+V4_FLASHMLA_D_NOPE = 448
+V4_FLASHMLA_D_ROPE = 64
+V4_FLASHMLA_TILE_SIZE = 64
+V4_FLASHMLA_NUM_TILES = 7
+
+
+def dequantize_v4_flashmla_sparse(quant_k_cache: torch.Tensor) -> torch.Tensor:
+    """Dequantize V4 FlashMLA sparse FP8 KV cache to BF16.
+
+    Args:
+        quant_k_cache: ``[num_blocks, block_size, 1, packed_dim]`` FP8 cache.
+
+    Returns:
+        ``[num_blocks, block_size, 1, 512]`` BF16 tensor.
+    """
+    assert quant_k_cache.dim() == 4
+    num_blocks, block_size, num_heads, _ = quant_k_cache.shape
+    assert num_heads == 1
+
+    result = torch.empty((num_blocks, block_size, V4_FLASHMLA_HEAD_DIM),
+                         dtype=torch.bfloat16,
+                         device=quant_k_cache.device)
+    quant_k_cache = quant_k_cache.view(num_blocks, -1)
+    input_nope_rope = quant_k_cache[:, :block_size * (V4_FLASHMLA_D_NOPE + 2 * V4_FLASHMLA_D_ROPE)].view(
+        num_blocks, block_size, V4_FLASHMLA_D_NOPE + 2 * V4_FLASHMLA_D_ROPE)
+    input_nope = input_nope_rope[:, :, :V4_FLASHMLA_D_NOPE]
+    input_rope = input_nope_rope[:, :, V4_FLASHMLA_D_NOPE:].view(torch.bfloat16)
+    input_scale = quant_k_cache[:, block_size * (V4_FLASHMLA_D_NOPE + 2 * V4_FLASHMLA_D_ROPE):].view(
+        num_blocks, block_size, 8)[:, :, :V4_FLASHMLA_NUM_TILES].view(torch.float8_e8m0fnu)
+
+    result[..., V4_FLASHMLA_D_NOPE:] = input_rope
+    for tile_idx in range(V4_FLASHMLA_NUM_TILES):
+        cur_nope = input_nope[..., tile_idx * V4_FLASHMLA_TILE_SIZE:(tile_idx + 1) * V4_FLASHMLA_TILE_SIZE].to(
+            torch.bfloat16)
+        cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1)
+        result[..., tile_idx * V4_FLASHMLA_TILE_SIZE:(tile_idx + 1) * V4_FLASHMLA_TILE_SIZE] = cur_nope * cur_scales
+
+    return result.view(num_blocks, block_size, 1, V4_FLASHMLA_HEAD_DIM)
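
A quick way to exercise the new helper is to hand-build one packed block from the documented byte layout and check the round trip. The snippet below is a sketch, not part of the commit; it assumes a PyTorch build that ships the float8_e4m3fn and float8_e8m0fnu dtypes and a CUDA device for the fp8-to-bf16 casts:

    import torch
    from lmdeploy.pytorch.kernels.cuda.dsv4.layout import (V4_FLASHMLA_D_NOPE, V4_FLASHMLA_HEAD_DIM,
                                                            dequantize_v4_flashmla_sparse)

    num_blocks, block_size, packed_dim = 1, 2, 584
    buf = torch.zeros(num_blocks, block_size * packed_dim, dtype=torch.uint8, device='cuda')

    # NoPE+RoPE region: 576 bytes per token.
    nope_rope = buf[:, :block_size * 576].view(num_blocks, block_size, 576)
    nope_rope[:, :, :448] = 0x38                            # e4m3fn bit pattern for 1.0
    nope_rope[:, :, 448:].view(torch.bfloat16)[:] = 0.5     # 64 bf16 RoPE values per token
    # Trailing scales region: 8 bytes per token (7 e8m0 scales + 1 pad).
    scales = buf[:, block_size * 576:].view(num_blocks, block_size, 8)
    scales[:, :, :7] = 128                                  # e8m0fnu byte 128 encodes 2^(128-127) = 2.0

    packed = buf.view(torch.float8_e4m3fn).view(num_blocks, block_size, 1, packed_dim)
    out = dequantize_v4_flashmla_sparse(packed)             # [1, 2, 1, 512] bf16
    assert out.shape == (num_blocks, block_size, 1, V4_FLASHMLA_HEAD_DIM)
    assert torch.all(out[..., :V4_FLASHMLA_D_NOPE] == 2.0)  # 1.0 * per-tile scale 2.0
    assert torch.all(out[..., V4_FLASHMLA_D_NOPE:] == 0.5)  # RoPE bytes pass through unchanged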

lmdeploy/pytorch/kernels/cuda/v4_compressor.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -625,9 +625,9 @@ def _fill_compressed_kv_kernel(
625625
cache_ptrs = kv_cache_ptr + phys_block * kvc_stride_b + block_off * kvc_stride_s + offs_d * kvc_stride_d
626626
tl.store(cache_ptrs, compressed.to(kv_cache_ptr.dtype.element_ty))
627627

628-
# ---- Write to FP8 paged block cache (MODEL1 sparse format) ----
628+
# ---- Write to FP8 paged block cache (V4 FlashMLA sparse format) ----
629629
if has_fp8:
630-
# FlashMLA MODEL1 sparse FP8 layout (matches C++ kernel addressing):
630+
# V4 FlashMLA sparse FP8 layout (matches C++ kernel addressing):
631631
# NoPE+RoPE region: [num_blocks, entries_per_block, 576] as e4m3fn
632632
# per-token: [NoPE 448 fp8 | RoPE 128 bytes (64 bf16)]
633633
# token stride = 576 bytes
@@ -639,6 +639,7 @@ def _fill_compressed_kv_kernel(
639639
# fp8_nope_rope_ptr — e4m3fn, stride_b/stride_s for NoPE write
640640
# fp8_rope_bf16_ptr — bfloat16 view of the RoPE region
641641
# fp8_scales_u8_ptr — uint8 view of the scales region
642+
# Must match V4_FLASHMLA_* in dsv4/layout.py
642643
D_NOPE: tl.constexpr = 448
643644
D_ROPE: tl.constexpr = 64
644645
TILE_SIZE: tl.constexpr = 64
@@ -732,7 +733,7 @@ def fill_compressed_kv(
732733
(abs_pos = n*ratio - 1), this kernel scatters those entries into the
733734
block-paged kv_cache used by the decode-phase sparse attention.
734735
735-
When fp8_cache is provided, also writes MODEL1 sparse FP8 packed entries
736+
When fp8_cache is provided, also writes V4 FlashMLA sparse FP8 packed entries
736737
directly into fp8_cache, eliminating the need for a separate Python-side
737738
packing step.
738739
@@ -755,7 +756,7 @@ def fill_compressed_kv(
755756
phys_block = block_offsets[batch_id, block_idx] (physical block in kv_cache)
756757
write target: kv_cache[phys_block, block_off]
757758
758-
== FP8 MODEL1 sparse format ==
759+
== FP8 V4 FlashMLA sparse format ==
759760
When fp8_cache is not None, the kernel also writes to:
760761
fp8_cache: [num_blocks, entries_per_block, packed_dim=584]
761762
Per-token layout: [NoPE 448 FP8 | RoPE 128 BF16-as-bytes | 7 E8M0 scales | 1 pad]
@@ -806,7 +807,7 @@ def fill_compressed_kv(
806807
kv_scale_cache = dummy
807808

808809
if has_fp8:
809-
# FlashMLA MODEL1 sparse FP8 layout: the fp8_cache tensor is
810+
# V4 FlashMLA sparse FP8 layout: the fp8_cache tensor is
810811
# [num_blocks, entries_per_block, 584] but the actual memory layout
811812
# has NoPE+RoPE at stride 576 bytes per token, with scales in a
812813
# separate region. We create three views matching FlashMLA's addressing:
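
For readers following the addressing comments above: one plausible construction of those three views, using the num_blocks / entries_per_block names from the docstring, is sketched below. The commit's actual view-building code sits outside this hunk and may differ in detail:

    # fp8_cache: [num_blocks, entries_per_block, 584] float8_e4m3fn, aliased three ways.
    flat = fp8_cache.view(num_blocks, -1)
    nope_rope = flat[:, :entries_per_block * 576].view(num_blocks, entries_per_block, 576)
    fp8_nope_rope = nope_rope[:, :, :448]                         # e4m3fn NoPE region
    fp8_rope_bf16 = nope_rope[:, :, 448:].view(torch.bfloat16)    # 64 bf16 RoPE values per token
    fp8_scales_u8 = flat[:, entries_per_block * 576:].view(
        num_blocks, entries_per_block, 8)[:, :, :7].view(torch.uint8)  # 7 e8m0 scale bytes per token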

lmdeploy/pytorch/kernels/cuda/v4_flatten_kv.py

Lines changed: 3 additions & 3 deletions
@@ -121,7 +121,7 @@ def flatten_v4_kv(
         cu_seqlens_k: optional [bsz+1] int32 cumulative KV sequence lengths.
             If None, computed from kv_seqlens.
         fp8_compressed_kv_cache: optional [num_blocks, entries_per_block, 584]
-            FP8 MODEL1 sparse paged cache. When provided and compressed_kv_cache
+            FP8 V4 FlashMLA sparse paged cache. When provided and compressed_kv_cache
             is None, the FP8 cache is dequantized to a temporary BF16 tensor
             and used instead.
         slot: optional [bsz] int64 slot indices into the global
@@ -136,9 +136,9 @@ def flatten_v4_kv(
     """
     # If FP8 cache is provided and no BF16 cache, dequantize first
    if fp8_compressed_kv_cache is not None and compressed_kv_cache is None:
-        from lmdeploy.pytorch.backends.cuda.attention.flashmla_utils import dequantize_model1_fp8_sparse
+        from lmdeploy.pytorch.kernels.cuda.dsv4.layout import dequantize_v4_flashmla_sparse
         # fp8_cache is [num_blocks, entries, 584]; dequantize expects [num_blocks, entries, 1, 584]
-        dequant = dequantize_model1_fp8_sparse(
+        dequant = dequantize_v4_flashmla_sparse(
             fp8_compressed_kv_cache.unsqueeze(2)).squeeze(2)  # [num_blocks, entries, 512]
         # Clone to decouple from FP8 cache views. No synchronize() needed —
         # same-stream kernel launches are ordered, so the Triton flatten kernel

lmdeploy/pytorch/kernels/cuda/v4_pack_window.py

Lines changed: 17 additions & 17 deletions
@@ -1,9 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-"""Triton kernel to pack BF16 tokens into FlashMLA MODEL1 sparse FP8 flat-
+"""Triton kernel to pack BF16 tokens into V4 FlashMLA sparse FP8 flat-
 layout window cache, replacing the per-token Python loop in
 _pack_window_state_tokens.

-FlashMLA MODEL1 flat layout per slot (viewed as flat bytes):
+V4 FlashMLA flat layout per slot (viewed as flat bytes):
     [token_0 NoPE+RoPE | token_1 NoPE+RoPE | ... | token_0 scales | token_1 scales | ...]
     NoPE+RoPE per token = 576 bytes (448 e4m3fn + 128 bf16)
     Scales per token = 8 bytes (7 e8m0fnu + 1 padding)
@@ -92,19 +92,19 @@ def pack_window_tokens_fp8(
     slot: torch.Tensor,
     positions: torch.Tensor,
 ):
-    """Pack BF16 tokens into FlashMLA MODEL1 sparse FP8 window cache.
+    """Pack BF16 tokens into V4 FlashMLA sparse FP8 window cache.

     Args:
         kv_tokens: [num_tokens, 512] BF16 tokens to pack.
         window_state_fp8_cache: [num_total_slots, window_size, packed_dim] FP8 cache.
         slot: [num_tokens] slot indices (which cache row to write to).
         positions: [num_tokens] ring-buffer positions within the window.
     """
-    from lmdeploy.pytorch.backends.cuda.attention.flashmla_utils import (
-        MODEL1_D_NOPE,
-        MODEL1_D_ROPE,
-        MODEL1_NUM_TILES,
-        MODEL1_TILE_SIZE,
+    from lmdeploy.pytorch.kernels.cuda.dsv4.layout import (
+        V4_FLASHMLA_D_NOPE,
+        V4_FLASHMLA_D_ROPE,
+        V4_FLASHMLA_NUM_TILES,
+        V4_FLASHMLA_TILE_SIZE,
     )

     assert kv_tokens.dim() == 2
@@ -113,7 +113,7 @@ def pack_window_tokens_fp8(
         return

     window_size = window_state_fp8_cache.size(1)
-    nope_rope_stride = MODEL1_D_NOPE + 2 * MODEL1_D_ROPE  # 576 bytes per token in NoPE+RoPE region
+    nope_rope_stride = V4_FLASHMLA_D_NOPE + 2 * V4_FLASHMLA_D_ROPE  # 576 bytes per token in NoPE+RoPE region
     num_slots = window_state_fp8_cache.size(0)

     # Create three views of the same FP8 cache buffer (same pattern as fill_compressed_kv)
@@ -122,16 +122,16 @@ def pack_window_tokens_fp8(
     # NoPE+RoPE region: [num_slots, window_size * 576] as e4m3fn
     nope_rope = flat[:, :window_size * nope_rope_stride].view(
         num_slots, window_size, nope_rope_stride)
-    nope_view = nope_rope[:, :, :MODEL1_D_NOPE]  # [num_slots, window_size, 448] e4m3fn
+    nope_view = nope_rope[:, :, :V4_FLASHMLA_D_NOPE]  # [num_slots, window_size, 448] e4m3fn

     # RoPE region: slice the RoPE part first (128 e4m3fn bytes = 64 bf16 elements),
-    # then view as bf16 — same pattern as quantize_model1_fp8_sparse
-    rope_e4 = nope_rope[:, :, MODEL1_D_NOPE:]  # [num_slots, window_size, 128] e4m3fn
+    # then view as bf16 — same pattern as quantize_v4_flashmla_sparse
+    rope_e4 = nope_rope[:, :, V4_FLASHMLA_D_NOPE:]  # [num_slots, window_size, 128] e4m3fn
     rope_view = rope_e4.view(torch.bfloat16)  # [num_slots, window_size, 64] bf16

     # Scale region: uint8 view
     scale_view = flat[:, window_size * nope_rope_stride:].view(
-        num_slots, window_size, 8)[:, :, :MODEL1_NUM_TILES].view(torch.uint8)
+        num_slots, window_size, 8)[:, :, :V4_FLASHMLA_NUM_TILES].view(torch.uint8)

     grid = (num_tokens,)
     _pack_window_tokens_fp8_kernel[grid](
@@ -151,8 +151,8 @@ def pack_window_tokens_fp8(
         stride_scale_pos=scale_view.stride(1),
         stride_slot=1,
         WINDOW_SIZE=window_size,
-        D_NOPE=MODEL1_D_NOPE,
-        D_ROPE=MODEL1_D_ROPE,
-        TILE_SIZE=MODEL1_TILE_SIZE,
-        NUM_TILES=MODEL1_NUM_TILES,
+        D_NOPE=V4_FLASHMLA_D_NOPE,
+        D_ROPE=V4_FLASHMLA_D_ROPE,
+        TILE_SIZE=V4_FLASHMLA_TILE_SIZE,
+        NUM_TILES=V4_FLASHMLA_NUM_TILES,
     )
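
As a reading aid for the flat per-slot layout described in the module docstring, the byte offsets of a single token inside one window slot work out as below (illustrative helper, not part of the commit):

    def v4_window_byte_offsets(token_idx: int, window_size: int) -> tuple[int, int]:
        """Byte offsets of one token inside a flat window slot."""
        nope_rope_off = token_idx * 576                  # 448 e4m3fn NoPE + 128 bf16-as-bytes RoPE
        scales_off = window_size * 576 + token_idx * 8   # 7 e8m0 scale bytes + 1 pad, after all NoPE+RoPE
        return nope_rope_off, scales_off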

tests/pytorch/kernel/test_v4_compressor.py

Lines changed: 7 additions & 7 deletions
@@ -749,7 +749,7 @@ def test_decode_r128(self, kvlen, compress_ratio, head_dim, block_size, device,
 class TestFillCompressedKVFP8:
     """Test FP8 direct write in fill_compressed_kv.

-    Verifies that the kernel's MODEL1 sparse FP8 output matches the Python reference (quantize_model1_fp8_sparse). Only
+    Verifies that the kernel's V4 FlashMLA sparse FP8 output matches the Python reference (quantize_v4_flashmla_sparse). Only
     ratio=4 is tested since r128 has no FP8 cache.
     """

@@ -772,13 +772,13 @@ def _packed_token_dim(self):
         return self.D_NOPE + 2 * self.D_ROPE + self.NUM_TILES + 1  # 584

     def _reference_pack_fp8(self, bf16_tokens):
-        """Pack BF16 tokens [N, 512] to MODEL1 FP8 using the Python
+        """Pack BF16 tokens [N, 512] to V4 FlashMLA FP8 using the Python
         reference."""
-        from lmdeploy.pytorch.backends.cuda.attention.flashmla_utils import quantize_model1_fp8_sparse
-        # quantize_model1_fp8_sparse expects [num_blocks, block_size, 1, 512]
+        from .dsv4_utils import quantize_v4_flashmla_sparse
+        # quantize_v4_flashmla_sparse expects [num_blocks, block_size, 1, 512]
         # For N tokens, treat as 1 block of N entries
         input_cache = bf16_tokens.unsqueeze(0).unsqueeze(2)  # [1, N, 1, 512]
-        packed = quantize_model1_fp8_sparse(input_cache)  # [1, N, 1, 584]
+        packed = quantize_v4_flashmla_sparse(input_cache)  # [1, N, 1, 584]
         return packed.squeeze(0).squeeze(1)  # [N, 584]

     def _run_test(self, compressed_kv, cu_q_seqlens, kv_seqlens, block_offsets, device):
@@ -805,9 +805,9 @@ def _run_test(self, compressed_kv, cu_q_seqlens, kv_seqlens, block_offsets, devi
             fp8_cache=fp8_cache)

         # Reference: dequantize FP8 cache and compare with BF16 cache
-        from lmdeploy.pytorch.backends.cuda.attention.flashmla_utils import dequantize_model1_fp8_sparse
+        from .dsv4_utils import dequantize_v4_flashmla_sparse
         # Dequantize all blocks
-        dequant = dequantize_model1_fp8_sparse(
+        dequant = dequantize_v4_flashmla_sparse(
             fp8_cache.unsqueeze(2))  # [num_blocks, entries_per_block, 1, 512]
         dequant = dequant.squeeze(2)  # [num_blocks, entries_per_block, 512]
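
The Python reference quantize_v4_flashmla_sparse now lives in the tests' dsv4_utils module, which is not shown in this commit view. A plausible sketch of such a reference, i.e. the inverse of dequantize_v4_flashmla_sparse with power-of-two per-tile scales, is given below; the function body, the ceil rounding, and the amax clamp are assumptions, not the commit's actual helper:

    import torch

    def quantize_v4_flashmla_sparse_sketch(k_cache: torch.Tensor) -> torch.Tensor:
        """[num_blocks, block_size, 1, 512] bf16 -> [num_blocks, block_size, 1, 584] e4m3fn (assumed layout)."""
        num_blocks, block_size = k_cache.shape[0], k_cache.shape[1]
        nope = k_cache[:, :, 0, :448].float()
        rope = k_cache[:, :, 0, 448:]                                # 64 bf16 values, stored verbatim
        tiles = nope.view(num_blocks, block_size, 7, 64)
        amax = tiles.abs().amax(dim=-1).clamp(min=2**-126)
        exp = torch.ceil(torch.log2(amax / 448.0)).clamp(-127, 127)  # power-of-two exponent; 448 = e4m3fn max
        q = (tiles / torch.exp2(exp).unsqueeze(-1)).to(torch.float8_e4m3fn)

        out = torch.zeros(num_blocks, block_size * 584, dtype=torch.uint8, device=k_cache.device)
        nope_rope = out[:, :block_size * 576].view(num_blocks, block_size, 576)
        nope_rope[:, :, :448] = q.reshape(num_blocks, block_size, 448).view(torch.uint8)
        nope_rope[:, :, 448:].view(torch.bfloat16)[:] = rope
        scales = out[:, block_size * 576:].view(num_blocks, block_size, 8)
        scales[:, :, :7] = (exp + 127).to(torch.uint8)               # e8m0fnu bit pattern, scale = 2^(byte - 127)
        return out.view(torch.float8_e4m3fn).view(num_blocks, block_size, 1, 584)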
