4141)
4242
4343
44- def get_checkpointing_command (run_date , hardware , steps , metrics_file , attention_type , dataset_type , dataset_path ):
44+ def get_checkpointing_command (run_date , hardware , steps , metrics_file , attention_type , dataset_type , dataset_path , base_output_directory = None ):
4545 """Generates a command list for a checkpointing test run.
4646
4747 Args:
@@ -56,7 +56,8 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
5656 Returns:
5757 A list of strings representing the command line arguments.
5858 """
59- base_output_directory = get_test_base_output_directory ()
59+ if base_output_directory is None :
60+ base_output_directory = get_test_base_output_directory ()
6061 model_params = [
6162 "base_emb_dim=128" ,
6263 "base_num_query_heads=2" ,
@@ -148,6 +149,7 @@ def run_checkpointing(hardware, attention_type):
148149 "grain_worker_count=0" ,
149150 f"grain_train_files={ selected_pattern } " ,
150151 ]
152+ local_ckpt_dir = "/tmp/maxtext_local_output"
151153 train_main (
152154 get_checkpointing_command (
153155 run_date ,
@@ -157,6 +159,7 @@ def run_checkpointing(hardware, attention_type):
157159 attention_type = attention_type ,
158160 dataset_type = "grain" ,
159161 dataset_path = dataset_path ,
162+ base_output_directory = local_ckpt_dir ,
160163 )
161164 + grain_command
162165 )
@@ -170,6 +173,7 @@ def run_checkpointing(hardware, attention_type):
170173 attention_type = attention_type ,
171174 dataset_type = "grain" ,
172175 dataset_path = dataset_path ,
176+ base_output_directory = local_ckpt_dir ,
173177 )
174178 + grain_command
175179 )
0 commit comments