Skip to content

Commit 35a2f3b

Browse files
angel-core authored and Google-ML-Automation committed
Migrate Tunix's usage of Orbax V0 Checkpoint Manager to V1 Checkpointer.
PiperOrigin-RevId: 927511408
1 parent 6ff97f1 commit 35a2f3b

10 files changed

Lines changed: 102 additions & 11 deletions

File tree

.github/workflows/run_tests_against_package.yml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -186,6 +186,7 @@ jobs:
186186
else
187187
$PYTHON_EXE -m pytest ${INPUTS_PYTEST_ADDOPTS} \
188188
-v \
189+
-s \
189190
-m "${FINAL_PYTEST_MARKER}" \
190191
--durations=0 \
191192
$PYTEST_COV_ARGS \
@@ -203,6 +204,9 @@ jobs:
203204
INPUTS_PYTEST_EXTRA_ARGS: ${{ inputs.pytest_extra_args }}
204205
INPUTS_MAXTEXT_INSTALLED: ${{ inputs.maxtext_installed }}
205206
INPUTS_IS_UPDATE_HLO: ${{ inputs.is_update_hlo }}
207+
- name: surface hang dump
208+
if: always()
209+
run: cat "$GITHUB_WORKSPACE/hang_watchdog_dump.txt" || true
206210
- name: Upload Reference HLO
207211
if: ${{ inputs.is_update_hlo }}
208212
uses: actions/upload-artifact@v4

