Skip to content

Commit 7058419

Browse files
committed
remove repeated code
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent 3d74cfc commit 7058419

1 file changed

Lines changed: 18 additions & 26 deletions

File tree

modelopt/torch/export/unified_export_hf.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,22 @@ def _is_enabled_quantizer(quantizer):
106106
return False
107107

108108

109+
def _save_component_state_dict_safetensors(
    component: nn.Module, component_export_dir: Path
) -> None:
    """Export a component's weights as ``model.safetensors`` plus a minimal ``config.json``.

    Used for components that do not provide a ``save_pretrained`` method.

    Args:
        component: Module whose parameters/buffers are serialized.
        component_export_dir: Existing directory that receives
            ``model.safetensors`` and ``config.json``.
    """
    # detach/contiguous/cpu so save_file receives plain CPU tensors with no autograd graph
    cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()}
    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
    # Record just enough metadata for a loader to identify this non-HF export format.
    # encoding="utf-8": don't depend on the platform's locale default for JSON output.
    with open(component_export_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(
            {
                "_class_name": type(component).__name__,
                "_export_format": "safetensors_state_dict",
            },
            f,
            indent=4,
        )
123+
124+
109125
def _collect_shared_input_modules(
110126
model: nn.Module,
111127
dummy_forward_fn: Callable[[], None],
@@ -853,19 +869,7 @@ def _export_diffusers_checkpoint(
853869
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
854870
else:
855871
with hide_quantizers_from_state_dict(component):
856-
cpu_state_dict = {
857-
k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()
858-
}
859-
save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
860-
with open(component_export_dir / "config.json", "w") as f:
861-
json.dump(
862-
{
863-
"_class_name": type(component).__name__,
864-
"_export_format": "safetensors_state_dict",
865-
},
866-
f,
867-
indent=4,
868-
)
872+
_save_component_state_dict_safetensors(component, component_export_dir)
869873

870874
# Step 7: Update config.json with quantization info
871875
if quant_config is not None:
@@ -882,19 +886,7 @@ def _export_diffusers_checkpoint(
882886
elif hasattr(component, "save_pretrained"):
883887
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
884888
else:
885-
cpu_state_dict = {
886-
k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()
887-
}
888-
save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
889-
with open(component_export_dir / "config.json", "w") as f:
890-
json.dump(
891-
{
892-
"_class_name": type(component).__name__,
893-
"_export_format": "safetensors_state_dict",
894-
},
895-
f,
896-
indent=4,
897-
)
889+
_save_component_state_dict_safetensors(component, component_export_dir)
898890

899891
print(f" Saved to: {component_export_dir}")
900892

0 commit comments

Comments (0)