Skip to content

Commit a346f92

Browse files
fix: directly save final ckpt in save_model_dir
Signed-off-by: yashasvi <yashasvi@ibm.com>
1 parent 3344193 commit a346f92

1 file changed

Lines changed: 24 additions & 20 deletions

File tree

tuning/config/acceleration_configs/fast_moe.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -97,30 +97,32 @@ def on_save(
9797
"""
9898

9999
def checkpoint(checkpoint_dir, save_dir):
100-
hf_converted_output_dir = os.path.join(
101-
save_dir, "hf_converted_checkpoint"
102-
)
103-
if os.path.exists(hf_converted_output_dir):
104-
# If the folder already exists
105-
# we return, since this is possible to happen
106-
# saving the checkpointing at the end of the training
100+
101+
# If the folder already exists
102+
# we return, since this is possible to happen
103+
# saving the checkpointing at the end of the training
104+
try:
105+
already_converted = any(os.scandir(save_dir))
106+
except FileNotFoundError:
107+
already_converted = False
108+
109+
if already_converted:
110+
# Already converted (resume); skip re-conversion
107111
return
108-
os.mkdir(hf_converted_output_dir)
112+
109113
try:
110114
recover_safetensors_from_dcp(
111115
checkpoint_dir,
112116
self.pretrained_model_name_or_path,
113-
hf_converted_output_dir,
117+
save_dir,
114118
)
115119
# Save tokenizer
116120
if self.trainer.processing_class:
117-
self.trainer.processing_class.save_pretrained(
118-
hf_converted_output_dir
119-
)
121+
self.trainer.processing_class.save_pretrained(save_dir)
120122
# Save training args
121123
torch.save(
122124
args,
123-
os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
125+
os.path.join(save_dir, TRAINING_ARGS_NAME),
124126
)
125127

126128
# Unwrap FSDP module
@@ -135,16 +137,14 @@ def checkpoint(checkpoint_dir, save_dir):
135137
list(config_dict["target_modules"])
136138
)
137139
with open(
138-
os.path.join(
139-
hf_converted_output_dir, "adapter_config.json"
140-
),
140+
os.path.join(save_dir, "adapter_config.json"),
141141
"w",
142142
encoding="utf-8",
143143
) as f:
144144
json.dump(config_dict, f, indent=2)
145145

146146
else:
147-
model.config.save_pretrained(hf_converted_output_dir)
147+
model.config.save_pretrained(save_dir)
148148

149149
except Exception as e:
150150
raise ValueError(
@@ -157,15 +157,19 @@ def checkpoint(checkpoint_dir, save_dir):
157157
checkpoint_dir = os.path.join(
158158
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
159159
)
160-
checkpoint(checkpoint_dir, checkpoint_dir)
160+
hf_converted_path = os.path.join(
161+
checkpoint_dir, "hf_converted_checkpoint"
162+
)
163+
if not os.path.exists(hf_converted_path):
164+
os.makedirs(hf_converted_path)
165+
checkpoint(checkpoint_dir, hf_converted_path)
161166

162167
# If final save directory is provided, save the model there
163168
if (
164169
getattr(self, "save_model_dir", None)
165170
and state.global_step == state.max_steps
166171
):
167-
if not os.path.exists(self.save_model_dir):
168-
os.mkdir(self.save_model_dir)
172+
os.makedirs(self.save_model_dir, exist_ok=True)
169173
checkpoint(checkpoint_dir, self.save_model_dir)
170174

171175
callbacks.append(

0 commit comments

Comments
 (0)