File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1+ defaults :
2+ - defaults
3+ - _self_
4+
5+ model_tag : nvidia/esm2_t48_15B_UR50D
6+ stop_after_n_steps : 500
7+ trainer :
8+ run_name : " esm2_t48_15B_UR50D_perf"
9+ per_device_train_batch_size : 12
10+ per_device_eval_batch_size : 12
11+ report_to : " wandb"
12+ learning_rate : 1.6e-4
13+ weight_decay : 0.1
14+ warmup_steps : 20_000
Original file line number Diff line number Diff line change 1919import hydra
2020import torch
2121import transformers
22+ from accelerate import Accelerator
2223from omegaconf import DictConfig
2324from transformers import AutoConfig , AutoModelForMaskedLM
2425from transformers .trainer import Trainer
3536@hydra .main (config_path = "hydra_config" , config_name = "L0_sanity" , version_base = "1.2" )
3637def main (args : DictConfig ):
3738 """Entrypoint."""
39+ # We need to initialize the Accelerator manually before creating our model; otherwise the current torch device is
40+ # never set and model creation all happens on a single GPU, typically leading to an out-of-memory (OOM) error.
41+ _ = Accelerator ()
42+
3843 config = AutoConfig .from_pretrained (args .model_tag , trust_remote_code = True )
3944 config .max_seq_length = args .max_seq_length
4045 config .micro_batch_size = args .trainer .per_device_train_batch_size
@@ -57,8 +62,6 @@ def main(args: DictConfig):
5762 callbacks = [StopAfterNStepsCallback (args .stop_after_n_steps )],
5863 )
5964
60- logger .info ("ACCELERATE STATE:\n %s\n " , trainer .accelerator .state )
61-
6265 if training_args .do_train :
6366 Path (training_args .output_dir ).mkdir (parents = True , exist_ok = True )
6467 last_checkpoint = transformers .trainer_utils .get_last_checkpoint (training_args .output_dir )
You can’t perform that action at this time.
0 commit comments