@@ -106,6 +106,22 @@ def _is_enabled_quantizer(quantizer):
106106 return False
107107
108108
109+ def _save_component_state_dict_safetensors (
110+ component : nn .Module , component_export_dir : Path
111+ ) -> None :
112+ cpu_state_dict = {k : v .detach ().contiguous ().cpu () for k , v in component .state_dict ().items ()}
113+ save_file (cpu_state_dict , str (component_export_dir / "model.safetensors" ))
114+ with open (component_export_dir / "config.json" , "w" ) as f :
115+ json .dump (
116+ {
117+ "_class_name" : type (component ).__name__ ,
118+ "_export_format" : "safetensors_state_dict" ,
119+ },
120+ f ,
121+ indent = 4 ,
122+ )
123+
124+
109125def _collect_shared_input_modules (
110126 model : nn .Module ,
111127 dummy_forward_fn : Callable [[], None ],
@@ -853,19 +869,7 @@ def _export_diffusers_checkpoint(
853869 component .save_pretrained (component_export_dir , max_shard_size = max_shard_size )
854870 else :
855871 with hide_quantizers_from_state_dict (component ):
856- cpu_state_dict = {
857- k : v .detach ().contiguous ().cpu () for k , v in component .state_dict ().items ()
858- }
859- save_file (cpu_state_dict , str (component_export_dir / "model.safetensors" ))
860- with open (component_export_dir / "config.json" , "w" ) as f :
861- json .dump (
862- {
863- "_class_name" : type (component ).__name__ ,
864- "_export_format" : "safetensors_state_dict" ,
865- },
866- f ,
867- indent = 4 ,
868- )
872+ _save_component_state_dict_safetensors (component , component_export_dir )
869873
870874 # Step 7: Update config.json with quantization info
871875 if quant_config is not None :
@@ -882,19 +886,7 @@ def _export_diffusers_checkpoint(
882886 elif hasattr (component , "save_pretrained" ):
883887 component .save_pretrained (component_export_dir , max_shard_size = max_shard_size )
884888 else :
885- cpu_state_dict = {
886- k : v .detach ().contiguous ().cpu () for k , v in component .state_dict ().items ()
887- }
888- save_file (cpu_state_dict , str (component_export_dir / "model.safetensors" ))
889- with open (component_export_dir / "config.json" , "w" ) as f :
890- json .dump (
891- {
892- "_class_name" : type (component ).__name__ ,
893- "_export_format" : "safetensors_state_dict" ,
894- },
895- f ,
896- indent = 4 ,
897- )
889+ _save_component_state_dict_safetensors (component , component_export_dir )
898890
899891 print (f" Saved to: { component_export_dir } " )
900892
0 commit comments