|
76 | 76 | from maxtext.utils.globals import EPS |
77 | 77 |
|
78 | 78 |
|
def _apply_rope(x, cos, sin, interleave=True):
  """Applies rotary positional embedding to the input.

  Each feature pair (x1, x2) is rotated as
  (x1 * cos - x2 * sin, x1 * sin + x2 * cos).

  Args:
    x: Input tensor [B, S, N, H] or [B, S, H].
    cos: Cosine component of RoPE, [B, S, 1, H/2] or [B, S, H/2].
    sin: Sine component of RoPE, [B, S, 1, H/2] or [B, S, H/2].
    interleave: If True, pairs are adjacent features (even/odd indices of the
      last axis); otherwise the last axis is split into two contiguous halves.

  Returns:
    Rotated input with the same shape as `x`.

  Raises:
    ValueError: If `x` is 3-D and a 4-D `cos`/`sin` has a heads axis (axis 2)
      whose size is not 1 (raised by `jnp.squeeze`).
  """

  def _match_rank(table):
    # Align the rank of a cos/sin table with `x` so the element-wise products
    # below broadcast over the heads axis instead of mis-aligning batch/seq.
    if x.ndim == 4 and table.ndim == 3:
      # [B, S, H/2] -> [B, S, 1, H/2]: broadcast across heads.
      return table[:, :, None, :]
    if x.ndim == 3 and table.ndim == 4:
      # [B, S, 1, H/2] -> [B, S, H/2]: x has no heads axis, so drop the
      # singleton one; otherwise [B, S, H/2] * [B, S, 1, H/2] would broadcast
      # to the wrong shape (or fail outright when B != S).
      return jnp.squeeze(table, axis=2)
    return table

  if interleave:
    x1, x2 = x[..., ::2], x[..., 1::2]
  else:
    x1, x2 = jnp.split(x, 2, axis=-1)

  cos = _match_rank(cos)
  sin = _match_rank(sin)

  y1 = x1 * cos - x2 * sin
  y2 = x1 * sin + x2 * cos

  if interleave:
    # Stacking on a new trailing axis and flattening it restores the original
    # even/odd interleaving of the rotated pairs.
    return jnp.stack([y1, y2], axis=-1).reshape(x.shape)
  return jnp.concatenate([y1, y2], axis=-1)
| 112 | + |
| 113 | + |
| 114 | +def _get_cos_sin(rotary_embedding, positions, dtype): |
| 115 | + """Computes cos and sin embeddings from the rotary_embedding module.""" |
| 116 | + # Get frequencies |
| 117 | + freqs = rotary_embedding.freqs_cis.at[positions].get( |
| 118 | + out_sharding=getattr(rotary_embedding, "freqs_sharding", None) |
| 119 | + ) # [B, S, half_dim] |
| 120 | + freqs = freqs[:, :, jnp.newaxis, :] # [B, S, 1, half_dim] |
| 121 | + cos = jnp.real(freqs).astype(dtype) |
| 122 | + sin = jnp.imag(freqs).astype(dtype) |
| 123 | + |
| 124 | + if getattr(rotary_embedding, "attention_scaling", False): |
| 125 | + rope_factor = getattr(rotary_embedding, "rope_factor", 1.0) |
| 126 | + scaling = 1.0 if rope_factor <= 1 else (0.1 * math.log(rope_factor) + 1.0) |
| 127 | + cos = cos * scaling |
| 128 | + sin = sin * scaling |
| 129 | + return cos, sin |
| 130 | + |
| 131 | + |
79 | 132 | class Indexer(nnx.Module): |
80 | 133 | """Indexer for DeepSeek Sparse Attention (DSA). |
81 | 134 |
|
@@ -189,9 +242,10 @@ def apply_partial_rope( |
189 | 242 | # indexer_head_dim -> [rope_head_dim, indexer_head_dim - rope_head_dim] |
190 | 243 | x_pe, x_nope = jnp.split(inputs, [self.rope_head_dim], axis=-1) |
191 | 244 | # x_pe [B, S, H, rope_head_dim], positions [B, S] |
192 | | - x_pe = self.rotary_embedding(x_pe, position=inputs_positions) |
| 245 | + cos, sin = _get_cos_sin(self.rotary_embedding, inputs_positions, self.dtype) |
| 246 | + x_pe = _apply_rope(x_pe, cos, sin, interleave=self.rotary_embedding.interleave) |
193 | 247 | x = jnp.concatenate([x_pe, x_nope], axis=-1) |
194 | | - return x |
| 248 | + return checkpoint_name(x, "indexer_partial_rope") |
195 | 249 |
|
196 | 250 | def generate_mask(self, topk_indices, s): |
197 | 251 | """ |
@@ -478,6 +532,17 @@ def mla_as_linen( |
478 | 532 | class MLA(Attention): |
479 | 533 | """Multi-Head Latent Attention (MLA) layer.""" |
480 | 534 |
|
| 535 | + def apply_rotary_embedding(self, inputs: Array, inputs_positions: Optional[Array | None] = None, rope_kwargs: dict | None = None): |
| 536 | + """Overrides RoPE with optimized implementation for MLA.""" |
| 537 | + with jax.named_scope("mla_rope"): |
| 538 | + if inputs_positions is None: |
| 539 | + seq_length = inputs.shape[1] |
| 540 | + inputs_positions = jnp.arange(seq_length, dtype=jnp.int32)[jnp.newaxis, :] |
| 541 | + |
| 542 | + cos, sin = _get_cos_sin(self.rotary_embedding, inputs_positions, self.dtype) |
| 543 | + x_out = _apply_rope(inputs, cos, sin, interleave=self.rotary_embedding.interleave) |
| 544 | + return checkpoint_name(x_out, "mla_rope") |
| 545 | + |
481 | 546 | def __init__( |
482 | 547 | self, |
483 | 548 | config: Config, |
|
0 commit comments