1313# limitations under the License.
1414
1515"""Tests for training and data loading hooks for SFT"""
16+
17+ from collections import defaultdict
1618import pytest
1719
1820pytest .importorskip ("tunix" )
2123import jax
2224
2325import numpy as np
26+ import json
2427import os
28+ import shutil
29+ import tempfile
2530import unittest
2631from unittest .mock import MagicMock , patch
2732from jax .sharding import Mesh
2833
2934from maxtext .configs import pyconfig
3035from maxtext .utils .globals import MAXTEXT_CONFIGS_DIR
3136from maxtext .trainers .post_train .sft import hooks
37+ from maxtext .common .metric_logger import MetricLogger
3238from maxtext .utils import maxtext_utils
3339
3440
3541class SFTHooksTest (unittest .TestCase ):
3642
3743 def setUp (self ):
3844 super ().setUp ()
45+ self .test_dir = tempfile .mkdtemp ()
46+ self .metrics_file = os .path .join (self .test_dir , "metrics.txt" )
3947 self .config = pyconfig .initialize (
4048 ["" , os .path .join (MAXTEXT_CONFIGS_DIR , "post_train" , "sft.yml" )],
4149 per_device_batch_size = 1 ,
4250 run_name = "test" ,
43- base_output_directory = "test" ,
51+ base_output_directory = self .test_dir ,
52+ tensorboard_dir = self .test_dir ,
53+ metrics_dir = self .test_dir ,
54+ metrics_file = self .metrics_file ,
4455 skip_jax_distributed_system = True ,
4556 )
4657 self .mesh = Mesh (maxtext_utils .create_device_mesh (self .config ), self .config .mesh_axes )
4758 learning_rate_schedule = maxtext_utils .create_learning_rate_schedule (self .config )
4859
4960 self .training_hooks = hooks .SFTTrainingHooks (self .config , self .mesh , learning_rate_schedule , goodput_recorder = None )
50- self .training_hooks .metric_logger = MagicMock ()
61+ # Use a real MetricLogger and validate the metrics it actually writes to the metrics file.
62+ # A real logger (rather than a mock) guards against problems like the one observed in
63+ # https://github.com/AI-Hypercomputer/maxtext/pull/3691, where the MetricLogger was changed
64+ # to expect a new metric that the SFT code did not provide.
65+ self .training_hooks .metric_logger = MetricLogger (self .config , learning_rate_schedule )
66+ # Initialize metadata to avoid KeyErrors in real MetricLogger
67+ self .training_hooks .metric_logger .metadata = defaultdict (float )
5168
5269 expected_shape = [jax .device_count (), self .config .max_target_length ]
5370 self .expected_batch = {
@@ -59,6 +76,20 @@ def setUp(self):
5976
6077 self .mock_train_ctx = MagicMock ()
6178
def tearDown(self):
  """Remove the per-test temporary directory, then run the base-class teardown.

  `self.test_dir` is the directory created with `tempfile.mkdtemp()` in `setUp`.
  Any error raised while deleting it still propagates to the test runner.
  """
  try:
    shutil.rmtree(self.test_dir)
  finally:
    # Always run the parent teardown, even if removing the directory failed.
    super().tearDown()
82+
83+ def _read_logged_metrics (self , num_expected = 1 ):
84+ """Read and parse metrics logged by the MetricLogger."""
85+ metrics = []
86+ if os .path .exists (self .metrics_file ):
87+ with open (self .metrics_file , "r" , encoding = "utf8" ) as f :
88+ for line in f :
89+ metrics .append (json .loads (line ))
90+ self .assertEqual (len (metrics ), num_expected )
91+ return metrics
92+
6293 @patch ("maxtext.trainers.post_train.sft.hooks.create_data_iterator" )
6394 def test_data_hooks_load_next_train_batch (self , mock_create_data_iterator ):
6495 mock_create_data_iterator .return_value = self .mock_data_iterator , None
@@ -88,15 +119,10 @@ def test_training_hooks_for_train_step(self):
88119 self .training_hooks .on_train_step_start (self .mock_train_ctx )
89120 self .training_hooks .on_train_step_end (self .mock_train_ctx , train_step = 1 , train_loss = 5.0 , step_time = 0.004 )
90121
91- expected_metrics = {
92- "scalar" : {
93- "learning/loss" : 5.0 ,
94- "learning/total_weights" : (jax .device_count () * self .config .max_target_length ),
95- }
96- }
97- self .training_hooks .metric_logger .record_train_metrics .assert_called ()
98- self .training_hooks .metric_logger .write_metrics .assert_called_with (expected_metrics , 1 )
99- self .assertEqual (len (self .training_hooks .train_metadata ), 1 )
122+ metrics = self ._read_logged_metrics (num_expected = 1 )[0 ]
123+ self .assertEqual (metrics ["step" ], 1 )
124+ self .assertAlmostEqual (metrics ["learning/loss" ], 5.0 )
125+ self .assertEqual (metrics ["learning/total_weights" ], (jax .device_count () * self .config .max_target_length ))
100126
101127 def test_training_hooks_for_eval_step (self ):
102128 self .mock_train_ctx .data_hooks .eval_batch = self .expected_batch
@@ -106,15 +132,12 @@ def test_training_hooks_for_eval_step(self):
106132 self .training_hooks .on_eval_step_start (self .mock_train_ctx )
107133 self .training_hooks .on_eval_step_end (self .mock_train_ctx , eval_loss = 10.0 )
108134
109- expected_metrics = {
110- "scalar" : {
111- "eval/total_loss" : 10.0 ,
112- "eval/avg_loss" : 5.0 ,
113- "eval/total_weights" : jax .device_count () * self .config .max_target_length * total_eval_steps ,
114- }
115- }
116- self .training_hooks .metric_logger .write_metrics .assert_called_with (expected_metrics , 0 , is_training = False )
117- self .assertEqual (len (self .training_hooks .eval_metadata ), 0 )
135+ metrics = self ._read_logged_metrics (num_expected = 1 )[0 ]
136+ self .assertEqual (metrics ["step" ], 0 )
137+ self .assertAlmostEqual (metrics ["eval/total_loss" ], 10.0 )
138+ self .assertAlmostEqual (metrics ["eval/avg_loss" ], 5.0 )
139+ self .assertAlmostEqual (metrics ["eval/avg_perplexity" ], np .exp (5.0 ), places = 2 )
140+ self .assertEqual (metrics ["eval/total_weights" ], jax .device_count () * self .config .max_target_length * total_eval_steps )
118141
119142 def test_on_train_end_asserts_if_on_train_start_not_called (self ):
120143 with self .assertRaises (AssertionError ):
0 commit comments