docs/tutorials/posttraining/knowledge_distillation.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -234,7 +234,7 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
234234
The online distillation trainer depends on Tunix. The XPK launcher script ([`scripts/run_distill_xpk.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh)) contains a `prep_image` step that layers Tunix on top of the MaxText base image. For local runs, install the same pin used by the launcher — the default `TUNIX_SOURCE` in `run_distill_xpk.sh` is the source of truth. As of this writing:
235235

236236
```bash
237-
pip install "git+https://github.com/google/tunix@110932a8395086511228483312131841521695c1"
237+
pip install "git+https://github.com/google/tunix@44af800726dd5b2c5779a1987a9294f9a3eec9ef"
238238
```
239239

240240
> **Note:** The commit pin above will drift as the launcher is updated. Before installing, check the `TUNIX_SOURCE` default in [`run_distill_xpk.sh`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh) and use that spec. Once a Tunix PyPI release ships, this will become a versioned `google-tunix==<ver>` install.
Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1-
google-tunix @ https://github.com/google/tunix/archive/387072374f99a100cb11f99dec951940b1475a04.zip
1+
orbax-checkpoint @ https://github.com/google/orbax/archive/030a16419688ca45d95e92990aeeb93891e12ec0.zip#subdirectory=checkpoint
2+
google-tunix @ https://github.com/google/tunix/archive/44af800726dd5b2c5779a1987a9294f9a3eec9ef.zip
23
tpu-inference @ https://github.com/vllm-project/tpu-inference/archive/a46baf9ee149da0fbc1cfe335650e3780e30b585.zip
34
vllm @ git+https://github.com/vllm-project/vllm@a51376b3f05a2f74eac6ceeed7e52598b871a0fb

src/dependencies/extra_deps/tpu_post_train_overrides.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,5 +2,5 @@ datasets>=4.8.5
22
flax==0.12.4
33
fsspec==2026.2.0
44
gcsfs==2026.2.0
5-
google-metrax>=0.2.3
5+
google-metrax>=0.2.4
66
optax==0.2.6

src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@ google-cloud-storage>=3.10.1
9595
google-cloud-storage-control>=1.11.0
9696
google-crc32c>=1.8.0
9797
google-genai>=2.4.0
98-
google-metrax>=0.2.3
98+
google-metrax>=0.2.4
9999
google-pasta>=0.2.0
100100
google-resumable-media>=2.9.0
101101
google-tunix>=0.1.3

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -661,15 +661,16 @@ def __init__(
661661
super().__init__(root_directory=root_directory, options=options)
662662
self.student_config = student_config
663663
self._iterator = raw_iterator
664+
self._checkpoint_manager: checkpoint.CheckpointManager | None = None
664665

665666
# Re-initialize internal Orbax manager with MaxText's Grain handler
666667
# pylint: disable=access-member-before-definition
667668
# pytype: disable=attribute-error
668-
if self._checkpoint_manager is not None:
669-
root_directory = self._checkpoint_manager.directory
669+
if self._checkpointer is not None:
670+
root_directory = self._checkpointer.directory
670671

671672
if options is None:
672-
options = getattr(self._checkpoint_manager, "options", None)
673+
options = getattr(self._checkpointer, "options", None) or getattr(self._checkpointer._manager, "options", None)
673674

674675
item_handlers = {
675676
"model_params": checkpoint.PyTreeCheckpointHandler(),
@@ -679,12 +680,13 @@ def __init__(
679680
"iter": GrainCheckpointHandler(),
680681
}
681682

682-
self._checkpoint_manager.close()
683-
self._checkpoint_manager = checkpoint.CheckpointManager(
683+
self._checkpointer._manager.close()
684+
self._checkpointer._manager = checkpoint.CheckpointManager(
684685
root_directory,
685686
item_handlers=item_handlers,
686687
options=options,
687688
)
689+
self._checkpoint_manager = self._checkpointer._manager
688690
# pytype: enable=attribute-error
689691
# pylint: enable=access-member-before-definition
690692

src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,7 @@
105105
#
106106
# Image pinning (used by prep_image):
107107
# TUNIX_SOURCE pip-installable spec for tunix.
108-
# default: git+https://github.com/google/tunix@110932a8395086511228483312131841521695c1
108+
# default: git+https://github.com/google/tunix@44af800726dd5b2c5779a1987a9294f9a3eec9ef
109109
# Use "google-tunix==<ver>" once a pypi release ships with the
110110
# multi-host shard_input fix.
111111
# JAX_PIN default: 0.10.0 — version to pin back after tunix deps resolve.
@@ -164,7 +164,7 @@ require_env() {
164164
: "${DISTILL_LAYER_INDICES:=[0,1,2,3,4,5,6,7]}"
165165

166166
# Image pinning (used by prep_image).
167-
: "${TUNIX_SOURCE:=git+https://github.com/google/tunix@110932a8395086511228483312131841521695c1}"
167+
: "${TUNIX_SOURCE:=git+https://github.com/google/tunix@44af800726dd5b2c5779a1987a9294f9a3eec9ef}"
168168
: "${JAX_PIN:=0.10.0}"
169169
: "${JAXLIB_PIN:=0.10.0}"
170170
: "${LIBTPU_PIN:=0.0.39}"

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -523,6 +523,11 @@ def setup_checkpoint_manager_and_restore(self, raw_train_iter, config):
523523

524524
# 3. Restore Model & Optimizer State correctly via MaxTextCheckpointManager.
525525
# Accessing protected variables of the base class IS allowed inside the subclass!
526+
# Fence: wait for the freshly-built teacher + student + optimizer device programs
527+
# to finish. Orbax v1 runs its restore device_put transfers on a background
528+
# thread; if they race the still-in-flight model build, they can deadlock on
529+
# TPU with no timeout. Draining here removes that concurrency.
530+
jax.block_until_ready((nnx.state(self.model), nnx.state(self.optimizer)))
526531
self._train_steps, self._restored_custom_metadata = self.checkpoint_manager.maybe_restore(
527532
self.model,
528533
self.optimizer,
Lines changed: 75 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Pytest configuration and fixtures for post-training unit tests."""
16+
17+
import faulthandler
18+
import os
19+
import sys
20+
import threading
21+
22+
import pytest
23+
24+
25+
_DUMP_PATH = (
26+
os.environ.get("HANG_DUMP_FILE")
27+
or os.environ.get("GITHUB_STEP_SUMMARY")
28+
or "/tmp/hang_watchdog_dump.txt"
29+
)
30+
_DUMP_FH = open(_DUMP_PATH, "a", buffering=1)
31+
os.environ["HANG_DUMP_FILE"] = _DUMP_PATH
32+
_DUMP_AFTER_SECS = float(os.environ.get("HANG_DUMP_AFTER_SECS", "300"))
33+
_EXIT_AFTER_SECS = float(os.environ.get("HANG_EXIT_AFTER_SECS", "900"))
34+
faulthandler.enable(file=_DUMP_FH, all_threads=True)
35+
36+
37+
def _dump(header):
38+
for sink in (_DUMP_FH, sys.__stderr__):
39+
try:
40+
sink.write("\n" + header + "\n")
41+
sink.flush()
42+
faulthandler.dump_traceback(file=sink, all_threads=True)
43+
sink.flush()
44+
except Exception:
45+
pass
46+
47+
48+
@pytest.fixture(autouse=True)
49+
def _hang_watchdog(request):
50+
"""Watchdog fixture to detect and dump stack traces for hanging tests."""
51+
node = request.node.nodeid
52+
stop = threading.Event()
53+
54+
def _watch():
55+
waited = 0.0
56+
while not stop.wait(_DUMP_AFTER_SECS):
57+
waited += _DUMP_AFTER_SECS
58+
_dump(
59+
f"===== HANG WATCHDOG: {node!r} still running after {int(waited)}s;"
60+
" all threads: ====="
61+
)
62+
if waited >= _EXIT_AFTER_SECS:
63+
_dump("===== HANG WATCHDOG: aborting process for CI =====")
64+
try:
65+
os.fsync(_DUMP_FH.fileno())
66+
except Exception:
67+
pass
68+
os._exit(99)
69+
70+
t = threading.Thread(target=_watch, name="hang-watchdog", daemon=True)
71+
t.start()
72+
try:
73+
yield
74+
finally:
75+
stop.set()

tests/post_training/unit/train_distill_test.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -890,6 +890,8 @@ def test_train_save_and_resume(self, mock_build_tokenizer, mock_writer):
890890
teacher_config_1 = pyconfig.initialize(argv_run1, **global_config_1.teacher_overrides)
891891

892892
# Execute first run
893+
with open(os.environ["HANG_DUMP_FILE"], "a") as _f:
894+
_f.write("\n>>> PHASE: RUN 1 -- train + save (step 1)\n")
893895
train_distill.train_distill(student_config_1, teacher_config_1)
894896

895897
# Run 2: Resume and train up to step 2
@@ -908,6 +910,8 @@ def side_effect(self, *args, **kwargs):
908910
mock_restore.side_effect = side_effect
909911

910912
# Execute second run
913+
with open(os.environ["HANG_DUMP_FILE"], "a") as _f:
914+
_f.write("\n>>> PHASE: RUN 2 -- restore step 1 + train to step 2\n")
911915
train_distill.train_distill(student_config_2, teacher_config_2)
912916

913917
# Verify that restore was called and returned train_steps = 1

0 commit comments

Comments
 (0)