Skip to content

Commit e886dd2

Browse files
committed
Add Warmup-Stable-Decay (WSD) learning rate scheduler with configurable stable and decay phases
Signed-off-by: bzantium <ryumin93@gmail.com>
1 parent 08216c6 commit e886dd2

4 files changed

Lines changed: 211 additions & 27 deletions

File tree

src/MaxText/configs/base.yml

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ grain_file_type: 'arrayrecord' # arrayrecord or parquet
607607
grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
608608
grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
609609
grain_per_worker_buffer_size: 1
610-
# num_threads and prefetch_buffer_size are per-worker per-dataset.
610+
# num_threads and prefetch_buffer_size are per-worker per-dataset.
611611
# When using array_records, they are used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
612612
# The default value matches that in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage.
613613
# When using parquet, grain_num_threads is the number of files to read and interleave in parallel
@@ -635,15 +635,28 @@ skip_jax_distributed_system: False # If True we will not initialize the jax dist
635635
# However when run on google internal TPUs the coordination service is started automatically
636636
# and we should set this to True so we won't try to initialize a second time manually.
637637

638-
# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
639-
# Learning rate schedule has either two or three parts:
638+
# Learning rate schedule structure depends on lr_schedule_type:
639+
#
640+
# Cosine schedule (lr_schedule_type='cosine'):
641+
# Inspired by Llama2's learning rate schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
642+
# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
643+
# 2) Cosine decay from [learning_rate] to [learning_rate * learning_rate_final_fraction] until learning_rate_schedule_steps
644+
# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps (if steps > learning_rate_schedule_steps)
645+
#
646+
# WSD schedule (lr_schedule_type='wsd', Warmup-Stable-Decay):
640647
# 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
641-
# 2) Cosine decay from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] from warmup to learning_rate_schedule_steps
642-
# 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
648+
# 2) Stable phase at [learning_rate] for the majority of training
649+
# 3) Decay from [learning_rate] to [learning_rate * learning_rate_final_fraction] over [learning_rate_schedule_steps * wsd_decay_steps_fraction] steps
650+
#    The decay is either linear or cosine, depending on wsd_decay_style
651+
# 4) Constant learning rate of 0 from learning_rate_schedule_steps to steps (if steps > learning_rate_schedule_steps)
652+
#
643653
# The zero learning rate section can be used to more accurately measure the fully trained model's performance.
644654
learning_rate: 3.e-5
645-
cosine_learning_rate_final_fraction: 0.1
646-
warmup_steps_fraction: 0.1
655+
lr_schedule_type: 'cosine' # Options: 'cosine' or 'wsd'
656+
learning_rate_final_fraction: 0.1 # Final LR as fraction of peak LR (applies to both cosine and WSD schedules)
657+
wsd_decay_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for decay phase in WSD (e.g., 0.1 = 10%)
658+
wsd_decay_style: 'linear' # Decay style for WSD schedule: 'linear' or 'cosine'
659+
warmup_steps_fraction: 0.1 # Fraction of learning_rate_schedule_steps used for warmup phase (applies to both schedules)
647660
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
648661
# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before
649662
# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0.

src/MaxText/configs/types.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,20 @@ class OptimizerType(str, Enum):
124124
MUON = "muon"
125125

126126

127+
class LearningRateScheduleType(str, Enum):
128+
"""Supported learning rate schedule types."""
129+
130+
COSINE = "cosine"
131+
WSD = "wsd"
132+
133+
134+
class WsdDecayStyle(str, Enum):
135+
"""Supported decay styles for WSD schedule."""
136+
137+
LINEAR = "linear"
138+
COSINE = "cosine"
139+
140+
127141
class RopeType(str, Enum):
128142
"""Supported Rotary Positional Embedding (RoPE) implementations."""
129143

