Skip to content

Commit 481ea60

Browse files
committed
restore train script
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 51c8a67 commit 481ea60

1 file changed

Lines changed: 3 additions & 12 deletions

File tree

recipes/esm2_accelerate/train.py

Lines changed: 3 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -38,12 +38,7 @@ def main(args: DictConfig):
3838
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True)
3939
config.max_seq_length = args.max_seq_length
4040
config.micro_batch_size = args.trainer.per_device_train_batch_size
41-
42-
model = AutoModelForMaskedLM.from_config(
43-
config,
44-
trust_remote_code=True,
45-
torch_dtype=torch.bfloat16,
46-
)
41+
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
4742

4843
train_dataset, eval_dataset, data_collator = create_datasets_and_collator(
4944
tokenizer_name=args.model_tag,
@@ -62,7 +57,7 @@ def main(args: DictConfig):
6257
callbacks=[StopAfterNStepsCallback(args.stop_after_n_steps)],
6358
)
6459

65-
train_result, eval_result = None, None
60+
logger.info("ACCELERATE STATE:\n%s\n", trainer.accelerator.state)
6661

6762
if training_args.do_train:
6863
Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
@@ -77,11 +72,7 @@ def main(args: DictConfig):
7772
trainer.save_model(str(Path(training_args.output_dir) / "checkpoint-last"))
7873

7974
if training_args.do_eval:
80-
eval_result = trainer.evaluate()
81-
logger.info("Evaluation complete. Metrics: %s", eval_result)
82-
trainer.save_metrics("eval", eval_result)
83-
84-
return train_result, eval_result
75+
trainer.evaluate()
8576

8677

8778
if __name__ == "__main__":

0 commit comments

Comments (0)