Skip to content

Commit 096fc41

Browse files
Merge pull request #2942 from ChingTsai:enable-grad-clipping-in-sft
PiperOrigin-RevId: 857296866
2 parents b314c5a + 4d08c26 commit 096fc41

1 file changed

Lines changed: 7 additions & 0 deletions

File tree

src/MaxText/sft/sft_trainer.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
from absl import app
4141
import os
4242
import jax
43+
import optax
4344
import pathwaysutils
4445

4546
from flax.linen import partitioning as nn_partitioning
@@ -147,6 +148,12 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
147148
# pass in model for muon
148149
optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)
149150

151+
if mt_config.gradient_clipping_threshold > 0:
152+
optimizer = optax.chain(
153+
optax.clip_by_global_norm(max_norm=mt_config.gradient_clipping_threshold),
154+
optimizer,
155+
)
156+
150157
with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION):
151158
training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
152159
data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)

0 commit comments

Comments
 (0)