Skip to content

Commit eab7dee

Browse files
committed
Add writability checks to the checkpoint directory.
1 parent f11f550 commit eab7dee

3 files changed

Lines changed: 184 additions & 7 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
2929
from maxtext.utils import exceptions
3030
from maxtext.utils import max_logging
31+
from maxtext.utils import gcs_utils
3132
import numpy as np
3233
import orbax.checkpoint as ocp
3334
from orbax.checkpoint import v1 as ocp_v1
@@ -245,8 +246,7 @@ def create_orbax_checkpoint_manager(
245246
item_handlers["iter"] = GrainCheckpointHandler()
246247

247248
# local storage checkpoint needs parent directory created
248-
p = epath.Path(checkpoint_dir)
249-
p.mkdir(exist_ok=True, parents=True)
249+
p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir)
250250
if enable_continuous_checkpointing:
251251
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
252252
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
@@ -300,19 +300,18 @@ def create_orbax_emergency_checkpoint_manager(
300300
flags.FLAGS.experimental_orbax_use_distributed_process_id = True
301301
max_logging.log("Creating emergency checkpoint manager...")
302302

303-
# Only create directories if running on GPUs as the previous
304-
# directory structure might be assumed by TPUs
303+
# Only create local directories if running on GPUs as the previous directory structure might be assumed by TPUs.
305304
if global_mesh.devices.flatten()[0].platform == "gpu":
306305
# pylint: disable=protected-access
307306
local_checkpoint_dir = f"{local_checkpoint_dir}/{jax._src.distributed.global_state.process_id}"
308307
local_p = epath.Path(local_checkpoint_dir)
309-
persistent_p = epath.Path(persistent_checkpoint_dir)
310308
local_p.mkdir(exist_ok=True, parents=True)
311-
persistent_p.mkdir(exist_ok=True, parents=True)
309+
310+
persistent_p = gcs_utils.mkdir_and_check_permissions(persistent_checkpoint_dir)
312311

313312
manager = EmergencyCheckpointManager(
314313
local_checkpoint_dir,
315-
epath.Path(persistent_checkpoint_dir),
314+
persistent_p,
316315
global_mesh=global_mesh,
317316
abstract_state=abstract_state,
318317
options=emergency_checkpoint_manager.CheckpointManagerOptions(

src/maxtext/utils/gcs_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import os
1919
import socket
2020
from pathlib import Path
21+
from etils import epath
22+
import uuid
2123

2224
import yaml
2325

@@ -242,3 +244,47 @@ def write_dict_to_gcs_json(data_dict, file_path):
242244
blob.upload_from_string(json_string, content_type="application/json")
243245
except (ValueError, TypeError, RecursionError) as e:
244246
print(f"Failed to write json file at {file_path} with error: {str(e)}")
247+
248+
249+
def mkdir_and_check_permissions(p: str | epath.Path) -> epath.Path:
250+
"""Creates a directory if it doesn't exist and verifies write permissions.
251+
252+
This function prevents the program from hanging when an output directory is inaccessible. The standard
253+
`epath.Path.mkdir` can hang or fail silently when pointed at a path in a non-existent or inaccessible GCS bucket.
254+
255+
For example, the following code can hang indefinitely:
256+
257+
from etils import epath
258+
p = epath.Path("gs://no_such_bucket/path/to/output")
259+
p.mkdir(exist_ok=True, parents=True)
260+
"""
261+
if isinstance(p, str):
262+
p = epath.Path(p)
263+
264+
if p.as_posix().startswith("gs://"):
265+
if len(p.parts) < 3:
266+
raise ValueError(f"Invalid GCS path (missing bucket name): '{p}'")
267+
bucket_name = p.parts[2]
268+
try:
269+
storage_client = storage.Client()
270+
storage_client.get_bucket(bucket_name)
271+
except Exception as e:
272+
raise FileNotFoundError(f"GCS bucket 'gs://{bucket_name}' not found or accessible.") from e
273+
p.mkdir(exist_ok=True, parents=True)
274+
if not p.exists():
275+
raise PermissionError(f"Failed to create the directory '{p}'. Please check that you have write access.")
276+
277+
# Verify write permissions by creating and deleting a temporary file.
278+
# This handles the case where the directory exists but is not writable.
279+
temp_file_path = p / f".write_test_{uuid.uuid4()}"
280+
try:
281+
temp_file_path.write_text("test")
282+
except Exception as e: # pylint: disable=broad-exception-caught
283+
raise PermissionError(f"Directory '{p}' exists, but is not writable. Please check permissions.") from e
284+
finally:
285+
try:
286+
temp_file_path.unlink() # Delete the temp file.
287+
except Exception: # pylint: disable=broad-exception-caught
288+
pass # Suppress errors during cleanup to not hide the original error.
289+
290+
return p

tests/unit/gcs_utils_test.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for GCS utility functions."""
16+
17+
import unittest
18+
from unittest import mock
19+
import os
20+
import tempfile
21+
22+
# Module to be tested
23+
from maxtext.utils import gcs_utils
24+
25+
26+
class GcsUtilsTest(unittest.TestCase):
27+
"""Unit tests for GCS utility functions."""
28+
29+
def test_add_trailing_slash(self):
30+
"""Tests the simple add_trailing_slash utility."""
31+
self.assertEqual(gcs_utils.add_trailing_slash("a/b"), "a/b/")
32+
self.assertEqual(gcs_utils.add_trailing_slash("a/b/"), "a/b/")
33+
34+
def test_mkdir_non_existing_dir(self):
35+
"""Tests that a non-existing directory is created and is empty."""
36+
with tempfile.TemporaryDirectory() as temp_dir:
37+
new_dir_path = os.path.join(temp_dir, "new_dir")
38+
self.assertFalse(os.path.exists(new_dir_path))
39+
40+
# Act
41+
gcs_utils.mkdir_and_check_permissions(new_dir_path)
42+
43+
# Assert
44+
self.assertTrue(os.path.isdir(new_dir_path))
45+
self.assertEqual(os.listdir(new_dir_path), [])
46+
47+
def test_mkdir_existing_non_empty_dir(self):
48+
"""Tests that an existing, non-empty directory's contents are unmodified."""
49+
with tempfile.TemporaryDirectory() as temp_dir:
50+
existing_dir_path = os.path.join(temp_dir, "existing_dir")
51+
os.makedirs(existing_dir_path)
52+
dummy_file_path = os.path.join(existing_dir_path, "dummy.txt")
53+
with open(dummy_file_path, "w", encoding="utf-8") as f:
54+
f.write("test")
55+
56+
# Act
57+
gcs_utils.mkdir_and_check_permissions(existing_dir_path)
58+
59+
# Assert
60+
self.assertTrue(os.path.isdir(existing_dir_path))
61+
self.assertEqual(os.listdir(existing_dir_path), ["dummy.txt"])
62+
63+
def test_mkdir_existing_read_only_dir(self):
64+
"""Tests that a PermissionError is raised for a read-only directory."""
65+
with tempfile.TemporaryDirectory() as temp_dir:
66+
read_only_dir_path = os.path.join(temp_dir, "read_only_dir")
67+
os.makedirs(read_only_dir_path)
68+
os.chmod(read_only_dir_path, 0o555)
69+
with self.assertRaises(PermissionError):
70+
gcs_utils.mkdir_and_check_permissions(read_only_dir_path)
71+
72+
def test_mkdir_read_only_parent_dir(self):
73+
"""Tests that a PermissionError is raised when the parent is read-only."""
74+
with tempfile.TemporaryDirectory() as temp_dir:
75+
parent_dir_path = os.path.join(temp_dir, "read_only_parent")
76+
os.makedirs(parent_dir_path)
77+
os.chmod(parent_dir_path, 0o555)
78+
new_dir_path = os.path.join(parent_dir_path, "new_dir")
79+
with self.assertRaises(PermissionError):
80+
gcs_utils.mkdir_and_check_permissions(new_dir_path)
81+
82+
@mock.patch("maxtext.utils.gcs_utils.storage.Client")
83+
def test_mkdir_gcs_no_such_bucket(self, mock_storage_client):
84+
"""Tests that an exception is raised for a non-existent GCS bucket."""
85+
mock_client_instance = mock_storage_client.return_value
86+
mock_client_instance.get_bucket.side_effect = Exception("Bucket not found!")
87+
gcs_path = "gs://no_such_bucket"
88+
89+
with self.assertRaises(FileNotFoundError):
90+
gcs_utils.mkdir_and_check_permissions(gcs_path)
91+
mock_client_instance.get_bucket.assert_called_with("no_such_bucket")
92+
93+
@mock.patch("maxtext.utils.gcs_utils.storage.Client")
94+
def test_mkdir_gcs_no_such_bucket_with_path(self, mock_storage_client):
95+
"""Tests an exception for a non-existent bucket with a subdirectory."""
96+
mock_client_instance = mock_storage_client.return_value
97+
mock_client_instance.get_bucket.side_effect = Exception("Bucket not found!")
98+
gcs_path = "gs://no_such_bucket/some/dir"
99+
100+
with self.assertRaises(FileNotFoundError):
101+
gcs_utils.mkdir_and_check_permissions(gcs_path)
102+
mock_client_instance.get_bucket.assert_called_with("no_such_bucket")
103+
104+
@mock.patch("maxtext.utils.gcs_utils.uuid.uuid4", return_value="test-uuid")
105+
@mock.patch("maxtext.utils.gcs_utils.epath.Path")
106+
@mock.patch("maxtext.utils.gcs_utils.storage.Client")
107+
def test_mkdir_gcs_valid_bucket(self, mock_storage_client, mock_epath, mock_uuid):
108+
"""Tests that a valid GCS path is handled correctly without errors."""
109+
# Arrange: Mock the GCS client to simulate a valid bucket
110+
mock_client_instance = mock_storage_client.return_value
111+
112+
# Arrange: Mock epath to prevent real GCS calls
113+
mock_path_instance = mock.MagicMock()
114+
mock_path_instance.as_posix.return_value = "gs://valid_bucket/some/dir"
115+
mock_path_instance.parts = ["gs:", "", "valid_bucket", "some", "dir"]
116+
mock_path_instance.exists.return_value = True
117+
118+
mock_temp_file_instance = mock.MagicMock()
119+
mock_path_instance.__truediv__.return_value = mock_temp_file_instance
120+
121+
mock_epath.return_value = mock_path_instance
122+
gcs_path = "gs://valid_bucket/some/dir"
123+
124+
# Act
125+
gcs_utils.mkdir_and_check_permissions(gcs_path)
126+
127+
# Assert
128+
mock_client_instance.get_bucket.assert_called_with("valid_bucket")
129+
mock_path_instance.mkdir.assert_called_with(exist_ok=True, parents=True)
130+
mock_path_instance.exists.assert_called_once()
131+
mock_temp_file_instance.write_text.assert_called_once_with("test")
132+
mock_temp_file_instance.unlink.assert_called_once()

0 commit comments

Comments
 (0)