Skip to content

Commit 9662a79

Browse files
committed
Fix validation for model config override
1 parent c73595a commit 9662a79

5 files changed

Lines changed: 34 additions & 7 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
run_name: ""
2222

2323
model_name: "default" # override config settings to match a specific model. other than the override, nothing should use this!
24-
override_model_config: False # When set to true allows overriding model parameters via CLI for the purpose of debugging/testing.
24+
override_model_config: False # When set to true allows overriding model parameters via CLI (or kwargs or env vars) for the purpose of debugging/testing.
2525
override_logical_axis_rules: False # When set overrides logical axis rules instead of merging them.
2626
debug:
2727
rl: False # RL-specific debugging

src/maxtext/configs/pyconfig.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ def yaml_key_to_env_key(s: str) -> str:
9393
return _MAX_PREFIX + s.upper()
9494

9595

96+
def validate_no_keys_overridden_twice(keys1: list[str], keys2: list[str]):
97+
overridden_keys = [k for k in keys1 if k in keys2]
98+
if overridden_keys:
99+
raise ValueError(
100+
f"Keys {overridden_keys} are overridden by both model config and CLI/kwargs."
101+
"This is not allowed, unless setting `override_model_config=True`."
102+
)
103+
104+
96105
def resolve_config_path(param: str) -> str:
97106
"""Resolve config path to auto rewrite to use new src folder."""
98107
if os.path.isfile(param):
@@ -330,6 +339,8 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
330339
model_cfg = {k: v for k, v in model_loaded_cfg.items() if k not in overrides_cfg}
331340
else:
332341
model_cfg = model_loaded_cfg
342+
# Validate that no keys are overridden by both model config and CLI/kwargs
343+
validate_no_keys_overridden_twice(model_loaded_cfg.keys(), overrides_cfg.keys())
333344
else:
334345
logger.warning("Model config for '%s' not found at %s", model_name, model_config_path)
335346

@@ -368,10 +379,17 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
368379
for k in tuple(raw_keys_dict.keys()):
369380
env_key = yaml_key_to_env_key(k)
370381
if env_key in os.environ:
382+
# Validate that no keys are overridden by both CLI/kwargs and environment variable
371383
if k in cli_keys or k in kwargs_keys:
372384
raise ValueError(
373385
f"Key '{k}' is overridden by both CLI/kwargs and environment variable '{env_key}'. This is not allowed."
374386
)
387+
# Validate that no keys are overridden by both model config and environment variable
388+
if not temp_cfg.get("override_model_config") and k in model_cfg.keys():
389+
raise ValueError(
390+
f"Key '{k}' is overridden by both model config and environment variable '{env_key}'."
391+
"This is not allowed, unless setting `override_model_config=True`."
392+
)
375393

376394
new_proposal = os.environ.get(env_key)
377395
original_value = raw_keys_dict.get(k)

tests/unit/mhc_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,15 @@ def setUp(self):
101101
per_device_batch_size=4,
102102
max_target_length=7,
103103
max_prefill_predict_length=7,
104+
attention="dot_product",
105+
routed_bias_update_rate=0.01,
106+
load_balance_loss_weight=0.02,
107+
# override
108+
override_model_config=True,
104109
base_emb_dim=self.dim,
105110
mhc_expansion_rate=3,
106111
num_experts=4,
107112
num_experts_per_tok=2,
108-
attention="dot_product",
109-
routed_bias_update_rate=0.01,
110-
load_balance_loss_weight=0.02,
111113
engram_layers=[],
112114
)
113115
devices_array = maxtext_utils.create_device_mesh(self.config)

tests/unit/nnx_decoders_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def _make_config(**overrides):
7070
**_BASE_CONFIG,
7171
**extra_args,
7272
**overrides,
73+
override_model_config=True,
7374
)
7475

7576

tests/unit/train_compile_test.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,8 @@ def test_indexer_dense_warmup(self):
817817
"max_target_length=1024",
818818
"attention=flash",
819819
"use_tokamax_splash=True",
820+
# override
821+
"override_model_config=True",
820822
"engram_layers=[]",
821823
# dense warmup
822824
"indexer_sparse_training=False",
@@ -842,6 +844,8 @@ def test_indexer_sparse_training(self):
842844
"max_target_length=1024",
843845
"attention=flash",
844846
"use_tokamax_splash=True",
847+
# override
848+
"override_model_config=True",
845849
"engram_layers=[]",
846850
# sparse training
847851
"indexer_sparse_training=True",
@@ -869,7 +873,7 @@ def test_olmo3_7b(self):
869873

870874
@pytest.mark.cpu_only
871875
def test_mhc_integration(self):
872-
"""AOT test for Manifold-onstrained Hyper Connection implementation"""
876+
"""AOT test for Manifold-constrained Hyper Connection implementation"""
873877
compiled_trainstep_file = "/tmp/test_mhc_integration"
874878
train_compile_main(
875879
(
@@ -881,10 +885,12 @@ def test_mhc_integration(self):
881885
"model_name=deepseek-custom",
882886
"per_device_batch_size=4",
883887
"scan_layers=True",
884-
"max_target_length=1024",
885-
"mhc_expansion_rate=4",
886888
"attention=flash",
887889
"use_tokamax_splash=True",
890+
"max_target_length=1024",
891+
# override
892+
"override_model_config=True",
893+
"mhc_expansion_rate=4",
888894
"engram_layers=[]",
889895
)
890896
)

0 commit comments

Comments
 (0)