2828
2929import torch
3030import torch .nn as nn
31- from safetensors .torch import save_file
31+ from safetensors .torch import save_file , load_file , safe_open
3232
3333try :
3434 import diffusers
@@ -111,20 +111,108 @@ def _is_enabled_quantizer(quantizer):
111111 return False
112112
113113
def _merge_diffusion_transformer_with_non_transformer_components(
    diffusion_transformer_state_dict: dict[str, torch.Tensor],
    merged_base_safetensor_path: str,
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    """Merge diffusion transformer weights with non-transformer components from a safetensors file.

    Non-transformer components (VAE, vocoder, text encoders) and embeddings connectors are
    taken from the base checkpoint. Transformer keys are prefixed with 'model.diffusion_model.'
    for ComfyUI compatibility.

    Args:
        diffusion_transformer_state_dict: The diffusion transformer state dict (already on CPU).
        merged_base_safetensor_path: Path to the full base model safetensors file containing
            all components (transformer, VAE, vocoder, etc.).

    Returns:
        Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
        safetensors metadata from the base checkpoint.
    """
    non_transformer_prefixes = (
        'vae.', 'audio_vae.', 'vocoder.', 'text_embedding_projection.',
        'text_encoders.', 'first_stage_model.', 'cond_stage_model.', 'conditioner.',
    )
    correct_prefix = 'model.diffusion_model.'
    strip_prefixes = ('diffusion_model.', 'transformer.', '_orig_mod.', 'model.', 'velocity_model.')

    merged: dict[str, torch.Tensor] = {}
    # Stream only the tensors we keep instead of load_file()-ing the whole base
    # checkpoint: load_file would also materialize the base transformer weights
    # just to discard them, roughly doubling peak host memory. A single
    # safe_open also gives us the metadata without reopening the file.
    with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
        base_metadata = f.metadata() or {}
        for key in f.keys():
            keep_non_transformer = any(key.startswith(p) for p in non_transformer_prefixes)
            keep_connector = 'embeddings_connector' in key and key.startswith(correct_prefix)
            if keep_non_transformer or keep_connector:
                merged[key] = f.get_tensor(key)

    # Re-prefix the quantized transformer keys for ComfyUI. At most ONE known
    # wrapper prefix is stripped (first match wins), matching how the export
    # wrappers nest; transformer keys override any colliding base keys.
    for key, tensor in diffusion_transformer_state_dict.items():
        clean_key = key
        for prefix in strip_prefixes:
            if clean_key.startswith(prefix):
                clean_key = clean_key[len(prefix):]
                break
        merged[f"{correct_prefix}{clean_key}"] = tensor

    return merged, base_metadata
165+
166+
def _build_layer_quant_metadata(
    state_dict: dict[str, torch.Tensor], hf_quant_config: dict
) -> dict[str, dict[str, str]]:
    """Build the per-layer ComfyUI quantization metadata from scale-tensor keys.

    A layer is considered quantized when its state dict carries a
    ``.weight_scale`` / ``.weight_scale_2`` entry; each such layer is tagged
    with the (lowercased) quantization algorithm from ``hf_quant_config``.
    """
    quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
    layers: dict[str, dict[str, str]] = {}
    for key in state_dict:
        if key.endswith((".weight_scale", ".weight_scale_2")):
            layer_name = key.rsplit(".", 1)[0]
            # Handle keys shaped like "<layer>.weight.weight_scale" too.
            if layer_name.endswith(".weight"):
                layer_name = layer_name.rsplit(".", 1)[0]
            layers.setdefault(layer_name, {"format": quant_algo})
    return layers


def _save_component_state_dict_safetensors(
    component: nn.Module,
    component_export_dir: Path,
    merged_base_safetensor_path: str | None = None,
    hf_quant_config: dict | None = None,
) -> None:
    """Save component state dict as safetensors with optional base checkpoint merge.

    Args:
        component: The nn.Module to save.
        component_export_dir: Directory to save model.safetensors and config.json.
        merged_base_safetensor_path: If provided, merge with non-transformer components
            from this base safetensors file.
        hf_quant_config: If provided, embed quantization config in safetensors metadata
            and per-layer _quantization_metadata for ComfyUI.
    """
    cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()}

    metadata: dict[str, str] = {
        "_export_format": "safetensors_state_dict",
        "_class_name": type(component).__name__,
    }
    metadata_full: dict[str, str] = {}
    if merged_base_safetensor_path is not None:
        cpu_state_dict, metadata_full = _merge_diffusion_transformer_with_non_transformer_components(
            cpu_state_dict, merged_base_safetensor_path
        )

    if hf_quant_config is not None:
        metadata_full["quantization_config"] = json.dumps(hf_quant_config)
        # Per-layer _quantization_metadata for ComfyUI.
        metadata_full["_quantization_metadata"] = json.dumps(
            {
                "format_version": "1.0",
                "layers": _build_layer_quant_metadata(cpu_state_dict, hf_quant_config),
            }
        )

    metadata_full.update(metadata)

    # Embed metadata whenever either optional input contributed to it. Previously
    # metadata was only written when merging with a base checkpoint, so the
    # quantization metadata built for a plain quantized export was silently
    # dropped from the safetensors file.
    embed_metadata = merged_base_safetensor_path is not None or hf_quant_config is not None
    save_file(
        cpu_state_dict,
        str(component_export_dir / "model.safetensors"),
        metadata=metadata_full if embed_metadata else None,
    )

    with open(component_export_dir / "config.json", "w") as f:
        json.dump(metadata, f, indent=4)
