Skip to content

Commit 73adf2a

Browse files
committed
update consistent rms
1 parent faf311a commit 73adf2a

3 files changed

Lines changed: 3 additions & 11 deletions

File tree

src/maxtext/configs/pyconfig.py

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -256,16 +256,6 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
256256
Please pass tokenizer_path in your command if this is not intended."
257257
)
258258

259-
# Preprocess muon_consistent_rms to be None or float
260-
if key == "muon_consistent_rms":
261-
if value in ["None", "none"]:
262-
new_value = None
263-
else:
264-
try:
265-
new_value = float(value)
266-
except ValueError as e:
267-
raise ValueError("muon_consistent_rms should be None or float") from e
268-
269259
pydantic_kwargs[key] = new_value
270260

271261
return pydantic_kwargs

src/maxtext/configs/types.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1343,7 +1343,7 @@ class Muon(BaseModel):
13431343
0,
13441344
description="Strength of the weight decay regularization. This is multiplied with the learning rate.",
13451345
)
1346-
muon_consistent_rms: None | float = Field(
1346+
muon_consistent_rms: float | None = Field(
13471347
None,
13481348
description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2).",
13491349
)

src/maxtext/optimizers/optimizers.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -197,6 +197,8 @@ def get_optimizer(config, learning_rate_schedule, model=None):
197197
muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config)
198198
else:
199199
raise ValueError("Please specify model to extract muon dimension number.")
200+
# TODO(shuningjin): remove
201+
print(f"DEBUG: {config.muon_consistent_rms}, {type(config.muon_consistent_rms)}")
200202
muon_kwargs = {
201203
# Shared parameters: "nesterov" uses default
202204
"learning_rate": learning_rate_schedule,

0 commit comments

Comments
 (0)