Skip to content

Commit 4af13bf

Browse files
Merge pull request #3235 from AI-Hypercomputer:pw-elastic-training-gp-fixes
PiperOrigin-RevId: 878677943
2 parents 94f760d + e7b041d commit 4af13bf

7 files changed

Lines changed: 128 additions & 22 deletions

File tree

.github/CODEOWNERS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @gagika @shralex @SurbhiJainUSC @hengtaoguo @A9isha @aireenmei @NuojCheng @jiangjy1982 @suexu1025 @NicoGrande @jesselu-google
1+
* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @gagika @shralex @SurbhiJainUSC @hengtaoguo @A9isha @aireenmei @NuojCheng @jiangjy1982 @suexu1025 @NicoGrande @jesselu-google @dipannita08 @igorts-git
22

33
# Model bring-up
44
src/MaxText/assets @parambole @shuningjin @RissyRan @suexu1025 @jiangjy1982 @gobbleturk @bvandermoon @gagika @shralex @richjames0 @NicoGrande

src/maxtext/common/goodput.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -36,6 +36,12 @@ class GoodputEvent(Enum):
3636
STEP = "step"
3737

3838

39+
# Recorder method name constants for explicit job start/end recording.
40+
# Derived from the enum so they stay in sync if the value ever changes.
41+
RECORD_JOB_START_TIME = f"record_{GoodputEvent.JOB.value}_start_time"
42+
RECORD_JOB_END_TIME = f"record_{GoodputEvent.JOB.value}_end_time"
43+
44+
3945
@contextlib.contextmanager
4046
def maybe_monitor_goodput(config):
4147
"""Monitor cumulative goodput if enabled on the lead host.
@@ -83,18 +89,22 @@ def maybe_monitor_goodput(config):
8389

8490
@contextlib.contextmanager
8591
def maybe_record_goodput(recorder, event_name, *args):
86-
"""Record goodput if `enable_goodput_recording=True`."""
92+
"""Record goodput if `enable_goodput_recording=True`.
93+
94+
The end-time event is only recorded when the wrapped block exits without
95+
raising an exception (i.e. the event truly completed). Callers that need
96+
explicit end-time control — e.g. GoodputEvent.JOB under elastic training
97+
where the elastic manager may suppress the JAX exception internally —
98+
should call record_goodput directly rather than using this context manager.
99+
"""
100+
record_goodput(recorder, f"record_{event_name.value}_start_time", *args)
101+
completed = False
87102
try:
88-
start_event_name = f"record_{event_name.value}_start_time"
89-
record_goodput(recorder, start_event_name, *args)
90103
yield
91-
except BaseException: # pylint: disable=W0706
92-
raise
93-
else:
94-
end_event_name = f"record_{event_name.value}_end_time"
95-
record_goodput(recorder, end_event_name, *args)
104+
completed = True
96105
finally:
97-
pass
106+
if completed:
107+
record_goodput(recorder, f"record_{event_name.value}_end_time", *args)
98108

99109

100110
def record_goodput(recorder, event_name, *args):

src/maxtext/experimental/rl/grpo_trainer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -74,9 +74,12 @@
7474
from maxtext.common.data_loader import DataLoader
7575
from maxtext.common.goodput import (
7676
GoodputEvent,
77+
RECORD_JOB_END_TIME,
78+
RECORD_JOB_START_TIME,
7779
create_goodput_recorder,
7880
maybe_monitor_goodput,
7981
maybe_record_goodput,
82+
record_goodput,
8083
)
8184
from maxtext.experimental.rl import grpo_input_pipeline
8285
from maxtext.experimental.rl import grpo_utils
@@ -795,6 +798,7 @@ def generation_worker_fn(
795798
)
796799
generation_thread.start()
797800

801+
_job_completed_gracefully = False
798802
try:
799803
last_step_completion = datetime.datetime.now()
800804
for step in np.arange(start_step, config.steps):
@@ -881,9 +885,13 @@ def generation_worker_fn(
881885
elif checkpoint_manager is not None:
882886
# in case the last checkpoint_period checkpoint is still in progress
883887
checkpoint_manager.wait_until_finished()
888+
_job_completed_gracefully = True
884889
except exceptions.StopTraining as e:
885890
max_logging.log(f"Training stopped: {str(e)}")
891+
_job_completed_gracefully = True
886892
finally:
893+
if _job_completed_gracefully:
894+
record_goodput(recorder, RECORD_JOB_END_TIME)
887895
metric_logger.flush_metrics_and_cleanup()
888896
max_logging.log("Training loop finished or exited. Signaling generation worker to stop.")
889897
stop_event.set()
@@ -955,8 +963,9 @@ def main(argv: Sequence[str]) -> None:
955963
)
956964
diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)
957965

966+
record_goodput(recorder, RECORD_JOB_START_TIME)
958967
with diagnostic.diagnose(diagnostic_config):
959-
with maybe_record_goodput(recorder, GoodputEvent.JOB), maybe_monitor_goodput(config):
968+
with maybe_monitor_goodput(config):
960969
train_loop(config, config_inference, recorder)
961970

962971

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@
5353
from maxtext.trainers.pre_train.train import loss_fn
5454
from maxtext.common.goodput import (
5555
GoodputEvent,
56+
RECORD_JOB_END_TIME,
57+
RECORD_JOB_START_TIME,
5658
create_goodput_recorder,
5759
maybe_monitor_goodput,
5860
maybe_record_goodput,
61+
record_goodput,
5962
)
6063
from maxtext.optimizers import optimizers
6164
from maxtext.trainers.post_train.sft import hooks
@@ -181,7 +184,13 @@ def train(mt_config, goodput_recorder=None):
181184
goodput_recorder: An optional GoodputRecorder to record performance metrics.
182185
"""
183186
trainer, mesh = setup_trainer_state(mt_config, goodput_recorder)
184-
trainer = train_model(mt_config, trainer, mesh)
187+
_job_completed_gracefully = False
188+
try:
189+
trainer = train_model(mt_config, trainer, mesh)
190+
_job_completed_gracefully = True
191+
finally:
192+
if _job_completed_gracefully:
193+
record_goodput(goodput_recorder, RECORD_JOB_END_TIME)
185194
return trainer, mesh
186195

