Skip to content

Commit eaf7dd6

Browse files
committed
fix(ci): stabilize GHA test workflows and use local checkpoints
- Wrap library path discovery in directory check to prevent crash in image-testing mode (when .venv is missing). - Force checkpoint integration tests to use local /tmp directory to avoid GCS permission failures on GHA runners.
1 parent f0842ca commit eaf7dd6

3 files changed

Lines changed: 20 additions & 9 deletions

File tree

.github/workflows/run_tests_against_package.yml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,15 @@ jobs:
162162
# Dynamically discover the 'nvidia' folder and prepend all its sub-library
163163
# directories (including nccl, cublas, cudnn) to LD_LIBRARY_PATH to prevent
164164
# JAX from partially loading incompatible system-level CUDA libraries.
165-
NVIDIA_DIR=$(find .venv/lib/ -maxdepth 3 -name "nvidia" -type d 2>/dev/null | head -n 1)
166-
if [ -n "${NVIDIA_DIR}" ]; then
167-
for dir in "${NVIDIA_DIR}"/*; do
168-
if [ -d "$dir/lib" ]; then
169-
export LD_LIBRARY_PATH=$(pwd)/$dir/lib:${LD_LIBRARY_PATH}
170-
fi
171-
done
165+
if [ -d ".venv/lib" ]; then
166+
NVIDIA_DIR=$(find .venv/lib/ -maxdepth 3 -name "nvidia" -type d 2>/dev/null | head -n 1)
167+
if [ -n "${NVIDIA_DIR}" ]; then
168+
for dir in "${NVIDIA_DIR}"/*; do
169+
if [ -d "$dir/lib" ]; then
170+
export LD_LIBRARY_PATH=$(pwd)/$dir/lib:${LD_LIBRARY_PATH}
171+
fi
172+
done
173+
fi
172174
fi
173175
fi
174176
if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then

tests/integration/checkpoint_compatibility_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def run_checkpoint_compatibility(hardware, attention_type):
4949
"grain_worker_count=0",
5050
"grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*",
5151
]
52+
local_ckpt_dir = "/tmp/maxtext_local_output"
5253

5354
# Run training using grain input pipeline
5455
train_main(
@@ -60,6 +61,7 @@ def run_checkpoint_compatibility(hardware, attention_type):
6061
attention_type=attention_type,
6162
dataset_type="grain",
6263
dataset_path="/tmp/gcsfuse",
64+
base_output_directory=local_ckpt_dir,
6365
)
6466
+ grain_command
6567
)
@@ -74,6 +76,7 @@ def run_checkpoint_compatibility(hardware, attention_type):
7476
attention_type=attention_type,
7577
dataset_type="tfds",
7678
dataset_path="/tmp/gcsfuse",
79+
base_output_directory=local_ckpt_dir,
7780
)
7881
)
7982

tests/integration/checkpointing_test.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
)
4242

4343

44-
def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention_type, dataset_type, dataset_path):
44+
def get_checkpointing_command(
45+
run_date, hardware, steps, metrics_file, attention_type, dataset_type, dataset_path, base_output_directory=None
46+
):
4547
"""Generates a command list for a checkpointing test run.
4648
4749
Args:
@@ -56,7 +58,8 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
5658
Returns:
5759
A list of strings representing the command line arguments.
5860
"""
59-
base_output_directory = get_test_base_output_directory()
61+
if base_output_directory is None:
62+
base_output_directory = get_test_base_output_directory()
6063
model_params = [
6164
"base_emb_dim=128",
6265
"base_num_query_heads=2",
@@ -148,6 +151,7 @@ def run_checkpointing(hardware, attention_type):
148151
"grain_worker_count=0",
149152
f"grain_train_files={selected_pattern}",
150153
]
154+
local_ckpt_dir = "/tmp/maxtext_local_output"
151155
train_main(
152156
get_checkpointing_command(
153157
run_date,
@@ -157,6 +161,7 @@ def run_checkpointing(hardware, attention_type):
157161
attention_type=attention_type,
158162
dataset_type="grain",
159163
dataset_path=dataset_path,
164+
base_output_directory=local_ckpt_dir,
160165
)
161166
+ grain_command
162167
)
@@ -170,6 +175,7 @@ def run_checkpointing(hardware, attention_type):
170175
attention_type=attention_type,
171176
dataset_type="grain",
172177
dataset_path=dataset_path,
178+
base_output_directory=local_ckpt_dir,
173179
)
174180
+ grain_command
175181
)

0 commit comments

Comments
 (0)