Skip to content

Commit c8f46fc

Browse files
Edwardf0t1kevalmorabia97
authored andcommitted
Remove quantization_config in config.json from original deepseek models (#753)
## What does this PR do? **Type of change:** Bug fix **Overview:** DeepSeek original checkpoints may include a `quantization_config` field in `config.json` (describing the source checkpoint's quantization). When we export ModelOpt quantization configs to `hf_quant_config.json`, leaving the original `quantization_config` in place can be confusing. Add a function to remove it. ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information Resolve nvbug https://nvbugspro.nvidia.com/bug/5736665 --------- Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
1 parent 7e04df8 commit c8f46fc

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

examples/deepseek/quantize_to_nvfp4.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,20 @@ def _remap_key(key_dict: dict[str, Any]):
8282
key_dict.update(new_dict)
8383

8484

85+
def remove_quantization_config_from_original_config(export_dir: str) -> None:
86+
"""Remove `quantization_config` from exported HF `config.json`.
87+
88+
Assumes the exported checkpoint directory has a `config.json` containing `quantization_config`.
89+
"""
90+
config_path = os.path.join(export_dir, "config.json")
91+
with open(config_path) as f:
92+
cfg = json.load(f)
93+
del cfg["quantization_config"]
94+
with open(config_path, "w") as f:
95+
json.dump(cfg, f, indent=2, sort_keys=True)
96+
f.write("\n")
97+
98+
8599
def load_and_preprocess_state_dict(modelopt_state_root, world_size=8):
86100
state_dict_list = [
87101
torch.load(f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt")
@@ -302,3 +316,5 @@ def get_tensor(tensor_name):
302316
save_root=args.fp4_path,
303317
per_layer_quant_config=per_layer_quant_config,
304318
)
319+
320+
remove_quantization_config_from_original_config(args.fp4_path)

0 commit comments

Comments
 (0)