diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py
index a9bf138767..f09032be8a 100644
--- a/modelopt/torch/export/diffusers_utils.py
+++ b/modelopt/torch/export/diffusers_utils.py
@@ -15,6 +15,7 @@
 """Code that export quantized Hugging Face models for deployment."""
 
+import json
 import warnings
 from collections.abc import Callable
 from contextlib import contextmanager
@@ -23,6 +24,7 @@
 import torch
 import torch.nn as nn
+from safetensors.torch import load_file, safe_open
 
 from .layer_utils import is_quantlinear
 
@@ -656,3 +658,146 @@ def infer_dtype_from_model(model: nn.Module) -> torch.dtype:
     for param in model.parameters():
         return param.dtype
     return torch.float16
+
+
+def _merge_ltx2(
+    diffusion_transformer_state_dict: dict[str, torch.Tensor],
+    merged_base_safetensor_path: str,
+) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
+    """Merge LTX-2 transformer weights with non-transformer components.
+
+    Non-transformer components (VAE, vocoder, text encoders) and embeddings
+    connectors are taken from the base checkpoint. Transformer keys are
+    re-prefixed with ``model.diffusion_model.`` for ComfyUI compatibility.
+
+    Args:
+        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
+        merged_base_safetensor_path: Path to the full base model safetensors file containing
+            all components (transformer, VAE, vocoder, etc.).
+
+    Returns:
+        Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
+        safetensors metadata from the base checkpoint.
+    """
+    base_state = load_file(merged_base_safetensor_path)
+
+    non_transformer_prefixes = [
+        "vae.",
+        "audio_vae.",
+        "vocoder.",
+        "text_embedding_projection.",
+        "text_encoders.",
+        "first_stage_model.",
+        "cond_stage_model.",
+        "conditioner.",
+    ]
+    correct_prefix = "model.diffusion_model."
+    strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]
+
+    base_non_transformer = {
+        k: v
+        for k, v in base_state.items()
+        if any(k.startswith(p) for p in non_transformer_prefixes)
+    }
+    base_connectors = {
+        k: v
+        for k, v in base_state.items()
+        if "embeddings_connector" in k and k.startswith(correct_prefix)
+    }
+
+    prefixed = {}
+    for k, v in diffusion_transformer_state_dict.items():
+        clean_k = k
+        for prefix in strip_prefixes:
+            if clean_k.startswith(prefix):
+                clean_k = clean_k[len(prefix) :]
+                break
+        prefixed[f"{correct_prefix}{clean_k}"] = v
+
+    merged = dict(base_non_transformer)
+    merged.update(base_connectors)
+    merged.update(prefixed)
+    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
+        base_metadata = f.metadata() or {}
+
+    del base_state
+    return merged, base_metadata
+
+
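+# Registry mapping model-type keys to merge functions. To support another
+# checkpoint layout, implement a merge function with the same signature as
+# ``_merge_ltx2`` and register it here (e.g. a hypothetical
+# ``"wan": _merge_wan`` entry), then add a matching detection clause in
+# ``get_diffusion_model_type`` below.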
+DIFFUSION_MERGE_FUNCTIONS: dict[str, Callable] = {
+    "ltx2": _merge_ltx2,
+}
+
+
+def merge_diffusion_checkpoint(
+    state_dict: dict[str, torch.Tensor],
+    merged_base_safetensor_path: str,
+    model_type: str,
+    hf_quant_config: dict | None = None,
+) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
+    """Merge transformer weights with a base checkpoint and build ComfyUI metadata.
+
+    Dispatches to the model-specific merge function in ``DIFFUSION_MERGE_FUNCTIONS``
+    and, when ``hf_quant_config`` is provided, embeds ``quantization_config`` and
+    per-layer ``_quantization_metadata`` in the safetensors metadata for ComfyUI.
+
+    Args:
+        state_dict: The transformer state dict (already on CPU).
+        merged_base_safetensor_path: Path to the full base model ``.safetensors`` file
+            containing all components (transformer, VAE, vocoder, etc.),
+            e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
+        model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge.
+        hf_quant_config: If provided, embed quantization config and per-layer
+            ``_quantization_metadata`` in the returned metadata dict.
+
+    Returns:
+        Tuple of (merged_state_dict, metadata) where *metadata* is the base checkpoint's
+        original metadata augmented with any quantization entries.
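+
+    Example (illustrative sketch; the path, input state dict, and quant config
+    are placeholders)::
+
+        merged_sd, metadata = merge_diffusion_checkpoint(
+            transformer_state_dict,
+            "path/to/ltx-2-19b-dev.safetensors",
+            model_type="ltx2",
+            hf_quant_config={"quant_algo": "FP8"},
+        )
+        # metadata now carries "quantization_config" and "_quantization_metadata"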
+    """
+    merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type]
+    merged_state_dict, metadata = merge_fn(state_dict, merged_base_safetensor_path)
+
+    if hf_quant_config is not None:
+        metadata["quantization_config"] = json.dumps(hf_quant_config)
+
+        quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
+        layer_metadata = {}
+        for k in merged_state_dict:
+            if k.endswith((".weight_scale", ".weight_scale_2")):
+                layer_name = k.rsplit(".", 1)[0]
+                if layer_name.endswith(".weight"):
+                    layer_name = layer_name.rsplit(".", 1)[0]
+                if layer_name not in layer_metadata:
+                    layer_metadata[layer_name] = {"format": quant_algo}
+        metadata["_quantization_metadata"] = json.dumps(
+            {
+                "format_version": "1.0",
+                "layers": layer_metadata,
+            }
+        )
+
+    return merged_state_dict, metadata
+
+
+def get_diffusion_model_type(pipe: Any) -> str:
+    """Detect the diffusion model type for merge function dispatch.
+
+    To add a new model type, add a detection clause here and a corresponding
+    merge function in ``DIFFUSION_MERGE_FUNCTIONS``.
+
+    Args:
+        pipe: The pipeline or component being exported.
+
+    Returns:
+        A string key into ``DIFFUSION_MERGE_FUNCTIONS``.
+
+    Raises:
+        ValueError: If the model type is not supported.
+    """
+    if TI2VidTwoStagesPipeline is not None and isinstance(pipe, TI2VidTwoStagesPipeline):
+        return "ltx2"
+
+    raise ValueError(
+        f"No merge function for model type '{type(pipe).__name__}'. "
+        "Add an entry to DIFFUSION_MERGE_FUNCTIONS in diffusers_utils.py."
+    )
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 4e9e3ba321..3567f64b44 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -36,11 +36,13 @@
 from .diffusers_utils import (
     generate_diffusion_dummy_forward_fn,
     get_diffusion_components,
+    get_diffusion_model_type,
     get_qkv_group_key,
     hide_quantizers_from_state_dict,
     infer_dtype_from_model,
     is_diffusers_object,
     is_qkv_projection,
+    merge_diffusion_checkpoint,
 )
 
 HAS_DIFFUSERS = True
@@ -116,20 +118,49 @@
 def _save_component_state_dict_safetensors(
-    component: nn.Module, component_export_dir: Path
+    component: nn.Module,
+    component_export_dir: Path,
+    merged_base_safetensor_path: str | None = None,
+    hf_quant_config: dict | None = None,
+    model_type: str | None = None,
 ) -> None:
+    """Save component state dict as safetensors with optional base checkpoint merge.
+
+    Args:
+        component: The nn.Module to save.
+        component_export_dir: Directory to save model.safetensors and config.json.
+        merged_base_safetensor_path: If provided, merge the exported transformer weights
+            with non-transformer components (VAE, vocoder, text encoders, etc.) from this
+            base safetensors file and add quantization metadata to produce a single-file
+            checkpoint compatible with ComfyUI. This should be the path to a full base
+            model ``.safetensors`` file, e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
+        hf_quant_config: If provided, embed quantization config in safetensors metadata
+            and per-layer _quantization_metadata for ComfyUI.
+        model_type: Key into ``DIFFUSION_MERGE_FUNCTIONS`` for the model-specific merge.
+            Required when ``merged_base_safetensor_path`` is not None.
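+
+    Example (illustrative sketch; the component and paths are placeholders)::
+
+        _save_component_state_dict_safetensors(
+            transformer,
+            Path("export_dir/transformer"),
+            merged_base_safetensor_path="path/to/ltx-2-19b-dev.safetensors",
+            hf_quant_config=hf_quant_config,
+            model_type="ltx2",
+        )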
+    """
     cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()}
-    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
-    with open(component_export_dir / "config.json", "w") as f:
-        json.dump(
-            {
-                "_class_name": type(component).__name__,
-                "_export_format": "safetensors_state_dict",
-            },
-            f,
-            indent=4,
+    metadata: dict[str, str] = {}
+    metadata_full: dict[str, str] = {}
+
+    if merged_base_safetensor_path is not None and model_type is not None:
+        cpu_state_dict, metadata_full = merge_diffusion_checkpoint(
+            cpu_state_dict, merged_base_safetensor_path, model_type, hf_quant_config
         )
+    metadata["_export_format"] = "safetensors_state_dict"
+    metadata["_class_name"] = type(component).__name__
+    metadata_full.update(metadata)
+
+    save_file(
+        cpu_state_dict,
+        str(component_export_dir / "model.safetensors"),
+        metadata=metadata_full if merged_base_safetensor_path is not None else None,
+    )
+
+    with open(component_export_dir / "config.json", "w") as f:
+        json.dump(metadata, f, indent=4)
+
 
 def _collect_shared_input_modules(
     model: nn.Module,
@@ -822,6 +853,7 @@ def _export_diffusers_checkpoint(
     dtype: torch.dtype | None,
     export_dir: Path,
     components: list[str] | None,
+    merged_base_safetensor_path: str | None = None,
     max_shard_size: int | str = "10GB",
 ) -> None:
     """Internal: Export diffusion(-like) model/pipeline checkpoint.
@@ -836,6 +868,11 @@
         export_dir: The directory to save the exported checkpoint.
         components: Optional list of component names to export. Only used for pipelines.
             If None, all components are exported.
+        merged_base_safetensor_path: If provided, merge the exported transformer weights
+            with non-transformer components (VAE, vocoder, text encoders, etc.) from this
+            base safetensors file and add quantization metadata to produce a single-file
+            checkpoint compatible with ComfyUI. This should be the path to a full base
+            model ``.safetensors`` file, e.g. ``"path/to/ltx-2-19b-dev.safetensors"``.
         max_shard_size: Maximum size of each shard file. If the model exceeds this size,
             it will be sharded into multiple files and a .safetensors.index.json will be
             created. Use smaller values like "5GB" or "2GB" to force sharding.
@@ -849,6 +886,9 @@
         warnings.warn("No exportable components found in the model.")
         return
 
+    # Resolve model type once (only needed when merging with a base checkpoint)
+    model_type = get_diffusion_model_type(pipe) if merged_base_safetensor_path else None
+
     # Separate nn.Module components for quantization-aware export
     module_components = {
         name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module)
     }
@@ -894,6 +934,7 @@
             # Step 5: Build quantization config
             quant_config = get_quant_config(component, is_modelopt_qlora=False)
+            hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None
 
             # Step 6: Save the component
             # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
@@ -903,12 +944,15 @@
                 component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
             else:
                 with hide_quantizers_from_state_dict(component):
-                    _save_component_state_dict_safetensors(component, component_export_dir)
-
+                    _save_component_state_dict_safetensors(
+                        component,
+                        component_export_dir,
+                        merged_base_safetensor_path,
+                        hf_quant_config,
+                        model_type,
+                    )
             # Step 7: Update config.json with quantization info
-            if quant_config is not None:
-                hf_quant_config = convert_hf_quant_config_format(quant_config)
-
+            if hf_quant_config is not None:
                 config_path = component_export_dir / "config.json"
                 if config_path.exists():
                     with open(config_path) as file:
@@ -920,7 +964,12 @@
         elif hasattr(component, "save_pretrained"):
             component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
         else:
-            _save_component_state_dict_safetensors(component, component_export_dir)
+            _save_component_state_dict_safetensors(
+                component,
+                component_export_dir,
+                merged_base_safetensor_path,
+                model_type=model_type,
+            )
 
         print(f"  Saved to: {component_export_dir}")
@@ -1044,6 +1093,7 @@
     save_modelopt_state: bool = False,
     components: list[str] | None = None,
     extra_state_dict: dict[str, torch.Tensor] | None = None,
+    **kwargs,
 ):
     """Export quantized HuggingFace model checkpoint (transformers or diffusers).
@@ -1061,7 +1111,15 @@
         components: Only used for diffusers pipelines. Optional list of component names
             to export. If None, all quantized components are exported.
         extra_state_dict: Extra state dictionary to add to the exported model.
+        **kwargs: Internal-only keyword arguments. Supported key: merged_base_safetensor_path
+            (str, optional). When provided, merges the exported diffusion transformer
+            weights with non-transformer components (VAE, vocoder, text encoders, etc.)
+            from this base safetensors file to produce a single-file checkpoint
+            compatible with ComfyUI. Value should be the path to a full base model
+            ``.safetensors`` file (e.g. ``"path/to/ltx-2-19b-dev.safetensors"``).
+            Only used for diffusion model exports.
     """
+    merged_base_safetensor_path: str | None = kwargs.get("merged_base_safetensor_path")
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
@@ -1069,7 +1127,9 @@
     if HAS_DIFFUSERS:
         is_diffusers_obj = is_diffusers_object(model)
        if is_diffusers_obj:
-            _export_diffusers_checkpoint(model, dtype, export_dir, components)
+            _export_diffusers_checkpoint(
+                model, dtype, export_dir, components, merged_base_safetensor_path
+            )
             return
 
     # Transformers model export
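
Usage sketch (illustrative; the quantized pipeline, export directory, and base
checkpoint path are placeholders):

    from modelopt.torch.export import export_hf_checkpoint

    # Quantize an LTX-2 pipeline with modelopt, then export a single-file
    # ComfyUI-compatible checkpoint by merging with the base safetensors.
    export_hf_checkpoint(
        quantized_pipe,
        export_dir="exported_ltx2",
        merged_base_safetensor_path="path/to/ltx-2-19b-dev.safetensors",
    )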