Skip to content

Commit ca7e2df

Browse files
Merge pull request #3317 from AI-Hypercomputer:igorts/b476521717-gke-logs
PiperOrigin-RevId: 883405275
2 parents d359768 + a3cc654 commit ca7e2df

3 files changed

Lines changed: 184 additions & 7 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
2929
from maxtext.utils import exceptions
3030
from maxtext.utils import max_logging
31+
from maxtext.utils import gcs_utils
3132
import numpy as np
3233
import orbax.checkpoint as ocp
3334
from orbax.checkpoint import v1 as ocp_v1
@@ -245,8 +246,7 @@ def create_orbax_checkpoint_manager(
245246
item_handlers["iter"] = GrainCheckpointHandler()
246247

247248
# Local storage checkpointing needs the parent directory to be created.
248-
p = epath.Path(checkpoint_dir)
249-
p.mkdir(exist_ok=True, parents=True)
249+
p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir)
250250
if enable_continuous_checkpointing:
251251
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
252252
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
@@ -300,19 +300,18 @@ def create_orbax_emergency_checkpoint_manager(
300300
flags.FLAGS.experimental_orbax_use_distributed_process_id = True
301301
max_logging.log("Creating emergency checkpoint manager...")
302302

303-
# Only create directories if running on GPUs as the previous
304-
# directory structure might be assumed by TPUs
303+
# Only create local directories if running on GPUs as the previous directory structure might be assumed by TPUs.
305304
if global_mesh.devices.flatten()[0].platform == "gpu":
306305
# pylint: disable=protected-access
307306
local_checkpoint_dir = f"{local_checkpoint_dir}/{jax._src.distributed.global_state.process_id}"
308307
local_p = epath.Path(local_checkpoint_dir)
309-
persistent_p = epath.Path(persistent_checkpoint_dir)
310308
local_p.mkdir(exist_ok=True, parents=True)
311-
persistent_p.mkdir(exist_ok=True, parents=True)
309+
310+
persistent_p = gcs_utils.mkdir_and_check_permissions(persistent_checkpoint_dir)
312311

313312
manager = EmergencyCheckpointManager(
314313
local_checkpoint_dir,
315-
epath.Path(persistent_checkpoint_dir),
314+
persistent_p,
316315
global_mesh=global_mesh,
317316
abstract_state=abstract_state,
318317
options=emergency_checkpoint_manager.CheckpointManagerOptions(

src/maxtext/utils/gcs_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import os
1919
import socket
2020
from pathlib import Path
21+
from etils import epath
22+
import uuid
2123

2224
import yaml
2325

@@ -242,3 +244,47 @@ def write_dict_to_gcs_json(data_dict, file_path):
242244
blob.upload_from_string(json_string, content_type="application/json")
243245
except (ValueError, TypeError, RecursionError) as e:
244246
print(f"Failed to write json file at {file_path} with error: {str(e)}")
247+
248+
249+
def mkdir_and_check_permissions(path: str | epath.Path) -> epath.Path:
    """Ensures `path` exists as a directory and that it can be written to.

    Motivation: calling `epath.Path.mkdir` directly on a path inside a non-existent or inaccessible GCS
    bucket can hang or fail silently. For instance, this snippet can hang indefinitely:

        from etils import epath
        path = epath.Path("gs://no_such_bucket/path/to/output")
        path.mkdir(exist_ok=True, parents=True)

    To fail fast instead, `gs://` paths get an explicit bucket lookup before any directory is created,
    and every path is probed by writing (and then removing) a uniquely named temporary file.

    Args:
      path: Directory to create. A `str` is wrapped in `epath.Path`; any other object is used as given.

    Returns:
      The path object that was checked (the `epath.Path` built from `path` when a string was passed).

    Raises:
      ValueError: The path starts with `gs://` but has fewer than three parts, i.e. no bucket name.
      FileNotFoundError: Creating the storage client or looking up the bucket raised any exception.
      PermissionError: The directory does not exist after `mkdir`, or writing the probe file failed.
        Exceptions raised by `mkdir` itself are not translated and propagate unchanged.
    """
    target = epath.Path(path) if isinstance(path, str) else path

    if target.as_posix().startswith("gs://"):
        components = target.parts
        if len(components) < 3:
            raise ValueError(f"Invalid GCS path (missing bucket name): '{target}'")
        bucket_name = components[2]
        try:
            # Look the bucket up first so that a missing/inaccessible bucket surfaces as an error here
            # rather than inside the mkdir call below.
            storage.Client().get_bucket(bucket_name)
        except Exception as e:
            raise FileNotFoundError(f"GCS bucket 'gs://{bucket_name}' not found or accessible.") from e

    target.mkdir(exist_ok=True, parents=True)
    if not target.exists():
        raise PermissionError(f"Failed to create the directory '{target}'. Please check that you have write access.")

    # An existing directory may still be read-only, so prove writability with a throw-away file.
    # The random suffix keeps the probe name unique per call.
    probe_file = target / f".write_test_{uuid.uuid4()}"
    try:
        probe_file.write_text("test")
    except Exception as e:  # pylint: disable=broad-exception-caught
        raise PermissionError(f"Directory '{target}' exists, but is not writable. Please check permissions.") from e
    finally:
        # Best-effort cleanup: a failure here must not mask the error raised above.
        try:
            probe_file.unlink()
        except Exception:  # pylint: disable=broad-exception-caught
            pass

    return target

tests/unit/gcs_utils_test.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for GCS utility functions."""
16+
17+
import unittest
18+
from unittest import mock
19+
import os
20+
import tempfile
21+
import pytest
22+
23+
# Module to be tested
24+
from maxtext.utils import gcs_utils
25+
26+
27+
@pytest.mark.cpu_only
class GcsUtilsTest(unittest.TestCase):
    """Test cases for helpers in `gcs_utils`, chiefly `mkdir_and_check_permissions`."""

    def _expect_missing_bucket_error(self, storage_client_cls, gcs_path):
        """Checks that `gcs_path` is rejected when the patched client's bucket lookup raises."""
        client = storage_client_cls.return_value
        client.get_bucket.side_effect = Exception("Bucket not found!")

        with self.assertRaises(FileNotFoundError):
            gcs_utils.mkdir_and_check_permissions(gcs_path)
        client.get_bucket.assert_called_with("no_such_bucket")

    def test_add_trailing_slash(self):
        """A slash is appended when missing and left alone when already present."""
        for raw in ("a/b", "a/b/"):
            self.assertEqual(gcs_utils.add_trailing_slash(raw), "a/b/")

    def test_mkdir_non_existing_dir(self):
        """A missing directory gets created and ends up empty (no probe file left behind)."""
        with tempfile.TemporaryDirectory() as scratch:
            target = os.path.join(scratch, "new_dir")
            self.assertFalse(os.path.exists(target))

            gcs_utils.mkdir_and_check_permissions(target)

            self.assertTrue(os.path.isdir(target))
            self.assertEqual(os.listdir(target), [])

    def test_mkdir_existing_non_empty_dir(self):
        """Pre-existing directory contents are left exactly as they were."""
        with tempfile.TemporaryDirectory() as scratch:
            target = os.path.join(scratch, "existing_dir")
            os.makedirs(target)
            with open(os.path.join(target, "dummy.txt"), "w", encoding="utf-8") as handle:
                handle.write("test")

            gcs_utils.mkdir_and_check_permissions(target)

            self.assertTrue(os.path.isdir(target))
            self.assertEqual(os.listdir(target), ["dummy.txt"])

    def test_mkdir_existing_read_only_dir(self):
        """An existing read-only directory is rejected with PermissionError/OSError."""
        # A temporary directory marked read-only would be the natural fixture, but that did not work
        # inside a GitHub action, presumably because the tests run with superuser rights there and
        # the permission bits are ignored. "/sys" is used instead, as it is expected to be read-only
        # even for superusers.
        with self.assertRaises((PermissionError, OSError)):
            gcs_utils.mkdir_and_check_permissions("/sys")

    def test_mkdir_read_only_parent_dir(self):
        """Creating a directory under a read-only parent is rejected with PermissionError/OSError."""
        # Same reasoning as in the read-only directory test: "/sys" stands in for a read-only parent
        # because chmod-based fixtures reportedly do not work when the tests run in a GitHub action.
        with self.assertRaises((PermissionError, OSError)):
            gcs_utils.mkdir_and_check_permissions("/sys/new_dir")

    @mock.patch("maxtext.utils.gcs_utils.storage.Client")
    def test_mkdir_gcs_no_such_bucket(self, mock_storage_client):
        """A bucket-only GCS path raises when the bucket lookup fails."""
        self._expect_missing_bucket_error(mock_storage_client, "gs://no_such_bucket")

    @mock.patch("maxtext.utils.gcs_utils.storage.Client")
    def test_mkdir_gcs_no_such_bucket_with_path(self, mock_storage_client):
        """A GCS path with a subdirectory raises when the bucket lookup fails."""
        self._expect_missing_bucket_error(mock_storage_client, "gs://no_such_bucket/some/dir")

    @mock.patch("maxtext.utils.gcs_utils.epath.Path")
    @mock.patch("maxtext.utils.gcs_utils.storage.Client")
    def test_mkdir_gcs_valid_bucket(self, mock_storage_client, mock_epath):
        """A GCS path in a reachable bucket goes through lookup, mkdir and the write probe."""
        gcs_path = "gs://valid_bucket/some/dir"

        # Stand-ins for the directory path and the probe file so that no real GCS call is made.
        probe_file = mock.MagicMock()
        fake_dir = mock.MagicMock()
        fake_dir.as_posix.return_value = gcs_path
        fake_dir.parts = ["gs:", "", "valid_bucket", "some", "dir"]
        fake_dir.exists.return_value = True
        fake_dir.__truediv__.return_value = probe_file
        mock_epath.return_value = fake_dir

        gcs_utils.mkdir_and_check_permissions(gcs_path)

        # The default mocked client does not raise, which simulates an existing, accessible bucket.
        mock_storage_client.return_value.get_bucket.assert_called_with("valid_bucket")
        fake_dir.mkdir.assert_called_with(exist_ok=True, parents=True)
        fake_dir.exists.assert_called_once()
        probe_file.write_text.assert_called_once_with("test")
        probe_file.unlink.assert_called_once()

0 commit comments

Comments
 (0)