 from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers


-def get_adamw_mask(config):
-  """Create a mask function for AdamW optimizer to exclude certain parameters from weight decay."""
-  if not getattr(config, "adamw_mask", None):
+def _get_path_mask_fn(patterns, match_returns_true=True):
+  """Helper to create a mask function from a list of regex patterns."""
+  if not patterns:
     return None

-  compiled_patterns = [re.compile(pattern) for pattern in config.adamw_mask]
+  compiled_patterns = [re.compile(pattern) for pattern in patterns]

   def mask_fn(params):
-    def _is_decayed(path, _):
+    def _is_masked(path, _):
       # Join path keys into a single string for pattern matching (e.g., "layer1/bias")
       path_str = "/".join(str(getattr(p, "key", getattr(p, "idx", getattr(p, "name", p)))) for p in path)
-      # If any pattern in adamw_mask matches the path, exclude from weight decay (return False).
-      # Otherwise, apply weight decay (return True).
-      return not any(pattern.search(path_str) for pattern in compiled_patterns)
+      matched = any(pattern.search(path_str) for pattern in compiled_patterns)
+      return matched if match_returns_true else not matched

-    return jax.tree_util.tree_map_with_path(_is_decayed, params)
+    return jax.tree_util.tree_map_with_path(_is_masked, params)

   return mask_fn


+def get_adamw_mask(config):
+  """Create a mask function for AdamW optimizer to exclude certain parameters from weight decay."""
+  return _get_path_mask_fn(getattr(config, "adamw_mask", None), match_returns_true=False)
+
+
 def get_optimizer(config, learning_rate_schedule, model=None):
   """Create optimizer."""
   if config.opt_type == "adamw":
     # Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
-    return optax.adamw(
+    base_opt = optax.adamw(
         learning_rate_schedule,
         b1=config.adam_b1,
         b2=config.adam_b2,
@@ -59,7 +63,7 @@ def get_optimizer(config, learning_rate_schedule, model=None):
         mask=get_adamw_mask(config),
     )
   elif config.opt_type == "adam_pax":
-    return adam_pax(
+    base_opt = adam_pax(
         learning_rate_schedule,
         beta1=config.adam_b1,
         beta2=config.adam_b2,
@@ -69,7 +73,7 @@ def get_optimizer(config, learning_rate_schedule, model=None):
         mask=get_adamw_mask(config),
     )
   elif config.opt_type == "sgd":
-    return optax.sgd(learning_rate_schedule)
+    base_opt = optax.sgd(learning_rate_schedule)
   elif config.opt_type == "muon":
     # extract muon dimension number from model structure
     if model is not None:
@@ -92,10 +96,19 @@ def get_optimizer(config, learning_rate_schedule, model=None):
         "adam_eps_root": config.adam_eps_root,
         "adam_weight_decay": config.adam_weight_decay,
     }
-    return muon(**muon_kwargs)
+    base_opt = muon(**muon_kwargs)
   else:
     raise ValueError(f"{config.opt_type=} is not a supported.")

+  # If a whitelist of trainable parameters is provided, freeze everything else.
+  # When trainable_parameters_mask is empty, freeze_mask_fn is None and all parameters are trained.
+  trainable_patterns = getattr(config, "trainable_parameters_mask", None)
+  freeze_mask_fn = _get_path_mask_fn(trainable_patterns, match_returns_true=False)
+  if freeze_mask_fn is not None:
+    return optax.chain(base_opt, optax.masked(optax.set_to_zero(), freeze_mask_fn))
+
+  return base_opt
+

 def adam_pax(
     learning_rate_fn: optax.Schedule,
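
For reference, here is a minimal, self-contained sketch of how the new freeze path behaves, outside the diff. It re-declares the regex path-mask helper locally and composes it with optax.masked(optax.set_to_zero(), ...) the same way get_optimizer now does; the toy parameter names ("embed/kernel", "head/kernel", "head/bias"), the learning rate, and the pattern list are illustrative assumptions, not MaxText config values.

import re

import jax
import jax.numpy as jnp
import optax


def path_mask_fn(patterns, match_returns_true=True):
  # Local stand-in for _get_path_mask_fn: map regexes over "/"-joined parameter paths.
  if not patterns:
    return None
  compiled = [re.compile(p) for p in patterns]

  def mask_fn(params):
    def _is_masked(path, _):
      path_str = "/".join(str(getattr(p, "key", getattr(p, "idx", getattr(p, "name", p)))) for p in path)
      matched = any(c.search(path_str) for c in compiled)
      return matched if match_returns_true else not matched

    return jax.tree_util.tree_map_with_path(_is_masked, params)

  return mask_fn


params = {
    "embed": {"kernel": jnp.ones((4, 4))},
    "head": {"kernel": jnp.ones((4, 2)), "bias": jnp.zeros((2,))},
}

# Whitelist: only paths matching "embed" stay trainable; everything else is frozen.
freeze_mask_fn = path_mask_fn([r"embed"], match_returns_true=False)
optimizer = optax.chain(
    optax.adamw(1e-3),
    optax.masked(optax.set_to_zero(), freeze_mask_fn),
)

grads = jax.tree_util.tree_map(jnp.ones_like, params)
state = optimizer.init(params)
updates, state = optimizer.update(grads, state, params)

# Frozen leaves ("head/...") receive exactly zero updates; "embed/kernel" does not.
print(jax.tree_util.tree_map(lambda u: bool(jnp.all(u == 0)), updates))

Because set_to_zero runs after the base optimizer in the chain, frozen parameters still carry optimizer state (e.g., AdamW moments) but their final updates are zeroed; with an empty trainable_parameters_mask the mask function is None, the chain is skipped, and the base optimizer is returned unchanged.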