Skip to content

Commit c6a7ef6

Browse files
committed
feat: Improve permission tests and squash feature branch
- Reworks tests for read-only directories to use /sys instead of temporary directories with chmod. This fixes CI failures when tests are run as root.
- Squashes commits related to int4 support, checkpoint directory checks, and RL parsing unit tests.
1 parent ccd91f4 commit c6a7ef6

13 files changed

Lines changed: 361 additions & 14 deletions

File tree

.github/CODEOWNERS

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22

33
# Model bring-up
44
src/MaxText/assets @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande
5-
src/MaxText/configs/models @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @suexu1025 @jesselu-google
5+
src/MaxText/configs/models @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @suexu1025 @jesselu-google @NuojCheng
66
src/maxtext/checkpoint_conversion @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @hengtaoguo @gagika @shralex @richjames0 @NicoGrande
7-
src/MaxText/layers @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @suexu1025 @jesselu-google
7+
src/MaxText/layers @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande @suexu1025 @jesselu-google @NuojCheng
88

99
# Features
1010
src/maxtext/experimental/rl @A9isha @khatwanimohit @xuefgu @gagika @richjames0 @shralex @NicoGrande
1111
src/MaxText/input_pipeline @aireenmei @SurbhiJainUSC @richjames0 @shralex @NicoGrande
1212
src/MaxText/kernels/megablox @RissyRan @michelle-yooh @gagika @richjames0 @shralex @suexu1025 @jesselu-google
1313
src/MaxText/kernels/ragged_attention.py @patemotter @vipannalla @richjames0 @shralex
14-
src/MaxText/layers/pipeline.py @gobbleturk @richjames0 @shralex
14+
src/MaxText/layers/pipeline.py @gobbleturk @richjames0 @shralex @NuojCheng
1515
src/MaxText/layers/moe.py @RissyRan @michelle-yooh @gagika @richjames0 @shralex @suexu1025 @jesselu-google
1616
src/MaxText/layers/multi_token_prediction.py @parambole @RissyRan @gagika @richjames0 @shralex
1717
src/MaxText/elastic_train.py @lukebaumann @shauryagup @richjames0 @shralex

dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ jax
1717
jaxlib
1818
jaxtyping
1919
jsonlines
20+
math-verify
2021
ml-collections
2122
ml-goodput-measurement
2223
numpy

dependencies/requirements/generated_requirements/tpu-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ lxml>=6.0.2
120120
markdown-it-py>=4.0.0
121121
markdown>=3.10
122122
markupsafe>=3.0.3
123+
math-verify>=0.9.0
123124
matplotlib>=3.10.7
124125
mccabe>=0.7.0
125126
mdurl>=0.1.2

