@@ -97,30 +97,32 @@ def on_save(
9797 """
9898
9999 def checkpoint (checkpoint_dir , save_dir ):
100- hf_converted_output_dir = os .path .join (
101- save_dir , "hf_converted_checkpoint"
102- )
103- if os .path .exists (hf_converted_output_dir ):
104- # If the folder already exists
105- # we return, since this is possible to happen
106- # saving the checkpointing at the end of the training
100+
101+ # If save_dir already contains any entries, we treat it as
102+ # already converted and return; this can happen when the
103+ # checkpoint is saved again at the end of the training
104+ try :
105+ already_converted = any (os .scandir (save_dir ))
106+ except FileNotFoundError :
107+ already_converted = False
108+
109+ if already_converted :
110+ # Already converted (resume); skip re-conversion
107111 return
108- os . mkdir ( hf_converted_output_dir )
112+
109113 try :
110114 recover_safetensors_from_dcp (
111115 checkpoint_dir ,
112116 self .pretrained_model_name_or_path ,
113- hf_converted_output_dir ,
117+ save_dir ,
114118 )
115119 # Save tokenizer
116120 if self .trainer .processing_class :
117- self .trainer .processing_class .save_pretrained (
118- hf_converted_output_dir
119- )
121+ self .trainer .processing_class .save_pretrained (save_dir )
120122 # Save training args
121123 torch .save (
122124 args ,
123- os .path .join (hf_converted_output_dir , TRAINING_ARGS_NAME ),
125+ os .path .join (save_dir , TRAINING_ARGS_NAME ),
124126 )
125127
126128 # Unwrap FSDP module
@@ -135,16 +137,14 @@ def checkpoint(checkpoint_dir, save_dir):
135137 list (config_dict ["target_modules" ])
136138 )
137139 with open (
138- os .path .join (
139- hf_converted_output_dir , "adapter_config.json"
140- ),
140+ os .path .join (save_dir , "adapter_config.json" ),
141141 "w" ,
142142 encoding = "utf-8" ,
143143 ) as f :
144144 json .dump (config_dict , f , indent = 2 )
145145
146146 else :
147- model .config .save_pretrained (hf_converted_output_dir )
147+ model .config .save_pretrained (save_dir )
148148
149149 except Exception as e :
150150 raise ValueError (
@@ -157,15 +157,19 @@ def checkpoint(checkpoint_dir, save_dir):
157157 checkpoint_dir = os .path .join (
158158 args .output_dir , f"{ PREFIX_CHECKPOINT_DIR } -{ state .global_step } "
159159 )
160- checkpoint (checkpoint_dir , checkpoint_dir )
160+ hf_converted_path = os .path .join (
161+ checkpoint_dir , "hf_converted_checkpoint"
162+ )
163+ if not os .path .exists (hf_converted_path ):
164+ os .makedirs (hf_converted_path )
165+ checkpoint (checkpoint_dir , hf_converted_path )
161166
162167 # If final save directory is provided, save the model there
163168 if (
164169 getattr (self , "save_model_dir" , None )
165170 and state .global_step == state .max_steps
166171 ):
167- if not os .path .exists (self .save_model_dir ):
168- os .mkdir (self .save_model_dir )
172+ os .makedirs (self .save_model_dir , exist_ok = True )
169173 checkpoint (checkpoint_dir , self .save_model_dir )
170174
171175 callbacks .append (
0 commit comments