Skip to content

Commit 1907615

Browse files
Merge pull request #3691 from AI-Hypercomputer:fix-sft-eval-metrics
PiperOrigin-RevId: 901518896
2 parents dbc1584 + 3797c23 commit 1907615

2 files changed

Lines changed: 44 additions & 20 deletions

File tree

src/maxtext/trainers/post_train/sft/hooks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def on_eval_step_end(self, train_ctx: peft_trainer.PeftTrainer, eval_loss: float
161161
"scalar": {
162162
"eval/total_loss": eval_loss,
163163
"eval/avg_loss": avg_loss,
164+
"eval/avg_perplexity": jnp.exp(avg_loss),
164165
"eval/total_weights": self.eval_metadata["total_weights"],
165166
}
166167
}

tests/post_training/unit/sft_hooks_test.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
"""Tests for training and data loading hooks for SFT"""
16+
17+
from collections import defaultdict
1618
import pytest
1719

1820
pytest.importorskip("tunix")
@@ -21,33 +23,48 @@
2123
import jax
2224

2325
import numpy as np
26+
import json
2427
import os
28+
import shutil
29+
import tempfile
2530
import unittest
2631
from unittest.mock import MagicMock, patch
2732
from jax.sharding import Mesh
2833

2934
from maxtext.configs import pyconfig
3035
from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
3136
from maxtext.trainers.post_train.sft import hooks
37+
from maxtext.common.metric_logger import MetricLogger
3238
from maxtext.utils import maxtext_utils
3339

3440

3541
class SFTHooksTest(unittest.TestCase):
3642

3743
def setUp(self):
3844
super().setUp()
45+
self.test_dir = tempfile.mkdtemp()
46+
self.metrics_file = os.path.join(self.test_dir, "metrics.txt")
3947
self.config = pyconfig.initialize(
4048
["", os.path.join(MAXTEXT_CONFIGS_DIR, "post_train", "sft.yml")],
4149
per_device_batch_size=1,
4250
run_name="test",
43-
base_output_directory="test",
51+
base_output_directory=self.test_dir,
52+
tensorboard_dir=self.test_dir,
53+
metrics_dir=self.test_dir,
54+
metrics_file=self.metrics_file,
4455
skip_jax_distributed_system=True,
4556
)
4657
self.mesh = Mesh(maxtext_utils.create_device_mesh(self.config), self.config.mesh_axes)
4758
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(self.config)
4859

4960
self.training_hooks = hooks.SFTTrainingHooks(self.config, self.mesh, learning_rate_schedule, goodput_recorder=None)
50-
self.training_hooks.metric_logger = MagicMock()
61+
# We will use the written metrics to validate the correctness.
62+
# The reason to use a real MetricLogger is to avoid a problem like what was observed in
63+
# https://github.com/AI-Hypercomputer/maxtext/pull/3691, where the MetricLogger was changed
64+
# to expect a new metric that the SFT code did not provide.
65+
self.training_hooks.metric_logger = MetricLogger(self.config, learning_rate_schedule)
66+
# Initialize metadata to avoid KeyErrors in real MetricLogger
67+
self.training_hooks.metric_logger.metadata = defaultdict(float)
5168

5269
expected_shape = [jax.device_count(), self.config.max_target_length]
5370
self.expected_batch = {
@@ -59,6 +76,20 @@ def setUp(self):
5976

6077
self.mock_train_ctx = MagicMock()
6178

79+
def tearDown(self):
80+
shutil.rmtree(self.test_dir)
81+
super().tearDown()
82+
83+
def _read_logged_metrics(self, num_expected=1):
84+
"""Read and parse metrics logged by the MetricLogger."""
85+
metrics = []
86+
if os.path.exists(self.metrics_file):
87+
with open(self.metrics_file, "r", encoding="utf8") as f:
88+
for line in f:
89+
metrics.append(json.loads(line))
90+
self.assertEqual(len(metrics), num_expected)
91+
return metrics
92+
6293
@patch("maxtext.trainers.post_train.sft.hooks.create_data_iterator")
6394
def test_data_hooks_load_next_train_batch(self, mock_create_data_iterator):
6495
mock_create_data_iterator.return_value = self.mock_data_iterator, None
@@ -88,15 +119,10 @@ def test_training_hooks_for_train_step(self):
88119
self.training_hooks.on_train_step_start(self.mock_train_ctx)
89120
self.training_hooks.on_train_step_end(self.mock_train_ctx, train_step=1, train_loss=5.0, step_time=0.004)
90121

91-
expected_metrics = {
92-
"scalar": {
93-
"learning/loss": 5.0,
94-
"learning/total_weights": (jax.device_count() * self.config.max_target_length),
95-
}
96-
}
97-
self.training_hooks.metric_logger.record_train_metrics.assert_called()
98-
self.training_hooks.metric_logger.write_metrics.assert_called_with(expected_metrics, 1)
99-
self.assertEqual(len(self.training_hooks.train_metadata), 1)
122+
metrics = self._read_logged_metrics(num_expected=1)[0]
123+
self.assertEqual(metrics["step"], 1)
124+
self.assertAlmostEqual(metrics["learning/loss"], 5.0)
125+
self.assertEqual(metrics["learning/total_weights"], (jax.device_count() * self.config.max_target_length))
100126

101127
def test_training_hooks_for_eval_step(self):
102128
self.mock_train_ctx.data_hooks.eval_batch = self.expected_batch
@@ -106,15 +132,12 @@ def test_training_hooks_for_eval_step(self):
106132
self.training_hooks.on_eval_step_start(self.mock_train_ctx)
107133
self.training_hooks.on_eval_step_end(self.mock_train_ctx, eval_loss=10.0)
108134

109-
expected_metrics = {
110-
"scalar": {
111-
"eval/total_loss": 10.0,
112-
"eval/avg_loss": 5.0,
113-
"eval/total_weights": jax.device_count() * self.config.max_target_length * total_eval_steps,
114-
}
115-
}
116-
self.training_hooks.metric_logger.write_metrics.assert_called_with(expected_metrics, 0, is_training=False)
117-
self.assertEqual(len(self.training_hooks.eval_metadata), 0)
135+
metrics = self._read_logged_metrics(num_expected=1)[0]
136+
self.assertEqual(metrics["step"], 0)
137+
self.assertAlmostEqual(metrics["eval/total_loss"], 10.0)
138+
self.assertAlmostEqual(metrics["eval/avg_loss"], 5.0)
139+
self.assertAlmostEqual(metrics["eval/avg_perplexity"], np.exp(5.0), places=2)
140+
self.assertEqual(metrics["eval/total_weights"], jax.device_count() * self.config.max_target_length * total_eval_steps)
118141

119142
def test_on_train_end_asserts_if_on_train_start_not_called(self):
120143
with self.assertRaises(AssertionError):

0 commit comments

Comments (0)