Skip to content

Commit 4f61dd7

Browse files
committed
Add support for exporting ComfyUI-compatible checkpoints for diffusion models (e.g., LTX-2)
Signed-off-by: ynankani <ynankani@nvidia.com>
1 parent 7c4c9fd commit 4f61dd7

1 file changed

Lines changed: 117 additions & 19 deletions

File tree

modelopt/torch/export/unified_export_hf.py

Lines changed: 117 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
import torch
3030
import torch.nn as nn
31-
from safetensors.torch import save_file
31+
from safetensors.torch import save_file, load_file, safe_open
3232

3333
try:
3434
import diffusers
@@ -111,20 +111,108 @@ def _is_enabled_quantizer(quantizer):
111111
return False
112112

113113

114+
def _merge_diffusion_transformer_with_non_transformer_components(
    diffusion_transformer_state_dict: dict[str, torch.Tensor],
    merged_base_safetensor_path: str,
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    """Merge diffusion transformer weights with non-transformer components from a safetensors file.

    Non-transformer components (VAE, vocoder, text encoders) and embeddings connectors are
    taken from the base checkpoint. Transformer keys are prefixed with 'model.diffusion_model.'
    for ComfyUI compatibility.

    Args:
        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
        merged_base_safetensor_path: Path to the full base model safetensors file containing
            all components (transformer, VAE, vocoder, etc.).

    Returns:
        Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
        safetensors metadata from the base checkpoint.
    """
    non_transformer_prefixes = (
        "vae.",
        "audio_vae.",
        "vocoder.",
        "text_embedding_projection.",
        "text_encoders.",
        "first_stage_model.",
        "cond_stage_model.",
        "conditioner.",
    )
    correct_prefix = "model.diffusion_model."
    # Order matters: "diffusion_model." must be tried before "model." so that keys like
    # "diffusion_model.x" are not left with a partial prefix.
    strip_prefixes = ("diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model.")

    merged: dict[str, torch.Tensor] = {}
    # Open the base checkpoint once and load ONLY the keys we keep (non-transformer
    # components and the embeddings connectors). This avoids materializing the full
    # base transformer in memory — its weights are replaced by the quantized ones
    # anyway — and also gives us the file-level metadata in the same pass.
    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
        base_metadata = f.metadata() or {}
        for key in f.keys():
            keep = any(key.startswith(p) for p in non_transformer_prefixes) or (
                "embeddings_connector" in key and key.startswith(correct_prefix)
            )
            if keep:
                merged[key] = f.get_tensor(key)

    # Re-key the (quantized) transformer weights under the ComfyUI prefix; any
    # overlapping base keys (e.g. connectors also present in the transformer dict)
    # are overridden by the transformer's own tensors, matching the previous
    # dict-update precedence.
    for key, tensor in diffusion_transformer_state_dict.items():
        clean_key = key
        for prefix in strip_prefixes:
            if clean_key.startswith(prefix):
                clean_key = clean_key.removeprefix(prefix)
                break
        merged[f"{correct_prefix}{clean_key}"] = tensor

    return merged, base_metadata
165+
166+
114167
def _save_component_state_dict_safetensors(
    component: nn.Module,
    component_export_dir: Path,
    merged_base_safetensor_path: str | None = None,
    hf_quant_config: dict | None = None,
) -> None:
    """Save component state dict as safetensors with optional base checkpoint merge.

    Args:
        component: The nn.Module to save.
        component_export_dir: Directory to save model.safetensors and config.json.
        merged_base_safetensor_path: If provided, merge with non-transformer components
            from this base safetensors file.
        hf_quant_config: If provided, embed quantization config in safetensors metadata
            and per-layer _quantization_metadata for ComfyUI.
    """
    cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()}

    metadata: dict[str, str] = {}
    metadata_full: dict[str, str] = {}
    if merged_base_safetensor_path is not None:
        cpu_state_dict, metadata_full = _merge_diffusion_transformer_with_non_transformer_components(
            cpu_state_dict, merged_base_safetensor_path
        )
    metadata["_export_format"] = "safetensors_state_dict"
    metadata["_class_name"] = type(component).__name__

    if hf_quant_config is not None:
        metadata_full["quantization_config"] = json.dumps(hf_quant_config)

        # Build per-layer _quantization_metadata for ComfyUI: each layer that carries a
        # weight scale tensor is recorded with its quantization format.
        quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
        layer_metadata: dict[str, dict[str, str]] = {}
        for key in cpu_state_dict:
            if key.endswith((".weight_scale", ".weight_scale_2")):
                layer_name = key.rsplit(".", 1)[0]
                # Handle "<layer>.weight.weight_scale"-style keys as well.
                if layer_name.endswith(".weight"):
                    layer_name = layer_name.rsplit(".", 1)[0]
                layer_metadata.setdefault(layer_name, {"format": quant_algo})
        metadata_full["_quantization_metadata"] = json.dumps(
            {"format_version": "1.0", "layers": layer_metadata}
        )

    metadata_full.update(metadata)
    # BUGFIX: the metadata was previously written only when a base-checkpoint merge was
    # requested, silently dropping quantization_config/_quantization_metadata for plain
    # (non-merged) quantized exports. Write it whenever we have either a merge or a
    # quant config; keep metadata=None otherwise to preserve the original plain-export
    # file layout.
    embed_metadata = merged_base_safetensor_path is not None or hf_quant_config is not None
    save_file(
        cpu_state_dict,
        str(component_export_dir / "model.safetensors"),
        metadata=metadata_full if embed_metadata else None,
    )

    with open(component_export_dir / "config.json", "w") as f:
        json.dump(metadata, f, indent=4)
