@@ -38,12 +38,7 @@ def main(args: DictConfig):
3838 config = AutoConfig .from_pretrained (args .model_tag , trust_remote_code = True )
3939 config .max_seq_length = args .max_seq_length
4040 config .micro_batch_size = args .trainer .per_device_train_batch_size
41-
42- model = AutoModelForMaskedLM .from_config (
43- config ,
44- trust_remote_code = True ,
45- torch_dtype = torch .bfloat16 ,
46- )
41+ model = AutoModelForMaskedLM .from_config (config , trust_remote_code = True , torch_dtype = torch .bfloat16 )
4742
4843 train_dataset , eval_dataset , data_collator = create_datasets_and_collator (
4944 tokenizer_name = args .model_tag ,
@@ -62,7 +57,7 @@ def main(args: DictConfig):
6257 callbacks = [StopAfterNStepsCallback (args .stop_after_n_steps )],
6358 )
6459
65- train_result , eval_result = None , None
60+ logger . info ( "ACCELERATE STATE: \n %s \n " , trainer . accelerator . state )
6661
6762 if training_args .do_train :
6863 Path (training_args .output_dir ).mkdir (parents = True , exist_ok = True )
@@ -77,11 +72,7 @@ def main(args: DictConfig):
7772 trainer .save_model (str (Path (training_args .output_dir ) / "checkpoint-last" ))
7873
7974 if training_args .do_eval :
80- eval_result = trainer .evaluate ()
81- logger .info ("Evaluation complete. Metrics: %s" , eval_result )
82- trainer .save_metrics ("eval" , eval_result )
83-
84- return train_result , eval_result
75+ trainer .evaluate ()
8576
8677
8778if __name__ == "__main__" :
0 commit comments