187196

@@ -198,8 +207,8 @@ def main(argv: Sequence[str]) -> None:
198207
max_utils.print_system_information()
199208

200209
goodput_recorder = create_goodput_recorder(mt_config)
201-
202-
with maybe_record_goodput(goodput_recorder, GoodputEvent.JOB), maybe_monitor_goodput(mt_config):
210+
record_goodput(goodput_recorder, RECORD_JOB_START_TIME)
211+
with maybe_monitor_goodput(mt_config):
203212
train(mt_config, goodput_recorder)
204213

205214

src/maxtext/trainers/post_train/sft/train_sft_deprecated.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -37,9 +37,12 @@
3737
from maxtext.common.data_loader import DataLoader
3838
from maxtext.common.goodput import (
3939
GoodputEvent,
40+
RECORD_JOB_END_TIME,
41+
RECORD_JOB_START_TIME,
4042
create_goodput_recorder,
4143
maybe_monitor_goodput,
4244
maybe_record_goodput,
45+
record_goodput,
4346
)
4447
from maxtext.common.metric_logger import MetricLogger
4548
from maxtext.utils import exceptions
@@ -90,6 +93,7 @@ def train_loop(config, recorder, state=None):
9093
# Write train config params, num model params, and XLA flags to tensorboard
9194
metric_logger.write_setup_info_to_tensorboard(state.params)
9295

96+
_job_completed_gracefully = False
9397
try:
9498
last_step_completion = datetime.datetime.now()
9599
for step in np.arange(start_step, config.steps):
@@ -147,9 +151,13 @@ def train_loop(config, recorder, state=None):
147151
elif checkpoint_manager is not None:
148152
# in case the last checkpoint_period checkpoint is still in progress
149153
checkpoint_manager.wait_until_finished()
154+
_job_completed_gracefully = True
150155
except exceptions.StopTraining as e:
151156
max_logging.log(f"Training stopped: {str(e)}")
157+
_job_completed_gracefully = True
152158
finally:
159+
if _job_completed_gracefully:
160+
record_goodput(recorder, RECORD_JOB_END_TIME)
153161
metric_logger.flush_metrics_and_cleanup()
154162

155163
return state
@@ -172,7 +180,8 @@ def main(argv: Sequence[str]) -> None:
172180
os.environ["TFDS_DATA_DIR"] = config.dataset_path
173181

174182
recorder = create_goodput_recorder(config)
175-
with maybe_record_goodput(recorder, GoodputEvent.JOB), maybe_monitor_goodput(config):
183+
record_goodput(recorder, RECORD_JOB_START_TIME)
184+
with maybe_monitor_goodput(config):
176185
train_loop(config, recorder)
177186

178187

src/maxtext/trainers/pre_train/train.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -48,9 +48,12 @@
4848
from maxtext.common import checkpointing, profiler
4949
from maxtext.common.goodput import (
5050
GoodputEvent,
51+
RECORD_JOB_END_TIME,
52+
RECORD_JOB_START_TIME,
5153
create_goodput_recorder,
5254
maybe_monitor_goodput,
5355
maybe_record_goodput,
56+
record_goodput,
5457
)
5558
from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag, is_decoupled
5659
from maxtext.common.gcloud_stub import vertex_tensorboard_modules
@@ -493,6 +496,7 @@ def train_loop(config, recorder, state=None):
493496
# Write train config params, num model params, and XLA flags to tensorboard
494497
metric_logger.write_setup_info_to_tensorboard(state.params)
495498