128216

129217

130218
def _collect_shared_input_modules(
@@ -807,6 +895,7 @@ def _export_diffusers_checkpoint(
807895
dtype: torch.dtype | None,
808896
export_dir: Path,
809897
components: list[str] | None,
898+
merged_base_safetensor_path: str | None = None,
810899
max_shard_size: int | str = "10GB",
811900
) -> None:
812901
"""Internal: Export diffusion(-like) model/pipeline checkpoint.
@@ -821,6 +910,8 @@ def _export_diffusers_checkpoint(
821910
export_dir: The directory to save the exported checkpoint.
822911
components: Optional list of component names to export. Only used for pipelines.
823912
If None, all components are exported.
913+
merged_base_safetensor_path: If provided, merge the exported transformer with
914+
non-transformer components from this base safetensors file.
824915
max_shard_size: Maximum size of each shard file. If the model exceeds this size,
825916
it will be sharded into multiple files and a .safetensors.index.json will be
826917
created. Use smaller values like "5GB" or "2GB" to force sharding.
@@ -879,7 +970,8 @@ def _export_diffusers_checkpoint(
879970

880971
# Step 5: Build quantization config
881972
quant_config = get_quant_config(component, is_modelopt_qlora=False)
882-
973+
hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None
974+
883975
# Step 6: Save the component
884976
# - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
885977
# - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save
@@ -888,12 +980,14 @@ def _export_diffusers_checkpoint(
888980
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
889981
else:
890982
with hide_quantizers_from_state_dict(component):
891-
_save_component_state_dict_safetensors(component, component_export_dir)
892-
983+
_save_component_state_dict_safetensors(
984+
component,
985+
component_export_dir,
986+
merged_base_safetensor_path,
987+
hf_quant_config,
988+
)
893989
# Step 7: Update config.json with quantization info
894-
if quant_config is not None:
895-
hf_quant_config = convert_hf_quant_config_format(quant_config)
896-
990+
if hf_quant_config is not None:
897991
config_path = component_export_dir / "config.json"
898992
if config_path.exists():
899993
with open(config_path) as file:
@@ -905,7 +999,7 @@ def _export_diffusers_checkpoint(
905999
elif hasattr(component, "save_pretrained"):
9061000
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
9071001
else:
908-
_save_component_state_dict_safetensors(component, component_export_dir)
1002+
_save_component_state_dict_safetensors(component, component_export_dir, merged_base_safetensor_path)
9091003

9101004
print(f" Saved to: {component_export_dir}")
9111005

@@ -985,6 +1079,7 @@ def export_hf_checkpoint(
9851079
save_modelopt_state: bool = False,
9861080
components: list[str] | None = None,
9871081
extra_state_dict: dict[str, torch.Tensor] | None = None,
1082+
merged_base_safetensor_path: str | None = None,
9881083
):
9891084
"""Export quantized HuggingFace model checkpoint (transformers or diffusers).
9901085
@@ -1002,6 +1097,9 @@ def export_hf_checkpoint(
10021097
components: Only used for diffusers pipelines. Optional list of component names
10031098
to export. If None, all quantized components are exported.
10041099
extra_state_dict: Extra state dictionary to add to the exported model.
1100+
merged_base_safetensor_path: If provided, merge the exported diffusion transformer
1101+
with non-transformer components (VAE, vocoder, etc.) from this base safetensors
1102+
file. Only used for diffusion model exports (e.g., LTX-2).
10051103
"""
10061104
export_dir = Path(export_dir)
10071105
export_dir.mkdir(parents=True, exist_ok=True)
@@ -1010,7 +1108,7 @@ def export_hf_checkpoint(
10101108
if HAS_DIFFUSERS:
10111109
is_diffusers_obj = is_diffusers_object(model)
10121110
if is_diffusers_obj:
1013-
_export_diffusers_checkpoint(model, dtype, export_dir, components)
1111+
_export_diffusers_checkpoint(model, dtype, export_dir, components, merged_base_safetensor_path)
10141112
return
10151113

10161114
# Transformers model export

0 commit comments

Comments
 (0)