# Float8 dtypes reported by the helper; used for `tensor.dtype in _FLOAT8_DTYPES`
# membership checks when deciding whether a weight needs dequantization.
_FLOAT8_DTYPES = available_float8_dtypes()
# Immutable set of the float8 format names reported by the helper.
_FLOAT8_FORMAT_NAMES = frozenset(available_float8_dtype_names())
# `torch.uint8` first, followed by whatever packed float4 dtypes the helper reports.
_NVFP4_STORAGE_DTYPES = (torch.uint8,) + tuple(available_float4_packed_dtypes())
# Number of logical (unpacked) FP4 columns that share a single scale entry.
_DEEPSEEK_V4_FP4_BLOCK_SIZE = 32

# Lookup table indexed by the raw 4-bit code. Codes 0-7 are the non-negative
# magnitudes; codes 8-15 repeat them with the sign flipped. Code 8 is kept as
# positive 0.0 (not -0.0).
_DEEPSEEK_V4_FP4_TABLE = (
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,  # codes 0-7
    0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,  # codes 8-15
)
3958
4059if TYPE_CHECKING :
4160 from compressed_tensors .compressors .base import BaseCompressor
@@ -79,6 +98,76 @@ def finalize_for_save(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.
7998 return tensor_cpu
8099
81100
def _is_deepseek_v4_routed_expert_weight_key(
    key: str,
    *,
    model_type: Optional[str],
) -> bool:
    """Return True when *key* names a DeepSeek-V4 routed-expert weight tensor.

    The model type is compared case-insensitively after trimming whitespace;
    a missing (``None``/empty) model type never matches. The key must end in
    ``.weight``, contain an ``.experts.`` path segment, and must not contain a
    ``.shared_experts.`` segment.
    """
    normalized_model_type = str(model_type or "").strip().lower()
    if normalized_model_type != "deepseek_v4":
        return False
    if not key.endswith(".weight"):
        return False
    if ".experts." not in key:
        return False
    return ".shared_experts." not in key
112+
113+
def dequantize_deepseek_v4_fp4_expert(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    *,
    target_dtype: torch.dtype = torch.bfloat16,
    row_chunk_size: int = 256,
) -> torch.Tensor:
    """Dequantize DeepSeek-V4 routed expert E2M1-FP4 weights.

    DeepSeek-V4-Pro stores routed experts as two FP4 nibbles packed in each I8
    element and per-row, per-32-logical-column E8M0 scales. This is distinct
    from the torchao NVFP4 layout used by other checkpoints.

    Every packed byte expands to two adjacent logical columns, low nibble
    first. Each 4-bit code is mapped through ``_DEEPSEEK_V4_FP4_TABLE`` and
    multiplied by the scale of its block after the scale is cast to float32.

    Args:
        tensor: 2D ``torch.int8`` tensor of shape ``(out_dim, packed_in_dim)``.
        scale: 2D tensor of shape
            ``(out_dim, 2 * packed_in_dim // _DEEPSEEK_V4_FP4_BLOCK_SIZE)``.
            It is moved to ``tensor.device`` if it lives elsewhere.
        target_dtype: dtype of the returned dense tensor.
        row_chunk_size: number of rows processed per iteration, bounding the
            size of the float32 intermediates. Values below 1 are treated as 1.

    Returns:
        Tensor of shape ``(out_dim, 2 * packed_in_dim)`` with dtype
        ``target_dtype`` on ``tensor.device``.

    Raises:
        ValueError: if ``tensor`` is not int8, either input is not 2D, the
            logical input dim is not a multiple of the block size, or the
            scale shape does not match the weight shape.
    """

    if tensor.dtype != torch.int8:
        raise ValueError(
            f"DeepSeek-V4 FP4 expert weights must be int8, got {tensor.dtype} "
        )
    if tensor.ndim != 2 or scale.ndim != 2:
        raise ValueError("DeepSeek-V4 FP4 expert weight and scale tensors must be 2D")

    block_size = _DEEPSEEK_V4_FP4_BLOCK_SIZE
    out_dim, packed_in_dim = tensor.shape
    logical_in_dim = packed_in_dim * 2
    if logical_in_dim % block_size != 0:
        raise ValueError(
            f"DeepSeek-V4 FP4 logical input dim {logical_in_dim} must be divisible by "
            f"{block_size} "
        )
    num_blocks = logical_in_dim // block_size
    expected_scale_shape = (out_dim, num_blocks)
    if tuple(scale.shape) != expected_scale_shape:
        raise ValueError(
            f"DeepSeek-V4 FP4 scale shape {tuple(scale.shape)} does not match expected "
            f"{expected_scale_shape} for weight shape {tuple(tensor.shape)} "
        )

    device = tensor.device
    table = torch.tensor(_DEEPSEEK_V4_FP4_TABLE, dtype=torch.float32, device=device)
    result = torch.empty((out_dim, logical_in_dim), dtype=target_dtype, device=device)
    row_chunk_size = max(1, int(row_chunk_size))

    for start in range(0, out_dim, row_chunk_size):
        end = min(start + row_chunk_size, out_dim)
        rows = end - start
        # Reinterpret as unsigned so the shift below does not sign-extend.
        packed = tensor[start:end].view(torch.uint8)
        low = packed & 0x0F
        high = (packed >> 4) & 0x0F
        # Interleave so column 2*i is the low nibble and 2*i + 1 the high one.
        codes = torch.stack((low, high), dim=-1).reshape(rows, logical_in_dim).long()
        # Group values by scale block so the per-block scale broadcasts over
        # the trailing axis instead of being materialized at full width.
        values = table[codes].reshape(rows, num_blocks, block_size)
        # The table and output live on the weight's device; align the scale
        # with them so a scale loaded elsewhere does not fail mid-loop.
        chunk_scale = scale[start:end].to(device=device, dtype=torch.float32)
        scaled = values * chunk_scale.unsqueeze(-1)
        result[start:end] = scaled.reshape(rows, logical_in_dim).to(target_dtype)

    return result
169+
170+
82171def normalize_device (device : Optional [str ]) -> Optional [str ]:
83172 if device is None :
84173 return None
@@ -799,6 +888,7 @@ def convert_fp8_shard(
799888 target_dtype : torch .dtype ,
800889 * ,
801890 block_shape : Optional [Tuple [int , int ]],
891+ model_type : Optional [str ] = None ,
802892 scale_semantics : str = "heuristic" ,
803893 tensor_lookup : Optional [_ShardTensorLookup ] = None ,
804894 ignored_layers : Iterable [str ] = (),
@@ -814,7 +904,30 @@ def convert_fp8_shard(
814904 if _tensor_key_matches_ignored_layer (key , ignored_layers ):
815905 continue
816906
817- if key .endswith (".weight" ) and tensor .dtype in _FLOAT8_DTYPES :
907+ if (
908+ _is_deepseek_v4_routed_expert_weight_key (key , model_type = model_type )
909+ and tensor .dtype == torch .int8
910+ ):
911+ scale_key = key [:- len (".weight" )] + ".scale"
912+ if tensor_lookup is None or not tensor_lookup .has_tensor (
913+ scale_key , local_reader = reader , local_keys = reader_keys
914+ ):
915+ raise KeyError (f"Missing DeepSeek-V4 FP4 expert scale tensor for { key } " )
916+ scale = tensor_lookup .get_tensor (
917+ scale_key , local_reader = reader , local_keys = reader_keys
918+ )
919+ LOG .debug (
920+ "Using scale tensor '%s' for DeepSeek-V4 FP4 expert weight '%s'" ,
921+ scale_key ,
922+ key ,
923+ )
924+ deq = dequantize_deepseek_v4_fp4_expert (
925+ tensor ,
926+ scale ,
927+ target_dtype = target_dtype ,
928+ )
929+ tensors [key ] = finalize_for_save (deq , target_dtype )
930+ elif key .endswith (".weight" ) and tensor .dtype in _FLOAT8_DTYPES :
818931 scale_key = key + "_scale_inv"
819932 scale_tensor = None
820933 scale_inv = None
@@ -923,15 +1036,38 @@ def convert_fp8_shard(
9231036 weight_tensor = tensor_lookup .get_tensor (
9241037 weight_key , local_reader = reader , local_keys = reader_keys
9251038 )
926- if weight_tensor .dtype in _FLOAT8_DTYPES :
1039+ if weight_tensor .dtype in _FLOAT8_DTYPES or (
1040+ weight_tensor .dtype == torch .int8
1041+ and _is_deepseek_v4_routed_expert_weight_key (
1042+ weight_key ,
1043+ model_type = model_type ,
1044+ )
1045+ ):
9271046 # Mirror the `_scale_inv` handling so exported BF16 checkpoints
928- # keep only dense weights, not FP8 reconstruction metadata.
929- LOG .debug ("Dropping auxiliary FP8 tensor '%s' after dequantization" , key )
1047+ # keep only dense weights, not FP8/FP4 reconstruction metadata.
1048+ LOG .debug (
1049+ "Dropping auxiliary quantization tensor '%s' after dequantization" ,
1050+ key ,
1051+ )
1052+ continue
1053+ elif weight_key in reader_keys :
1054+ weight_tensor = reader .get_tensor (weight_key )
1055+ should_drop_scale = weight_tensor .dtype in _FLOAT8_DTYPES or (
1056+ weight_tensor .dtype == torch .int8
1057+ and _is_deepseek_v4_routed_expert_weight_key (
1058+ weight_key ,
1059+ model_type = model_type ,
1060+ )
1061+ )
1062+ if not should_drop_scale :
1063+ tensors [key ] = finalize_for_save (tensor , target_dtype )
9301064 continue
931- elif weight_key in reader_keys and reader .get_tensor (weight_key ).dtype in _FLOAT8_DTYPES :
9321065 # Mirror the `_scale_inv` handling so exported BF16 checkpoints
933- # keep only dense weights, not FP8 reconstruction metadata.
934- LOG .debug ("Dropping auxiliary FP8 tensor '%s' after dequantization" , key )
1066+ # keep only dense weights, not FP8/FP4 reconstruction metadata.
1067+ LOG .debug (
1068+ "Dropping auxiliary quantization tensor '%s' after dequantization" ,
1069+ key ,
1070+ )
9351071 continue
9361072 tensors [key ] = finalize_for_save (tensor , target_dtype )
9371073 else :
@@ -1361,6 +1497,7 @@ def dequantize_model(
13611497 reader ,
13621498 target_dtype ,
13631499 block_shape = block_shape ,
1500+ model_type = config .get ("model_type" ),
13641501 scale_semantics = fp8_scale_semantics ,
13651502 tensor_lookup = tensor_lookup ,
13661503 ignored_layers = ignored_layers ,
0 commit comments