499+
_job_completed_gracefully = False
496500
try:
497501
last_step_completion = datetime.datetime.now()
498502
for step in np.arange(start_step, config.steps):
@@ -558,9 +562,13 @@ def train_loop(config, recorder, state=None):
558562
if checkpoint_manager is not None:
559563
# in case the last checkpoint_period checkpoint is still in progress
560564
checkpoint_manager.wait_until_finished()
565+
_job_completed_gracefully = True
561566
except exceptions.StopTraining as e:
562567
max_logging.log(f"Training stopped: {str(e)}")
568+
_job_completed_gracefully = True
563569
finally:
570+
if _job_completed_gracefully:
571+
record_goodput(recorder, RECORD_JOB_END_TIME)
564572
metric_logger.flush_metrics_and_cleanup()
565573

566574
return state
@@ -623,14 +631,14 @@ def run(config, recorder, diagnostic_config):
623631

624632
with (
625633
diagnostics_context,
626-
maybe_record_goodput(recorder, GoodputEvent.JOB),
627634
max_utils.maybe_get_transformer_engine_context(config),
628635
):
629636
train_loop(config, recorder)
630637

631638

632639
def main(argv: Sequence[str]) -> None:
633640
config, recorder, diagnostic_config = initialize(argv)
641+
record_goodput(recorder, RECORD_JOB_START_TIME)
634642
with maybe_monitor_goodput(config):
635643
run(config, recorder, diagnostic_config)
636644

tests/unit/goodput_utils_test.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -20,7 +20,15 @@
2020
import pytest
2121

2222
from maxtext.configs import pyconfig
23-
from maxtext.common.goodput import create_goodput_recorder, maybe_monitor_goodput, maybe_record_goodput, GoodputEvent
23+
from maxtext.common.goodput import (
24+
GoodputEvent,
25+
RECORD_JOB_END_TIME,
26+
RECORD_JOB_START_TIME,
27+
create_goodput_recorder,
28+
maybe_monitor_goodput,
29+
maybe_record_goodput,
30+
record_goodput,
31+
)
2432
from tests.utils.test_helpers import get_test_config_path, get_test_base_output_directory
2533

2634
pytestmark = [pytest.mark.external_training]
@@ -80,6 +88,59 @@ def test_monitor_goodput(self, mock_start_goodput_uploader, mock_stop_goodput_up
8088
mock_start_goodput_uploader.assert_called()
8189
mock_stop_goodput_uploader.assert_called()
8290

91+
def test_job_recording_constants(self):
92+
"""Constants must map to the recorder method names."""
93+
self.assertEqual(RECORD_JOB_START_TIME, "record_job_start_time")
94+
self.assertEqual(RECORD_JOB_END_TIME, "record_job_end_time")
95+
96+
@mock.patch("ml_goodput_measurement.goodput.GoodputRecorder.record_job_end_time")
97+
@mock.patch("ml_goodput_measurement.goodput.GoodputRecorder.record_job_start_time")
98+
@mock.patch("google.cloud.logging.Client")
99+
def test_explicit_job_recording_graceful_completion(
100+
self, mock_cloud_logger, mock_record_job_start_time, mock_record_job_end_time
101+
):
102+
"""Both start and end are recorded when the job completes gracefully."""
103+
mock_cloud_logger.return_value = mock.MagicMock()
104+
recorder = create_goodput_recorder(self.config)
105+
106+
record_goodput(recorder, RECORD_JOB_START_TIME)
107+
_job_completed_gracefully = False
108+
try:
109+
_job_completed_gracefully = True
110+
finally:
111+
if _job_completed_gracefully:
112+
record_goodput(recorder, RECORD_JOB_END_TIME)
113+
114+
mock_record_job_start_time.assert_called_once()
115+
mock_record_job_end_time.assert_called_once()
116+
117+
@mock.patch("ml_goodput_measurement.goodput.GoodputRecorder.record_job_end_time")
118+
@mock.patch("ml_goodput_measurement.goodput.GoodputRecorder.record_job_start_time")
119+
@mock.patch("google.cloud.logging.Client")
120+
def test_explicit_job_recording_elastic_restart(
121+
self, mock_cloud_logger, mock_record_job_start_time, mock_record_job_end_time
122+
):
123+
"""Only start is recorded when the elastic manager handles the error internally.
124+
125+
This simulates the elastic-restart scenario: the manager catches the JAX
126+
exception inside train_loop, so the loop exits without raising. The
127+
_job_completed_gracefully flag is never set, so record_job_end_time must
128+
not be called.
129+
"""
130+
mock_cloud_logger.return_value = mock.MagicMock()
131+
recorder = create_goodput_recorder(self.config)
132+
133+
record_goodput(recorder, RECORD_JOB_START_TIME)
134+
_job_completed_gracefully = False
135+
try:
136+
pass # Elastic manager caught and suppressed the exception.
137+
finally:
138+
if _job_completed_gracefully:
139+
record_goodput(recorder, RECORD_JOB_END_TIME)
140+
141+
mock_record_job_start_time.assert_called_once()
142+
mock_record_job_end_time.assert_not_called()
143+
83144

84145
if __name__ == "__main__":
85146
unittest.main()

0 commit comments

Comments
 (0)