Skip to content

Commit 503da7e

Browse files
committed
turn off pth saving which is useless now
1 parent c9a0dc6 commit 503da7e

4 files changed

Lines changed: 1 addition & 42 deletions

File tree

fms_fsdp/utils/checkpointing_utils.py

Lines changed: 0 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -314,34 +314,3 @@ def save(
314314
)
315315

316316
return self._cleanup()
317-
318-
def save_single_file(
319-
self,
320-
step,
321-
model,
322-
is_compiled=False,
323-
**kwargs,
324-
):
325-
# Note: metadata kwargs cannot contain any of:
326-
# (step, model)
327-
pth_path = os.path.join(self.ckp_path[:-12], "pth", "step_" + str(step))
328-
os.makedirs(pth_path, exist_ok=True)
329-
save_name = os.path.join(pth_path, "consolidated.00.pth")
330-
save_time = time.time()
331-
with FSDP.state_dict_type(
332-
model,
333-
StateDictType.FULL_STATE_DICT,
334-
FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
335-
):
336-
if is_compiled:
337-
model_state = model._orig_mod.state_dict()
338-
else:
339-
model_state = model.state_dict()
340-
if self.rank == 0:
341-
metadata = kwargs
342-
metadata["step"] = step
343-
metadata["model_state"] = model_state
344-
torch.save(metadata, save_name)
345-
self.report("Checkpoint written", model_save_time=time.time() - save_time)
346-
347-
return self._cleanup()

main_training_llama.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -167,8 +167,6 @@ def main(**kwargs):
167167
tokens_seen,
168168
)
169169

170-
checkpointer.save_single_file(cfg.num_steps, model)
171-
172170
dist.barrier()
173171
dist.destroy_process_group()
174172

main_training_mamba.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -167,8 +167,6 @@ def main(**kwargs):
167167
tokens_seen,
168168
)
169169

170-
checkpointer.save_single_file(cfg.num_steps, model)
171-
172170
dist.barrier()
173171
dist.destroy_process_group()
174172

speculator/train_speculator_utils.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -412,6 +412,7 @@ def train_speculator(
412412

413413
if (
414414
batch_idx % cfg.checkpoint_interval == 0
415+
or batch_idx == cfg.num_steps
415416
or do_ckpt(cfg.ckpt_save_path) is True
416417
):
417418
torch.cuda.empty_cache()
@@ -425,13 +426,6 @@ def train_speculator(
425426
torch.cuda.empty_cache()
426427
do_ckpt(cfg.ckpt_save_path, reset=True)
427428

428-
checkpointer.save_single_file(
429-
batch_idx,
430-
speculator,
431-
tokens_seen=elapsed_tokens + n_tok,
432-
is_compiled=cfg.use_torch_compile,
433-
)
434-
435429

436430
class EmbedGPTBigCode(GPTBigCode):
437431
# Overrides the forward function of GPTBigCode to allow returning embedding vectors

0 commit comments

Comments (0)