diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py
index db6afd813c..8a486e3ac3 100644
--- a/src/maxtext/layers/attention_mla.py
+++ b/src/maxtext/layers/attention_mla.py
@@ -76,6 +76,71 @@
 from maxtext.utils.globals import EPS
 
 
+def _apply_rope(x, cos, sin, interleave=True):
+  """Applies rotary positional embedding to the input.
+
+  Args:
+    x: Input tensor [B, S, N, H] or [B, S, H].
+    cos: Cosine component of RoPE, [B, S, 1, H/2] or [B, S, H/2].
+    sin: Sine component of RoPE, [B, S, 1, H/2] or [B, S, H/2].
+    interleave: Whether to use interleaved or concatenated layout.
+
+  Returns:
+    Rotated input, same shape as x.
+  """
+  if interleave:
+    x1, x2 = x[..., ::2], x[..., 1::2]
+  else:
+    x1, x2 = jnp.split(x, 2, axis=-1)
+
+  # Align cos/sin rank with the input (add or drop the heads axis).
+  if x.ndim == 4:
+    cos = cos[:, :, None, :] if cos.ndim == 3 else cos
+    sin = sin[:, :, None, :] if sin.ndim == 3 else sin
+  elif x.ndim == 3:
+    cos = jnp.squeeze(cos, axis=2) if cos.ndim == 4 else cos
+    sin = jnp.squeeze(sin, axis=2) if sin.ndim == 4 else sin
+
+  y1 = x1 * cos - x2 * sin
+  y2 = x1 * sin + x2 * cos
+
+  if interleave:
+    rotated = jnp.stack([y1, y2], axis=-1)
+    return rotated.reshape(x.shape)
+  else:
+    return jnp.concatenate([y1, y2], axis=-1)
+
+
+def _compute_rope(head_dim, positions, theta, dtype):
+  """Computes RoPE cos/sin on the fly for the given positions."""
+  freqs = 1.0 / (
+      theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim)
+  )
+  # positions shape [B, S], freqs shape [D/2] -> angles shape [B, S, D/2]
+  angles = positions[..., None] * freqs
+  return jnp.cos(angles).astype(dtype), jnp.sin(angles).astype(dtype)
+
+
+def _get_cos_sin(rotary_embedding, positions, dtype):
+  """Computes cos and sin embeddings from the rotary_embedding module."""
+  # Use optimized on-the-fly computation instead of a table lookup.
+  head_dim = rotary_embedding.embedding_dims
+  theta = rotary_embedding.rope_theta
+
+  cos, sin = _compute_rope(head_dim, positions, theta, dtype)
+
+  # Add heads dimension for broadcasting: [B, S, D/2] -> [B, S, 1, D/2]
+  cos = cos[:, :, jnp.newaxis, :]
+  sin = sin[:, :, jnp.newaxis, :]
+
+  if getattr(rotary_embedding, "attention_scaling", False):
+    rope_factor = getattr(rotary_embedding, "rope_factor", 1.0)
+    scaling = 1.0 if rope_factor <= 1 else (0.1 * math.log(rope_factor) + 1.0)
+    cos = cos * scaling
+    sin = sin * scaling
+  return cos, sin
+
+
 class Indexer(nnx.Module):
   """Indexer for DeepSeek Sparse Attention (DSA).
 
@@ -189,9 +254,10 @@ def apply_partial_rope(
     # indexer_head_dim -> [rope_head_dim, indexer_head_dim - rope_head_dim]
     x_pe, x_nope = jnp.split(inputs, [self.rope_head_dim], axis=-1)
     # x_pe [B, S, H, rope_head_dim], positions [B, S]
-    x_pe = self.rotary_embedding(x_pe, position=inputs_positions)
+    cos, sin = _get_cos_sin(self.rotary_embedding, inputs_positions, self.dtype)
+    x_pe = _apply_rope(x_pe, cos, sin, interleave=self.rotary_embedding.interleave)
     x = jnp.concatenate([x_pe, x_nope], axis=-1)
-    return x
+    return checkpoint_name(x, "indexer_partial_rope")
 
   def generate_mask(self, topk_indices, s):
     """
@@ -478,6 +544,17 @@ def mla_as_linen(
 class MLA(Attention):
   """Multi-Head Latent Attention (MLA) layer."""
 
+  def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array] = None, rope_kwargs: dict | None = None):
+    """Overrides RoPE with an optimized on-the-fly implementation for MLA."""
+    with jax.named_scope("mla_rope"):
+      if inputs_positions is None:
+        seq_length = inputs.shape[1]
+        inputs_positions = jnp.arange(seq_length, dtype=jnp.int32)[jnp.newaxis, :]
+
+      cos, sin = _get_cos_sin(self.rotary_embedding, inputs_positions, self.dtype)
+      x_out = _apply_rope(inputs, cos, sin, interleave=self.rotary_embedding.interleave)
+      return checkpoint_name(x_out, "mla_rope")
+
   def __init__(
       self,
       config: Config,
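Reviewer note: a minimal standalone sketch (not part of the patch) of the interleaved rotation that _apply_rope implements. The shapes, theta, and the pair-norm check below are illustrative assumptions, not code from this PR.

# Standalone check of the interleaved RoPE rotation, assuming only jax.numpy.
# Shapes follow the helper's docstring: x is [B, S, N, H], cos/sin [B, S, 1, H/2].
import jax.numpy as jnp

def apply_rope_interleaved(x, cos, sin):
  # Split even/odd feature pairs and rotate each pair by the per-position angle.
  x1, x2 = x[..., ::2], x[..., 1::2]
  y1 = x1 * cos - x2 * sin
  y2 = x1 * sin + x2 * cos
  # Re-interleave so the output layout matches the input layout.
  return jnp.stack([y1, y2], axis=-1).reshape(x.shape)

B, S, N, H = 1, 4, 2, 8
positions = jnp.arange(S, dtype=jnp.float32)[None, :]         # [B, S]
freqs = 1.0 / (10000.0 ** (jnp.arange(0, H, 2) / H))          # [H/2]
angles = positions[..., None] * freqs                         # [B, S, H/2]
cos, sin = jnp.cos(angles)[:, :, None, :], jnp.sin(angles)[:, :, None, :]

x = jnp.ones((B, S, N, H))
y = apply_rope_interleaved(x, cos, sin)
# Rotation preserves the norm of each (even, odd) feature pair.
pair_norm_in = jnp.sqrt(x[..., ::2] ** 2 + x[..., 1::2] ** 2)
pair_norm_out = jnp.sqrt(y[..., ::2] ** 2 + y[..., 1::2] ** 2)
assert jnp.allclose(pair_norm_in, pair_norm_out, atol=1e-5)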
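Similarly, a small sketch of why the on-the-fly path in _compute_rope matches gathering rows from a precomputed cos/sin table for the same positions; max_len and the table construction here are assumptions for illustration only.

# On-the-fly cos/sin vs. a precomputed table, for the same positions.
import jax.numpy as jnp

head_dim, theta, max_len = 8, 10000.0, 16
freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))

# Table path: angles for every position up to max_len, built once.
table_angles = jnp.arange(max_len, dtype=jnp.float32)[:, None] * freqs  # [L, D/2]
cos_table, sin_table = jnp.cos(table_angles), jnp.sin(table_angles)

# On-the-fly path for a batch of positions, as in _compute_rope.
positions = jnp.array([[0, 3, 7, 15]], dtype=jnp.float32)               # [B, S]
angles = positions[..., None] * freqs                                    # [B, S, D/2]
assert jnp.allclose(jnp.cos(angles), cos_table[positions.astype(jnp.int32)])
assert jnp.allclose(jnp.sin(angles), sin_table[positions.astype(jnp.int32)])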
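Finally, a quick look at the long-context attention scaling applied in _get_cos_sin: cos/sin are multiplied by 0.1 * ln(rope_factor) + 1 when rope_factor > 1, and left unchanged at factor 1. The factors below are arbitrary examples.

# Sketch of the attention-scaling multiplier, assuming only the stdlib.
import math

def mscale(rope_factor: float) -> float:
  return 1.0 if rope_factor <= 1 else 0.1 * math.log(rope_factor) + 1.0

for factor in (1.0, 4.0, 8.0, 40.0):
  print(factor, round(mscale(factor), 4))
# 1.0 -> 1.0, 4.0 -> 1.1386, 8.0 -> 1.2079, 40.0 -> 1.3689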