From eaf7dd6c67bc1b9218d8f5bb6ddf06505425e9e1 Mon Sep 17 00:00:00 2001 From: Doruk Arisoy Date: Tue, 2 Jun 2026 20:18:38 +0000 Subject: [PATCH] fix(ci): stabilize GHA test workflows and use local checkpoints - Wrap library path discovery in directory check to prevent crash in image-testing mode (when .venv is missing). - Force checkpoint integration tests to use local /tmp directory to avoid GCS permission failures on GHA runners. --- .github/workflows/run_tests_against_package.yml | 16 +++++++++------- .../integration/checkpoint_compatibility_test.py | 3 +++ tests/integration/checkpointing_test.py | 10 ++++++++-- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 5c44e268bc..9df299ab21 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -162,13 +162,15 @@ jobs: # Dynamically discover the 'nvidia' folder and prepend all its sub-library # directories (including nccl, cublas, cudnn) to LD_LIBRARY_PATH to prevent # JAX from partially loading incompatible system-level CUDA libraries. - NVIDIA_DIR=$(find .venv/lib/ -maxdepth 3 -name "nvidia" -type d 2>/dev/null | head -n 1) - if [ -n "${NVIDIA_DIR}" ]; then - for dir in "${NVIDIA_DIR}"/*; do - if [ -d "$dir/lib" ]; then - export LD_LIBRARY_PATH=$(pwd)/$dir/lib:${LD_LIBRARY_PATH} - fi - done + if [ -d ".venv/lib" ]; then + NVIDIA_DIR=$(find .venv/lib/ -maxdepth 3 -name "nvidia" -type d 2>/dev/null | head -n 1) + if [ -n "${NVIDIA_DIR}" ]; then + for dir in "${NVIDIA_DIR}"/*; do + if [ -d "$dir/lib" ]; then + export LD_LIBRARY_PATH=$(pwd)/$dir/lib:${LD_LIBRARY_PATH} + fi + done + fi fi fi if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then diff --git a/tests/integration/checkpoint_compatibility_test.py b/tests/integration/checkpoint_compatibility_test.py index 5b628695b5..a147a177df 100644 --- a/tests/integration/checkpoint_compatibility_test.py +++ b/tests/integration/checkpoint_compatibility_test.py @@ -49,6 +49,7 @@ def run_checkpoint_compatibility(hardware, attention_type): "grain_worker_count=0", "grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*", ] + local_ckpt_dir = "/tmp/maxtext_local_output" # Run training using grain input pipeline train_main( @@ -60,6 +61,7 @@ def run_checkpoint_compatibility(hardware, attention_type): attention_type=attention_type, dataset_type="grain", dataset_path="/tmp/gcsfuse", + base_output_directory=local_ckpt_dir, ) + grain_command ) @@ -74,6 +76,7 @@ def run_checkpoint_compatibility(hardware, attention_type): attention_type=attention_type, dataset_type="tfds", dataset_path="/tmp/gcsfuse", + base_output_directory=local_ckpt_dir, ) ) diff --git a/tests/integration/checkpointing_test.py b/tests/integration/checkpointing_test.py index 22b5cad5c4..bb7581fcf6 100644 --- a/tests/integration/checkpointing_test.py +++ b/tests/integration/checkpointing_test.py @@ -41,7 +41,9 @@ ) -def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention_type, dataset_type, dataset_path): +def get_checkpointing_command( + run_date, hardware, steps, metrics_file, attention_type, dataset_type, dataset_path, base_output_directory=None +): """Generates a command list for a checkpointing test run. Args: @@ -56,7 +58,8 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention Returns: A list of strings representing the command line arguments. """ - base_output_directory = get_test_base_output_directory() + if base_output_directory is None: + base_output_directory = get_test_base_output_directory() model_params = [ "base_emb_dim=128", "base_num_query_heads=2", @@ -148,6 +151,7 @@ def run_checkpointing(hardware, attention_type): "grain_worker_count=0", f"grain_train_files={selected_pattern}", ] + local_ckpt_dir = "/tmp/maxtext_local_output" train_main( get_checkpointing_command( run_date, @@ -157,6 +161,7 @@ def run_checkpointing(hardware, attention_type): attention_type=attention_type, dataset_type="grain", dataset_path=dataset_path, + base_output_directory=local_ckpt_dir, ) + grain_command ) @@ -170,6 +175,7 @@ def run_checkpointing(hardware, attention_type): attention_type=attention_type, dataset_type="grain", dataset_path=dataset_path, + base_output_directory=local_ckpt_dir, ) + grain_command )