@@ -1005,8 +1019,17 @@ class Optimizer(BaseModel):
10051019
1.0, description="The threshold for gradient clipping. 0 disables clipping."
10061020
)
10071021
learning_rate: NonNegativeFloat = Field(3.0e-5, description="The peak learning rate.")
1008-
cosine_learning_rate_final_fraction: float = Field(
1009-
0.1, description="Final LR as a fraction of peak LR in cosine decay."
1022+
lr_schedule_type: LearningRateScheduleType = Field(
1023+
LearningRateScheduleType.COSINE, description="The type of learning rate schedule to use."
1024+
)
1025+
learning_rate_final_fraction: float = Field(
1026+
0.1, description="Final LR as a fraction of peak LR (applies to both cosine and WSD schedules)."
1027+
)
1028+
wsd_decay_steps_fraction: float = Field(
1029+
0.1, ge=0.0, le=1.0, description="Fraction of learning_rate_schedule_steps used for the decay phase in the WSD schedule."
1030+
)
1031+
wsd_decay_style: WsdDecayStyle = Field(
1032+
WsdDecayStyle.LINEAR, description="The decay style for WSD schedule ('linear' or 'cosine')."
10101033
)
10111034
warmup_steps_fraction: float = Field(0.1, ge=0.0, le=1.0, description="Fraction of learning_rate_schedule_steps used for LR warmup.")
10121035
learning_rate_schedule_steps: int = Field(
@@ -1748,6 +1771,17 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
17481771
# If steps is -1, it defaults to the length of the learning rate schedule.
17491772
if self.steps == -1:
17501773
self.steps = self.learning_rate_schedule_steps
1774+
1775+
# Validate WSD learning rate schedule fractions
1776+
if self.lr_schedule_type == LearningRateScheduleType.WSD:
1777+
total_fraction = self.warmup_steps_fraction + self.wsd_decay_steps_fraction
1778+
if total_fraction > 1.0:
1779+
raise ValueError(
1780+
f"Invalid WSD schedule: warmup_steps_fraction ({self.warmup_steps_fraction}) + "
1781+
f"wsd_decay_steps_fraction ({self.wsd_decay_steps_fraction}) must not exceed 1.0. "
1782+
f"Current sum: {total_fraction}"
1783+
)
1784+
17511785
# If eval_per_device_batch_size is not set, it defaults to the training per_device_batch_size.
17521786
if getattr(self, "eval_per_device_batch_size", 0.0) == 0.0:
17531787
self.eval_per_device_batch_size = self.per_device_batch_size

src/MaxText/maxtext_utils.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from MaxText import max_utils
4141
from MaxText import multimodal_utils
4242
from MaxText import sharding
43+
from MaxText.configs import types
4344
from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
4445
from MaxText.inference.page_manager import PageState
4546

@@ -1103,44 +1104,72 @@ def create_device_mesh(config, devices=None):
11031104

11041105

11051106
def create_learning_rate_schedule(config):
1106-
"""Creates a warmup and cosine decay learning rate schedule:
1107-
We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
1108-
Learning rate schedule has either two or three parts:
1107+
"""Creates a learning rate schedule with warmup and decay.
1108+
1109+
Supports two schedule types:
1110+
- Cosine: Inspired by Llama2's learning rate schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
1111+
- WSD (Warmup-Stable-Decay): Maintains constant learning rate for most of training before final decay
1112+
1113+
Schedule structure:
11091114
1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction]
1110-
2) Cosine from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps
1115+
2) Decay from [learning_rate] to a final value until learning_rate_schedule_steps
1116+
- Cosine: decays to [learning_rate * learning_rate_final_fraction]
1117+
- WSD: maintains [learning_rate] for a stable phase, then decays to [learning_rate * learning_rate_final_fraction]
1118+
using either linear or cosine decay based on wsd_decay_style
11111119
3) Constant learning rate of 0 from learning_rate_schedule_steps to steps.
11121120
The zero learning rate section can be used to more accurately measure the fully trained model's performance.
11131121
"""
11141122

11151123
def make_cos_schedule(init_lr, final_lr, len_steps):
11161124
def schedule(step):
1117-
pct = (step) / len_steps
1125+
pct = step / (len_steps - 1) if len_steps > 1 else 1.0
11181126
a = 0.5 * (jnp.cos(jnp.pi * pct) + 1)
11191127
lr = init_lr * a + final_lr * (1 - a)
11201128
return lr
11211129

11221130
return schedule
11231131

11241132
lr = config.learning_rate
1125-
cos_final_lr = lr * config.cosine_learning_rate_final_fraction
1126-
1133+
final_lr = lr * config.learning_rate_final_fraction
11271134
warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction)
1128-
cos_steps = config.learning_rate_schedule_steps - warmup_steps
11291135
constant_zero_steps = config.steps - config.learning_rate_schedule_steps
11301136

1131-
warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps)
1132-
cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps)
1133-
constant_schedule = optax.constant_schedule(0.0)
1134-
1135-
pieces = [warmup_schedule, cos_schedule]
1136-
boundaries = [
1137-
warmup_steps,
1138-
warmup_steps + cos_steps,
1139-
]
1137+
pieces = []
1138+
boundaries = []
1139+
1140+
if warmup_steps > 0:
1141+
warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps - 1)
1142+
pieces.append(warmup_schedule)
1143+
boundaries.append(warmup_steps)
1144+
1145+
if config.lr_schedule_type == types.LearningRateScheduleType.COSINE:
1146+
cos_steps = config.learning_rate_schedule_steps - warmup_steps
1147+
if cos_steps > 0:
1148+
cos_schedule = make_cos_schedule(lr, final_lr, cos_steps)
1149+
pieces.append(cos_schedule)
1150+
boundaries.append(warmup_steps + cos_steps)
1151+
1152+
else: # WSD
1153+
decay_steps = int(config.learning_rate_schedule_steps * config.wsd_decay_steps_fraction)
1154+
stable_steps = config.learning_rate_schedule_steps - warmup_steps - decay_steps
1155+
1156+
if stable_steps > 0:
1157+
stable_schedule = optax.constant_schedule(lr)
1158+
pieces.append(stable_schedule)
1159+
boundaries.append(warmup_steps + stable_steps)
1160+
if decay_steps > 0:
1161+
# Create decay schedule based on wsd_decay_style
1162+
if config.wsd_decay_style == types.WsdDecayStyle.LINEAR:
1163+
decay_schedule = optax.linear_schedule(init_value=lr, end_value=final_lr, transition_steps=decay_steps - 1)
1164+
else: # COSINE
1165+
decay_schedule = make_cos_schedule(lr, final_lr, decay_steps)
1166+
pieces.append(decay_schedule)
1167+
boundaries.append(warmup_steps + stable_steps + decay_steps)
11401168

11411169
if constant_zero_steps > 0:
1170+
constant_schedule = optax.constant_schedule(0.0)
11421171
pieces.append(constant_schedule)
1143-
boundaries.append(warmup_steps + cos_steps + constant_zero_steps)
1172+
boundaries.append(config.learning_rate_schedule_steps)
11441173

11451174
return optax.join_schedules(pieces, boundaries)
11461175

tests/maxtext_utils_test.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,5 +682,113 @@ def test_bytes_from_pytree_empty_dict(self):
682682
self.assertEqual(max_utils.calculate_bytes_from_pytree({}), 0)
683683

684684

685+
class TestLearningRateSchedules(unittest.TestCase):
686+
"""Test suite for learning rate schedule functions."""
687+
688+
def test_cosine_schedule(self):
689+
"""Tests cosine learning rate schedule."""
690+
learning_rate = 1e-3
691+
learning_rate_schedule_steps = 1000
692+
steps = 1200
693+
warmup_steps_fraction = 0.1
694+
learning_rate_final_fraction = 0.1
695+
696+
warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)
697+
698+
config = pyconfig.initialize(
699+
[None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
700+
enable_checkpointing=False,
701+
learning_rate=learning_rate,
702+
learning_rate_schedule_steps=learning_rate_schedule_steps,
703+
steps=steps,
704+
warmup_steps_fraction=warmup_steps_fraction,
705+
lr_schedule_type="cosine",
706+
learning_rate_final_fraction=learning_rate_final_fraction,
707+
)
708+
709+
schedule_fn = maxtext_utils.create_learning_rate_schedule(config)
710+
711+
# Warmup phase: 0 -> peak
712+
self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
713+
self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6)
714+
715+
# Cosine decay phase
716+
lr_end = schedule_fn(learning_rate_schedule_steps - 1)
717+
expected_final = learning_rate * learning_rate_final_fraction
718+
self.assertLess(float(lr_end), learning_rate)
719+
self.assertAlmostEqual(float(lr_end), expected_final, places=6)
720+
721+
# Zero phase
722+
self.assertAlmostEqual(float(schedule_fn(steps - 1)), 0.0, places=6)
723+
724+
def test_wsd_schedule(self):
725+
"""Tests WSD learning rate schedule with both linear and cosine decay styles."""
726+
learning_rate = 1e-3
727+
learning_rate_schedule_steps = 1000
728+
steps = 1200
729+
warmup_steps_fraction = 0.1
730+
learning_rate_final_fraction = 0.1
731+
wsd_decay_steps_fraction = 0.1
732+
733+
warmup_steps = int(learning_rate_schedule_steps * warmup_steps_fraction)
734+
decay_steps = int(learning_rate_schedule_steps * wsd_decay_steps_fraction)
735+
stable_steps = learning_rate_schedule_steps - warmup_steps - decay_steps
736+
decay_start = warmup_steps + stable_steps
737+
738+
# Test both decay styles: linear and cosine
739+
for decay_style in ["linear", "cosine"]:
740+
config = pyconfig.initialize(
741+
[None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
742+
enable_checkpointing=False,
743+
learning_rate=learning_rate,
744+
learning_rate_schedule_steps=learning_rate_schedule_steps,
745+
steps=steps,
746+
warmup_steps_fraction=warmup_steps_fraction,
747+
lr_schedule_type="wsd",
748+
learning_rate_final_fraction=learning_rate_final_fraction,
749+
wsd_decay_steps_fraction=wsd_decay_steps_fraction,
750+
wsd_decay_style=decay_style,
751+
)
752+
schedule_fn = maxtext_utils.create_learning_rate_schedule(config)
753+
754+
# Warmup phase: 0 -> peak
755+
self.assertAlmostEqual(float(schedule_fn(0)), 0.0, places=6)
756+
self.assertAlmostEqual(float(schedule_fn(warmup_steps)), learning_rate, places=6)
757+
758+
# Stable phase: constant at peak
759+
self.assertAlmostEqual(float(schedule_fn(warmup_steps + 10)), learning_rate, places=6)
760+
self.assertAlmostEqual(float(schedule_fn(warmup_steps + stable_steps // 2)), learning_rate, places=6)
761+
self.assertAlmostEqual(float(schedule_fn(decay_start - 1)), learning_rate, places=6)
762+
763+
# Decay phase: peak -> final
764+
lr_mid_decay = schedule_fn(decay_start + decay_steps // 2)
765+
expected_final = learning_rate * learning_rate_final_fraction
766+
self.assertLess(float(lr_mid_decay), learning_rate)
767+
self.assertGreater(float(lr_mid_decay), expected_final)
768+
769+
# End of decay phase: should reach expected_final
770+
lr_end = schedule_fn(learning_rate_schedule_steps - 1)
771+
self.assertAlmostEqual(float(lr_end), expected_final, places=6)
772+
773+
# Zero phase
774+
self.assertAlmostEqual(float(schedule_fn(steps - 1)), 0.0, places=6)
775+
776+
# Test invalid fractions - should raise during config initialization
777+
with self.assertRaises(ValueError) as cm:
778+
pyconfig.initialize(
779+
[None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")],
780+
enable_checkpointing=False,
781+
learning_rate=learning_rate,
782+
learning_rate_schedule_steps=learning_rate_schedule_steps,
783+
steps=steps,
784+
warmup_steps_fraction=0.6,
785+
lr_schedule_type="wsd",
786+
learning_rate_final_fraction=learning_rate_final_fraction,
787+
wsd_decay_steps_fraction=0.5, # Sum > 1.0
788+
)
789+
self.assertIn("warmup_steps_fraction", str(cm.exception))
790+
self.assertIn("wsd_decay_steps_fraction", str(cm.exception))
791+
792+
685793
if __name__ == "__main__":
686794
unittest.main()

0 commit comments

Comments
 (0)