
Commit 54007b2

fix little bugs
1 parent 4f64822 commit 54007b2

4 files changed

Lines changed: 134 additions & 32 deletions

lmdeploy/pytorch/backends/cuda/graph_runner.py

Lines changed: 0 additions & 1 deletion
@@ -162,7 +162,6 @@ def __init__(self, model: torch.nn.Module, model_config: ModelConfig, cache_conf
                  backend_config: BackendConfig, device: torch.device):
         super().__init__(model, model_config, cache_config, backend_config, device)
         self.max_batches = cache_config.max_batches
-        self.max_tokens = cache_config.max_prefill_token_num
         self.num_blocks = cache_config.num_gpu_blocks

         self.enable_graph = self.check_enable_graph()

lmdeploy/pytorch/configurations/deepseek_v4.py

Lines changed: 0 additions & 3 deletions
@@ -32,9 +32,6 @@ def _check_env_v4(device: str = 'cuda'):
     except ImportError as e:
         raise ImportError('DeepSeek-V4 requires <fast_hadamard_transform> to be installed.') from e

-    if not hasattr(torch, 'float4_e2m1fn_x2'):
-        raise RuntimeError('DeepSeek-V4 requires PyTorch with float4_e2m1fn_x2 support.')
-

 def _finalize_v4_cache_specs(model_config: ModelConfig, block_size: int):
     if block_size < 128:

lmdeploy/pytorch/nn/moe/v4_fp4.py

Lines changed: 0 additions & 28 deletions
@@ -10,32 +10,6 @@
 from .base import split_size as _split_size


-def _v4_swiglu(intermediate: torch.Tensor, swiglu_limit: float) -> torch.Tensor:
-    """Match DeepSeek-V4 routed-expert activation semantics.
-
-    Keep the activation hook in `nn/moe` so the V4 fused MoE wrapper does not depend on the legacy CUDA backend
-    implementation file.
-    """
-    hidden = intermediate.size(-1) // 2
-    gate = intermediate[..., :hidden].float()
-    up = intermediate[..., hidden:].float()
-    if swiglu_limit > 0:
-        up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit)
-        gate = torch.clamp(gate, max=swiglu_limit)
-    return (torch.nn.functional.silu(gate) * up).to(intermediate.dtype)
-
-
-def _get_v4_moe_runtime_kind(device: torch.device) -> str:
-    """Select the routed-expert runtime path for the current GPU.
-
-    CUDA uses lmdeploy's Triton FP8xFP4 MoE path, which keeps checkpoint-native packed FP4 expert weights resident and
-    unpacks them inside the GEMM kernel.
-    """
-    if device.type == 'cuda' and torch.cuda.is_available():
-        return 'triton_fp4'
-    raise RuntimeError('DeepSeek-V4 FP4 MoE requires CUDA because the expert weights stay in packed FP4 format.')
-
-
 class V4ExpertWeights(nn.Module):
     """Local expert-sharded V4 expert weights.

@@ -142,7 +116,6 @@ def __init__(self,
         self.ffn_dim = ffn_dim
         self.top_k = top_k
         self.block_size = 128
-        self.runtime_kind = _get_v4_moe_runtime_kind(device)

         self.gate_up = V4ExpertWeights(self.num_local_experts,
                                        hidden_dim,
@@ -283,7 +256,6 @@ def __init__(self,
                  device: torch.device | None = None):
         super().__init__()
         device = device or torch.device('cpu')
-        self.runtime_kind = _get_v4_moe_runtime_kind(device)
         dist_ctx = get_dist_manager().current_context()
         dist_config = dist_ctx.dist_config
         if dist_config.ep > 1:

tests/pytorch/kernel/dsv4_utils.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+"""Reference implementations for V4 FlashMLA sparse FP8 quantize/dequantize.
+
+Used by kernel tests for correctness comparison only — the production path fuses these operations into Triton kernels.
+"""
+
+import torch
+
+from lmdeploy.pytorch.consts import (
+    V4_FLASHMLA_D_NOPE,
+    V4_FLASHMLA_D_ROPE,
+    V4_FLASHMLA_NUM_TILES,
+    V4_FLASHMLA_TILE_SIZE,
+)
+
+D_NOPE = V4_FLASHMLA_D_NOPE  # 448
+D_ROPE = V4_FLASHMLA_D_ROPE  # 64
+TILE_SIZE = V4_FLASHMLA_TILE_SIZE  # 64
+NUM_TILES = V4_FLASHMLA_NUM_TILES  # 7
+NR_DIM = D_NOPE + 2 * D_ROPE  # 576 bytes per token (NoPE + RoPE in e4m3fn)
+FP8_MAX = 448.0
+
+
+def quantize_v4_flashmla_sparse(input_k_cache: torch.Tensor) -> torch.Tensor:
+    """Pack BF16 ``[num_blocks, block_size, 1, 512]`` K cache into V4 FlashMLA
+    sparse FP8 layout.
+
+    Returns ``[num_blocks, block_size, 1, 584]`` e4m3fn tensor.
+    """
+    assert input_k_cache.dim() == 4
+    num_blocks, block_size, _, head_dim = input_k_cache.shape
+    assert head_dim == 512
+
+    device = input_k_cache.device
+    packed_dim = NR_DIM + 8  # 576 + 8 = 584
+    output = torch.zeros(num_blocks, block_size, 1, packed_dim,
+                         dtype=torch.float8_e4m3fn, device=device)
+
+    # Flat view for layout construction (same pattern as v4_compressor.py / v4_flatten_kv.py)
+    flat_out = output.view(num_blocks, -1)
+
+    # NoPE+RoPE region: [num_blocks, block_size * NR_DIM] as e4m3fn
+    nope_rope = flat_out[:, :block_size * NR_DIM].view(
+        num_blocks, block_size, NR_DIM)
+    nope_view = nope_rope[:, :, :D_NOPE]  # [num_blocks, block_size, 448] e4m3fn
+
+    # RoPE region: view as bf16
+    rope_e4 = nope_rope[:, :, D_NOPE:]  # [num_blocks, block_size, 128] e4m3fn
+    rope_view = rope_e4.view(torch.bfloat16)  # [num_blocks, block_size, 64] bf16
+
+    # Scale region: uint8
+    scale_view = flat_out[:, block_size * NR_DIM:].view(
+        num_blocks, block_size, 8).view(torch.uint8)
+
+    # Per-block, per-token quantize
+    for b in range(num_blocks):
+        for t in range(block_size):
+            token = input_k_cache[b, t, 0]  # [512] bf16
+
+            # Quantize NoPE tiles
+            for tile_idx in range(NUM_TILES):
+                d_base = tile_idx * TILE_SIZE
+                tile = token[d_base:d_base + TILE_SIZE].float()
+
+                amax = tile.abs().max()
+                scale_inv = max(amax.item() / FP8_MAX, 1e-4)
+                ceil_log2 = torch.ceil(torch.log2(torch.tensor(scale_inv, dtype=torch.float32)))
+                scale_inv_pow2 = torch.exp2(ceil_log2)
+
+                quantized = (tile / scale_inv_pow2).to(torch.float8_e4m3fn)
+                nope_view[b, t, d_base:d_base + TILE_SIZE] = quantized
+
+                # e8m0fnu scale byte: raw byte = ceil_log2 + 127
+                scale_byte = int(ceil_log2.item() + 127)
+                scale_view[b, t, tile_idx] = scale_byte
+
+            # RoPE: direct bf16 copy (128 e4m3fn bytes = 64 bf16 elements)
+            rope_vals = token[D_NOPE:]  # [64] bf16
+            rope_view[b, t] = rope_vals
+
+    return output
+
+
+def dequantize_v4_flashmla_sparse(quant_k_cache: torch.Tensor) -> torch.Tensor:
+    """Dequantize V4 FlashMLA sparse FP8 K cache to BF16.
+
+    Re-exports from the production module for test convenience.
+
+    Args:
+        quant_k_cache: [num_blocks, block_size, 1, 584] e4m3fn FP8 cache.
+
+    Returns:
+        [num_blocks, block_size, 1, 512] BF16 cache.
+    """
+    assert quant_k_cache.dim() == 4
+    num_blocks, block_size, _, packed_dim = quant_k_cache.shape
+    assert packed_dim == NR_DIM + 8
+
+    device = quant_k_cache.device
+    output = torch.zeros(num_blocks, block_size, 1, 512,
+                         dtype=torch.bfloat16, device=device)
+
+    # Build views (same layout as quantize)
+    flat = quant_k_cache.view(num_blocks, -1)
+    nope_rope = flat[:, :block_size * NR_DIM].view(
+        num_blocks, block_size, NR_DIM)
+    nope_view = nope_rope[:, :, :D_NOPE]  # [num_blocks, block_size, 448] e4m3fn
+
+    rope_e4 = nope_rope[:, :, D_NOPE:]  # [num_blocks, block_size, 128] e4m3fn
+    rope_view = rope_e4.view(torch.bfloat16)  # [num_blocks, block_size, 64] bf16
+
+    scale_view = flat[:, block_size * NR_DIM:].view(
+        num_blocks, block_size, 8).view(torch.uint8)
+
+    # Per-block, per-token dequantize
+    for b in range(num_blocks):
+        for t in range(block_size):
+            # Dequantize NoPE tiles
+            for tile_idx in range(NUM_TILES):
+                d_base = tile_idx * TILE_SIZE
+                nope_fp8 = nope_view[b, t, d_base:d_base + TILE_SIZE].float()
+
+                # Read scale byte and reconstruct float scale
+                scale_byte = scale_view[b, t, tile_idx].item()
+                # e8m0fnu: bits = scale_byte, float = 2^(scale_byte - 127)
+                scale_bits = scale_byte << 23
+                scale_f32 = torch.tensor(scale_bits, dtype=torch.int32).view(torch.float32)
+
+                dequant = (nope_fp8 * scale_f32).to(torch.bfloat16)
+                output[b, t, 0, d_base:d_base + TILE_SIZE] = dequant
+
+            # RoPE: direct bf16 copy
+            output[b, t, 0, D_NOPE:] = rope_view[b, t]
+
+    return output
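
For context on the scale bytes written above: each 64-wide NoPE tile stores one e8m0-style uint8, which is simply the biased exponent of a power-of-two dequantization scale. The quantize side stores ceil(log2(amax / 448)) + 127, and the dequantize side shifts that byte into the float32 exponent field. A minimal standalone sketch of that round trip follows; the amax value is a made-up illustration, not taken from the commit.

import math

import torch

# Hypothetical per-tile absolute maximum; any positive value works.
amax = 12.0
FP8_MAX = 448.0  # e4m3fn max, as in dsv4_utils.py

# Quantize side: round the dequantization scale up to a power of two.
scale_inv = max(amax / FP8_MAX, 1e-4)        # ~0.0268
ceil_log2 = math.ceil(math.log2(scale_inv))  # -5
scale = 2.0 ** ceil_log2                     # 0.03125; the tile is divided by this
scale_byte = ceil_log2 + 127                 # 122, stored as one uint8 per tile

# Dequantize side: place the byte in the float32 exponent field, i.e. 2^(byte - 127).
scale_bits = scale_byte << 23
rebuilt = torch.tensor(scale_bits, dtype=torch.int32).view(torch.float32).item()
assert rebuilt == scale  # exact power-of-two round trip

Rounding the scale up to a power of two keeps every quantized tile value inside the e4m3fn range and makes the scale exactly recoverable from a single byte.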
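The new file only defines the reference pair, so a natural sanity check is a quantize/dequantize round trip: the RoPE tail should come back bit-exact, while the NoPE tiles should match only up to FP8 precision. The sketch below is not part of the commit; check_roundtrip is a hypothetical helper, and it assumes the repo root is on PYTHONPATH and a PyTorch build with float8_e4m3fn support.

import torch

# Assumes the repo root is on PYTHONPATH so the test utility is importable.
from tests.pytorch.kernel.dsv4_utils import (D_NOPE, dequantize_v4_flashmla_sparse,
                                              quantize_v4_flashmla_sparse)


def check_roundtrip(num_blocks: int = 2, block_size: int = 4) -> None:
    # Small random BF16 K cache in the [num_blocks, block_size, 1, 512] layout.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    k_cache = torch.randn(num_blocks, block_size, 1, 512, dtype=torch.bfloat16, device=device)

    packed = quantize_v4_flashmla_sparse(k_cache)     # [num_blocks, block_size, 1, 584] e4m3fn
    restored = dequantize_v4_flashmla_sparse(packed)  # [num_blocks, block_size, 1, 512] bf16

    # The RoPE tail (last 64 dims) is copied verbatim through the bf16 view, so it matches exactly.
    assert torch.equal(restored[..., D_NOPE:], k_cache[..., D_NOPE:])

    # The NoPE dims pass through e4m3fn with power-of-two scales, so only compare loosely.
    err = (restored[..., :D_NOPE].float() - k_cache[..., :D_NOPE].float()).abs().max()
    print(f'max NoPE round-trip error: {err.item():.4f}')


check_roundtrip()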
