Skip to content

Commit 8c1a4d2

Browse files
Merge pull request #3761 from AI-Hypercomputer:chengnuojin-fix-eval
PiperOrigin-RevId: 907255273
2 parents 96215fd + dddb0cf commit 8c1a4d2

2 files changed

Lines changed: 71 additions & 21 deletions

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,7 @@ def get_first_step(model, state):
8787
# -----------------------------------------------------------------------------
8888

8989

90-
def loss_fn(
91-
model, config, data, dropout_rng, params, sparsity_state=None, is_train=True
92-
):
90+
def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_train=True):
9391
"""loss_fn for both train and eval.
9492
9593
Args:
@@ -121,9 +119,7 @@ def loss_fn(
121119
# make its specific collection mutable so the MTPBlock can sow into it.
122120
if config.mtp_eval_target_module > 0 and not is_train:
123121
mutable_collections.append("mtp_acceptance")
124-
sparsity_enabled = (
125-
is_train and config.weight_sparsity_n and config.weight_sparsity_m
126-
)
122+
sparsity_enabled = is_train and config.weight_sparsity_n and config.weight_sparsity_m
127123
if sparsity_enabled:
128124
mutable_collections.append("batch_stats")
129125
if isinstance(model, nn.Module):
@@ -143,9 +139,7 @@ def loss_fn(
143139
data["inputs_position"],
144140
decoder_segment_ids=data["inputs_segmentation"],
145141
encoder_images=data["images"] if config.use_multimodal else None,
146-
encoder_image_masks=data["image_masks"]
147-
if config.use_multimodal and "image_masks" in data
148-
else None,
142+
encoder_image_masks=data["image_masks"] if config.use_multimodal and "image_masks" in data else None,
149143
enable_dropout=config.enable_dropout if is_train else False,
150144
rngs={"dropout": rng1, "params": aqt_rng},
151145
mutable=mutable_collections,
@@ -286,11 +280,7 @@ def loss_fn(
286280
"indexer_loss": indexer_loss,
287281
"moe_bias_updates": moe_bias_updates,
288282
"mtp_loss": mtp_loss,
289-
"batch_stats": (
290-
intermediate_outputs.get("batch_stats", None)
291-
if hasattr(intermediate_outputs, "get")
292-
else None
293-
),
283+
"batch_stats": (intermediate_outputs.get("batch_stats", None) if hasattr(intermediate_outputs, "get") else None),
294284
}
295285
return loss, aux
296286

@@ -416,9 +406,7 @@ def move(path, value):
416406
if sparsity_enabled:
417407
full_grads = {"params": grads}
418408
if sparsity_enabled and "batch_stats" in state.params:
419-
batch_stats_grads = jax.tree_util.tree_map(
420-
jnp.zeros_like, state.params.get("batch_stats", {})
421-
)
409+
batch_stats_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {}))
422410
full_grads["batch_stats"] = batch_stats_grads
423411
full_grads = max_utils.unbox_logicallypartioned(full_grads)
424412
else:
@@ -501,9 +489,7 @@ def eval_step(model, config, state, data, dropout_rng):
501489
batch_stats = state.params.get("batch_stats", {})
502490

503491
eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False)
504-
loss, aux = eval_loss_fn(
505-
pure_params, *extra_dpo_args, sparsity_state=batch_stats
506-
)
492+
loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats)
507493

508494
mtp_acceptance_rate = 0.0
509495
if config.mtp_eval_target_module > 0:
@@ -630,6 +616,8 @@ def train_loop(config, recorder, state=None):
630616
eval_step_count = 0
631617
# pylint: disable=not-callable
632618
for eval_batch in eval_data_iterator:
619+
# Shard input eval data
620+
eval_batch = jax.device_put(eval_batch, sharding.get_input_data_sharding(config, mesh))
633621
if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
634622
break
635623
with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):

tests/integration/smoke/train_smoke_test.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
""" Smoke test """
15+
"""Smoke test"""
1616
import os
1717
import unittest
1818

@@ -94,6 +94,36 @@ def test_tiny_config_no_scan(self):
9494
]
9595
)
9696

97+
def test_tiny_eval(self):
98+
test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable
99+
train_main(
100+
[
101+
None,
102+
get_test_config_path(),
103+
# pylint: disable=f-string-without-interpolation
104+
f"base_output_directory={self.base_output_directory}",
105+
"run_name=runner_test",
106+
r"dataset_path={self.dataset_path}",
107+
"base_emb_dim=8",
108+
"base_num_query_heads=4",
109+
"base_num_kv_heads=4",
110+
"base_mlp_dim=32",
111+
"base_num_decoder_layers=1",
112+
"head_dim=128",
113+
"per_device_batch_size=2",
114+
"max_target_length=128",
115+
"dataset_type=synthetic",
116+
"steps=5",
117+
"eval_steps=2",
118+
"eval_interval=10",
119+
"enable_checkpointing=False",
120+
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
121+
"enable_goodput_recording=False",
122+
"enable_checkpoint_cloud_logger=False",
123+
"monitor_goodput=False",
124+
]
125+
)
126+
97127
def test_qwen3_custom_moe_config(self):
98128
test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable
99129
train_main(
@@ -159,6 +189,38 @@ def test_tiny_config_explicit_shardmode(self):
159189
]
160190
)
161191

192+
def test_eval_explicit_shardmode(self):
193+
test_tmpdir = os.environ.get("TEST_TMPDIR") # pylint: disable=unused-variable
194+
train_main(
195+
[
196+
None,
197+
get_test_config_path(),
198+
# pylint: disable=f-string-without-interpolation
199+
f"base_output_directory={self.base_output_directory}",
200+
"run_name=runner_test",
201+
r"dataset_path={self.dataset_path}",
202+
"base_emb_dim=8",
203+
"base_num_query_heads=4",
204+
"base_num_kv_heads=4",
205+
"base_mlp_dim=32",
206+
"base_num_decoder_layers=1",
207+
"head_dim=128",
208+
"per_device_batch_size=2",
209+
"max_target_length=128",
210+
"dataset_type=synthetic",
211+
"steps=5",
212+
"eval_steps=2",
213+
"eval_interval=10",
214+
"shard_mode=explicit",
215+
"remove_size_one_mesh_axis_from_type=false",
216+
"enable_checkpointing=False",
217+
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
218+
"enable_goodput_recording=False",
219+
"enable_checkpoint_cloud_logger=False",
220+
"monitor_goodput=False",
221+
]
222+
)
223+
162224

163225
if __name__ == "__main__":
164226
absltest.main()

0 commit comments

Comments
 (0)