@@ -28,7 +28,7 @@
 
 import torch
 import torch.nn as nn
-from safetensors.torch import save_file, load_file, safe_open
+from safetensors.torch import load_file, safe_open, save_file
 
 try:
     import diffusers
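
For reference on the reordered safetensors imports: `save_file` writes a flat tensor dict with optional string-to-string metadata, and `safe_open` can read that metadata back without loading any tensors, which is presumably how the merge helper below recovers the base checkpoint's metadata. A minimal sketch (file name hypothetical):

```python
import torch
from safetensors.torch import safe_open, save_file

# save_file takes a flat {name: tensor} dict plus optional str -> str metadata.
save_file({"weight": torch.zeros(2, 2)}, "demo.safetensors", metadata={"_class_name": "Demo"})

# safe_open reads the header lazily; no tensor data is materialized here.
with safe_open("demo.safetensors", framework="pt") as f:
    print(f.metadata())  # {'_class_name': 'Demo'}
```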
@@ -130,27 +130,38 @@ def _merge_diffusion_transformer_with_non_transformer_components(
         Tuple of (merged_state_dict, base_metadata) where base_metadata is the original
         safetensors metadata from the base checkpoint.
     """
-
     base_state = load_file(merged_base_safetensor_path)
 
     non_transformer_prefixes = [
-        'vae.', 'audio_vae.', 'vocoder.', 'text_embedding_projection.',
-        'text_encoders.', 'first_stage_model.', 'cond_stage_model.', 'conditioner.',
+        "vae.",
+        "audio_vae.",
+        "vocoder.",
+        "text_embedding_projection.",
+        "text_encoders.",
+        "first_stage_model.",
+        "cond_stage_model.",
+        "conditioner.",
     ]
-    correct_prefix = 'model.diffusion_model.'
-    strip_prefixes = ['diffusion_model.', 'transformer.', '_orig_mod.', 'model.', 'velocity_model.']
+    correct_prefix = "model.diffusion_model."
+    strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]
 
-    base_non_transformer = {k: v for k, v in base_state.items()
-                            if any(k.startswith(p) for p in non_transformer_prefixes)}
-    base_connectors = {k: v for k, v in base_state.items()
-                       if 'embeddings_connector' in k and k.startswith(correct_prefix)}
+    base_non_transformer = {
+        k: v
+        for k, v in base_state.items()
+        if any(k.startswith(p) for p in non_transformer_prefixes)
+    }
+    base_connectors = {
+        k: v
+        for k, v in base_state.items()
+        if "embeddings_connector" in k and k.startswith(correct_prefix)
+    }
 
     prefixed = {}
     for k, v in diffusion_transformer_state_dict.items():
         clean_k = k
         for prefix in strip_prefixes:
             if clean_k.startswith(prefix):
-                clean_k = clean_k[len(prefix):]
+                clean_k = clean_k[len(prefix) :]
                 break
         prefixed[f"{correct_prefix}{clean_k}"] = v
 
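
The reformatted block above amounts to a key-namespace normalization: strip the first matching wrapper prefix from each transformer key, then re-home it under the canonical `model.diffusion_model.` prefix so it lines up with the base checkpoint. A standalone sketch of the same idea, with a made-up key:

```python
strip_prefixes = ["diffusion_model.", "transformer.", "_orig_mod.", "model.", "velocity_model."]
correct_prefix = "model.diffusion_model."

def normalize_key(k: str) -> str:
    # Strip at most one wrapper prefix, then prepend the canonical namespace.
    for prefix in strip_prefixes:
        if k.startswith(prefix):
            k = k[len(prefix):]
            break
    return f"{correct_prefix}{k}"

# Hypothetical key, for illustration only.
print(normalize_key("transformer.blocks.0.attn.to_q.weight"))
# -> model.diffusion_model.blocks.0.attn.to_q.weight
```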
@@ -165,10 +176,10 @@ def _merge_diffusion_transformer_with_non_transformer_components(
 
 
 def _save_component_state_dict_safetensors(
-    component: nn.Module,
-    component_export_dir: Path,
-    merged_base_safetensor_path: str | None = None,
-    hf_quant_config: dict | None = None
+    component: nn.Module,
+    component_export_dir: Path,
+    merged_base_safetensor_path: str | None = None,
+    hf_quant_config: dict | None = None,
 ) -> None:
     """Save component state dict as safetensors with optional base checkpoint merge.
 
@@ -184,10 +195,12 @@ def _save_component_state_dict_safetensors(
     metadata: dict[str, str] = {}
     metadata_full: dict[str, str] = {}
     if merged_base_safetensor_path is not None:
-        cpu_state_dict, metadata_full = _merge_diffusion_transformer_with_non_transformer_components(
-            cpu_state_dict, merged_base_safetensor_path
+        cpu_state_dict, metadata_full = (
+            _merge_diffusion_transformer_with_non_transformer_components(
+                cpu_state_dict, merged_base_safetensor_path
+            )
         )
-    metadata["_export_format"] = "safetensors_state_dict"
+    metadata["_export_format"] = "safetensors_state_dict"
     metadata["_class_name"] = type(component).__name__
 
     if hf_quant_config is not None:
@@ -197,20 +210,26 @@ def _save_component_state_dict_safetensors(
         quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
         layer_metadata = {}
         for k in cpu_state_dict:
-            if k.endswith(".weight_scale") or k.endswith(".weight_scale_2"):
+            if k.endswith((".weight_scale", ".weight_scale_2")):
                 layer_name = k.rsplit(".", 1)[0]
                 if layer_name.endswith(".weight"):
                     layer_name = layer_name.rsplit(".", 1)[0]
                 if layer_name not in layer_metadata:
                     layer_metadata[layer_name] = {"format": quant_algo}
-        metadata_full["_quantization_metadata"] = json.dumps({
-            "format_version": "1.0",
-            "layers": layer_metadata,
-        })
+        metadata_full["_quantization_metadata"] = json.dumps(
+            {
+                "format_version": "1.0",
+                "layers": layer_metadata,
+            }
+        )
 
     metadata_full.update(metadata)
-    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"), metadata=metadata_full if merged_base_safetensor_path is not None else None)
-
+    save_file(
+        cpu_state_dict,
+        str(component_export_dir / "model.safetensors"),
+        metadata=metadata_full if merged_base_safetensor_path is not None else None,
+    )
+
     with open(component_export_dir / "config.json", "w") as f:
         json.dump(metadata, f, indent=4)
 
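
A detail worth noting in the hunk above: safetensors metadata values must be strings, which is why the per-layer quantization table is serialized with `json.dumps` before being stored under `_quantization_metadata`. A sketch of the scale-key-to-layer mapping, with hypothetical key names (and `nvfp4` standing in for whatever `quant_algo` resolves to):

```python
import json

state_keys = [
    "blocks.0.ff.weight",                 # plain weight: ignored
    "blocks.0.ff.weight_scale",           # per-layer scale -> layer is quantized
    "blocks.0.ff.weight.weight_scale_2",  # second-level scale variant
]
quant_algo = "nvfp4"

layer_metadata = {}
for k in state_keys:
    if k.endswith((".weight_scale", ".weight_scale_2")):
        layer_name = k.rsplit(".", 1)[0]
        if layer_name.endswith(".weight"):  # peel a trailing ".weight" too
            layer_name = layer_name.rsplit(".", 1)[0]
        layer_metadata.setdefault(layer_name, {"format": quant_algo})

print(json.dumps({"format_version": "1.0", "layers": layer_metadata}))
# -> {"format_version": "1.0", "layers": {"blocks.0.ff": {"format": "nvfp4"}}}
```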
@@ -971,7 +990,7 @@ def _export_diffusers_checkpoint(
     # Step 5: Build quantization config
     quant_config = get_quant_config(component, is_modelopt_qlora=False)
     hf_quant_config = convert_hf_quant_config_format(quant_config) if quant_config else None
-
+
     # Step 6: Save the component
     # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter
     # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save
@@ -981,8 +1000,8 @@ def _export_diffusers_checkpoint(
     else:
         with hide_quantizers_from_state_dict(component):
             _save_component_state_dict_safetensors(
-                component,
-                component_export_dir,
+                component,
+                component_export_dir,
                 merged_base_safetensor_path,
                 hf_quant_config,
             )
@@ -999,7 +1018,9 @@ def _export_diffusers_checkpoint(
     elif hasattr(component, "save_pretrained"):
         component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
     else:
-        _save_component_state_dict_safetensors(component, component_export_dir, merged_base_safetensor_path)
+        _save_component_state_dict_safetensors(
+            component, component_export_dir, merged_base_safetensor_path
+        )
 
     print(f" Saved to: {component_export_dir}")
 
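
The save path in this hunk is a three-way dispatch: the quantizer-aware safetensors writer when a merged base checkpoint is involved, `save_pretrained` when the component provides it (the diffusers/transformers-native path, which also handles sharding via `max_shard_size`), and the plain safetensors writer otherwise. A hedged sketch of the last two branches, assuming only that `component` is an `nn.Module`:

```python
from pathlib import Path

import torch.nn as nn
from safetensors.torch import save_file

def save_component(component: nn.Module, export_dir: Path) -> None:
    # Prefer the model's own serializer; diffusers ModelMixin and
    # transformers PreTrainedModel both expose save_pretrained.
    if hasattr(component, "save_pretrained"):
        component.save_pretrained(export_dir, max_shard_size="10GB")
    else:
        # Plain nn.Module: dump the state dict as a single safetensors file.
        state = {k: v.contiguous().cpu() for k, v in component.state_dict().items()}
        save_file(state, str(export_dir / "model.safetensors"))
```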
@@ -1108,7 +1129,9 @@ def export_hf_checkpoint(
     if HAS_DIFFUSERS:
         is_diffusers_obj = is_diffusers_object(model)
         if is_diffusers_obj:
-            _export_diffusers_checkpoint(model, dtype, export_dir, components, merged_base_safetensor_path)
+            _export_diffusers_checkpoint(
+                model, dtype, export_dir, components, merged_base_safetensor_path
+            )
             return
 
     # Transformers model export
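
Taken together, the new branch means a quantized diffusers object flows through the same public entry point as a transformers model. A usage sketch under loud assumptions: the model id is hypothetical, and `export_hf_checkpoint` is assumed to accept `export_dir` as a keyword, as the internal call above suggests:

```python
from diffusers import DiffusionPipeline

from modelopt.torch.export import export_hf_checkpoint

# Hypothetical model id; any diffusers pipeline whose transformer was already
# quantized in-place (e.g. with modelopt.torch.quantization.quantize).
pipe = DiffusionPipeline.from_pretrained("some/model-id")

export_hf_checkpoint(pipe, export_dir="exported_pipe")
# With HAS_DIFFUSERS set and is_diffusers_object(pipe) true, this routes into
# _export_diffusers_checkpoint rather than the transformers export path.
```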