
Commit 1137c42

Merge pull request #2883 from bzantium:feature/#2882
PiperOrigin-RevId: 858771206
2 parents: d4a259d + e886dd2

4 files changed: 210 additions & 26 deletions


src/MaxText/configs/base.yml

Lines changed: 19 additions & 6 deletions
@@ -636,15 +636,28 @@ skip_jax_distributed_system: False # If True we will not initialize the jax dist
 # However when run on google internal TPUs the coordination service is started automatically
 # and we should set this to True so we won't try to initialize a second time manually.

-# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
-# Learning rate schedule has either two or three parts:
+# Learning rate schedule structure depends on lr_schedule_type:
+#
+# Cosine schedule (lr_schedule_type='cosine'):
+# Inspired by Llama2's learning rate schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
+# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
+# 2) Cosine decay from [learning_rate] to [learning_rate * learning_rate_final_fraction] until learning_rate_schedule_steps
+# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps (if steps > learning_rate_schedule_steps)
+#
+# WSD schedule (lr_schedule_type='wsd', Warmup-Stable-Decay):
 # 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
-# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps
-# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
+# 2) Stable phase at [learning_rate] for the majority of training
+# 3) Decay from [learning_rate] to [learning_rate * learning_rate_final_fraction] over [learning_rate_schedule_steps * wsd_decay_steps_fraction] steps
+#    The decay can be either linear or cosine based on wsd_decay_style
+# 4) Constant learning rate of 0 from learning_rate_schedule_steps to steps (if steps > learning_rate_schedule_steps)
+#
 # The zero learning rate section can be used to more accurately measure the fully trained model's performance.
 learning_rate: 3.e-5
-cosine_learning_rate_final_fraction: 0.1
-warmup_steps_fraction: 0.1
+lr_schedule_type: 'cosine'  # Options: 'cosine' or 'wsd'
+learning_rate_final_fraction: 0.1  # Final LR as fraction of peak LR (applies to both cosine and WSD schedules)
+wsd_decay_steps_fraction: 0.1  # Fraction of learning_rate_schedule_steps used for decay phase in WSD (e.g., 0.1 = 10%)
+wsd_decay_style: 'linear'  # Decay style for WSD schedule: 'linear' or 'cosine'
+warmup_steps_fraction: 0.1  # Fraction of learning_rate_schedule_steps used for warmup phase (applies to both schedules)
 learning_rate_schedule_steps: -1  # By default the length of the schedule is set to the number of steps.
 # However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
 # dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.
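For orientation (an editorial sketch, not part of the commit): every new WSD knob is a fraction of learning_rate_schedule_steps, so the phase lengths fall out of simple arithmetic. The values below are assumed and match the defaults shown above.

# Illustration only: how the WSD fractions translate into phase lengths.
learning_rate_schedule_steps = 1000
warmup_steps_fraction = 0.1      # warmup phase
wsd_decay_steps_fraction = 0.1   # decay phase

warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)
decay_steps = int(learning_rate_schedule_steps * wsd_decay_steps_fraction)
stable_steps = learning_rate_schedule_steps - warmup_steps - decay_steps  # whatever remains

print(warmup_steps, stable_steps, decay_steps)  # 100 800 100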

src/MaxText/configs/types.py

Lines changed: 36 additions & 2 deletions
@@ -124,6 +124,20 @@ class OptimizerType(str, Enum):
   MUON = "muon"


+class LearningRateScheduleType(str, Enum):
+  """Supported learning rate schedule types."""
+
+  COSINE = "cosine"
+  WSD = "wsd"
+
+
+class WsdDecayStyle(str, Enum):
+  """Supported decay styles for WSD schedule."""
+
+  LINEAR = "linear"
+  COSINE = "cosine"
+
+
 class RopeType(str, Enum):
   """Supported Rotary Positional Embedding (RoPE) implementations."""

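A note on the str-mixin enums (a standalone sketch, not the MaxText module): because the members subclass str, YAML strings such as 'wsd' or 'linear' map directly onto enum members and still compare equal to plain strings.

# Minimal re-creation for illustration; the real classes live in src/MaxText/configs/types.py.
from enum import Enum

class LearningRateScheduleType(str, Enum):
  COSINE = "cosine"
  WSD = "wsd"

print(LearningRateScheduleType("wsd") is LearningRateScheduleType.WSD)  # True: lookup by value
print(LearningRateScheduleType.WSD == "wsd")                            # True: str comparison still works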
@@ -1030,8 +1044,17 @@ class Optimizer(BaseModel):
       1.0, description="The threshold for gradient clipping. 0 disables clipping."
   )
   learning_rate: NonNegativeFloat = Field(3.0e-5, description="The peak learning rate.")
-  cosine_learning_rate_final_fraction: float = Field(
-      0.1, description="Final LR as a fraction of peak LR in cosine decay."
+  lr_schedule_type: LearningRateScheduleType = Field(
+      LearningRateScheduleType.COSINE, description="The type of learning rate schedule to use."
+  )
+  learning_rate_final_fraction: float = Field(
+      0.1, description="Final LR as a fraction of peak LR (applies to both cosine and WSD schedules)."
+  )
+  wsd_decay_steps_fraction: float = Field(
+      0.1, ge=0.0, le=1.0, description="Fraction of total steps for decay phase in WSD schedule."
+  )
+  wsd_decay_style: WsdDecayStyle = Field(
+      WsdDecayStyle.LINEAR, description="The decay style for WSD schedule ('linear' or 'cosine')."
   )
   warmup_steps_fraction: float = Field(0.1, ge=0.0, le=1.0, description="Fraction of total steps for LR warmup.")
   learning_rate_schedule_steps: int = Field(
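The ge=/le= bounds on wsd_decay_steps_fraction mirror the existing warmup_steps_fraction field, so out-of-range values fail when the config is parsed. A hedged standalone sketch of that behaviour (pydantic v2 assumed; hypothetical model name, not the actual Optimizer class):

from pydantic import BaseModel, Field, ValidationError

class LrScheduleFractions(BaseModel):
  warmup_steps_fraction: float = Field(0.1, ge=0.0, le=1.0)
  wsd_decay_steps_fraction: float = Field(0.1, ge=0.0, le=1.0)

print(LrScheduleFractions(wsd_decay_steps_fraction=0.2))  # accepted
try:
  LrScheduleFractions(wsd_decay_steps_fraction=1.5)  # rejected: above le=1.0
except ValidationError as e:
  print(e.errors()[0]["type"])  # 'less_than_equal'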
@@ -1775,6 +1798,17 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
     # If steps is -1, it defaults to the length of the learning rate schedule.
     if self.steps == -1:
       self.steps = self.learning_rate_schedule_steps
+
+    # Validate WSD learning rate schedule fractions
+    if self.lr_schedule_type == LearningRateScheduleType.WSD:
+      total_fraction = self.warmup_steps_fraction + self.wsd_decay_steps_fraction
+      if total_fraction > 1.0:
+        raise ValueError(
+            f"Invalid WSD schedule: warmup_steps_fraction ({self.warmup_steps_fraction}) + "
+            f"wsd_decay_steps_fraction ({self.wsd_decay_steps_fraction}) must not exceed 1.0. "
+            f"Current sum: {total_fraction}"
+        )
+
     # If eval_per_device_batch_size is not set, it defaults to the training per_device_batch_size.
     if getattr(self, "eval_per_device_batch_size", 0.0) == 0.0:
       self.eval_per_device_batch_size = self.per_device_batch_size
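The per-field bounds above cannot catch the cross-field case where warmup plus decay exceeds the whole schedule; that is what this validator adds. A standalone mirror of the check for illustration (the real one runs inside set_derived_and_validate_values on the parsed config):

def validate_wsd_fractions(warmup_steps_fraction: float, wsd_decay_steps_fraction: float) -> None:
  # Same rule as the commit: warmup + decay must leave room for the stable phase.
  total_fraction = warmup_steps_fraction + wsd_decay_steps_fraction
  if total_fraction > 1.0:
    raise ValueError(
        f"Invalid WSD schedule: warmup_steps_fraction ({warmup_steps_fraction}) + "
        f"wsd_decay_steps_fraction ({wsd_decay_steps_fraction}) must not exceed 1.0. "
        f"Current sum: {total_fraction}"
    )

validate_wsd_fractions(0.1, 0.1)    # fine: 80% of the schedule remains for the stable phase
# validate_wsd_fractions(0.6, 0.5)  # raises ValueError: sum 1.1 > 1.0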

src/MaxText/maxtext_utils.py

Lines changed: 47 additions & 18 deletions
@@ -40,6 +40,7 @@
 from MaxText import max_utils
 from MaxText import multimodal_utils
 from MaxText import sharding
+from MaxText.configs import types
 from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
 from MaxText.inference.page_manager import PageState

@@ -1133,44 +1134,72 @@ def create_device_mesh(config, devices=None):


 def create_learning_rate_schedule(config):
-  """Creates a warmup and cosine decay learning rate schedule:
-  We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
-  Learning rate schedule has either two or three parts:
+  """Creates a learning rate schedule with warmup and decay.
+
+  Supports two schedule types:
+  - Cosine: Inspired by Llama2's learning rate schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
+  - WSD (Warmup-Stable-Decay): Maintains constant learning rate for most of training before final decay
+
+  Schedule structure:
   1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
-  2) Cosine from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps
+  2) Decay from [learning_rate] to a final value until learning_rate_schedule_steps
+     - Cosine: decays to [learning_rate * learning_rate_final_fraction]
+     - WSD: maintains [learning_rate] for a stable phase, then decays to [learning_rate * learning_rate_final_fraction]
+       using either linear or cosine decay based on wsd_decay_style
   3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
   The zero learning rate section can be used to more accurately measure the fully trained model's performance.
   """

   def make_cos_schedule(init_lr, final_lr, len_steps):
     def schedule(step):
-      pct = (step) / len_steps
+      pct = step / (len_steps - 1) if len_steps > 1 else 1.0
       a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
       lr = init_lr * a + final_lr * (1 - a)
       return lr

     return schedule

   lr = config.learning_rate
-  cos_final_lr = lr * config.cosine_learning_rate_final_fraction
-
+  final_lr = lr * config.learning_rate_final_fraction
   warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction)
-  cos_steps = config.learning_rate_schedule_steps - warmup_steps
   constant_zero_steps = config.steps - config.learning_rate_schedule_steps

-  warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps)
-  cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps)
-  constant_schedule = optax.constant_schedule(0.0)
-
-  pieces = [warmup_schedule, cos_schedule]
-  boundaries = [
-      warmup_steps,
-      warmup_steps + cos_steps,
-  ]
+  pieces = []
+  boundaries = []
+
+  if warmup_steps > 0:
+    warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps - 1)
+    pieces.append(warmup_schedule)
+    boundaries.append(warmup_steps)
+
+  if config.lr_schedule_type == types.LearningRateScheduleType.COSINE:
+    cos_steps = config.learning_rate_schedule_steps - warmup_steps
+    if cos_steps > 0:
+      cos_schedule = make_cos_schedule(lr, final_lr, cos_steps)
+      pieces.append(cos_schedule)
+      boundaries.append(warmup_steps + cos_steps)

+  else:  # WSD
+    decay_steps = int(config.learning_rate_schedule_steps * config.wsd_decay_steps_fraction)
+    stable_steps = config.learning_rate_schedule_steps - warmup_steps - decay_steps
+
+    if stable_steps > 0:
+      stable_schedule = optax.constant_schedule(lr)
+      pieces.append(stable_schedule)
+      boundaries.append(warmup_steps + stable_steps)
+    if decay_steps > 0:
+      # Create decay schedule based on wsd_decay_style
+      if config.wsd_decay_style == types.WsdDecayStyle.LINEAR:
+        decay_schedule = optax.linear_schedule(init_value=lr, end_value=final_lr, transition_steps=decay_steps - 1)
+      else:  # COSINE
+        decay_schedule = make_cos_schedule(lr, final_lr, decay_steps)
+      pieces.append(decay_schedule)
+      boundaries.append(warmup_steps + stable_steps + decay_steps)

   if constant_zero_steps > 0:
+    constant_schedule = optax.constant_schedule(0.0)
     pieces.append(constant_schedule)
-    boundaries.append(warmup_steps + cos_steps + constant_zero_steps)
+    boundaries.append(config.learning_rate_schedule_steps)

   return optax.join_schedules(pieces, boundaries)

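To see the resulting shape without the MaxText config object, here is an optax-only sketch that joins the same four pieces (warmup, stable, decay, zero tail). The step counts and peak learning rate are assumed values, and the boundaries are simplified relative to the transition_steps - 1 adjustments in the commit.

import optax

lr, final_fraction = 3e-4, 0.1
warmup_steps, stable_steps, decay_steps = 100, 800, 100

schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, lr, transition_steps=warmup_steps),  # warmup
        optax.constant_schedule(lr),  # stable
        optax.linear_schedule(lr, lr * final_fraction, transition_steps=decay_steps),  # decay
        optax.constant_schedule(0.0),  # zero tail beyond learning_rate_schedule_steps
    ],
    boundaries=[warmup_steps, warmup_steps + stable_steps, warmup_steps + stable_steps + decay_steps],
)

for step in (0, 50, 100, 500, 950, 1000, 1100):
  print(step, float(schedule(step)))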
tests/maxtext_utils_test.py

Lines changed: 108 additions & 0 deletions
@@ -721,5 +721,113 @@ def test_bytes_from_pytree_empty_dict(self):
     self.assertEqual(max_utils.calculate_bytes_from_pytree({}), 0)


+class TestLearningRateSchedules(unittest.TestCase):
+  """Test suite for learning rate schedule functions."""
+
+  def test_cosine_schedule(self):
+    """Tests cosine learning rate schedule."""
+    learning_rate = 1e-3
+    learning_rate_schedule_steps = 1000
+    steps = 1200
+    warmup_steps_fraction = 0.1
+    learning_rate_final_fraction = 0.1
+
+    warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)
+
+    config = pyconfig.initialize(
+        [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+        enable_checkpointing=False,
+        learning_rate=learning_rate,
+        learning_rate_schedule_steps=learning_rate_schedule_steps,
+        steps=steps,
+        warmup_steps_fraction=warmup_steps_fraction,
+        lr_schedule_type="cosine",
+        learning_rate_final_fraction=learning_rate_final_fraction,
+    )
+
+    schedule_fn = maxtext_utils.create_learning_rate_schedule(config)
+
+    # Warmup phase: 0 -> peak
+    self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
+    self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6)
+
+    # Cosine decay phase
+    lr_end = schedule_fn(learning_rate_schedule_steps - 1)
+    expected_final = learning_rate * learning_rate_final_fraction
+    self.assertLess(float(lr_end), learning_rate)
+    self.assertAlmostEqual(float(lr_end), expected_final, places=6)
+
+    # Zero phase
+    self.assertAlmostEqual(float(schedule_fn(steps - 1)), 0.0, places=6)
+
+  def test_wsd_schedule(self):
+    """Tests WSD learning rate schedule with both linear and cosine decay styles."""
+    learning_rate = 1e-3
+    learning_rate_schedule_steps = 1000
+    steps = 1200
+    warmup_steps_fraction = 0.1
+    learning_rate_final_fraction = 0.1
+    wsd_decay_steps_fraction = 0.1
+
+    warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)
+    decay_steps = int(learning_rate_schedule_steps * wsd_decay_steps_fraction)
+    stable_steps = learning_rate_schedule_steps - warmup_steps - decay_steps
+    decay_start = warmup_steps + stable_steps
+
+    # Test both decay styles: linear and cosine
+    for decay_style in ["linear", "cosine"]:
+      config = pyconfig.initialize(
+          [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+          enable_checkpointing=False,
+          learning_rate=learning_rate,
+          learning_rate_schedule_steps=learning_rate_schedule_steps,
+          steps=steps,
+          warmup_steps_fraction=warmup_steps_fraction,
+          lr_schedule_type="wsd",
+          learning_rate_final_fraction=learning_rate_final_fraction,
+          wsd_decay_steps_fraction=wsd_decay_steps_fraction,
+          wsd_decay_style=decay_style,
+      )
+      schedule_fn = maxtext_utils.create_learning_rate_schedule(config)
+
+      # Warmup phase: 0 -> peak
+      self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
+      self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6)
+
+      # Stable phase: constant at peak
+      self.assertAlmostEqual(float(schedule_fn(warmup_steps + 10)), learning_rate, places=6)
+      self.assertAlmostEqual(float(schedule_fn(warmup_steps + stable_steps // 2)), learning_rate, places=6)
+      self.assertAlmostEqual(float(schedule_fn(decay_start - 1)), learning_rate, places=6)
+
+      # Decay phase: peak -> final
+      lr_mid_decay = schedule_fn(decay_start + decay_steps // 2)
+      expected_final = learning_rate * learning_rate_final_fraction
+      self.assertLess(float(lr_mid_decay), learning_rate)
+      self.assertGreater(float(lr_mid_decay), expected_final)
+
+      # End of decay phase: should reach expected_final
+      lr_end = schedule_fn(learning_rate_schedule_steps - 1)
+      self.assertAlmostEqual(float(lr_end), expected_final, places=6)
+
+      # Zero phase
+      self.assertAlmostEqual(float(schedule_fn(steps - 1)), 0.0, places=6)
+
+    # Test invalid fractions - should raise during config initialization
+    with self.assertRaises(ValueError) as cm:
+      pyconfig.initialize(
+          [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
+          enable_checkpointing=False,
+          learning_rate=learning_rate,
+          learning_rate_schedule_steps=learning_rate_schedule_steps,
+          steps=steps,
+          warmup_steps_fraction=0.6,
+          lr_schedule_type="wsd",
+          learning_rate_final_fraction=learning_rate_final_fraction,
+          wsd_decay_steps_fraction=0.5,  # Sum > 1.0
+      )
+    self.assertIn("warmup_steps_fraction", str(cm.exception))
+    self.assertIn("wsd_decay_steps_fraction", str(cm.exception))
+
+
 if __name__ == "__main__":
   unittest.main()
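One detail the tests lean on: schedule_fn(learning_rate_schedule_steps - 1) is compared against learning_rate * learning_rate_final_fraction, which only holds because make_cos_schedule now normalizes by len_steps - 1, so the last step lands exactly on the final value. A quick standalone check of that endpoint behaviour (illustrative values, mirroring the helper in maxtext_utils.py):

import jax.numpy as jnp

init_lr, final_lr, len_steps = 1e-3, 1e-4, 100

def cos_lr(step):
  # a = 1 at pct = 0 (start of decay), a = 0 at pct = 1 (end of decay)
  pct = step / (len_steps - 1) if len_steps > 1 else 1.0
  a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
  return init_lr * a + final_lr * (1 - a)

print(float(cos_lr(0)))              # ~0.001  (peak learning rate)
print(float(cos_lr(len_steps - 1)))  # ~0.0001 (peak * learning_rate_final_fraction)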
