20 changes: 12 additions & 8 deletions src/maxtext/configs/base.yml
@@ -356,15 +356,15 @@ moba_topk: 8

# DeepSeek Sparse Attention (DSA)
# deepseek3.2 introduces an indexer in MLA
use_sparse_indexer: False
index_head_dim: 128
index_n_heads: 64
index_topk: 2048
# Determines the token selection strategy for indexer loss:
# - False: Uses all tokens (Dense Warm-up).
# - True: Uses only top-k tokens (Sparse Training).
use_indexer: False
indexer_head_dim: 128
indexer_n_heads: 64
indexer_topk: 2048
# Determines the training strategy for the indexer:
# - False (Dense Warm-up): Computes indexer loss over all tokens. Used with `trainable_parameters_mask` to freeze other model parameters.
# - True (Sparse Training): Computes indexer loss over top-k tokens only and detaches the indexer input for independent optimization.
# Note: This is only active when `indexer_loss_scaling_factor` > 0.
sparse_indexer_loss: False
indexer_sparse_training: False
# Multiplier for the indexer KL divergence loss
indexer_loss_scaling_factor: 0.0

@@ -789,6 +789,10 @@ gradient_clipping_threshold: 1.0
gradient_accumulation_steps: 1

opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon"
# List of parameter names/patterns to train.
# If non-empty, all other parameters will be frozen. Example: ['.*indexer.*'].
# If empty (default), all parameters are trained.
trainable_parameters_mask: []

# AdamW optimizer parameters
# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
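Taken together, the flags above sketch a two-stage recipe: a Dense Warm-up in which only the indexer trains (via `trainable_parameters_mask`) and the indexer loss is computed over all tokens, followed by Sparse Training with `indexer_sparse_training: True`. The snippet below is illustrative only; the override values (including the non-zero `indexer_loss_scaling_factor` and the empty mask in the second stage) are assumptions, not settings taken from this PR.

```python
# Illustrative two-stage DSA configuration implied by the flags above.
# All values are example choices, not ones prescribed by this change.
dense_warmup_overrides = {
    "use_indexer": True,
    "indexer_sparse_training": False,              # indexer loss over all tokens
    "indexer_loss_scaling_factor": 1.0,            # assumed > 0 so the indexer loss is active
    "trainable_parameters_mask": [".*indexer.*"],  # freeze everything except the indexer
}

sparse_training_overrides = {
    "use_indexer": True,
    "indexer_sparse_training": True,               # indexer loss over top-k tokens, inputs detached
    "indexer_loss_scaling_factor": 1.0,            # assumed example value
    "trainable_parameters_mask": [],               # assumption: all parameters train in this stage
}
```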
8 changes: 4 additions & 4 deletions src/maxtext/configs/models/deepseek-custom.yml
@@ -56,10 +56,10 @@ rope_interleave: True
rope_truncate: True
rope_attention_scaling: False
# Indexer for DeepSeek Sparse Attention
use_sparse_indexer: True
index_n_heads: 16 # Reduced from 64
index_head_dim: 64 # Reduced from 128
index_topk: 256 # Reduced from 2048
use_indexer: True
indexer_n_heads: 16 # Reduced from 64
indexer_head_dim: 64 # Reduced from 128
indexer_topk: 256 # Reduced from 2048
# Hyper-connections: mHC enabled
mhc_expansion_rate: 4
sinkhorn_iterations: 20
8 changes: 4 additions & 4 deletions src/maxtext/configs/models/deepseek3.2-671b.yml
@@ -53,7 +53,7 @@ rope_interleave: True
rope_truncate: True
rope_attention_scaling: False
# Indexer for DeepSeek Sparse Attention
use_sparse_indexer: True
index_n_heads: 64
index_head_dim: 128
index_topk: 2048
use_indexer: True
indexer_n_heads: 64
indexer_head_dim: 128
indexer_topk: 2048
22 changes: 16 additions & 6 deletions src/maxtext/configs/types.py
@@ -542,11 +542,14 @@ class MlaAttention(BaseModel):
class AttentionIndexer(BaseModel):
"""Configuration for DeepSeek Sparse Attention (DSA): DeepSeek3.2-style MLA with indexer."""

use_sparse_indexer: bool = Field(False, description="Whether to use sparse indexer for MLA.")
index_head_dim: NonNegativeInt = Field(128, description="Head dim for indexer query and key.")
index_n_heads: NonNegativeInt = Field(64, description="Number of query heads in indexer.")
index_topk: NonNegativeInt = Field(2048, description="Number of tokens selected by the query token in indexer.")
sparse_indexer_loss: bool = Field(False, description="Determines the token selection strategy for indexer loss.")
use_indexer: bool = Field(False, description="Whether to use sparse indexer for MLA.")
indexer_head_dim: NonNegativeInt = Field(128, description="Head dim for indexer query and key.")
indexer_n_heads: NonNegativeInt = Field(64, description="Number of query heads in indexer.")
indexer_topk: NonNegativeInt = Field(2048, description="Number of tokens selected by the query token in indexer.")
indexer_sparse_training: bool = Field(
False,
description="Determines the training strategy for the indexer: Dense Warm-up or Sparse Training stage.",
)
indexer_loss_scaling_factor: float = Field(0.0, description="Multiplier for the indexer KL divergence loss.")


@@ -1185,6 +1188,13 @@ class Optimizer(BaseModel):
ge=-1,
description="Total steps for the LR schedule. -1 defaults to `steps`.",
)
trainable_parameters_mask: list[str] = Field(
default_factory=list,
description=(
"List of parameter names/patterns to train. If non-empty, all other parameters will be frozen, "
"example: ['.*indexer.*']. If empty (default), all parameters are trained."
),
)


class AdamW(BaseModel):
@@ -2388,7 +2398,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
raise ValueError("`local_checkpoint_period` must be > 0 for emergency checkpointing.")
if self.moba and self.attention not in ("dot_product",):
raise ValueError("MoBA is only supported with dot_product attention.")
if self.use_sparse_indexer:
if self.use_indexer:
if self.q_lora_rank == 0:
raise NotImplementedError("Sparse indexer has not implemented for q_lora_rank = 0.")
supports_dot_product = self.attention == "dot_product"
60 changes: 45 additions & 15 deletions src/maxtext/layers/attention_mla.py
@@ -110,9 +110,9 @@ def __init__(
self.dtype = config.dtype
self.weight_dtype = config.weight_dtype

self.n_heads = config.index_n_heads
self.head_dim = config.index_head_dim
self.index_topk = config.index_topk
self.n_heads = config.indexer_n_heads
self.head_dim = config.indexer_head_dim
self.indexer_topk = config.indexer_topk
self.emb_dim = config.emb_dim
self.rope_head_dim = config.qk_rope_head_dim
self.q_lora_rank = config.q_lora_rank
@@ -180,13 +180,13 @@ def apply_partial_rope(
2. Input Layout: Indexer uses concatenated layout (interleave=False), whereas MLA uses interleaved (interleave=True).

Args:
inputs: Input array of shape [batch, seqlen, index_n_heads, index_head_dim].
inputs: Input array of shape [batch, seqlen, indexer_n_heads, indexer_head_dim].
positions: Position array of shape [batch, seqlen].

Returns:
Array with partial RoPE applied, with shape [batch, seqlen, index_n_heads, index_head_dim]
Array with partial RoPE applied, with shape [batch, seqlen, indexer_n_heads, indexer_head_dim]
"""
# index_head_dim -> [rope_head_dim, index_head_dim - rope_head_dim]
# indexer_head_dim -> [rope_head_dim, indexer_head_dim - rope_head_dim]
x_pe, x_nope = jnp.split(inputs, [self.rope_head_dim], axis=-1)
# x_pe [B, S, H, rope_head_dim], positions [B, S]
x_pe = self.rotary_embedding(x_pe, position=inputs_positions)
@@ -256,14 +256,37 @@ def __call__(
b: Batch size
t: Query Sequence Length (Target), note t = s here
s: Key/Value Sequence Length (Source)
h: Number of Indexer Heads (index_n_heads)
d: Indexer Head Dimension (index_head_dim)
h: Number of Indexer Heads (indexer_n_heads)
d: Indexer Head Dimension (indexer_head_dim)
"""
# NOTE: If sequence length <= topk, indexer always selects all tokens.
if self.config.max_target_length <= self.index_topk:
if self.config.max_target_length <= self.indexer_topk:
return None, None, None

bsz, seqlen, _ = inputs_q.shape # s = t = seqlen
# ==============================================================================
# Gradient Isolation Strategy: Main Model vs. Indexer
# ==============================================================================
# This creates a barrier so that the two components can be trained independently; it applies
# to both the Dense Warm-up and Sparse Training stages:
#
# Forward Pass:
# - The Indexer receives a detached copy of the inputs (via `stop_gradient`)
# to independently calculate its scores and `indexer_loss`.
#
# Backward Pass (Main Model):
# - The main model optimizes its weights based solely on the LM loss.
# - The `indexer_mask` in the Attention layer prevents gradients from the main
# loss from flowing into the Indexer's weights.
#
# Backward Pass (Indexer):
# - Gradients from the `indexer_loss` flow back to update the Indexer's weights.
# - The `stop_gradient` applied to the inputs acts as a mathematical wall, dropping
# gradients to 0.0 and preventing the Indexer loss from altering the main model's
# earlier layers.
inputs_q = jax.lax.stop_gradient(inputs_q)
low_rank_q = jax.lax.stop_gradient(low_rank_q)
inputs_kv = jax.lax.stop_gradient(inputs_kv)

# Query Processing: Project from Latent low_rank_q
q = self.wq_b(low_rank_q) # [b, t, q_lora_rank] -> [b, t, h * d]
@@ -295,7 +318,7 @@ def __call__(
indexer_score += attention_mask

# TopK selection based on index score
_, topk_indices = jax.lax.top_k(indexer_score, k=self.index_topk) # topk_indices [b, t, k]
_, topk_indices = jax.lax.top_k(indexer_score, k=self.indexer_topk) # topk_indices [b, t, k]

# Create Sparse Index Mask: 0 and large negatives
indexer_mask = self.generate_mask(topk_indices, seqlen) # [b, t, s]
@@ -607,8 +630,8 @@ def __init__(
)

# Initialize Indexer
self.use_sparse_indexer = config.use_sparse_indexer
if self.use_sparse_indexer:
self.use_indexer = config.use_indexer
if self.use_indexer:
# Need two versions of rope.
# MLA applies yarn with interleave layout.
# Indexer applies yarn with concatenate layout.
@@ -989,6 +1012,13 @@ def calculate_indexer_loss(
Returns:
The computed KL divergence loss.
"""
# Detach main model components from the computational graph.
# The indexer should match the main model, but the main model should not be influenced
# by the indexer's learning progress via this loss in the sparse training stage.
# We also apply this during the Dense Warm-up stage to save compute and memory.
query = jax.lax.stop_gradient(query)
key = jax.lax.stop_gradient(key)

# Compute attention scores: [b, t, h, d] @ [b, s, h, d] -> [b, h, t, s]
attention_scores = jnp.einsum("bthd, bshd -> bhts", query, key, precision=self.config.matmul_precision)

@@ -1080,7 +1110,7 @@ def __call__(

# Indexer Logic
indexer_mask = None
if self.use_sparse_indexer:
if self.use_indexer:
if model_mode != MODEL_MODE_TRAIN:
raise NotImplementedError("Sparse indexer has not implemented for inference yet.")
# generate mask: with 0 and large negative, [b, 1, 1, q_len, kv_len] -> [b, q_len, kv_len]
@@ -1098,14 +1128,14 @@ def __call__(
attention_mask=attention_mask,
)

if self.config.indexer_loss_scaling_factor > 0.0:
if indexer_mask is not None and self.config.indexer_loss_scaling_factor > 0.0:
indexer_loss = self.calculate_indexer_loss(
indexer_score=indexer_score,
query=query,
key=key,
attention_mask=attention_mask,
indexer_mask=indexer_mask,
sparse_loss=self.config.sparse_indexer_loss,
sparse_loss=self.config.indexer_sparse_training,
scaling_factor=self.config.indexer_loss_scaling_factor,
)
self.sow(nnx.Intermediate, "indexer_loss", indexer_loss)
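The gradient-isolation comments above can be checked with a toy example. The sketch below is not the MaxText indexer; the parameter names, shapes, and quadratic losses are invented purely to show how `jax.lax.stop_gradient` keeps the indexer loss out of the main model and how the detached mask keeps the LM loss out of the indexer.

```python
# Toy check of the gradient-isolation pattern (illustrative names and shapes only).
import jax
import jax.numpy as jnp


def toy_losses(params, x):
  # "Main model" projection.
  hidden = x @ params["w_main"]
  # The indexer sees a detached copy of the activations, so its loss
  # cannot reach w_main through this path.
  indexer_in = jax.lax.stop_gradient(hidden)
  indexer_score = indexer_in @ params["w_indexer"]
  # The mask fed back to attention is treated as a constant w.r.t. the LM loss
  # (in the real code the top-k indices are non-differentiable anyway).
  mask = jax.lax.stop_gradient(jnp.where(indexer_score > 0.0, 0.0, -1e9))
  lm_loss = jnp.sum((hidden + mask) ** 2)
  indexer_loss = jnp.sum(indexer_score**2)
  return lm_loss, indexer_loss


params = {"w_main": jnp.ones((4, 4)), "w_indexer": jnp.ones((4, 4))}
x = jnp.ones((2, 4))
g_lm = jax.grad(lambda p: toy_losses(p, x)[0])(params)
g_idx = jax.grad(lambda p: toy_losses(p, x)[1])(params)
print(bool(jnp.all(g_lm["w_indexer"] == 0)))  # True: LM loss never updates the indexer
print(bool(jnp.all(g_idx["w_main"] == 0)))    # True: indexer loss never updates the main model
```

In the PR itself the same separation comes from detaching `inputs_q`/`low_rank_q`/`inputs_kv` before the indexer and `query`/`key` before `calculate_indexer_loss`.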
2 changes: 1 addition & 1 deletion src/maxtext/layers/attention_op.py
@@ -1415,7 +1415,7 @@ def wrap_flash_attention(
decoder_segment_ids_tuple = None

if self.config.use_tokamax_splash:
if self.config.use_sparse_indexer and indexer_mask is not None:
if self.config.use_indexer and indexer_mask is not None:
# Construct the splash kernel call with dynamic mask
def dynamic_mask_splash_kernel(q, k, v, segment, sinks, indexer_mask):
splash_kernel = tokamax_splash_kernel.make_dynamic_splash_mha(
5 changes: 4 additions & 1 deletion src/maxtext/layers/decoders.py
@@ -1081,7 +1081,10 @@ def __call__(
# When invoking from vLLM with RPA attention, logit computation is deferred to a later stage.
if cfg.attention == "vllm_rpa":
logits = None

# When in the Indexer Dense Warm-up stage, skip the expensive output head projection
# for efficiency, as the main model is frozen and the LM loss is not needed.
elif (cfg.use_indexer and not cfg.indexer_sparse_training) and self.model_mode == MODEL_MODE_TRAIN:
logits = None
# When vocab tiling is enabled in training mode, full logits are not generated, to reduce memory.
# Instead, we keep the hidden states, which are smaller than the full logits.
elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
48 changes: 34 additions & 14 deletions src/maxtext/optimizers/optimizers.py
@@ -24,31 +24,35 @@
from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers


def get_adamw_mask(config):
"""Create a mask function for AdamW optimizer to exclude certain parameters from weight decay."""
if not getattr(config, "adamw_mask", None):
def _get_path_mask_fn(patterns, match_returns_true=True):
"""Helper to create a mask function from a list of regex patterns."""
if not patterns:
return None

compiled_patterns = [re.compile(pattern) for pattern in config.adamw_mask]
compiled_patterns = [re.compile(pattern) for pattern in patterns]

def mask_fn(params):
def _is_decayed(path, _):
def _is_masked(path, _):
# Join path keys into a single string for pattern matching (e.g., "layer1/bias")
path_str = "/".join(str(getattr(p, "key", getattr(p, "idx", getattr(p, "name", p)))) for p in path)
# If any pattern in adamw_mask matches the path, exclude from weight decay (return False).
# Otherwise, apply weight decay (return True).
return not any(pattern.search(path_str) for pattern in compiled_patterns)
path_str = jax.tree_util.keystr(path, simple=True, separator="/")
matched = any(pattern.search(path_str) for pattern in compiled_patterns)
return matched if match_returns_true else not matched

return jax.tree_util.tree_map_with_path(_is_decayed, params)
return jax.tree_util.tree_map_with_path(_is_masked, params)

return mask_fn


def get_adamw_mask(config):
"""Create a mask function for AdamW optimizer to exclude certain parameters from weight decay."""
return _get_path_mask_fn(getattr(config, "adamw_mask", None), match_returns_true=False)


def get_optimizer(config, learning_rate_schedule, model=None):
"""Create optimizer."""
if config.opt_type == "adamw":
# Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
return optax.adamw(
base_opt = optax.adamw(
learning_rate_schedule,
b1=config.adam_b1,
b2=config.adam_b2,
@@ -59,7 +63,7 @@ def get_optimizer(config, learning_rate_schedule, model=None):
mask=get_adamw_mask(config),
)
elif config.opt_type == "adam_pax":
return adam_pax(
base_opt = adam_pax(
learning_rate_schedule,
beta1=config.adam_b1,
beta2=config.adam_b2,
@@ -69,7 +73,7 @@ def get_optimizer(config, learning_rate_schedule, model=None):
mask=get_adamw_mask(config),
)
elif config.opt_type == "sgd":
return optax.sgd(learning_rate_schedule)
base_opt = optax.sgd(learning_rate_schedule)
elif config.opt_type == "muon":
# extract muon dimension number from model structure
if model is not None:
@@ -92,10 +96,26 @@ def get_optimizer(config, learning_rate_schedule, model=None):
"adam_eps_root": config.adam_eps_root,
"adam_weight_decay": config.adam_weight_decay,
}
return muon(**muon_kwargs)
base_opt = muon(**muon_kwargs)
else:
raise ValueError(f"{config.opt_type=} is not a supported.")

# If a whitelist of trainable parameters is provided, freeze everything else.
# When trainable_parameters_mask is empty, freeze_mask_fn is None and all parameters are trained.
trainable_patterns = getattr(config, "trainable_parameters_mask", None)
freeze_mask_fn = _get_path_mask_fn(trainable_patterns, match_returns_true=False)
if freeze_mask_fn is not None:
# Use optax.multi_transform to explicitly map frozen parameters to a stateless set_to_zero() optimizer.
# If we simply wrapped base_opt in optax.masked() or chained it, Optax would still allocate
# massive states (momentum, variance) for the entire model before zeroing the updates.
# By using multi_transform, only the trainable parameters get states allocated.
return optax.multi_transform(
{"trainable": base_opt, "frozen": optax.set_to_zero()},
lambda params: jax.tree_util.tree_map(lambda x: "frozen" if x else "trainable", freeze_mask_fn(params)),
)

return base_opt


def adam_pax(
learning_rate_fn: optax.Schedule,
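A self-contained sketch of the freezing strategy above, assuming the standard `optax.multi_transform` and `optax.set_to_zero` APIs; the parameter tree and the `'.*indexer.*'` pattern are illustrative rather than MaxText's actual parameter paths.

```python
# Minimal sketch: label parameters via a regex whitelist, then route frozen ones
# to a stateless set_to_zero() transform (illustrative tree, not real MaxText params).
import re

import jax
import jax.numpy as jnp
import optax

params = {
    "decoder": {"mlp": jnp.ones((8, 8))},
    "indexer": {"wq_b": jnp.ones((8, 8))},
}
patterns = [re.compile(p) for p in [".*indexer.*"]]


def label_fn(tree):
  def _label(path, _):
    path_str = "/".join(str(getattr(k, "key", k)) for k in path)
    return "trainable" if any(p.search(path_str) for p in patterns) else "frozen"

  return jax.tree_util.tree_map_with_path(_label, tree)


opt = optax.multi_transform(
    {"trainable": optax.adamw(1e-3), "frozen": optax.set_to_zero()},
    label_fn,
)
state = opt.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, state = opt.update(grads, state, params)
print(bool(jnp.all(updates["decoder"]["mlp"] == 0)))   # True: frozen leaf gets a zero update
print(bool(jnp.any(updates["indexer"]["wq_b"] != 0)))  # True: whitelisted leaf is updated
```

As in the PR, `set_to_zero()` keeps no per-parameter state, so frozen weights add nothing to optimizer memory; the labeling here ("trainable" when a pattern matches, "frozen" otherwise) produces the same partition as the `freeze_mask_fn` plus the `"frozen" if x else "trainable"` lambda in the change.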