Skip to content

Commit d5130c5

Browse files
committed
fix(test): use local checkpoints in integration tests
1 parent 4b213c1 commit d5130c5

2 files changed

Lines changed: 9 additions & 2 deletions

File tree

tests/integration/checkpoint_compatibility_test.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@ def run_checkpoint_compatibility(hardware, attention_type):
4949
"grain_worker_count=0",
5050
"grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*",
5151
]
52+
local_ckpt_dir = "/tmp/maxtext_local_output"
5253

5354
# Run training using grain input pipeline
5455
train_main(
@@ -60,6 +61,7 @@ def run_checkpoint_compatibility(hardware, attention_type):
6061
attention_type=attention_type,
6162
dataset_type="grain",
6263
dataset_path="/tmp/gcsfuse",
64+
base_output_directory=local_ckpt_dir,
6365
)
6466
+ grain_command
6567
)
@@ -74,6 +76,7 @@ def run_checkpoint_compatibility(hardware, attention_type):
7476
attention_type=attention_type,
7577
dataset_type="tfds",
7678
dataset_path="/tmp/gcsfuse",
79+
base_output_directory=local_ckpt_dir,
7780
)
7881
)
7982

tests/integration/checkpointing_test.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@
4141
)
4242

4343

44-
def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention_type, dataset_type, dataset_path):
44+
def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention_type, dataset_type, dataset_path, base_output_directory=None):
4545
"""Generates a command list for a checkpointing test run.
4646
4747
Args:
@@ -56,7 +56,8 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
5656
Returns:
5757
A list of strings representing the command line arguments.
5858
"""
59-
base_output_directory = get_test_base_output_directory()
59+
if base_output_directory is None:
60+
base_output_directory = get_test_base_output_directory()
6061
model_params = [
6162
"base_emb_dim=128",
6263
"base_num_query_heads=2",
@@ -148,6 +149,7 @@ def run_checkpointing(hardware, attention_type):
148149
"grain_worker_count=0",
149150
f"grain_train_files={selected_pattern}",
150151
]
152+
local_ckpt_dir = "/tmp/maxtext_local_output"
151153
train_main(
152154
get_checkpointing_command(
153155
run_date,
@@ -157,6 +159,7 @@ def run_checkpointing(hardware, attention_type):
157159
attention_type=attention_type,
158160
dataset_type="grain",
159161
dataset_path=dataset_path,
162+
base_output_directory=local_ckpt_dir,
160163
)
161164
+ grain_command
162165
)
@@ -170,6 +173,7 @@ def run_checkpointing(hardware, attention_type):
170173
attention_type=attention_type,
171174
dataset_type="grain",
172175
dataset_path=dataset_path,
176+
base_output_directory=local_ckpt_dir,
173177
)
174178
+ grain_command
175179
)

0 commit comments

Comments
 (0)