Skip to content

Commit 110a44c

Browse files
authored
Add support for exporting a ComfyUI-compatible checkpoint for diffusion models (e.g., LTX-2) (#911)
## What does this PR do Add support for exporting a ComfyUI-compatible checkpoint for diffusion models (e.g., LTX-2) **Type of change:** <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** Add support for exporting a ComfyUI-compatible checkpoint for diffusion models (e.g., LTX-2) 1) Added a parameter for merging the base vae, vocoder, connectors in the quantized checkpoint 2) storing quantization metadata and export tool as modelopt, required for ComfyUI compatibility. 3) Internally updating the transformer block prefixes to match the expectation of ComfyUI ## Usage <!-- You can potentially add a usage example below. --> ```python export_hf_checkpoint( pipeline, export_dir=EXPORT_DIR, merged_base_safetensor_path=BASE_CKPT, # merge VAE/vocoder from base ) ``` ## Testing <!-- Mention how have you tested your change if applicable. --> 1) Tested with the ltx-2 model a) initializing a twoStagePipeline object b) calling mtq.quantize on the transformer with NVFP4_DEFAULT_CFG c) then exporting with export_hf_checkpoint, passing the param merged_base_safetensor_path to generate a merged checkpoint 2) Ran the checkpoint generated in step 1 on ComfyUI to validate 3) Ran step 1 without merged_base_safetensor_path to check backward compatibility. ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes <!--- If No, explain why. --> - **Did you write any new necessary tests?**: NA - **Did you add or update any necessary documentation?**: NA - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: NA <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. 
related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added support for exporting LTX-2 diffusion models with merged base checkpoint integration * Enhanced export functionality to preserve and attach quantization metadata during model export * Extended model export capabilities with automatic model type detection for improved export handling <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: ynankani <ynankani@nvidia.com>
1 parent 6f094d7 commit 110a44c

2 files changed

Lines changed: 222 additions & 17 deletions

File tree

modelopt/torch/export/diffusers_utils.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Code that export quantized Hugging Face models for deployment."""
1717

18+
import json
1819
import warnings
1920
from collections.abc import Callable
2021
from contextlib import contextmanager
@@ -23,6 +24,7 @@
2324

2425
import torch
2526
import torch.nn as nn
27+
from safetensors.torch import load_file, safe_open
2628

2729
from .layer_utils import is_quantlinear
2830

@@ -656,3 +658,146 @@ def infer_dtype_from_model(model: nn.Module) -> torch.dtype:
656658
for param in model.parameters():
657659
return param.dtype
658660
return torch.float16
661+
662+
663+
def _merge_ltx2(
    diffusion_transformer_state_dict: dict[str, torch.Tensor],
    merged_base_safetensor_path: str,
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    """Merge LTX-2 transformer weights with non-transformer components.

    Non-transformer components (VAE, vocoder, text encoders) and embeddings
    connectors are taken from the base checkpoint. Transformer keys are
    re-prefixed with ``model.diffusion_model.`` for ComfyUI compatibility.

    Args:
        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
        merged_base_safetensor_path: Path to the full base model safetensors file containing
            all components (transformer, VAE, vocoder, etc.).

    Returns:
        Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
        safetensors metadata from the base checkpoint.
    """
    # Prefixes identifying components kept verbatim from the base checkpoint.
    keep_prefixes = (
        "vae.",
        "audio_vae.",
        "vocoder.",
        "text_embedding_projection.",
        "text_encoders.",
        "first_stage_model.",
        "cond_stage_model.",
        "conditioner.",
    )
    target_prefix = "model.diffusion_model."
    # Wrapper prefixes that may appear on transformer keys; at most one is removed.
    removable_prefixes = ("diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model.")

    base_weights = load_file(merged_base_safetensor_path)

    merged: dict[str, torch.Tensor] = {}
    # 1) Non-transformer components come straight from the base checkpoint.
    for key, tensor in base_weights.items():
        if key.startswith(keep_prefixes):
            merged[key] = tensor
    # 2) Embedding connectors already under the target prefix are kept from base.
    for key, tensor in base_weights.items():
        if key.startswith(target_prefix) and "embeddings_connector" in key:
            merged[key] = tensor
    # 3) Quantized transformer weights, re-prefixed for ComfyUI, override any overlap.
    for key, tensor in diffusion_transformer_state_dict.items():
        stripped = key
        for candidate in removable_prefixes:
            if stripped.startswith(candidate):
                # NOTE(review): only the first matching prefix is removed; keys with
                # stacked prefixes (e.g. "model.diffusion_model.x") keep the remainder.
                stripped = stripped[len(candidate) :]
                break
        merged[f"{target_prefix}{stripped}"] = tensor

    # Preserve the base checkpoint's safetensors metadata for the merged file.
    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as handle:
        base_metadata = handle.metadata() or {}

    del base_weights  # release base tensors promptly before returning
    return merged, base_metadata
725+
726+
727+
# Registry mapping a model-type key (as returned by ``get_diffusion_model_type``)
# to its checkpoint merge function. Add new entries here to support more models.
DIFFUSION_MERGE_FUNCTIONS: dict[str, Callable] = {"ltx2": _merge_ltx2}
730+
731+
732+
def merge_diffusion_checkpoint(
    state_dict: dict[str, torch.Tensor],
    merged_base_safetensor_path: str,
    model_type: str,
    hf_quant_config: dict | None = None,
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    """Merge transformer weights with a base checkpoint and build ComfyUI metadata.

    Dispatches to the model-specific merge function in ``DIFFUSION_MERGE_FUNCTIONS``
    and, when ``hf_quant_config`` is provided, embeds ``quantization_config`` and
    per-layer ``_quantization_metadata`` in the safetensors metadata for ComfyUI.

    Args:
        state_dict: The transformer state dict (already on CPU).
        merged_base_safetensor_path: Path to the full base model ``.safetensors`` file
            containing all components (transformer, VAE, vocoder, etc.),
            e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
        model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge.
        hf_quant_config: If provided, embed quantization config and per-layer
            ``_quantization_metadata`` in the returned metadata dict.

    Returns:
        Tuple of (merged_state_dict, metadata) where *metadata* is the base checkpoint's
        original metadata augmented with any quantization entries.

    Raises:
        ValueError: If ``model_type`` has no registered merge function.
    """
    merge_fn = DIFFUSION_MERGE_FUNCTIONS.get(model_type)
    if merge_fn is None:
        # Raise ValueError (not a bare KeyError) for consistency with
        # get_diffusion_model_type and to tell the caller what is supported.
        raise ValueError(
            f"No merge function registered for model type '{model_type}'. "
            f"Supported types: {sorted(DIFFUSION_MERGE_FUNCTIONS)}."
        )
    merged_state_dict, metadata = merge_fn(state_dict, merged_base_safetensor_path)

    if hf_quant_config is not None:
        metadata["quantization_config"] = json.dumps(hf_quant_config)

        # Guard against an explicit None value for "quant_algo" before lower().
        quant_algo = str(hf_quant_config.get("quant_algo") or "unknown").lower()
        layer_metadata: dict[str, dict[str, str]] = {}
        for key in merged_state_dict:
            if not key.endswith((".weight_scale", ".weight_scale_2")):
                continue
            layer_name = key.rsplit(".", 1)[0]
            # Normalize keys shaped like "<layer>.weight.weight_scale".
            if layer_name.endswith(".weight"):
                layer_name = layer_name.rsplit(".", 1)[0]
            layer_metadata.setdefault(layer_name, {"format": quant_algo})
        metadata["_quantization_metadata"] = json.dumps(
            {
                "format_version": "1.0",
                "layers": layer_metadata,
            }
        )

    return merged_state_dict, metadata
780+
781+
782+
def get_diffusion_model_type(pipe: Any) -> str:
    """Detect the diffusion model type for merge function dispatch.

    To add a new model type, add a detection clause here and a corresponding
    merge function in ``DIFFUSION_MERGE_FUNCTIONS``.

    Args:
        pipe: The pipeline or component being exported.

    Returns:
        A string key into ``DIFFUSION_MERGE_FUNCTIONS``.

    Raises:
        ValueError: If the model type is not supported.
    """
    # TI2VidTwoStagesPipeline is None when the ltx-2 package is unavailable.
    is_ltx2 = TI2VidTwoStagesPipeline is not None and isinstance(pipe, TI2VidTwoStagesPipeline)
    if is_ltx2:
        return "ltx2"

    raise ValueError(
        f"No merge function for model type '{type(pipe).__name__}'. "
        "Add an entry to DIFFUSION_MERGE_FUNCTIONS in diffusers_utils.py."
    )

modelopt/torch/export/unified_export_hf.py

Lines changed: 77 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
from .diffusers_utils import (
3737
generate_diffusion_dummy_forward_fn,
3838
get_diffusion_components,
39+
get_diffusion_model_type,
3940
get_qkv_group_key,
4041
hide_quantizers_from_state_dict,
4142
infer_dtype_from_model,
4243
is_diffusers_object,
4344
is_qkv_projection,
45+
merge_diffusion_checkpoint,
4446
)
4547

4648
HAS_DIFFUSERS = True
@@ -116,20 +118,49 @@ def _is_enabled_quantizer(quantizer):
116118

117119

118120
def _save_component_state_dict_safetensors(
    component: nn.Module,
    component_export_dir: Path,
    merged_base_safetensor_path: str | None = None,
    hf_quant_config: dict | None = None,
    model_type: str | None = None,
) -> None:
    """Save component state dict as safetensors with optional base checkpoint merge.

    Args:
        component: The nn.Module to save.
        component_export_dir: Directory to save model.safetensors and config.json.
        merged_base_safetensor_path: If provided, merge the exported transformer weights
            with non-transformer components (VAE, vocoder, text encoders, etc.) from this
            base safetensors file and add quantization metadata to produce a single-file
            checkpoint compatible with ComfyUI. This should be the path to a full base
            model ``.safetensors`` file, e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
        hf_quant_config: If provided, embed quantization config in safetensors metadata
            and per-layer _quantization_metadata for ComfyUI.
        model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge.
            Required when ``merged_base_safetensor_path`` is not None.
    """
    # Detach and move everything to CPU so save_file never touches device memory.
    cpu_state_dict = {
        name: tensor.detach().contiguous().cpu()
        for name, tensor in component.state_dict().items()
    }

    export_metadata: dict[str, str] = {}
    file_metadata: dict[str, str] = {}

    if merged_base_safetensor_path is not None and model_type is not None:
        cpu_state_dict, file_metadata = merge_diffusion_checkpoint(
            cpu_state_dict, merged_base_safetensor_path, model_type, hf_quant_config
        )

    export_metadata["_export_format"] = "safetensors_state_dict"
    export_metadata["_class_name"] = type(component).__name__
    file_metadata.update(export_metadata)

    # Safetensors metadata is only attached on the merged (ComfyUI) path,
    # preserving the original plain-export file format otherwise.
    save_file(
        cpu_state_dict,
        str(component_export_dir / "model.safetensors"),
        metadata=file_metadata if merged_base_safetensor_path is not None else None,
    )

    with open(component_export_dir / "config.json", "w") as config_file:
        json.dump(export_metadata, config_file, indent=4)
163+
133164

134165
def _collect_shared_input_modules(
135166
model: nn.Module,
@@ -822,6 +853,7 @@ def _export_diffusers_checkpoint(
822853
dtype: torch.dtype | None,
823854
export_dir: Path,
824855
components: list[str] | None,
856+
merged_base_safetensor_path: str | None = None,
825857
max_shard_size: int | str = "10GB",
826858
) -> None:
827859
"""Internal: Export diffusion(-like) model/pipeline checkpoint.
@@ -836,6 +868,11 @@ def _export_diffusers_checkpoint(
836868
export_dir: The directory to save the exported checkpoint.
837869
components: Optional list of component names to export. Only used for pipelines.
838870
If None, all components are exported.
871+
merged_base_safetensor_path: If provided, merge the exported transformer weights
872+
with non-transformer components (VAE, vocoder, text encoders, etc.) from this
873+
base safetensors file and add quantization metadata to produce a single-file
874+
checkpoint compatible with ComfyUI. This should be the path to a full base
875+
model ``.safetensors`` file, e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
839876
max_shard_size: Maximum size of each shard file. If the model exceeds this size,
840877
it will be sharded into multiple files and a .safetensors.index.json will be
841878
created. Use smaller values like "5GB" or "2GB" to force sharding.
@@ -849,6 +886,9 @@ def _export_diffusers_checkpoint(
849886
warnings.warn("No exportable components found in the model.")
850887
return
851888

889+
# Resolve model type once (only needed when merging with a base checkpoint)
890+
model_type = get_diffusion_model_type(pipe) if merged_base_safetensor_path else None
891+
852892
# Separate nn.Module components for quantization-aware export
853893
module_components = {
854894
name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module)
@@ -894,6 +934,7 @@ def _export_diffusers_checkpoint(
894934

895935
# Step 5: Build quantization config
896936
quant_config = get_quant_config(component, is_modelopt_qlora=False)
937+
hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None
897938

898939
# Step 6: Save the component
899940
# - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
@@ -903,12 +944,15 @@ def _export_diffusers_checkpoint(
903944
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
904945
else:
905946
with hide_quantizers_from_state_dict(component):
906-
_save_component_state_dict_safetensors(component, component_export_dir)
907-
947+
_save_component_state_dict_safetensors(
948+
component,
949+
component_export_dir,
950+
merged_base_safetensor_path,
951+
hf_quant_config,
952+
model_type,
953+
)
908954
# Step 7: Update config.json with quantization info
909-
if quant_config is not None:
910-
hf_quant_config = convert_hf_quant_config_format(quant_config)
911-
955+
if hf_quant_config is not None:
912956
config_path = component_export_dir / "config.json"
913957
if config_path.exists():
914958
with open(config_path) as file:
@@ -920,7 +964,12 @@ def _export_diffusers_checkpoint(
920964
elif hasattr(component, "save_pretrained"):
921965
component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
922966
else:
923-
_save_component_state_dict_safetensors(component, component_export_dir)
967+
_save_component_state_dict_safetensors(
968+
component,
969+
component_export_dir,
970+
merged_base_safetensor_path,
971+
model_type=model_type,
972+
)
924973

925974
print(f" Saved to: {component_export_dir}")
926975

@@ -1044,6 +1093,7 @@ def export_hf_checkpoint(
10441093
save_modelopt_state: bool = False,
10451094
components: list[str] | None = None,
10461095
extra_state_dict: dict[str, torch.Tensor] | None = None,
1096+
**kwargs,
10471097
):
10481098
"""Export quantized HuggingFace model checkpoint (transformers or diffusers).
10491099
@@ -1061,15 +1111,25 @@ def export_hf_checkpoint(
10611111
components: Only used for diffusers pipelines. Optional list of component names
10621112
to export. If None, all quantized components are exported.
10631113
extra_state_dict: Extra state dictionary to add to the exported model.
1114+
**kwargs: Internal-only keyword arguments. Supported key: merged_base_safetensor_path
1115+
(str, optional). When provided, merges the exported diffusion transformer
1116+
weights with non-transformer components (VAE, vocoder, text encoders, etc.)
1117+
from this base safetensors file to produce a single-file checkpoint
1118+
compatible with ComfyUI. Value should be the path to a full base model
1119+
``.safetensors`` file (e.g. ``"path/to/ltx-2-19b-dev.safetensors"``).
1120+
Only used for diffusion model exports.
10641121
"""
1122+
merged_base_safetensor_path: str | None = kwargs.get("merged_base_safetensor_path")
10651123
export_dir = Path(export_dir)
10661124
export_dir.mkdir(parents=True, exist_ok=True)
10671125

10681126
is_diffusers_obj = False
10691127
if HAS_DIFFUSERS:
10701128
is_diffusers_obj = is_diffusers_object(model)
10711129
if is_diffusers_obj:
1072-
_export_diffusers_checkpoint(model, dtype, export_dir, components)
1130+
_export_diffusers_checkpoint(
1131+
model, dtype, export_dir, components, merged_base_safetensor_path
1132+
)
10731133
return
10741134

10751135
# Transformers model export

0 commit comments

Comments
 (0)