Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
from maxtext.utils import exceptions
from maxtext.utils import max_logging
from maxtext.utils import gcs_utils
import numpy as np
import orbax.checkpoint as ocp
from orbax.checkpoint import v1 as ocp_v1
Expand Down Expand Up @@ -245,8 +246,7 @@ def create_orbax_checkpoint_manager(
item_handlers["iter"] = GrainCheckpointHandler()

# local storage checkpoint needs parent directory created
p = epath.Path(checkpoint_dir)
p.mkdir(exist_ok=True, parents=True)
p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir)
if enable_continuous_checkpointing:
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
Expand Down Expand Up @@ -300,19 +300,18 @@ def create_orbax_emergency_checkpoint_manager(
flags.FLAGS.experimental_orbax_use_distributed_process_id = True
max_logging.log("Creating emergency checkpoint manager...")

# Only create directories if running on GPUs as the previous
# directory structure might be assumed by TPUs
# Only create local directories if running on GPUs as the previous directory structure might be assumed by TPUs.
if global_mesh.devices.flatten()[0].platform == "gpu":
# pylint: disable=protected-access
local_checkpoint_dir = f"{local_checkpoint_dir}/{jax._src.distributed.global_state.process_id}"
local_p = epath.Path(local_checkpoint_dir)
persistent_p = epath.Path(persistent_checkpoint_dir)
local_p.mkdir(exist_ok=True, parents=True)
persistent_p.mkdir(exist_ok=True, parents=True)

persistent_p = gcs_utils.mkdir_and_check_permissions(persistent_checkpoint_dir)

manager = EmergencyCheckpointManager(
local_checkpoint_dir,
epath.Path(persistent_checkpoint_dir),
persistent_p,
global_mesh=global_mesh,
abstract_state=abstract_state,
options=emergency_checkpoint_manager.CheckpointManagerOptions(
Expand Down
46 changes: 46 additions & 0 deletions src/maxtext/utils/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import os
import socket
from pathlib import Path
from etils import epath
import uuid

import yaml

Expand Down Expand Up @@ -242,3 +244,47 @@ def write_dict_to_gcs_json(data_dict, file_path):
blob.upload_from_string(json_string, content_type="application/json")
except (ValueError, TypeError, RecursionError) as e:
print(f"Failed to write json file at {file_path} with error: {str(e)}")


def mkdir_and_check_permissions(path: str | epath.Path) -> epath.Path:
  """Creates a directory if it doesn't exist and verifies write permissions.

  This function prevents the program from hanging when an output directory is inaccessible. The standard
  `epath.Path.mkdir` can hang or fail silently when pointed at a path in a non-existent or inaccessible GCS bucket.
  For example, the following code can hang indefinitely:
    from etils import epath
    path = epath.Path("gs://no_such_bucket/path/to/output")
    path.mkdir(exist_ok=True, parents=True)

  Args:
    path: A local or GCS (``gs://...``) directory path, given as a string or an `epath.Path`.

  Returns:
    The directory as an `epath.Path` (created if it did not already exist).

  Raises:
    ValueError: If a GCS path is missing a bucket name.
    FileNotFoundError: If the GCS bucket does not exist or cannot be accessed.
    PermissionError: If the directory cannot be created, or exists but is not writable.
  """
  if isinstance(path, str):
    path = epath.Path(path)

  if path.as_posix().startswith("gs://"):
    # NOTE(review): assumes `epath` splits "gs://bucket/..." into parts like
    # ("gs:", "", "<bucket>", ...) so the bucket name is at index 2 and fewer
    # than 3 parts means the bucket is missing — confirm against the installed
    # etils version.
    if len(path.parts) < 3:
      raise ValueError(f"Invalid GCS path (missing bucket name): '{path}'")
    bucket_name = path.parts[2]
    try:
      # Probe the bucket via the storage API to fail fast with a clear error,
      # instead of letting the `mkdir` below hang on a bad bucket (see docstring).
      storage_client = storage.Client()
      storage_client.get_bucket(bucket_name)
    except Exception as e:
      raise FileNotFoundError(f"GCS bucket 'gs://{bucket_name}' not found or accessible.") from e
  path.mkdir(exist_ok=True, parents=True)
  # `mkdir(exist_ok=True)` can fail silently in some backends, so re-check.
  if not path.exists():
    raise PermissionError(f"Failed to create the directory '{path}'. Please check that you have write access.")

  # Verify write permissions by creating and deleting a temporary file.
  # This handles the case where the directory exists but is not writable.
  temp_file_path = path / f".write_test_{uuid.uuid4()}"
  try:
    temp_file_path.write_text("test")
  except Exception as e:
    raise PermissionError(f"Directory '{path}' exists, but is not writable. Please check permissions.") from e
  finally:
    try:
      temp_file_path.unlink()  # Delete the temp file.
    except Exception:  # pylint: disable=broad-exception-caught
      pass  # Suppress errors during cleanup to not hide the original error.

  return path
132 changes: 132 additions & 0 deletions tests/unit/gcs_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for GCS utility functions."""

import unittest
from unittest import mock
import os
import tempfile
import pytest

# Module to be tested
from maxtext.utils import gcs_utils


@pytest.mark.cpu_only
class GcsUtilsTest(unittest.TestCase):
  """Unit tests for GCS utility functions."""

  def test_add_trailing_slash(self):
    """add_trailing_slash appends a slash only when one is missing."""
    for given, expected in (("a/b", "a/b/"), ("a/b/", "a/b/")):
      self.assertEqual(gcs_utils.add_trailing_slash(given), expected)

  def test_mkdir_non_existing_dir(self):
    """A missing local directory is created and starts out empty."""
    with tempfile.TemporaryDirectory() as workspace:
      target = os.path.join(workspace, "new_dir")
      self.assertFalse(os.path.exists(target))

      gcs_utils.mkdir_and_check_permissions(target)

      self.assertTrue(os.path.isdir(target))
      self.assertEqual(os.listdir(target), [])

  def test_mkdir_existing_non_empty_dir(self):
    """An existing, non-empty directory keeps its contents untouched."""
    with tempfile.TemporaryDirectory() as workspace:
      target = os.path.join(workspace, "existing_dir")
      os.makedirs(target)
      with open(os.path.join(target, "dummy.txt"), "w", encoding="utf-8") as f:
        f.write("test")

      gcs_utils.mkdir_and_check_permissions(target)

      self.assertTrue(os.path.isdir(target))
      self.assertEqual(os.listdir(target), ["dummy.txt"])

  def test_mkdir_existing_read_only_dir(self):
    """A read-only directory raises a permission-style error."""
    # We cannot simply chmod a temp directory read-only: in GitHub actions the
    # tests appear to run with superuser rights, which ignore file permissions.
    # "/sys" is read-only even for root, so we probe that instead.
    with self.assertRaises((PermissionError, OSError)):
      gcs_utils.mkdir_and_check_permissions("/sys")

  def test_mkdir_read_only_parent_dir(self):
    """Creating a child of a read-only parent raises a permission-style error."""
    # Same rationale as above: "/sys" is read-only even for superusers, so a
    # chmod-ed temp directory would not work under the GitHub actions runner.
    with self.assertRaises((PermissionError, OSError)):
      gcs_utils.mkdir_and_check_permissions("/sys/new_dir")

  @mock.patch("maxtext.utils.gcs_utils.storage.Client")
  def test_mkdir_gcs_no_such_bucket(self, mock_storage_client):
    """A non-existent GCS bucket surfaces as FileNotFoundError."""
    client = mock_storage_client.return_value
    client.get_bucket.side_effect = Exception("Bucket not found!")

    with self.assertRaises(FileNotFoundError):
      gcs_utils.mkdir_and_check_permissions("gs://no_such_bucket")
    client.get_bucket.assert_called_with("no_such_bucket")

  @mock.patch("maxtext.utils.gcs_utils.storage.Client")
  def test_mkdir_gcs_no_such_bucket_with_path(self, mock_storage_client):
    """A non-existent bucket plus subdirectory also raises FileNotFoundError."""
    client = mock_storage_client.return_value
    client.get_bucket.side_effect = Exception("Bucket not found!")

    with self.assertRaises(FileNotFoundError):
      gcs_utils.mkdir_and_check_permissions("gs://no_such_bucket/some/dir")
    client.get_bucket.assert_called_with("no_such_bucket")

  @mock.patch("maxtext.utils.gcs_utils.epath.Path")
  @mock.patch("maxtext.utils.gcs_utils.storage.Client")
  def test_mkdir_gcs_valid_bucket(self, mock_storage_client, mock_epath):
    """A valid GCS path passes every check without raising."""
    # Simulate a bucket that the storage client can resolve.
    client = mock_storage_client.return_value

    # Fake out epath so no real GCS traffic happens.
    fake_dir = mock.MagicMock()
    fake_dir.as_posix.return_value = "gs://valid_bucket/some/dir"
    fake_dir.parts = ["gs:", "", "valid_bucket", "some", "dir"]
    fake_dir.exists.return_value = True

    fake_probe = mock.MagicMock()
    fake_dir.__truediv__.return_value = fake_probe

    mock_epath.return_value = fake_dir

    gcs_utils.mkdir_and_check_permissions("gs://valid_bucket/some/dir")

    client.get_bucket.assert_called_with("valid_bucket")
    fake_dir.mkdir.assert_called_with(exist_ok=True, parents=True)
    fake_dir.exists.assert_called_once()
    fake_probe.write_text.assert_called_once_with("test")
    fake_probe.unlink.assert_called_once()
Loading