128216
129217
130218def _collect_shared_input_modules (
@@ -807,6 +895,7 @@ def _export_diffusers_checkpoint(
807895 dtype : torch .dtype | None ,
808896 export_dir : Path ,
809897 components : list [str ] | None ,
898+ merged_base_safetensor_path : str | None = None ,
810899 max_shard_size : int | str = "10GB" ,
811900) -> None :
812901 """Internal: Export diffusion(-like) model/pipeline checkpoint.
@@ -821,6 +910,8 @@ def _export_diffusers_checkpoint(
821910 export_dir: The directory to save the exported checkpoint.
822911 components: Optional list of component names to export. Only used for pipelines.
823912 If None, all components are exported.
913+ merged_base_safetensor_path: If provided, merge the exported transformer with
914+ non-transformer components from this base safetensors file.
824915 max_shard_size: Maximum size of each shard file. If the model exceeds this size,
825916 it will be sharded into multiple files and a .safetensors.index.json will be
826917 created. Use smaller values like "5GB" or "2GB" to force sharding.
@@ -879,7 +970,8 @@ def _export_diffusers_checkpoint(
879970
880971 # Step 5: Build quantization config
881972 quant_config = get_quant_config (component , is_modelopt_qlora = False )
882-
973+ hf_quant_config = convert_hf_quant_config_format (quant_config ) if quant_config else None
974+
883975 # Step 6: Save the component
884976 # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
885977 # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save
@@ -888,12 +980,14 @@ def _export_diffusers_checkpoint(
888980 component .save_pretrained (component_export_dir , max_shard_size = max_shard_size )
889981 else :
890982 with hide_quantizers_from_state_dict (component ):
891- _save_component_state_dict_safetensors (component , component_export_dir )
892-
983+ _save_component_state_dict_safetensors (
984+ component ,
985+ component_export_dir ,
986+ merged_base_safetensor_path ,
987+ hf_quant_config ,
988+ )
893989 # Step 7: Update config.json with quantization info
894- if quant_config is not None :
895- hf_quant_config = convert_hf_quant_config_format (quant_config )
896-
990+ if hf_quant_config is not None :
897991 config_path = component_export_dir / "config.json"
898992 if config_path .exists ():
899993 with open (config_path ) as file :
@@ -905,7 +999,7 @@ def _export_diffusers_checkpoint(
905999 elif hasattr (component , "save_pretrained" ):
9061000 component .save_pretrained (component_export_dir , max_shard_size = max_shard_size )
9071001 else :
908- _save_component_state_dict_safetensors (component , component_export_dir )
1002+ _save_component_state_dict_safetensors (component , component_export_dir , merged_base_safetensor_path )
9091003
9101004 print (f" Saved to: { component_export_dir } " )
9111005
@@ -985,6 +1079,7 @@ def export_hf_checkpoint(
9851079 save_modelopt_state : bool = False ,
9861080 components : list [str ] | None = None ,
9871081 extra_state_dict : dict [str , torch .Tensor ] | None = None ,
1082+ merged_base_safetensor_path : str | None = None ,
9881083):
9891084 """Export quantized HuggingFace model checkpoint (transformers or diffusers).
9901085
@@ -1002,6 +1097,9 @@ def export_hf_checkpoint(
10021097 components: Only used for diffusers pipelines. Optional list of component names
10031098 to export. If None, all quantized components are exported.
10041099 extra_state_dict: Extra state dictionary to add to the exported model.
1100+ merged_base_safetensor_path: If provided, merge the exported diffusion transformer
1101+ with non-transformer components (VAE, vocoder, etc.) from this base safetensors
1102+ file. Only used for diffusion model exports (e.g., LTX-2).
10051103 """
10061104 export_dir = Path (export_dir )
10071105 export_dir .mkdir (parents = True , exist_ok = True )
@@ -1010,7 +1108,7 @@ def export_hf_checkpoint(
10101108 if HAS_DIFFUSERS :
10111109 is_diffusers_obj = is_diffusers_object (model )
10121110 if is_diffusers_obj :
1013- _export_diffusers_checkpoint (model , dtype , export_dir , components )
1111+ _export_diffusers_checkpoint (model , dtype , export_dir , components , merged_base_safetensor_path )
10141112 return
10151113
10161114 # Transformers model export
0 commit comments