From a3cc654d992e60a39ae5123936d8d93957773a64 Mon Sep 17 00:00:00 2001 From: Igor Tsvetkov Date: Wed, 11 Mar 2026 11:31:30 -0700 Subject: [PATCH] Ensure checkpoint directory is writable and mark GCS utils tests as CPU only --- src/maxtext/common/checkpointing.py | 13 ++- src/maxtext/utils/gcs_utils.py | 46 ++++++++++ tests/unit/gcs_utils_test.py | 132 ++++++++++++++++++++++++++++ 3 files changed, 184 insertions(+), 7 deletions(-) create mode 100644 tests/unit/gcs_utils_test.py diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 54c0cbc48b..220ff6f16d 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -28,6 +28,7 @@ from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator from maxtext.utils import exceptions from maxtext.utils import max_logging +from maxtext.utils import gcs_utils import numpy as np import orbax.checkpoint as ocp from orbax.checkpoint import v1 as ocp_v1 @@ -245,8 +246,7 @@ def create_orbax_checkpoint_manager( item_handlers["iter"] = GrainCheckpointHandler() # local storage checkpoint needs parent directory created - p = epath.Path(checkpoint_dir) - p.mkdir(exist_ok=True, parents=True) + p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir) if enable_continuous_checkpointing: save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy() preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep) @@ -300,19 +300,18 @@ def create_orbax_emergency_checkpoint_manager( flags.FLAGS.experimental_orbax_use_distributed_process_id = True max_logging.log("Creating emergency checkpoint manager...") - # Only create directories if running on GPUs as the previous - # directory structure might be assumed by TPUs + # Only create local directories if running on GPUs as the previous directory structure might be assumed by TPUs. if global_mesh.devices.flatten()[0].platform == "gpu": # pylint: disable=protected-access local_checkpoint_dir = f"{local_checkpoint_dir}/{jax._src.distributed.global_state.process_id}" local_p = epath.Path(local_checkpoint_dir) - persistent_p = epath.Path(persistent_checkpoint_dir) local_p.mkdir(exist_ok=True, parents=True) - persistent_p.mkdir(exist_ok=True, parents=True) + + persistent_p = gcs_utils.mkdir_and_check_permissions(persistent_checkpoint_dir) manager = EmergencyCheckpointManager( local_checkpoint_dir, - epath.Path(persistent_checkpoint_dir), + persistent_p, global_mesh=global_mesh, abstract_state=abstract_state, options=emergency_checkpoint_manager.CheckpointManagerOptions( diff --git a/src/maxtext/utils/gcs_utils.py b/src/maxtext/utils/gcs_utils.py index cdd2e1a8f8..b0f8a98c01 100644 --- a/src/maxtext/utils/gcs_utils.py +++ b/src/maxtext/utils/gcs_utils.py @@ -18,6 +18,8 @@ import os import socket from pathlib import Path +from etils import epath +import uuid import yaml @@ -242,3 +244,47 @@ def write_dict_to_gcs_json(data_dict, file_path): blob.upload_from_string(json_string, content_type="application/json") except (ValueError, TypeError, RecursionError) as e: print(f"Failed to write json file at {file_path} with error: {str(e)}") + + +def mkdir_and_check_permissions(path: str | epath.Path) -> epath.Path: + """Creates a directory if it doesn't exist and verifies write permissions. + + This function prevents the program from hanging when an output directory is inaccessible. The standard + `epath.Path.mkdir` can hang or fail silently when pointed at a path in a non-existent or inaccessible GCS bucket. + + For example, the following code can hang indefinitely: + + from etils import epath + path = epath.Path("gs://no_such_bucket/path/to/output") + path.mkdir(exist_ok=True, parents=True) + """ + if isinstance(path, str): + path = epath.Path(path) + + if path.as_posix().startswith("gs://"): + if len(path.parts) < 3: + raise ValueError(f"Invalid GCS path (missing bucket name): '{path}'") + bucket_name = path.parts[2] + try: + storage_client = storage.Client() + storage_client.get_bucket(bucket_name) + except Exception as e: + raise FileNotFoundError(f"GCS bucket 'gs://{bucket_name}' not found or accessible.") from e + path.mkdir(exist_ok=True, parents=True) + if not path.exists(): + raise PermissionError(f"Failed to create the directory '{path}'. Please check that you have write access.") + + # Verify write permissions by creating and deleting a temporary file. + # This handles the case where the directory exists but is not writable. + temp_file_path = path / f".write_test_{uuid.uuid4()}" + try: + temp_file_path.write_text("test") + except Exception as e: # pylint: disable=broad-exception-caught + raise PermissionError(f"Directory '{path}' exists, but is not writable. Please check permissions.") from e + finally: + try: + temp_file_path.unlink() # Delete the temp file. + except Exception: # pylint: disable=broad-exception-caught + pass # Suppress errors during cleanup to not hide the original error. + + return path diff --git a/tests/unit/gcs_utils_test.py b/tests/unit/gcs_utils_test.py new file mode 100644 index 0000000000..fd6c0bfb84 --- /dev/null +++ b/tests/unit/gcs_utils_test.py @@ -0,0 +1,132 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for GCS utility functions.""" + +import unittest +from unittest import mock +import os +import tempfile +import pytest + +# Module to be tested +from maxtext.utils import gcs_utils + + +@pytest.mark.cpu_only +class GcsUtilsTest(unittest.TestCase): + """Unit tests for GCS utility functions.""" + + def test_add_trailing_slash(self): + """Tests the simple add_trailing_slash utility.""" + self.assertEqual(gcs_utils.add_trailing_slash("a/b"), "a/b/") + self.assertEqual(gcs_utils.add_trailing_slash("a/b/"), "a/b/") + + def test_mkdir_non_existing_dir(self): + """Tests that a non-existing directory is created and is empty.""" + with tempfile.TemporaryDirectory() as temp_dir: + new_dir_path = os.path.join(temp_dir, "new_dir") + self.assertFalse(os.path.exists(new_dir_path)) + + # Act + gcs_utils.mkdir_and_check_permissions(new_dir_path) + + # Assert + self.assertTrue(os.path.isdir(new_dir_path)) + self.assertEqual(os.listdir(new_dir_path), []) + + def test_mkdir_existing_non_empty_dir(self): + """Tests that an existing, non-empty directory's contents are unmodified.""" + with tempfile.TemporaryDirectory() as temp_dir: + existing_dir_path = os.path.join(temp_dir, "existing_dir") + os.makedirs(existing_dir_path) + dummy_file_path = os.path.join(existing_dir_path, "dummy.txt") + with open(dummy_file_path, "w", encoding="utf-8") as f: + f.write("test") + + # Act + gcs_utils.mkdir_and_check_permissions(existing_dir_path) + + # Assert + self.assertTrue(os.path.isdir(existing_dir_path)) + self.assertEqual(os.listdir(existing_dir_path), ["dummy.txt"]) + + def test_mkdir_existing_read_only_dir(self): + """Tests that a PermissionError is raised for a read-only directory.""" + # Ideally, we would create a temporary directory here and mark it read-only. Unfortunately it does not work when the + # tests run inside a GitHub action. I think that those tests are executed as sudo and they ignore the permissions. + # Instead we use "/sys", which is a universally read-only directory, even for superusers. + read_only_dir_path = "/sys" + with self.assertRaises((PermissionError, OSError)): + gcs_utils.mkdir_and_check_permissions(read_only_dir_path) + + def test_mkdir_read_only_parent_dir(self): + """Tests that a PermissionError is raised when the parent is read-only.""" + # Ideally, we would create a temporary directory here and mark it read-only. Unfortunately it does not work when the + # tests run inside a GitHub action. I think that those tests are executed as sudo and they ignore the permissions. + # Instead we use "/sys", which is a universally read-only directory, even for superusers. + new_dir_path = "/sys/new_dir" + with self.assertRaises((PermissionError, OSError)): + gcs_utils.mkdir_and_check_permissions(new_dir_path) + + @mock.patch("maxtext.utils.gcs_utils.storage.Client") + def test_mkdir_gcs_no_such_bucket(self, mock_storage_client): + """Tests that an exception is raised for a non-existent GCS bucket.""" + mock_client_instance = mock_storage_client.return_value + mock_client_instance.get_bucket.side_effect = Exception("Bucket not found!") + gcs_path = "gs://no_such_bucket" + + with self.assertRaises(FileNotFoundError): + gcs_utils.mkdir_and_check_permissions(gcs_path) + mock_client_instance.get_bucket.assert_called_with("no_such_bucket") + + @mock.patch("maxtext.utils.gcs_utils.storage.Client") + def test_mkdir_gcs_no_such_bucket_with_path(self, mock_storage_client): + """Tests an exception for a non-existent bucket with a subdirectory.""" + mock_client_instance = mock_storage_client.return_value + mock_client_instance.get_bucket.side_effect = Exception("Bucket not found!") + gcs_path = "gs://no_such_bucket/some/dir" + + with self.assertRaises(FileNotFoundError): + gcs_utils.mkdir_and_check_permissions(gcs_path) + mock_client_instance.get_bucket.assert_called_with("no_such_bucket") + + @mock.patch("maxtext.utils.gcs_utils.epath.Path") + @mock.patch("maxtext.utils.gcs_utils.storage.Client") + def test_mkdir_gcs_valid_bucket(self, mock_storage_client, mock_epath): + """Tests that a valid GCS path is handled correctly without errors.""" + # Arrange: Mock the GCS client to simulate a valid bucket + mock_client_instance = mock_storage_client.return_value + + # Arrange: Mock epath to prevent real GCS calls + mock_path_instance = mock.MagicMock() + mock_path_instance.as_posix.return_value = "gs://valid_bucket/some/dir" + mock_path_instance.parts = ["gs:", "", "valid_bucket", "some", "dir"] + mock_path_instance.exists.return_value = True + + mock_temp_file_instance = mock.MagicMock() + mock_path_instance.__truediv__.return_value = mock_temp_file_instance + + mock_epath.return_value = mock_path_instance + gcs_path = "gs://valid_bucket/some/dir" + + # Act + gcs_utils.mkdir_and_check_permissions(gcs_path) + + # Assert + mock_client_instance.get_bucket.assert_called_with("valid_bucket") + mock_path_instance.mkdir.assert_called_with(exist_ok=True, parents=True) + mock_path_instance.exists.assert_called_once() + mock_temp_file_instance.write_text.assert_called_once_with("test") + mock_temp_file_instance.unlink.assert_called_once()