Commit cff7f1c
Wraps creation of PeftTrainer in an nn_partitioning.axis_rules(mt_config.logical_axis_rules) context, which prevents checkpoint resuming from accidentally treating the logical axis 'norm' as a physical mesh axis.
PiperOrigin-RevId: 902936535
1 parent 1448d5a commit cff7f1c

1 file changed: src/maxtext/trainers/post_train/sft/train_sft.py (6 additions, 5 deletions)
```diff
@@ -160,11 +160,12 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
   with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION):
     training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
     data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)
-
-    trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
-    trainer.with_training_hooks(training_hooks)
-    trainer.with_data_hooks(data_hooks)
-    trainer = use_maxtext_loss_function(trainer, mt_config)
+    # Provide rules context so 'norm' is translated to mesh axes during maybe_restore
+    with nn_partitioning.axis_rules(mt_config.logical_axis_rules):
+      trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
+      trainer.with_training_hooks(training_hooks)
+      trainer.with_data_hooks(data_hooks)
+      trainer = use_maxtext_loss_function(trainer, mt_config)

   return trainer, mesh
```