Skip to content

Commit 08ed263

Browse files
committed
The DeepSeek-V4-Pro experts module can now be correctly dequantized to BF16.
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent 1693760 commit 08ed263

2 files changed

Lines changed: 249 additions & 7 deletions

File tree

gptqmodel/utils/model_dequant.py

Lines changed: 144 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,25 @@
3636
_FLOAT8_DTYPES = available_float8_dtypes()
3737
_FLOAT8_FORMAT_NAMES = frozenset(available_float8_dtype_names())
3838
_NVFP4_STORAGE_DTYPES = (torch.uint8, *available_float4_packed_dtypes())
39+
_DEEPSEEK_V4_FP4_BLOCK_SIZE = 32
40+
_DEEPSEEK_V4_FP4_TABLE = (
41+
0.0,
42+
0.5,
43+
1.0,
44+
1.5,
45+
2.0,
46+
3.0,
47+
4.0,
48+
6.0,
49+
0.0,
50+
-0.5,
51+
-1.0,
52+
-1.5,
53+
-2.0,
54+
-3.0,
55+
-4.0,
56+
-6.0,
57+
)
3958

4059
if TYPE_CHECKING:
4160
from compressed_tensors.compressors.base import BaseCompressor
@@ -79,6 +98,76 @@ def finalize_for_save(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.
7998
return tensor_cpu
8099

81100

101+
def _is_deepseek_v4_routed_expert_weight_key(
102+
key: str,
103+
*,
104+
model_type: Optional[str],
105+
) -> bool:
106+
return (
107+
str(model_type or "").strip().lower() == "deepseek_v4"
108+
and key.endswith(".weight")
109+
and ".experts." in key
110+
and ".shared_experts." not in key
111+
)
112+
113+
114+
def dequantize_deepseek_v4_fp4_expert(
115+
tensor: torch.Tensor,
116+
scale: torch.Tensor,
117+
*,
118+
target_dtype: torch.dtype = torch.bfloat16,
119+
row_chunk_size: int = 256,
120+
) -> torch.Tensor:
121+
"""Dequantize DeepSeek-V4 routed expert E2M1-FP4 weights.
122+
123+
DeepSeek-V4-Pro stores routed experts as two FP4 nibbles packed in each I8
124+
element and per-row, per-32-logical-column E8M0 scales. This is distinct
125+
from the torchao NVFP4 layout used by other checkpoints.
126+
"""
127+
128+
if tensor.dtype != torch.int8:
129+
raise ValueError(
130+
f"DeepSeek-V4 FP4 expert weights must be int8, got {tensor.dtype}"
131+
)
132+
if tensor.ndim != 2 or scale.ndim != 2:
133+
raise ValueError("DeepSeek-V4 FP4 expert weight and scale tensors must be 2D")
134+
135+
out_dim, packed_in_dim = tensor.shape
136+
logical_in_dim = packed_in_dim * 2
137+
if logical_in_dim % _DEEPSEEK_V4_FP4_BLOCK_SIZE != 0:
138+
raise ValueError(
139+
f"DeepSeek-V4 FP4 logical input dim {logical_in_dim} must be divisible by "
140+
f"{_DEEPSEEK_V4_FP4_BLOCK_SIZE}"
141+
)
142+
expected_scale_shape = (out_dim, logical_in_dim // _DEEPSEEK_V4_FP4_BLOCK_SIZE)
143+
if tuple(scale.shape) != expected_scale_shape:
144+
raise ValueError(
145+
f"DeepSeek-V4 FP4 scale shape {tuple(scale.shape)} does not match expected "
146+
f"{expected_scale_shape} for weight shape {tuple(tensor.shape)}"
147+
)
148+
149+
table = torch.tensor(_DEEPSEEK_V4_FP4_TABLE, dtype=torch.float32, device=tensor.device)
150+
result = torch.empty((out_dim, logical_in_dim), dtype=target_dtype, device=tensor.device)
151+
row_chunk_size = max(1, int(row_chunk_size))
152+
153+
for start in range(0, out_dim, row_chunk_size):
154+
end = min(start + row_chunk_size, out_dim)
155+
packed = tensor[start:end].view(torch.uint8)
156+
low = packed & 0x0F
157+
high = (packed >> 4) & 0x0F
158+
codes = torch.stack((low, high), dim=-1).reshape(
159+
end - start,
160+
logical_in_dim,
161+
).long()
162+
scale_expanded = scale[start:end].to(torch.float32).repeat_interleave(
163+
_DEEPSEEK_V4_FP4_BLOCK_SIZE,
164+
dim=1,
165+
)
166+
result[start:end] = (table[codes] * scale_expanded).to(target_dtype)
167+
168+
return result
169+
170+
82171
def normalize_device(device: Optional[str]) -> Optional[str]:
83172
if device is None:
84173
return None
@@ -799,6 +888,7 @@ def convert_fp8_shard(
799888
target_dtype: torch.dtype,
800889
*,
801890
block_shape: Optional[Tuple[int, int]],
891+
model_type: Optional[str] = None,
802892
scale_semantics: str = "heuristic",
803893
tensor_lookup: Optional[_ShardTensorLookup] = None,
804894
ignored_layers: Iterable[str] = (),
@@ -814,7 +904,30 @@ def convert_fp8_shard(
814904
if _tensor_key_matches_ignored_layer(key, ignored_layers):
815905
continue
816906

817-
if key.endswith(".weight") and tensor.dtype in _FLOAT8_DTYPES:
907+
if (
908+
_is_deepseek_v4_routed_expert_weight_key(key, model_type=model_type)
909+
and tensor.dtype == torch.int8
910+
):
911+
scale_key = key[:-len(".weight")] + ".scale"
912+
if tensor_lookup is None or not tensor_lookup.has_tensor(
913+
scale_key, local_reader=reader, local_keys=reader_keys
914+
):
915+
raise KeyError(f"Missing DeepSeek-V4 FP4 expert scale tensor for {key}")
916+
scale = tensor_lookup.get_tensor(
917+
scale_key, local_reader=reader, local_keys=reader_keys
918+
)
919+
LOG.debug(
920+
"Using scale tensor '%s' for DeepSeek-V4 FP4 expert weight '%s'",
921+
scale_key,
922+
key,
923+
)
924+
deq = dequantize_deepseek_v4_fp4_expert(
925+
tensor,
926+
scale,
927+
target_dtype=target_dtype,
928+
)
929+
tensors[key] = finalize_for_save(deq, target_dtype)
930+
elif key.endswith(".weight") and tensor.dtype in _FLOAT8_DTYPES:
818931
scale_key = key + "_scale_inv"
819932
scale_tensor = None
820933
scale_inv = None
@@ -923,15 +1036,38 @@ def convert_fp8_shard(
9231036
weight_tensor = tensor_lookup.get_tensor(
9241037
weight_key, local_reader=reader, local_keys=reader_keys
9251038
)
926-
if weight_tensor.dtype in _FLOAT8_DTYPES:
1039+
if weight_tensor.dtype in _FLOAT8_DTYPES or (
1040+
weight_tensor.dtype == torch.int8
1041+
and _is_deepseek_v4_routed_expert_weight_key(
1042+
weight_key,
1043+
model_type=model_type,
1044+
)
1045+
):
9271046
# Mirror the `_scale_inv` handling so exported BF16 checkpoints
928-
# keep only dense weights, not FP8 reconstruction metadata.
929-
LOG.debug("Dropping auxiliary FP8 tensor '%s' after dequantization", key)
1047+
# keep only dense weights, not FP8/FP4 reconstruction metadata.
1048+
LOG.debug(
1049+
"Dropping auxiliary quantization tensor '%s' after dequantization",
1050+
key,
1051+
)
1052+
continue
1053+
elif weight_key in reader_keys:
1054+
weight_tensor = reader.get_tensor(weight_key)
1055+
should_drop_scale = weight_tensor.dtype in _FLOAT8_DTYPES or (
1056+
weight_tensor.dtype == torch.int8
1057+
and _is_deepseek_v4_routed_expert_weight_key(
1058+
weight_key,
1059+
model_type=model_type,
1060+
)
1061+
)
1062+
if not should_drop_scale:
1063+
tensors[key] = finalize_for_save(tensor, target_dtype)
9301064
continue
931-
elif weight_key in reader_keys and reader.get_tensor(weight_key).dtype in _FLOAT8_DTYPES:
9321065
# Mirror the `_scale_inv` handling so exported BF16 checkpoints
933-
# keep only dense weights, not FP8 reconstruction metadata.
934-
LOG.debug("Dropping auxiliary FP8 tensor '%s' after dequantization", key)
1066+
# keep only dense weights, not FP8/FP4 reconstruction metadata.
1067+
LOG.debug(
1068+
"Dropping auxiliary quantization tensor '%s' after dequantization",
1069+
key,
1070+
)
9351071
continue
9361072
tensors[key] = finalize_for_save(tensor, target_dtype)
9371073
else:
@@ -1361,6 +1497,7 @@ def dequantize_model(
13611497
reader,
13621498
target_dtype,
13631499
block_shape=block_shape,
1500+
model_type=config.get("model_type"),
13641501
scale_semantics=fp8_scale_semantics,
13651502
tensor_lookup=tensor_lookup,
13661503
ignored_layers=ignored_layers,

tests/test_model_dequant_fp8.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,108 @@ def test_dequantize_model_fp8_honors_ignored_layers(tmp_path):
278278
expected_quant = quant_weight.to(torch.bfloat16) / quant_scale_inv.to(torch.bfloat16)
279279
torch.testing.assert_close(quant_out, expected_quant)
280280
torch.testing.assert_close(ignored_out, ignored_weight)
281+
282+
283+
@pytest.mark.skipif(
284+
not hasattr(torch, "float8_e8m0fnu"),
285+
reason="float8_e8m0fnu dtype not available",
286+
)
287+
def test_dequantize_model_fp8_dequantizes_deepseek_v4_packed_experts(tmp_path):
288+
model_dir = tmp_path / "deepseek_v4_fp4_experts"
289+
output_dir = tmp_path / "deepseek_v4_fp4_experts_out"
290+
model_dir.mkdir()
291+
292+
config = {
293+
"architectures": ["DeepseekV4ForCausalLM"],
294+
"model_type": "deepseek_v4",
295+
"expert_dtype": "fp4",
296+
"quantization_config": {
297+
"quant_method": "fp8",
298+
"fmt": "e4m3",
299+
"scale_fmt": "ue8m0",
300+
"weight_block_size": [128, 128],
301+
},
302+
}
303+
(model_dir / "config.json").write_text(json.dumps(config), encoding="utf-8")
304+
305+
# Logical FP4 codes 0..15 repeated once, packed as low/high nibbles.
306+
packed_bytes = torch.tensor(
307+
[[lo | (hi << 4) for lo, hi in zip(range(0, 16, 2), range(1, 16, 2))] * 2],
308+
dtype=torch.uint8,
309+
)
310+
weight = packed_bytes.view(torch.int8)
311+
scale = torch.tensor([[2.0]], dtype=torch.float32).to(torch.float8_e8m0fnu)
312+
313+
weight_key = "layers.0.ffn.experts.0.w1.weight"
314+
scale_key = "layers.0.ffn.experts.0.w1.scale"
315+
shard_name = "model.safetensors"
316+
save_file({weight_key: weight, scale_key: scale}, str(model_dir / shard_name))
317+
_write_index(model_dir, shard_name, [weight_key, scale_key])
318+
319+
dequantize_model(model_dir, output_dir, target_dtype=torch.bfloat16, device="cpu")
320+
321+
with safe_open(output_dir / shard_name, framework="pt", device="cpu") as reader:
322+
assert set(reader.keys()) == {weight_key}
323+
output = reader.get_tensor(weight_key)
324+
325+
fp4_table = torch.tensor(
326+
[
327+
0.0,
328+
0.5,
329+
1.0,
330+
1.5,
331+
2.0,
332+
3.0,
333+
4.0,
334+
6.0,
335+
0.0,
336+
-0.5,
337+
-1.0,
338+
-1.5,
339+
-2.0,
340+
-3.0,
341+
-4.0,
342+
-6.0,
343+
],
344+
dtype=torch.float32,
345+
)
346+
expected = (fp4_table.repeat(2).view(1, 32) * 2.0).to(torch.bfloat16)
347+
assert output.dtype is torch.bfloat16
348+
torch.testing.assert_close(output, expected)
349+
350+
351+
@pytest.mark.skipif(
352+
not hasattr(torch, "float8_e8m0fnu"),
353+
reason="float8_e8m0fnu dtype not available",
354+
)
355+
def test_dequantize_model_fp8_does_not_treat_other_models_as_deepseek_v4(tmp_path):
356+
model_dir = tmp_path / "non_deepseek_v4_fp8"
357+
output_dir = tmp_path / "non_deepseek_v4_fp8_out"
358+
model_dir.mkdir()
359+
360+
config = {
361+
"architectures": ["TestModel"],
362+
"model_type": "not_deepseek_v4",
363+
"quantization_config": {
364+
"quant_method": "fp8",
365+
"fmt": "e4m3",
366+
"scale_fmt": "ue8m0",
367+
"weight_block_size": [128, 128],
368+
},
369+
}
370+
(model_dir / "config.json").write_text(json.dumps(config), encoding="utf-8")
371+
372+
weight_key = "layers.0.ffn.experts.0.w1.weight"
373+
scale_key = "layers.0.ffn.experts.0.w1.scale"
374+
weight = torch.zeros((1, 16), dtype=torch.int8)
375+
scale = torch.ones((1, 1), dtype=torch.float32).to(torch.float8_e8m0fnu)
376+
shard_name = "model.safetensors"
377+
save_file({weight_key: weight, scale_key: scale}, str(model_dir / shard_name))
378+
_write_index(model_dir, shard_name, [weight_key, scale_key])
379+
380+
dequantize_model(model_dir, output_dir, target_dtype=torch.bfloat16, device="cpu")
381+
382+
with safe_open(output_dir / shard_name, framework="pt", device="cpu") as reader:
383+
assert set(reader.keys()) == {weight_key, scale_key}
384+
assert reader.get_tensor(weight_key).dtype is torch.int8
385+
assert reader.get_tensor(scale_key).dtype is torch.bfloat16

0 commit comments

Comments
 (0)