
Commit 27eb0b4

Enable selective parameter training strategy
1 parent b672dc7 commit 27eb0b4

7 files changed

Lines changed: 204 additions & 16 deletions


src/maxtext/configs/base.yml

Lines changed: 6 additions & 0 deletions
@@ -367,6 +367,8 @@ index_topk: 2048
 sparse_indexer_loss: False
 # Multiplier for the indexer KL divergence loss
 indexer_loss_scaling_factor: 0.0
+# Whether to enable sparse training for indexer by detaching its input from the computational graph
+indexer_sparse_training: False

 # MLA parameters
 q_lora_rank: 0
@@ -797,6 +799,10 @@ adam_eps: 1.e-8 # A small constant applied to denominator outside of the square
 adam_eps_root: 0. # A small constant applied to denominator inside the square root.
 adam_weight_decay: 0.1 # AdamW Weight decay
 adamw_mask: [] # List of parameter names/patterns to exclude from weight decay in AdamW, like ['bias', '.*norm', '.*ln.*'].
+# List of parameter names/patterns to train.
+# If non-empty, all other parameters will be frozen. Example: ['.*indexer.*'].
+# If empty (default), all parameters are trained.
+trainable_parameters_mask: []
 mu_dtype: "" # data type to store "mu" of AdamW tracking the first moment. Inherits from weight_dtype if unset.
 # Setting nu_dtype is not yet supported by optax, instead nu_dtype is always inherited from weights.
 # See b/399961932 for more.
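
Note: a minimal sketch of how the two new options might be combined. The config path and run name below are placeholders; the override syntax mirrors the unit tests added in this commit.

from maxtext.configs import pyconfig

argv = [
    "",
    "path/to/base.yml",  # placeholder config path
    "run_name=indexer_only_finetune",  # placeholder run name
    "indexer_sparse_training=True",
    "trainable_parameters_mask=['.*indexer.*']",
]
config = pyconfig.initialize(argv)

With this configuration, the indexer is optimized on its own loss while every parameter whose path does not match '.*indexer.*' receives zero updates.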

src/maxtext/configs/types.py

Lines changed: 11 additions & 0 deletions
@@ -537,6 +537,10 @@ class AttentionIndexer(BaseModel):
   index_topk: NonNegativeInt = Field(2048, description="Number of tokens selected by the query token in indexer.")
   sparse_indexer_loss: bool = Field(False, description="Determines the token selection strategy for indexer loss.")
   indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.")
+  indexer_sparse_training: bool = Field(
+      False,
+      description="Whether to enable sparse training for indexer by detaching its input from the computational graph.",
+  )


 class Llama4Attention(BaseModel):
@@ -1195,6 +1199,13 @@ class AdamW(BaseModel):
           "List of parameter names/patterns to exclude from weight decay in AdamW," " like ['bias', '.*norm', '.*ln.*']"
       ),
   )
+  trainable_parameters_mask: list[str] = Field(
+      default_factory=list,
+      description=(
+          "List of parameter names/patterns to train. If non-empty, all other parameters will be frozen, "
+          "example: ['.*indexer.*']. If empty (default), all parameters are trained."
+      ),
+  )
   mu_dtype: str = Field(
       "",
       description="Data type for 'mu' (first moment) in AdamW. Inherits from weight_dtype if empty.",

src/maxtext/layers/attention_mla.py

Lines changed: 14 additions & 3 deletions
@@ -266,12 +266,23 @@ def __call__(
     bsz, seqlen, _ = inputs_q.shape # s = t = seqlen

     # Query Processing: Project from Latent low_rank_q
-    q = self.wq_b(low_rank_q) # [b, t, q_lora_rank] -> [b, t, h * d]
+    if self.config.indexer_sparse_training:
+      # Detach indexer input from the computational graph so main loss doesn't backprop through indexer,
+      # and indexer loss doesn't backprop into main model embeddings/latent variables.
+      inputs_q_for_indexer = jax.lax.stop_gradient(inputs_q)
+      low_rank_q_for_indexer = jax.lax.stop_gradient(low_rank_q)
+      inputs_kv_for_indexer = jax.lax.stop_gradient(inputs_kv)
+    else:
+      inputs_q_for_indexer = inputs_q
+      low_rank_q_for_indexer = low_rank_q
+      inputs_kv_for_indexer = inputs_kv
+
+    q = self.wq_b(low_rank_q_for_indexer) # [b, t, q_lora_rank] -> [b, t, h * d]
     q = q.reshape(bsz, seqlen, self.n_heads, self.head_dim) # [b, t, h, d]
     q = self.apply_partial_rope(q, inputs_positions=inputs_positions)

     # Key Processing: Project from Input
-    k = self.wk(inputs_kv) # [b, s, embed_dim] -> [b, s, d]
+    k = self.wk(inputs_kv_for_indexer) # [b, s, embed_dim] -> [b, s, d]
     k = self.k_norm(k)
     k = k[:, :, None, :] # [b, s, d] -> [b, s, 1, d]
     k = self.apply_partial_rope(k, inputs_positions=inputs_positions)
@@ -283,7 +294,7 @@ def __call__(
     logits = jnp.einsum("bthd, bsd -> btsh", q, k, precision=self.config.matmul_precision)
     logits = jax.nn.relu(logits)
     # Compute head weights: project from input, [b, t, embed_dim] -> [b, t, h]
-    weights = self.weights_proj(inputs_q)
+    weights = self.weights_proj(inputs_q_for_indexer)
     # Weights scaling affect indexer_score, but does not affect topk_indices. Keep scaling for numerical stability.
     # https://github.com/deepseek-ai/DeepSeek-V3.2-Exp/blob/87e509a2e5a100d221c97df52c6e8be7835f0057/inference/model.py#L478-L480
     weights = weights * (self.n_heads**-0.5) * self.softmax_scale
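
Note: a self-contained sketch of the stop_gradient pattern used above, with a toy stand-in for the indexer computation (the shapes and the toy function are illustrative, not the real module):

import jax
import jax.numpy as jnp

def toy_indexer(x):
  return jnp.sum(x ** 2)  # stand-in for the indexer score computation

x = jnp.ones((2, 8, 16))

# Without detaching, the indexer loss backpropagates into its input.
g_dense = jax.grad(toy_indexer)(x)

# With indexer_sparse_training, the input is detached first, so the gradient
# of the indexer loss with respect to the upstream activations is exactly zero.
g_sparse = jax.grad(lambda v: toy_indexer(jax.lax.stop_gradient(v)))(x)

print(bool(jnp.all(g_dense != 0)), bool(jnp.all(g_sparse == 0)))  # True True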

src/maxtext/optimizers/optimizers.py

Lines changed: 26 additions & 13 deletions
@@ -24,31 +24,35 @@
 from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers


-def get_adamw_mask(config):
-  """Create a mask function for AdamW optimizer to exclude certain parameters from weight decay."""
-  if not getattr(config, "adamw_mask", None):
+def _get_path_mask_fn(patterns, match_returns_true=True):
+  """Helper to create a mask function from a list of regex patterns."""
+  if not patterns:
     return None

-  compiled_patterns = [re.compile(pattern) for pattern in config.adamw_mask]
+  compiled_patterns = [re.compile(pattern) for pattern in patterns]

   def mask_fn(params):
-    def _is_decayed(path, _):
+    def _is_masked(path, _):
       # Join path keys into a single string for pattern matching (e.g., "layer1/bias")
       path_str = "/".join(str(getattr(p, "key", getattr(p, "idx", getattr(p, "name", p)))) for p in path)
-      # If any pattern in adamw_mask matches the path, exclude from weight decay (return False).
-      # Otherwise, apply weight decay (return True).
-      return not any(pattern.search(path_str) for pattern in compiled_patterns)
+      matched = any(pattern.search(path_str) for pattern in compiled_patterns)
+      return matched if match_returns_true else not matched

-    return jax.tree_util.tree_map_with_path(_is_decayed, params)
+    return jax.tree_util.tree_map_with_path(_is_masked, params)

   return mask_fn


+def get_adamw_mask(config):
+  """Create a mask function for AdamW optimizer to exclude certain parameters from weight decay."""
+  return _get_path_mask_fn(getattr(config, "adamw_mask", None), match_returns_true=False)
+
+
 def get_optimizer(config, learning_rate_schedule, model=None):
   """Create optimizer."""
   if config.opt_type == "adamw":
     # Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
-    return optax.adamw(
+    base_opt = optax.adamw(
         learning_rate_schedule,
         b1=config.adam_b1,
         b2=config.adam_b2,
@@ -59,7 +63,7 @@ def get_optimizer(config, learning_rate_schedule, model=None):
         mask=get_adamw_mask(config),
     )
   elif config.opt_type == "adam_pax":
-    return adam_pax(
+    base_opt = adam_pax(
         learning_rate_schedule,
         beta1=config.adam_b1,
         beta2=config.adam_b2,
@@ -69,7 +73,7 @@ def get_optimizer(config, learning_rate_schedule, model=None):
         mask=get_adamw_mask(config),
     )
   elif config.opt_type == "sgd":
-    return optax.sgd(learning_rate_schedule)
+    base_opt = optax.sgd(learning_rate_schedule)
   elif config.opt_type == "muon":
     # extract muon dimension number from model structure
     if model is not None:
@@ -92,10 +96,19 @@ def get_optimizer(config, learning_rate_schedule, model=None):
         "adam_eps_root": config.adam_eps_root,
         "adam_weight_decay": config.adam_weight_decay,
     }
-    return muon(**muon_kwargs)
+    base_opt = muon(**muon_kwargs)
   else:
     raise ValueError(f"{config.opt_type=} is not a supported.")

+  # If a whitelist of trainable parameters is provided, freeze everything else.
+  # When trainable_parameters_mask is empty, freeze_mask_fn is None and all parameters are trained.
+  trainable_patterns = getattr(config, "trainable_parameters_mask", None)
+  freeze_mask_fn = _get_path_mask_fn(trainable_patterns, match_returns_true=False)
+  if freeze_mask_fn is not None:
+    return optax.chain(base_opt, optax.masked(optax.set_to_zero(), freeze_mask_fn))
+
+  return base_opt
+

 def adam_pax(
     learning_rate_fn: optax.Schedule,
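
Note: a self-contained sketch of the freezing mechanism, using the same optax.chain / optax.masked / optax.set_to_zero composition as above. The parameter names are illustrative; in the real code the mask is built from trainable_parameters_mask by _get_path_mask_fn.

import jax
import jax.numpy as jnp
import optax

params = {"indexer": {"w": jnp.ones((2, 2))}, "mlp": {"w": jnp.ones((2, 2))}}
grads = jax.tree_util.tree_map(lambda p: jnp.full_like(p, 0.5), params)

# True marks parameters whose updates should be set to zero (frozen).
freeze_mask = {"indexer": {"w": False}, "mlp": {"w": True}}

opt = optax.chain(optax.sgd(1.0), optax.masked(optax.set_to_zero(), freeze_mask))
state = opt.init(params)
updates, _ = opt.update(grads, state, params)

print(bool(jnp.all(updates["mlp"]["w"] == 0)))      # frozen -> True
print(bool(jnp.any(updates["indexer"]["w"] != 0)))  # trained -> True

Zeroing updates after the base optimizer keeps the optimizer state structure unchanged; frozen parameters simply never move.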

src/maxtext/trainers/pre_train/train.py

Lines changed: 7 additions & 0 deletions
@@ -232,6 +232,13 @@ def loss_fn(model, config, data, dropout_rng, params, is_train=True):

   if indexer_losses:
     indexer_loss = jnp.mean(jnp.concatenate(indexer_losses))
+    # DeepSeek V3.2: When `indexer_sparse_training` is true, we optimize the indexer
+    # using ONLY the indexer loss, and the main model using ONLY the language modeling loss.
+    # To do this, we decouple the gradients. We detach the indexer input from the
+    # computational graph inside the Indexer module itself (in attention_mla.py)
+    # by stopping gradients on its inputs.
+    # So here, we just add the indexer loss to the total loss. The gradients will
+    # naturally separate because the inputs to the indexer were stopped.
     loss += indexer_loss
   else:
     max_logging.debug("No indexer loss found.")
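
Note: a toy illustration (toy losses and shapes, not the real ones) of why simply summing the two losses is enough once the indexer inputs are detached: the gradient of the combined loss with respect to the shared activations comes only from the language-modeling term.

import jax
import jax.numpy as jnp

def lm_loss(h):
  return jnp.mean(h ** 2)

def indexer_loss(h):
  h = jax.lax.stop_gradient(h)  # detached inside the indexer, as in attention_mla.py above
  return jnp.sum(jnp.abs(h))

h = jnp.ones((4, 8))
g_total = jax.grad(lambda v: lm_loss(v) + indexer_loss(v))(h)
g_lm = jax.grad(lm_loss)(h)

print(bool(jnp.allclose(g_total, g_lm)))  # True: the indexer term contributes nothing to d(loss)/dh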

tests/unit/attention_test.py

Lines changed: 74 additions & 0 deletions
@@ -37,8 +37,10 @@
     DEFAULT_MASK_VALUE,
 )
 from maxtext.layers.attention_mla import MLA
+from maxtext.layers.attention_mla import Indexer
 from maxtext.layers.attention_op import ChunkedCausalMask, _generate_chunk_attention_mask, _make_bidirectional_block_mask
 from maxtext.layers.attentions import Attention
+from maxtext.layers import embeddings
 from maxtext.configs import pyconfig
 from maxtext.models.qwen3 import Qwen3NextGatedDeltaNet
 import numpy as np
@@ -1693,6 +1695,78 @@ def test_indexer_loss_kl_divergence_zero(self):

     np.testing.assert_allclose(loss, 0.0, atol=1e-5)

+  def test_indexer_gradients(self):
+    # Test that gradients flow back to inputs when indexer_sparse_training=False
+    # but do NOT flow back when indexer_sparse_training=True
+    bsz, seqlen = 2, 8
+    inputs_positions = jnp.broadcast_to(jnp.arange(seqlen)[None, :], (bsz, seqlen))
+
+    for sparse_training in [False, True]:
+      with self.subTest(indexer_sparse_training=sparse_training):
+        argv = [
+            "",
+            get_test_config_path(),
+            "run_name=test",
+            "attention_type=mla",
+            f"indexer_sparse_training={sparse_training}",
+            "max_target_length=16",
+            "index_topk=4",
+            "index_n_heads=2",
+            "index_head_dim=8",
+            "emb_dim=16",
+            "qk_rope_head_dim=4",
+            "q_lora_rank=16",
+        ]
+        config = pyconfig.initialize(argv)
+        rngs = nnx.Rngs(0)
+        mesh = jax.sharding.Mesh(jax.devices(), ("data",))
+        rope = embeddings.RotaryEmbedding(
+            min_timescale=1,
+            max_timescale=10000,
+            mesh=mesh,
+            embedding_dims=config.qk_rope_head_dim,
+            fprop_dtype=jnp.float32,
+            rngs=rngs,
+        )
+        rope.interleave = False
+
+        indexer = Indexer(
+            config=config,
+            rotary_embedding=rope,
+            rngs=rngs,
+        )
+
+        inputs_q = jnp.ones((bsz, seqlen, config.emb_dim))
+        low_rank_q = jnp.ones((bsz, seqlen, config.q_lora_rank))
+        inputs_kv = jnp.ones((bsz, seqlen, config.emb_dim))
+
+        def loss_fn(inputs_q, low_rank_q, inputs_kv, indexer):
+          _, _, indexer_score = indexer(
+              inputs_q=inputs_q,
+              low_rank_q=low_rank_q,
+              inputs_kv=inputs_kv,
+              inputs_positions=inputs_positions,
+          )
+          # A dummy loss function (e.g., sum of scores)
+          return jnp.sum(indexer_score)
+
+        # Calculate gradients with respect to the 3 inputs
+        grad_fn = nnx.grad(loss_fn, argnums=(0, 1, 2))
+        grads = grad_fn(inputs_q, low_rank_q, inputs_kv, indexer)
+
+        grad_q, grad_low_rank, grad_kv = grads
+
+        if sparse_training:
+          # Gradients should be exactly zero
+          self.assertTrue(jnp.all(grad_q == 0.0))
+          self.assertTrue(jnp.all(grad_low_rank == 0.0))
+          self.assertTrue(jnp.all(grad_kv == 0.0))
+        else:
+          # Gradients should be non-zero
+          self.assertFalse(jnp.all(grad_q == 0.0))
+          self.assertFalse(jnp.all(grad_low_rank == 0.0))
+          self.assertFalse(jnp.all(grad_kv == 0.0))
+

 class Qwen3NextGatedDeltaNetTest(unittest.TestCase):
   """Test for the Gated Delta Net in Qwen3-Next"""

tests/unit/optimizers_test.py

Lines changed: 66 additions & 0 deletions
@@ -362,5 +362,71 @@ def test_optimizer_without_mask(self, opt_type, mock_path):
     self.assertIsNone(kwargs["mask"])


+class TrainableParametersMaskTest(parameterized.TestCase):
+  """Tests for the trainable parameters mask functionality via get_optimizer"""
+
+  def test_get_optimizer_with_trainable_mask(self):
+    """Test get_optimizer with a valid trainable_parameters_mask."""
+    argv = [
+        "",
+        get_test_config_path(),
+        "run_name=test_with_trainable_mask",
+        "trainable_parameters_mask=['.*indexer.*', 'layer_norm']",
+    ]
+    config = pyconfig.initialize(argv)
+
+    # Use a constant learning rate > 0 to ensure non-zero updates
+    def learning_rate_schedule(step):
+      return 1.0
+
+    opt = optimizers.get_optimizer(config, learning_rate_schedule)
+
+    # We can test the optimizer by creating some dummy params and gradients
+    # and checking if the updates are zeroed out for non-trainable parameters.
+    params = {
+        "layer1": {"kernel": jax.numpy.ones((2, 2)), "indexer": jax.numpy.ones((2, 2))},
+        "layer2": {"layer_norm": {"scale": jax.numpy.ones((2, 2))}},
+        "layer3": {"ln": {"scale": jax.numpy.ones((2, 2))}},
+    }
+
+    # Give some non-zero gradients
+    grads = jax.tree_util.tree_map(lambda x: jax.numpy.ones_like(x) * 0.5, params)
+
+    # Initialize optimizer state
+    opt_state = opt.init(params)
+
+    # Compute updates
+    updates, _ = opt.update(grads, opt_state, params)
+
+    # 'layer1/kernel' doesn't match the trainable mask, so it should be frozen (update == 0)
+    self.assertTrue(jax.numpy.all(updates["layer1"]["kernel"] == 0))
+    # 'layer3/ln/scale' doesn't match the trainable mask, so it should be frozen (update == 0)
+    self.assertTrue(jax.numpy.all(updates["layer3"]["ln"]["scale"] == 0))
+    # 'layer1/indexer' matches, so it should be trained (update != 0)
+    self.assertFalse(jax.numpy.all(updates["layer1"]["indexer"] == 0))
+    # 'layer2/layer_norm/scale' matches, so it should be trained (update != 0)
+    self.assertFalse(jax.numpy.all(updates["layer2"]["layer_norm"]["scale"] == 0))
+
+  def test_get_optimizer_without_trainable_mask(self):
+    """Test get_optimizer when trainable_parameters_mask is empty."""
+    argv = ["", get_test_config_path(), "run_name=test", "trainable_parameters_mask=[]"]
+    config = pyconfig.initialize(argv)
+
+    # Use a constant learning rate > 0 to ensure non-zero updates
+    def learning_rate_schedule(step):
+      return 1.0
+
+    opt = optimizers.get_optimizer(config, learning_rate_schedule)
+
+    params = {"layer1": {"kernel": jax.numpy.ones((2, 2))}}
+    grads = {"layer1": {"kernel": jax.numpy.ones((2, 2)) * 0.5}}
+
+    opt_state = opt.init(params)
+    updates, _ = opt.update(grads, opt_state, params)
+
+    # When no trainable mask is provided, nothing is frozen by this mechanism
+    self.assertFalse(jax.numpy.all(updates["layer1"]["kernel"] == 0))
+
+
 if __name__ == "__main__":
   unittest.main()
