Skip to content

Commit 21a132d

Browse files
committed
Ensure checkpoint directory is writable and mark GCS utils tests as CPU only
1 parent 7b35530 commit 21a132d

3 files changed

Lines changed: 185 additions & 7 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
2929
from maxtext.utils import exceptions
3030
from maxtext.utils import max_logging
31+
from maxtext.utils import gcs_utils
3132
import numpy as np
3233
import orbax.checkpoint as ocp
3334
from orbax.checkpoint import v1 as ocp_v1
@@ -245,8 +246,7 @@ def create_orbax_checkpoint_manager(
245246
item_handlers["iter"] = GrainCheckpointHandler()
246247

247248
# local storage checkpoint needs parent directory created
248-
p = epath.Path(checkpoint_dir)
249-
p.mkdir(exist_ok=True, parents=True)
249+
p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir)
250250
if enable_continuous_checkpointing:
251251
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
252252
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
@@ -300,19 +300,18 @@ def create_orbax_emergency_checkpoint_manager(
300300
flags.FLAGS.experimental_orbax_use_distributed_process_id = True
301301
max_logging.log("Creating emergency checkpoint manager...")
302302

303-
# Only create directories if running on GPUs as the previous
304-
# directory structure might be assumed by TPUs
303+
# Only create local directories if running on GPUs as the previous directory structure might be assumed by TPUs.
305304
if global_mesh.devices.flatten()[0].platform == "gpu":
306305
# pylint: disable=protected-access
307306
local_checkpoint_dir = f"{local_checkpoint_dir}/{jax._src.distributed.global_state.process_id}"
308307
local_p = epath.Path(local_checkpoint_dir)
309-
persistent_p = epath.Path(persistent_checkpoint_dir)
310308
local_p.mkdir(exist_ok=True, parents=True)
311-
persistent_p.mkdir(exist_ok=True, parents=True)
309+
310+
persistent_p = gcs_utils.mkdir_and_check_permissions(persistent_checkpoint_dir)
312311

313312
manager = EmergencyCheckpointManager(
314313
local_checkpoint_dir,
315-
epath.Path(persistent_checkpoint_dir),
314+
persistent_p,
316315
global_mesh=global_mesh,
317316
abstract_state=abstract_state,
318317
options=emergency_checkpoint_manager.CheckpointManagerOptions(

src/maxtext/utils/gcs_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import os
1919
import socket
2020
from pathlib import Path
21+
from etils import epath
22+
import uuid
2123

2224
import yaml
2325

@@ -242,3 +244,47 @@ def write_dict_to_gcs_json(data_dict, file_path):
242244
blob.upload_from_string(json_string, content_type="application/json")
243245
except (ValueError, TypeError, RecursionError) as e:
244246
print(f"Failed to write json file at {file_path} with error: {str(e)}")
247+
248+
249+
def mkdir_and_check_permissions(p: str | epath.Path) -> epath.Path:
250+
"""Creates a directory if it doesn't exist and verifies write permissions.
251+
252+
This function prevents the program from hanging when an output directory is inaccessible. The standard
253+
`epath.Path.mkdir` can hang or fail silently when pointed at a path in a non-existent or inaccessible GCS bucket.
254+
255+
For example, the following code can hang indefinitely:
256+
257+
from etils import epath
258+
p = epath.Path("gs://no_such_bucket/path/to/output")
259+
p.mkdir(exist_ok=True, parents=True)
260+
"""
261+
if isinstance(p, str):
262+
p = epath.Path(p)
263+
264+
if p.as_posix().startswith("gs://"):
265+
if len(p.parts) < 3:
266+
raise ValueError(f"Invalid GCS path (missing bucket name): '{p}'")
267+
bucket_name = p.parts[2]
268+
try:
269+
storage_client = storage.Client()
270+
storage_client.get_bucket(bucket_name)
271+
except Exception as e:
272+
raise FileNotFoundError(f"GCS bucket 'gs://{bucket_name}' not found or accessible.") from e
273+
p.mkdir(exist_ok=True, parents=True)
274+
if not p.exists():
275+
raise PermissionError(f"Failed to create the directory '{p}'. Please check that you have write access.")
276+
277+
# Verify write permissions by creating and deleting a temporary file.
278+
# This handles the case where the directory exists but is not writable.
279+
temp_file_path = p / f".write_test_{uuid.uuid4()}"
280+
try:
281+
temp_file_path.write_text("test")
282+
except Exception as e: # pylint: disable=broad-exception-caught
283+
raise PermissionError(f"Directory '{p}' exists, but is not writable. Please check permissions.") from e
284+
finally:
285+
try:
286+
temp_file_path.unlink() # Delete the temp file.
287+
except Exception: # pylint: disable=broad-exception-caught
288+
pass # Suppress errors during cleanup to not hide the original error.
289+
290+
return p

tests/unit/gcs_utils_test.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for GCS utility functions."""
16+
17+
import unittest
18+
from unittest import mock
19+
import os
20+
import tempfile
21+
import pytest
22+
23+
# Module to be tested
24+
from maxtext.utils import gcs_utils
25+
26+
27+
@pytest.mark.cpu_only
28+
class GcsUtilsTest(unittest.TestCase):
29+
"""Unit tests for GCS utility functions."""
30+
31+
def test_add_trailing_slash(self):
32+
"""Tests the simple add_trailing_slash utility."""
33+
self.assertEqual(gcs_utils.add_trailing_slash("a/b"), "a/b/")
34+
self.assertEqual(gcs_utils.add_trailing_slash("a/b/"), "a/b/")
35+
36+
def test_mkdir_non_existing_dir(self):
37+
"""Tests that a non-existing directory is created and is empty."""
38+
with tempfile.TemporaryDirectory() as temp_dir:
39+
new_dir_path = os.path.join(temp_dir, "new_dir")
40+
self.assertFalse(os.path.exists(new_dir_path))
41+
42+
# Act
43+
gcs_utils.mkdir_and_check_permissions(new_dir_path)
44+
45+
# Assert
46+
self.assertTrue(os.path.isdir(new_dir_path))
47+
self.assertEqual(os.listdir(new_dir_path), [])
48+
49+
def test_mkdir_existing_non_empty_dir(self):
50+
"""Tests that an existing, non-empty directory's contents are unmodified."""
51+
with tempfile.TemporaryDirectory() as temp_dir:
52+
existing_dir_path = os.path.join(temp_dir, "existing_dir")
53+
os.makedirs(existing_dir_path)
54+
dummy_file_path = os.path.join(existing_dir_path, "dummy.txt")
55+
with open(dummy_file_path, "w", encoding="utf-8") as f:
56+
f.write("test")
57+
58+
# Act
59+
gcs_utils.mkdir_and_check_permissions(existing_dir_path)
60+
61+
# Assert
62+
self.assertTrue(os.path.isdir(existing_dir_path))
63+
self.assertEqual(os.listdir(existing_dir_path), ["dummy.txt"])
64+
65+
def test_mkdir_existing_read_only_dir(self):
66+
"""Tests that a PermissionError is raised for a read-only directory."""
67+
with tempfile.TemporaryDirectory() as temp_dir:
68+
read_only_dir_path = os.path.join(temp_dir, "read_only_dir")
69+
os.makedirs(read_only_dir_path)
70+
os.chmod(read_only_dir_path, 0o555)
71+
gcs_utils.mkdir_and_check_permissions(read_only_dir_path)
72+
self.assertTrue(os.path.isdir(read_only_dir_path))
73+
74+
def test_mkdir_read_only_parent_dir(self):
75+
"""Tests that a PermissionError is raised when the parent is read-only."""
76+
with tempfile.TemporaryDirectory() as temp_dir:
77+
parent_dir_path = os.path.join(temp_dir, "read_only_parent")
78+
os.makedirs(parent_dir_path)
79+
os.chmod(parent_dir_path, 0o555)
80+
new_dir_path = os.path.join(parent_dir_path, "new_dir")
81+
gcs_utils.mkdir_and_check_permissions(new_dir_path)
82+
self.assertFalse(os.path.isdir(new_dir_path))
83+
84+
@mock.patch("maxtext.utils.gcs_utils.storage.Client")
85+
def test_mkdir_gcs_no_such_bucket(self, mock_storage_client):
86+
"""Tests that an exception is raised for a non-existent GCS bucket."""
87+
mock_client_instance = mock_storage_client.return_value
88+
mock_client_instance.get_bucket.side_effect = Exception("Bucket not found!")
89+
gcs_path = "gs://no_such_bucket"
90+
91+
with self.assertRaises(FileNotFoundError):
92+
gcs_utils.mkdir_and_check_permissions(gcs_path)
93+
mock_client_instance.get_bucket.assert_called_with("no_such_bucket")
94+
95+
@mock.patch("maxtext.utils.gcs_utils.storage.Client")
96+
def test_mkdir_gcs_no_such_bucket_with_path(self, mock_storage_client):
97+
"""Tests an exception for a non-existent bucket with a subdirectory."""
98+
mock_client_instance = mock_storage_client.return_value
99+
mock_client_instance.get_bucket.side_effect = Exception("Bucket not found!")
100+
gcs_path = "gs://no_such_bucket/some/dir"
101+
102+
with self.assertRaises(FileNotFoundError):
103+
gcs_utils.mkdir_and_check_permissions(gcs_path)
104+
mock_client_instance.get_bucket.assert_called_with("no_such_bucket")
105+
106+
@mock.patch("maxtext.utils.gcs_utils.epath.Path")
107+
@mock.patch("maxtext.utils.gcs_utils.storage.Client")
108+
def test_mkdir_gcs_valid_bucket(self, mock_storage_client, mock_epath):
109+
"""Tests that a valid GCS path is handled correctly without errors."""
110+
# Arrange: Mock the GCS client to simulate a valid bucket
111+
mock_client_instance = mock_storage_client.return_value
112+
113+
# Arrange: Mock epath to prevent real GCS calls
114+
mock_path_instance = mock.MagicMock()
115+
mock_path_instance.as_posix.return_value = "gs://valid_bucket/some/dir"
116+
mock_path_instance.parts = ["gs:", "", "valid_bucket", "some", "dir"]
117+
mock_path_instance.exists.return_value = True
118+
119+
mock_temp_file_instance = mock.MagicMock()
120+
mock_path_instance.__truediv__.return_value = mock_temp_file_instance
121+
122+
mock_epath.return_value = mock_path_instance
123+
gcs_path = "gs://valid_bucket/some/dir"
124+
125+
# Act
126+
gcs_utils.mkdir_and_check_permissions(gcs_path)
127+
128+
# Assert
129+
mock_client_instance.get_bucket.assert_called_with("valid_bucket")
130+
mock_path_instance.mkdir.assert_called_with(exist_ok=True, parents=True)
131+
mock_path_instance.exists.assert_called_once()
132+
mock_temp_file_instance.write_text.assert_called_once_with("test")
133+
mock_temp_file_instance.unlink.assert_called_once()

0 commit comments

Comments
 (0)