src/maxtext/common/checkpointing.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from maxtext.input_pipeline.synthetic_data_processing import PlaceHolderDataIterator
2929
from maxtext.utils import exceptions
3030
from maxtext.utils import max_logging
31+
from maxtext.utils import gcs_utils
3132
import numpy as np
3233
import orbax.checkpoint as ocp
3334
from orbax.checkpoint import v1 as ocp_v1
@@ -245,8 +246,7 @@ def create_orbax_checkpoint_manager(
245246
item_handlers["iter"] = GrainCheckpointHandler()
246247

247248
# local storage checkpoint needs parent directory created
248-
p = epath.Path(checkpoint_dir)
249-
p.mkdir(exist_ok=True, parents=True)
249+
p = gcs_utils.mkdir_and_check_permissions(checkpoint_dir)
250250
if enable_continuous_checkpointing:
251251
save_decision_policy = save_decision_policy_lib.ContinuousCheckpointingPolicy()
252252
preservation_policy = preservation_policy_lib.LatestN(max_num_checkpoints_to_keep)
@@ -300,19 +300,19 @@ def create_orbax_emergency_checkpoint_manager(
300300
flags.FLAGS.experimental_orbax_use_distributed_process_id = True
301301
max_logging.log("Creating emergency checkpoint manager...")
302302

303+
persistent_p = gcs_utils.mkdir_and_check_permissions(persistent_checkpoint_dir)
304+
303305
# Only create directories if running on GPUs as the previous
304306
# directory structure might be assumed by TPUs
305307
if global_mesh.devices.flatten()[0].platform == "gpu":
306308
# pylint: disable=protected-access
307309
local_checkpoint_dir = f"{local_checkpoint_dir}/{jax._src.distributed.global_state.process_id}"
308310
local_p = epath.Path(local_checkpoint_dir)
309-
persistent_p = epath.Path(persistent_checkpoint_dir)
310311
local_p.mkdir(exist_ok=True, parents=True)
311-
persistent_p.mkdir(exist_ok=True, parents=True)
312312

313313
manager = EmergencyCheckpointManager(
314314
local_checkpoint_dir,
315-
epath.Path(persistent_checkpoint_dir),
315+
persistent_p,
316316
global_mesh=global_mesh,
317317
abstract_state=abstract_state,
318318
options=emergency_checkpoint_manager.CheckpointManagerOptions(

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class QuantizationType(str, Enum):
8181
"""Supported quantization schemes."""
8282

8383
NONE = ""
84+
INT4 = "int4"
8485
INT8 = "int8"
8586
INTMP = "intmp"
8687
FP8 = "fp8"

src/maxtext/layers/moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1318,7 +1318,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
13181318
)
13191319

13201320
# Sum up the partial outputs across the expert shards.
1321-
output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim))
1321+
output = jnp.reshape(output, (-1, sequence_length, self.config.emb_dim // self.get_tensor_parallelism_size()))
13221322
output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True)
13231323

13241324
else:

src/maxtext/layers/quantizations.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,15 @@ def get_fp8_full_qwix_rule(config: Config):
655655

656656
def get_quantization_rule(config: Config):
657657
match config.quantization:
658+
case "int4":
659+
return qwix.QtRule(
660+
module_path="decoder/.*layers.*",
661+
weight_qtype=jnp.int4,
662+
act_qtype=jnp.int4,
663+
bwd_qtype=jnp.int4,
664+
bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count,
665+
op_names=("dot_general",),
666+
)
658667
case "int8":
659668
return qwix.QtRule(
660669
module_path="decoder/.*layers.*",
@@ -702,6 +711,8 @@ def get_qt_provider(config):
702711
match config.quantization:
703712
case "int8":
704713
return qwix.QtProvider([get_quantization_rule(config)])
714+
case "int4":
715+
return qwix.QtProvider([get_quantization_rule(config)])
705716
case "fp8":
706717
return qwix.QtProvider([get_quantization_rule(config)])
707718
case "fp8_full":

src/maxtext/trainers/post_train/rl/evaluate_rl.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def score_responses(tmvp_config, question, responses, answer):
100100
Tuple of (is_correct, is_partially_correct, has_correct_format)
101101
"""
102102
match_format = utils_rl.get_match_format_regex(tmvp_config)
103+
answer_fallback = utils_rl.get_answer_fallback_regex(tmvp_config)
103104

104105
if tmvp_config.debug.rl:
105106
max_logging.log("========================================")
@@ -113,10 +114,19 @@ def score_responses(tmvp_config, question, responses, answer):
113114
has_correct_format = False
114115

115116
for response in responses:
116-
# Extract numerical response
117-
extracted_response = guess.group(1) if (guess := match_format.search(response)) is not None else "-1000000"
117+
# Extract answer: prefer the full format match; fall back to the last
118+
# <answer>...</answer> tag if full format match is not found, so result
119+
# scoring is decoupled from format.
120+
full_match = match_format.search(response)
121+
if full_match is not None:
122+
extracted_response = full_match.group(1)
123+
else:
124+
# Find the *last* occurrence of the answer tag (most likely the final answer).
125+
fallback_matches = answer_fallback.findall(response)
126+
extracted_response = fallback_matches[-1].strip() if fallback_matches else "-1000000"
118127
if tmvp_config.debug.rl:
119-
max_logging.log(f"Evaluation extracted_response: {extracted_response}")
128+
used = "full format" if full_match is not None else "answer-tag fallback"
129+
max_logging.log(f"Evaluation extracted_response ({used}): {extracted_response}")
120130

121131
# Check exact correctness
122132
try:
@@ -146,8 +156,8 @@ def score_responses(tmvp_config, question, responses, answer):
146156
max_logging.log(f"Evaluation Exception: {e}")
147157
max_logging.log("SKIPPED")
148158

149-
# Check format correctness
150-
if match_format.search(response) is not None:
159+
# Check format correctness (requires the full <reasoning>...</reasoning><answer>...</answer> structure)
160+
if full_match is not None:
151161
has_correct_format = True
152162

153163
# Early exit if all criteria are met

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,19 @@ def get_match_format_regex(tmvp_config):
118118
return match_format
119119

120120

121+
def get_answer_fallback_regex(tmvp_config):
  """Returns a compiled regex that captures the body of each answer tag pair.

  Used as a fallback when the full <reasoning>...</reasoning><answer>...</answer>
  format is incomplete (e.g. missing the closing reasoning tag), so the result
  reward can still be computed independently from the format reward.

  The pattern matches every non-empty start/end token pair in a completion;
  callers that want the final answer should take the last element of
  `findall()`.

  The captured text is never allowed to contain either delimiter token. With a
  plain `(.+?)` capture, an empty pair followed by a real answer
  (`<answer></answer> ... <answer>5</answer>`) or an unclosed start token
  (`<answer>draft <answer>7</answer>`) would produce a capture that spans
  across tags instead of the innermost answer.

  Args:
    tmvp_config: Config object providing `solution_start_token` and
      `solution_end_token`.

  Returns:
    A compiled pattern with a single capture group holding the answer text.
  """
  start_token = re.escape(tmvp_config.solution_start_token)
  end_token = re.escape(tmvp_config.solution_end_token)
  # The negative lookahead is evaluated before each consumed character, so the
  # capture stops at the first delimiter and can never swallow one.
  answer_body = rf"((?:(?!{start_token}|{end_token}).)+)"
  return re.compile(
      rf"{start_token}{answer_body}{end_token}",
      flags=re.MULTILINE | re.DOTALL,
  )
132+
133+
121134
def match_format_exactly(prompts, completions, tmvp_config, **kargs):
122135
"""
123136
Give the model a reward of tmvp_config.reward_exact_format_match points if the format matches exactly.

src/maxtext/utils/gcs_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import os
1919
import socket
2020
from pathlib import Path
21+
from etils import epath
22+
import uuid
2123

2224
import yaml
2325

@@ -242,3 +244,45 @@ def write_dict_to_gcs_json(data_dict, file_path):
242244
blob.upload_from_string(json_string, content_type="application/json")
243245
except (ValueError, TypeError, RecursionError) as e:
244246
print(f"Failed to write json file at {file_path} with error: {str(e)}")
247+
248+
249+
def mkdir_and_check_permissions(p: str | epath.Path) -> epath.Path:
  """Ensures directory `p` exists and is writable, then returns it as a path.

  Guards against the program stalling on an inaccessible output location: per
  the note that accompanied this helper, a plain `epath.Path.mkdir` can hang or
  fail silently when the path lives in a GCS bucket that does not exist or
  cannot be reached, e.g.

    from etils import epath
    p = epath.Path("gs://no_such_bucket/path/to/output")
    p.mkdir(exist_ok=True, parents=True)

  Args:
    p: Directory to create, as a string or an `epath.Path`.

  Returns:
    The directory as an `epath.Path`.

  Raises:
    FileNotFoundError: If `p` is a `gs://` path and the bucket lookup fails.
    PermissionError: If the directory is missing after creation, or a probe
      file cannot be written into it.
  """
  target = epath.Path(p) if isinstance(p, str) else p

  # For GCS destinations, look the bucket up first so a bad bucket is reported
  # as an error instead of reaching mkdir.
  if target.as_posix().startswith("gs://"):
    bucket_name = target.parts[2]
    try:
      storage.Client().get_bucket(bucket_name)
    except Exception as e:
      raise FileNotFoundError(f"GCS bucket 'gs://{bucket_name}' not found or accessible.") from e

  target.mkdir(exist_ok=True, parents=True)
  if not target.exists():
    raise PermissionError(f"Failed to create the directory '{target}'. Please check that you have write access.")

  # An existing directory may still be read-only, so probe it with a uniquely
  # named throwaway file.
  probe = target / f".write_test_{uuid.uuid4()}"
  try:
    probe.write_text("test")
  except Exception as e:  # pylint: disable=broad-exception-caught
    raise PermissionError(f"Directory '{target}' exists, but is not writable. Please check permissions.") from e
  finally:
    # Best-effort cleanup; a failure here must not mask the write error above.
    try:
      probe.unlink()
    except Exception:  # pylint: disable=broad-exception-caught
      pass

  return target

0 commit comments

Comments
 (0)