|
1 | | -# Copyright 2023–2025 Google LLC |
| 1 | +# Copyright 2023–2026 Google LLC |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
20 | 20 | import pytest |
21 | 21 |
|
22 | 22 | from maxtext.configs import pyconfig |
23 | | -from maxtext.common.goodput import create_goodput_recorder, maybe_monitor_goodput, maybe_record_goodput, GoodputEvent |
| 23 | +from maxtext.common.goodput import ( |
| 24 | + GoodputEvent, |
| 25 | + RECORD_JOB_END_TIME, |
| 26 | + RECORD_JOB_START_TIME, |
| 27 | + create_goodput_recorder, |
| 28 | + maybe_monitor_goodput, |
| 29 | + maybe_record_goodput, |
| 30 | + record_goodput, |
| 31 | +) |
24 | 32 | from tests.utils.test_helpers import get_test_config_path, get_test_base_output_directory |
25 | 33 |
|
# Module-level pytest marker list: pytest applies every mark in `pytestmark`
# to all tests collected from this file (here, the `external_training` mark).
pytestmark = [pytest.mark.external_training]
@@ -80,6 +88,59 @@ def test_monitor_goodput(self, mock_start_goodput_uploader, mock_stop_goodput_up |
80 | 88 | mock_start_goodput_uploader.assert_called() |
81 | 89 | mock_stop_goodput_uploader.assert_called() |
82 | 90 |
|
def test_job_recording_constants(self):
    """Each job-recording constant must equal its recorder method-name string."""
    # (constant, expected string) pairs, checked in start -> end order.
    expected_pairs = (
        (RECORD_JOB_START_TIME, "record_job_start_time"),
        (RECORD_JOB_END_TIME, "record_job_end_time"),
    )
    for constant_value, method_name in expected_pairs:
        self.assertEqual(constant_value, method_name)
| 95 | + |
@mock.patch("ml_goodput_measurement.goodput.GoodputRecorder.record_job_end_time")
@mock.patch("ml_goodput_measurement.goodput.GoodputRecorder.record_job_start_time")
@mock.patch("google.cloud.logging.Client")
def test_explicit_job_recording_graceful_completion(
    self, mock_cloud_logger, mock_record_job_start_time, mock_record_job_end_time
):
    """Job start and job end are each recorded exactly once on a graceful finish.

    The try/finally below imitates a caller that records the job end time
    only if a completion flag was set before the ``finally`` clause runs.
    """
    # Replace the Cloud Logging client with a mock (presumably so that
    # building the recorder does not create a real client -- confirm
    # against create_goodput_recorder).
    mock_cloud_logger.return_value = mock.MagicMock()
    goodput_recorder = create_goodput_recorder(self.config)

    record_goodput(goodput_recorder, RECORD_JOB_START_TIME)
    finished_cleanly = False
    try:
        # Stand-in for the job body running to completion without error.
        finished_cleanly = True
    finally:
        if finished_cleanly:
            record_goodput(goodput_recorder, RECORD_JOB_END_TIME)

    # Checked in the same order as the events: start first, then end.
    for patched_method in (mock_record_job_start_time, mock_record_job_end_time):
        patched_method.assert_called_once()
| 116 | + |
@mock.patch("ml_goodput_measurement.goodput.GoodputRecorder.record_job_end_time")
@mock.patch("ml_goodput_measurement.goodput.GoodputRecorder.record_job_start_time")
@mock.patch("google.cloud.logging.Client")
def test_explicit_job_recording_elastic_restart(
    self, mock_cloud_logger, mock_record_job_start_time, mock_record_job_end_time
):
    """Job start is recorded but job end is not when the error is handled internally.

    Models the elastic-restart case: the manager swallows the JAX exception
    inside train_loop, so the loop returns without raising. Because the
    completion flag is never flipped to True, the ``finally`` clause must
    skip the record_job_end_time call.
    """
    # Replace the Cloud Logging client with a mock (presumably so that
    # building the recorder does not create a real client -- confirm
    # against create_goodput_recorder).
    mock_cloud_logger.return_value = mock.MagicMock()
    goodput_recorder = create_goodput_recorder(self.config)

    record_goodput(goodput_recorder, RECORD_JOB_START_TIME)
    finished_cleanly = False
    try:
        pass  # Elastic manager caught and suppressed the exception.
    finally:
        # The flag is still False here, so the end-time call is skipped.
        if finished_cleanly:
            record_goodput(goodput_recorder, RECORD_JOB_END_TIME)

    mock_record_job_start_time.assert_called_once()
    mock_record_job_end_time.assert_not_called()
| 143 | + |
83 | 144 |
|
# When this file is executed directly as a script, hand control to
# unittest's command-line runner.
if __name__ == "__main__":
    unittest.main()
0 commit comments