Skip to content

Commit 33eab9e

Browse files
committed
change
1 parent dc29039 commit 33eab9e

1 file changed

Lines changed: 67 additions & 2 deletions

File tree

src/maxtext/layers/attention_mla.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,59 @@
7676
from maxtext.utils.globals import EPS
7777

7878

79+
def _apply_rope(x, cos, sin, interleave=True):
  """Applies rotary positional embedding to the input.

  Every feature pair (x1, x2) taken from the last axis of `x` is rotated:
  y1 = x1 * cos - x2 * sin and y2 = x1 * sin + x2 * cos.

  Args:
    x: Input tensor [B, S, N, H] or [B, S, H].
    cos: Cosine component of RoPE, [B, S, 1, H/2] or [B, S, H/2].
    sin: Sine component of RoPE, [B, S, 1, H/2] or [B, S, H/2].
    interleave: If True, a pair is two adjacent elements of the last axis
      (x[..., 0::2] and x[..., 1::2]). If False, a pair is made of matching
      elements from the first and second halves of the last axis.

  Returns:
    Rotated input with the same shape and the same pair layout as `x`.
  """

  def _match_rank(table):
    """Makes a cos/sin table broadcastable against the split halves of `x`."""
    if x.ndim == 4 and table.ndim == 3:
      # [B, S, H/2] -> [B, S, 1, H/2]: add a heads axis to broadcast over N.
      return table[:, :, None, :]
    if x.ndim == 3 and table.ndim == 4:
      # [B, S, 1, H/2] -> [B, S, H/2]: without this the table would broadcast
      # against the wrong axes of a heads-less input. squeeze raises if the
      # heads axis is not of size 1 instead of silently picking one head.
      return jnp.squeeze(table, axis=2)
    return table

  if interleave:
    x1, x2 = x[..., ::2], x[..., 1::2]
  else:
    x1, x2 = jnp.split(x, 2, axis=-1)

  cos = _match_rank(cos)
  sin = _match_rank(sin)

  y1 = x1 * cos - x2 * sin
  y2 = x1 * sin + x2 * cos

  if interleave:
    # Stacking on a new trailing axis and flattening it restores the
    # [y1_0, y2_0, y1_1, y2_1, ...] ordering of the input.
    return jnp.stack([y1, y2], axis=-1).reshape(x.shape)
  return jnp.concatenate([y1, y2], axis=-1)
112+
113+
114+
def _get_cos_sin(rotary_embedding, positions, dtype):
115+
"""Computes cos and sin embeddings from the rotary_embedding module."""
116+
# Get frequencies
117+
freqs = rotary_embedding.freqs_cis.at[positions].get(
118+
out_sharding=getattr(rotary_embedding, "freqs_sharding", None)
119+
) # [B, S, half_dim]
120+
freqs = freqs[:, :, jnp.newaxis, :] # [B, S, 1, half_dim]
121+
cos = jnp.real(freqs).astype(dtype)
122+
sin = jnp.imag(freqs).astype(dtype)
123+
124+
if getattr(rotary_embedding, "attention_scaling", False):
125+
rope_factor = getattr(rotary_embedding, "rope_factor", 1.0)
126+
scaling = 1.0 if rope_factor <= 1 else (0.1 * math.log(rope_factor) + 1.0)
127+
cos = cos * scaling
128+
sin = sin * scaling
129+
return cos, sin
130+
131+
79132
class Indexer(nnx.Module):
80133
"""Indexer for DeepSeek Sparse Attention (DSA).
81134
@@ -189,9 +242,10 @@ def apply_partial_rope(
189242
# indexer_head_dim -> [rope_head_dim, indexer_head_dim - rope_head_dim]
190243
x_pe, x_nope = jnp.split(inputs, [self.rope_head_dim], axis=-1)
191244
# x_pe [B, S, H, rope_head_dim], positions [B, S]
192-
x_pe = self.rotary_embedding(x_pe, position=inputs_positions)
245+
cos, sin = _get_cos_sin(self.rotary_embedding, inputs_positions, self.dtype)
246+
x_pe = _apply_rope(x_pe, cos, sin, interleave=self.rotary_embedding.interleave)
193247
x = jnp.concatenate([x_pe, x_nope], axis=-1)
194-
return x
248+
return checkpoint_name(x, "indexer_partial_rope")
195249

196250
def generate_mask(self, topk_indices, s):
197251
"""
@@ -478,6 +532,17 @@ def mla_as_linen(
478532
class MLA(Attention):
479533
"""Multi-Head Latent Attention (MLA) layer."""
480534

535+
def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array | None] = None, rope_kwargs: dict | None = None):
  """Applies RoPE to `inputs` via the module-level cos/sin helpers.

  Args:
    inputs: Tensor to rotate; axis 1 is used as the sequence length when
      positions have to be generated.
    inputs_positions: Position indices for the rotary table. When None,
      positions 0..S-1 are generated with a leading broadcast axis of size 1.
    rope_kwargs: Accepted in the signature but not read by this override.

  Returns:
    The rotated tensor, tagged with the checkpoint name "mla_rope".
  """
  with jax.named_scope("mla_rope"):
    positions = inputs_positions
    if positions is None:
      # Default to consecutive positions, shaped [1, S].
      positions = jnp.arange(inputs.shape[1], dtype=jnp.int32)[None, :]

    rope = self.rotary_embedding
    cos, sin = _get_cos_sin(rope, positions, self.dtype)
    rotated = _apply_rope(inputs, cos, sin, interleave=rope.interleave)
    return checkpoint_name(rotated, "mla_rope")
545+
481546
def __init__(
482547
self,
483548
config: Config,

0 commit comments

Comments
 (0)