
Commit 69107c0

Add support for exporting a ComfyUI-compatible checkpoint for diffusion models (e.g., LTX-2)

Signed-off-by: ynankani <ynankani@nvidia.com>
1 parent 4f61dd7 · commit 69107c0

1 file changed: 53 additions & 30 deletions

modelopt/torch/export/unified_export_hf.py
@@ -28,7 +28,7 @@
 
 import torch
 import torch.nn as nn
-from safetensors.torch import save_file, load_file, safe_open
+from safetensors.torch import load_file, safe_open, save_file
 
 try:
     import diffusers
@@ -130,27 +130,38 @@ def _merge_diffusion_transformer_with_non_transformer_components(
         Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
         safetensors metadata from the base checkpoint.
     """
-
     base_state = load_file(merged_base_safetensor_path)
 
     non_transformer_prefixes = [
-        'vae.', 'audio_vae.', 'vocoder.', 'text_embedding_projection.',
-        'text_encoders.', 'first_stage_model.', 'cond_stage_model.', 'conditioner.',
+        "vae.",
+        "audio_vae.",
+        "vocoder.",
+        "text_embedding_projection.",
+        "text_encoders.",
+        "first_stage_model.",
+        "cond_stage_model.",
+        "conditioner.",
     ]
-    correct_prefix = 'model.diffusion_model.'
-    strip_prefixes = ['diffusion_model.', 'transformer.', '_orig_mod.', 'model.', 'velocity_model.']
+    correct_prefix = "model.diffusion_model."
+    strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]
 
-    base_non_transformer = {k: v for k, v in base_state.items()
-                            if any(k.startswith(p) for p in non_transformer_prefixes)}
-    base_connectors = {k: v for k, v in base_state.items()
-                       if 'embeddings_connector' in k and k.startswith(correct_prefix)}
+    base_non_transformer = {
+        k: v
+        for k, v in base_state.items()
+        if any(k.startswith(p) for p in non_transformer_prefixes)
+    }
+    base_connectors = {
+        k: v
+        for k, v in base_state.items()
+        if "embeddings_connector" in k and k.startswith(correct_prefix)
+    }
 
     prefixed = {}
     for k, v in diffusion_transformer_state_dict.items():
         clean_k = k
         for prefix in strip_prefixes:
             if clean_k.startswith(prefix):
-                clean_k = clean_k[len(prefix):]
+                clean_k = clean_k[len(prefix) :]
                 break
         prefixed[f"{correct_prefix}{clean_k}"] = v
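The heart of the ComfyUI compatibility is this key remapping: whatever wrapper prefix the training or inference stack left on the transformer weights (`transformer.`, `_orig_mod.`, ...) is stripped, and every key is re-homed under `model.diffusion_model.`, where ComfyUI-style merged checkpoints expect the diffusion model to live; VAE, text-encoder, and connector weights are carried over from the base checkpoint untouched. A standalone sketch of the remapping rule (prefix lists copied from the diff; the assert inputs are made-up key names):

```python
# Standalone sketch of the remapping performed above; lists copied from the diff.
strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]
correct_prefix = "model.diffusion_model."


def remap_key(key: str) -> str:
    """Strip the outermost wrapper prefix, then re-home under ComfyUI's prefix."""
    for prefix in strip_prefixes:
        if key.startswith(prefix):
            key = key[len(prefix) :]
            break  # only the first matching prefix is stripped
    return f"{correct_prefix}{key}"


# Hypothetical key names for illustration:
assert remap_key("transformer.blocks.0.attn.to_q.weight") == "model.diffusion_model.blocks.0.attn.to_q.weight"
assert remap_key("blocks.0.attn.to_q.weight") == "model.diffusion_model.blocks.0.attn.to_q.weight"
```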

@@ -165,10 +176,10 @@ def _merge_diffusion_transformer_with_non_transformer_components(
 
 
 def _save_component_state_dict_safetensors(
-        component: nn.Module,
-        component_export_dir: Path,
-        merged_base_safetensor_path: str | None = None,
-        hf_quant_config: dict | None = None
+    component: nn.Module,
+    component_export_dir: Path,
+    merged_base_safetensor_path: str | None = None,
+    hf_quant_config: dict | None = None,
 ) -> None:
     """Save component state dict as safetensors with optional base checkpoint merge.
 
@@ -184,10 +195,12 @@ def _save_component_state_dict_safetensors(
     metadata: dict[str, str] = {}
     metadata_full: dict[str, str] = {}
     if merged_base_safetensor_path is not None:
-        cpu_state_dict, metadata_full = _merge_diffusion_transformer_with_non_transformer_components(
-            cpu_state_dict, merged_base_safetensor_path
+        cpu_state_dict, metadata_full = (
+            _merge_diffusion_transformer_with_non_transformer_components(
+                cpu_state_dict, merged_base_safetensor_path
+            )
         )
-        metadata["_export_format"] = "safetensors_state_dict"
+    metadata["_export_format"] = "safetensors_state_dict"
     metadata["_class_name"] = type(component).__name__
 
     if hf_quant_config is not None:
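Note the split between the two dicts: `metadata` (the `_export_format` marker and the component class name) is dumped to `config.json`, while `metadata_full` additionally carries the base checkpoint's own safetensors metadata plus the quantization metadata built below, and is embedded in the merged file's header. A downstream tool could use the `config.json` marker to tell this layout apart from a diffusers `save_pretrained` directory; a hypothetical helper (`is_state_dict_export` is not part of this commit):

```python
import json
from pathlib import Path


def is_state_dict_export(component_export_dir: Path) -> bool:
    """Detect the raw state-dict layout written here, as opposed to a
    diffusers save_pretrained directory (hypothetical helper)."""
    config_path = component_export_dir / "config.json"
    if not config_path.exists():
        return False
    config = json.loads(config_path.read_text())
    return config.get("_export_format") == "safetensors_state_dict"
```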
@@ -197,20 +210,26 @@ def _save_component_state_dict_safetensors(
         quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
         layer_metadata = {}
         for k in cpu_state_dict:
-            if k.endswith(".weight_scale") or k.endswith(".weight_scale_2"):
+            if k.endswith((".weight_scale", ".weight_scale_2")):
                 layer_name = k.rsplit(".", 1)[0]
                 if layer_name.endswith(".weight"):
                     layer_name = layer_name.rsplit(".", 1)[0]
                 if layer_name not in layer_metadata:
                     layer_metadata[layer_name] = {"format": quant_algo}
-        metadata_full["_quantization_metadata"] = json.dumps({
-            "format_version": "1.0",
-            "layers": layer_metadata,
-        })
+        metadata_full["_quantization_metadata"] = json.dumps(
+            {
+                "format_version": "1.0",
+                "layers": layer_metadata,
+            }
+        )
 
     metadata_full.update(metadata)
-    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"), metadata=metadata_full if merged_base_safetensor_path is not None else None)
-
+    save_file(
+        cpu_state_dict,
+        str(component_export_dir / "model.safetensors"),
+        metadata=metadata_full if merged_base_safetensor_path is not None else None,
+    )
+
     with open(component_export_dir / "config.json", "w") as f:
         json.dump(metadata, f, indent=4)
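Because the per-layer quantization formats are written into the safetensors header, a consumer can recover them without loading a single tensor. A minimal read-back sketch using the stock safetensors API; the file path is a placeholder:

```python
import json

from safetensors import safe_open

# Path is a placeholder for the component's exported file.
with safe_open("exported_ltx2/transformer/model.safetensors", framework="pt") as f:
    header = f.metadata() or {}

quant_meta = json.loads(header.get("_quantization_metadata", "{}"))
for layer, info in quant_meta.get("layers", {}).items():
    print(layer, "->", info["format"])  # lowercased quant_algo, e.g. "fp8"
```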

@@ -971,7 +990,7 @@ def _export_diffusers_checkpoint(
     # Step 5: Build quantization config
     quant_config = get_quant_config(component, is_modelopt_qlora=False)
     hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None
-
+
     # Step 6: Save the component
     # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
     # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save
@@ -981,8 +1000,8 @@ def _export_diffusers_checkpoint(
     else:
         with hide_quantizers_from_state_dict(component):
             _save_component_state_dict_safetensors(
-                    component,
-                    component_export_dir,
+                component,
+                component_export_dir,
                 merged_base_safetensor_path,
                 hf_quant_config,
             )
@@ -999,7 +1018,9 @@ def _export_diffusers_checkpoint(
     elif hasattr(component, "save_pretrained"):
         component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
     else:
-        _save_component_state_dict_safetensors(component, component_export_dir, merged_base_safetensor_path)
+        _save_component_state_dict_safetensors(
+            component, component_export_dir, merged_base_safetensor_path
+        )
 
     print(f" Saved to: {component_export_dir}")
 
@@ -1108,7 +1129,9 @@ def export_hf_checkpoint(
     if HAS_DIFFUSERS:
         is_diffusers_obj = is_diffusers_object(model)
         if is_diffusers_obj:
-            _export_diffusers_checkpoint(model, dtype, export_dir, components, merged_base_safetensor_path)
+            _export_diffusers_checkpoint(
+                model, dtype, export_dir, components, merged_base_safetensor_path
+            )
             return
 
     # Transformers model export
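Taken together, the `merged_base_safetensor_path` plumbing lets a quantized diffusion transformer be exported straight into a single ComfyUI-style checkpoint. A hedged usage sketch: it assumes `export_hf_checkpoint` accepts the same arguments it forwards to `_export_diffusers_checkpoint` above, and the model and paths are placeholders:

```python
# Hypothetical usage; the keyword names mirror the internal call in the hunk
# above -- verify against the actual export_hf_checkpoint signature.
import torch

from modelopt.torch.export import export_hf_checkpoint  # import path assumed

model = ...  # quantized LTX-2 transformer module (placeholder)

export_hf_checkpoint(
    model,
    dtype=torch.bfloat16,
    export_dir="exported_ltx2",
    # Base checkpoint whose VAE / text-encoder weights get merged back in:
    merged_base_safetensor_path="ltx2_base.safetensors",
)
```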
