Skip to content

Commit 1125324

Browse files
committed
feat: Improve permission tests and squash feature branch
- Reworks tests for read-only directories to use /sys instead of temporary directories with chmod. This fixes CI failures when tests are run as root. - Squashes commits related to int4 support, checkpoint directory checks, and RL parsing unit tests.
1 parent f11f550 commit 1125324

3 files changed

Lines changed: 176 additions & 5 deletions

File tree

src/maxtext/common/checkpointing.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
2929
from maxtext.utils import exceptions
3030
from maxtext.utils import max_logging
31+
from maxtext.utils import gcs_utils
3132
import numpy as np
3233
import orbax.checkpoint as ocp
3334
from orbax.checkpoint import v1 as ocp_v1
@@ -245,8 +246,7 @@ def create_orbax_checkpoint_manager(
245246
item_handlers["iter"] = GrainCheckpointHandler()
246247

247248
# local storage checkpoint needs parent directory created
248-
p = epath.Path(checkpoint_dir)
249-
p.mkdir(exist_ok=True, parents=True)
249+
p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir)
250250
if enable_continuous_checkpointing:
251251
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
252252
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
@@ -300,19 +300,19 @@ def create_orbax_emergency_checkpoint_manager(
300300
flags.FLAGS.experimental_orbax_use_distributed_process_id = True
301301
max_logging.log("Creating emergency checkpoint manager...")
302302

303+
persistent_p = gcs_utils.mkdir_and_check_permissions(persistent_checkpoint_dir)
304+
303305
# Only create directories if running on GPUs as the previous
304306
# directory structure might be assumed by TPUs
305307
if global_mesh.devices.flatten()[0].platform == "gpu":
306308
# pylint: disable=protected-access
307309
local_checkpoint_dir = f"{local_checkpoint_dir}/{jax._src.distributed.global_state.process_id}"
308310
local_p = epath.Path(local_checkpoint_dir)
309-
persistent_p = epath.Path(persistent_checkpoint_dir)
310311
local_p.mkdir(exist_ok=True, parents=True)
311-
persistent_p.mkdir(exist_ok=True, parents=True)
312312

313313
manager = EmergencyCheckpointManager(
314314
local_checkpoint_dir,
315-
epath.Path(persistent_checkpoint_dir),
315+
persistent_p,
316316
global_mesh=global_mesh,
317317
abstract_state=abstract_state,
318318
options=emergency_checkpoint_manager.CheckpointManagerOptions(

src/maxtext/utils/gcs_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import os
1919
import socket
2020
from pathlib import Path
21+
from etils import epath
22+
import uuid
2123

2224
import yaml
2325

@@ -242,3 +244,45 @@ def write_dict_to_gcs_json(data_dict, file_path):
242244
blob.upload_from_string(json_string, content_type="application/json")
243245
except (ValueError, TypeError, RecursionError) as e:
244246
print(f"Failed to write json file at {file_path} with error: {str(e)}")
247+
248+
249+
def mkdir_and_check_permissions(p: str | epath.Path) -> epath.Path:
250+
"""Creates a directory if it doesn't exist and verifies write permissions.
251+
252+
This function prevents the program from hanging when an output directory is inaccessible. The standard
253+
`epath.Path.mkdir` can hang or fail silently when pointed at a path in a non-existent or inaccessible GCS bucket.
254+
255+
For example, the following code can hang indefinitely:
256+
257+
from etils import epath
258+
p = epath.Path("gs://no_such_bucket/path/to/output")
259+
p.mkdir(exist_ok=True, parents=True)
260+
"""
261+
if isinstance(p, str):
262+
p = epath.Path(p)
263+
264+
if p.as_posix().startswith("gs://"):
265+
bucket_name = p.parts[2]
266+
try:
267+
storage_client = storage.Client()
268+
storage_client.get_bucket(bucket_name)
269+
except Exception as e:
270+
raise FileNotFoundError(f"GCS bucket 'gs://{bucket_name}' not found or accessible.") from e
271+
p.mkdir(exist_ok=True, parents=True)
272+
if not p.exists():
273+
raise PermissionError(f"Failed to create the directory '{p}'. Please check that you have write access.")
274+
275+
# Verify write permissions by creating and deleting a temporary file.
276+
# This handles the case where the directory exists but is not writable.
277+
temp_file_path = p / f".write_test_{uuid.uuid4()}"
278+
try:
279+
temp_file_path.write_text("test")
280+
except Exception as e: # pylint: disable=broad-exception-caught
281+
raise PermissionError(f"Directory '{p}' exists, but is not writable. Please check permissions.") from e
282+
finally:
283+
try:
284+
temp_file_path.unlink() # Delete the temp file.
285+
except Exception: # pylint: disable=broad-exception-caught
286+
pass # Suppress errors during cleanup to not hide the original error.
287+
288+
return p

tests/unit/gcs_utils_test.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Unit tests for GCS utility functions."""
16+
17+
import unittest
18+
from unittest import mock
19+
import os
20+
import tempfile
21+
22+
# Module to be tested
23+
from maxtext.utils import gcs_utils
24+
25+
26+
class GcsUtilsTest(unittest.TestCase):
27+
"""Unit tests for GCS utility functions."""
28+
29+
def test_add_trailing_slash(self):
30+
"""Tests the simple add_trailing_slash utility."""
31+
self.assertEqual(gcs_utils.add_trailing_slash("a/b"), "a/b/")
32+
self.assertEqual(gcs_utils.add_trailing_slash("a/b/"), "a/b/")
33+
34+
def test_mkdir_non_existing_dir(self):
35+
"""Tests that a non-existing directory is created and is empty."""
36+
with tempfile.TemporaryDirectory() as temp_dir:
37+
new_dir_path = os.path.join(temp_dir, "new_dir")
38+
self.assertFalse(os.path.exists(new_dir_path))
39+
40+
# Act
41+
gcs_utils.mkdir_and_check_permissions(new_dir_path)
42+
43+
# Assert
44+
self.assertTrue(os.path.isdir(new_dir_path))
45+
self.assertEqual(os.listdir(new_dir_path), [])
46+
47+
def test_mkdir_existing_non_empty_dir(self):
48+
"""Tests that an existing, non-empty directory's contents are unmodified."""
49+
with tempfile.TemporaryDirectory() as temp_dir:
50+
existing_dir_path = os.path.join(temp_dir, "existing_dir")
51+
os.makedirs(existing_dir_path)
52+
dummy_file_path = os.path.join(existing_dir_path, "dummy.txt")
53+
with open(dummy_file_path, "w", encoding="utf-8") as f:
54+
f.write("test")
55+
56+
# Act
57+
gcs_utils.mkdir_and_check_permissions(existing_dir_path)
58+
59+
# Assert
60+
self.assertTrue(os.path.isdir(existing_dir_path))
61+
self.assertEqual(os.listdir(existing_dir_path), ["dummy.txt"])
62+
63+
def test_mkdir_existing_read_only_dir(self):
64+
"""Tests that a PermissionError is raised for a read-only directory."""
65+
# /sys is a universally read-only directory, even for root
66+
read_only_dir_path = "/sys"
67+
with self.assertRaises(PermissionError):
68+
gcs_utils.mkdir_and_check_permissions(read_only_dir_path)
69+
70+
def test_mkdir_read_only_parent_dir(self):
71+
"""Tests that a PermissionError is raised when the parent is read-only."""
72+
# /sys is a universally read-only directory, even for root
73+
new_dir_path = "/sys/new_dir"
74+
with self.assertRaises(PermissionError):
75+
gcs_utils.mkdir_and_check_permissions(new_dir_path)
76+
77+
@mock.patch("maxtext.utils.gcs_utils.storage.Client")
78+
def test_mkdir_gcs_no_such_bucket(self, mock_storage_client):
79+
"""Tests that an exception is raised for a non-existent GCS bucket."""
80+
mock_client_instance = mock_storage_client.return_value
81+
mock_client_instance.get_bucket.side_effect = Exception("Bucket not found!")
82+
gcs_path = "gs://no_such_bucket"
83+
84+
with self.assertRaises(FileNotFoundError):
85+
gcs_utils.mkdir_and_check_permissions(gcs_path)
86+
mock_client_instance.get_bucket.assert_called_with("no_such_bucket")
87+
88+
@mock.patch("maxtext.utils.gcs_utils.storage.Client")
89+
def test_mkdir_gcs_no_such_bucket_with_path(self, mock_storage_client):
90+
"""Tests an exception for a non-existent bucket with a subdirectory."""
91+
mock_client_instance = mock_storage_client.return_value
92+
mock_client_instance.get_bucket.side_effect = Exception("Bucket not found!")
93+
gcs_path = "gs://no_such_bucket/some/dir"
94+
95+
with self.assertRaises(FileNotFoundError):
96+
gcs_utils.mkdir_and_check_permissions(gcs_path)
97+
mock_client_instance.get_bucket.assert_called_with("no_such_bucket")
98+
99+
@mock.patch("maxtext.utils.gcs_utils.uuid.uuid4", return_value="test-uuid")
100+
@mock.patch("maxtext.utils.gcs_utils.epath.Path")
101+
@mock.patch("maxtext.utils.gcs_utils.storage.Client")
102+
def test_mkdir_gcs_valid_bucket(self, mock_storage_client, mock_epath, mock_uuid):
103+
"""Tests that a valid GCS path is handled correctly without errors."""
104+
# Arrange: Mock the GCS client to simulate a valid bucket
105+
mock_client_instance = mock_storage_client.return_value
106+
107+
# Arrange: Mock epath to prevent real GCS calls
108+
mock_path_instance = mock.MagicMock()
109+
mock_path_instance.as_posix.return_value = "gs://valid_bucket/some/dir"
110+
mock_path_instance.parts = ["gs:", "", "valid_bucket", "some", "dir"]
111+
mock_path_instance.exists.return_value = True
112+
113+
mock_temp_file_instance = mock.MagicMock()
114+
mock_path_instance.__truediv__.return_value = mock_temp_file_instance
115+
116+
mock_epath.return_value = mock_path_instance
117+
gcs_path = "gs://valid_bucket/some/dir"
118+
119+
# Act
120+
gcs_utils.mkdir_and_check_permissions(gcs_path)
121+
122+
# Assert
123+
mock_client_instance.get_bucket.assert_called_with("valid_bucket")
124+
mock_path_instance.mkdir.assert_called_with(exist_ok=True, parents=True)
125+
mock_path_instance.exists.assert_called_once()
126+
mock_temp_file_instance.write_text.assert_called_once_with("test")
127+
mock_temp_file_instance.unlink.assert_called_once()

0 commit comments

Comments
 (0)