|
29 | 29 | from flax import linen as nn |
30 | 30 | from flax import nnx |
31 | 31 |
|
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

from maxtext.common.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN
from maxtext.layers import attentions
from maxtext.layers import initializers as max_initializers
# ... (intervening maxtext.layers imports unchanged; elided in this diff view) ...
from maxtext.layers.moe import RoutedMoE
from maxtext.layers.initializers import nd_dense_init, variable_to_logically_partitioned
46 | 66 | from maxtext.utils import max_utils |
47 | 67 | from maxtext.inference import page_manager, kvcache |
48 | 68 |
|
| 69 | +from maxtext.scratch_code import gdn_pallas |
49 | 70 |
|
50 | 71 | # ----------------------------------------- |
51 | 72 | # Qwen3-Next Layer Implementations |
@@ -177,6 +198,150 @@ def scan_body(prev_state, x): |
177 | 198 | return core_attn_out, final_state if output_final_state else None |
178 | 199 |
|
179 | 200 |
|
def pallas_chunk_gated_delta_rule(
    query: jax.Array,
    key: jax.Array,
    value: jax.Array,
    g: jax.Array,
    beta: jax.Array,
    chunk_size: int = 64,
    initial_state: None | jax.Array = None,
    use_qk_norm_in_gdn: bool = False,
    compute_dtype: jnp.dtype = jnp.bfloat16,
    mesh: Mesh | None = None,
) -> tuple[jax.Array, None | jax.Array]:
  """Pallas-accelerated chunked Gated Delta Rule (linear-attention recurrence).

  The sequence is split into fixed-size chunks; the intra-chunk triangular
  solves are done in plain JAX, and the sequential inter-chunk recurrence is
  delegated to the ``gdn_pallas`` kernel (optionally sharded over ``mesh``
  via ``shard_map``).

  Args:
    query: ``(batch, seq, heads, k_dim)`` queries.
    key: ``(batch, seq, heads, k_dim)`` keys.
    value: ``(batch, seq, heads, v_dim)`` values.
    g: ``(batch, seq, heads)`` per-step log decay gates.
    beta: ``(batch, seq, heads)`` per-step write strengths.
    chunk_size: chunk length; the sequence axis is zero-padded to a multiple.
    initial_state: optional ``(batch, heads, k_dim, v_dim)`` recurrent state;
      zeros are used when ``None``.
    use_qk_norm_in_gdn: if True, l2-normalize q/k along the feature axis.
    compute_dtype: dtype used for the matmul-heavy stages.
    mesh: device mesh; when given, the kernel runs under ``shard_map`` with
      batch/head sharding, otherwise it is invoked directly.

  Returns:
    ``(output, final_state)``: ``output`` has the input dtype and shape
    ``(batch, seq, heads, v_dim)``; ``final_state`` is whatever the kernel
    returns as its second output.
  """
  # =========================================================================
  # STAGE 1: PREPARATION & PADDING
  # =========================================================================
  initial_dtype = query.dtype
  if use_qk_norm_in_gdn:
    # Lazy import; lowercase ``maxtext`` matches this file's other imports —
    # the camel-case ``MaxText`` path used previously does not exist in this
    # tree and would raise ImportError.
    from maxtext.layers.normalizations import l2norm
    query = l2norm(query, dim=-1, eps=1e-6)
    key = l2norm(key, dim=-1, eps=1e-6)

  g = g.astype(jnp.float32)  # gates stay fp32 for a stable cumsum/exp
  query = query.astype(compute_dtype)
  key = key.astype(compute_dtype)
  value = value.astype(compute_dtype)
  beta = beta.astype(compute_dtype)

  # Standard 1/sqrt(d_k) scaling, computed in fp32 then cast down.
  scale = jax.lax.rsqrt(jnp.array(query.shape[-1], dtype=jnp.float32)).astype(compute_dtype)
  query = query * scale

  batch, seq_len, num_heads, k_dim = key.shape
  v_dim = value.shape[-1]

  # Zero-pad the sequence axis so it divides evenly into chunks.
  pad_len = (chunk_size - (seq_len % chunk_size)) % chunk_size
  if pad_len > 0:

    def _pad_seq(x, fill=0.0):
      # Pads axis 1 only; all trailing axes are untouched.
      widths = ((0, 0), (0, pad_len)) + ((0, 0),) * (x.ndim - 2)
      return jnp.pad(x, widths, constant_values=fill)

    query, key, value, g, beta = (_pad_seq(t) for t in (query, key, value, g, beta))

  num_chunks = query.shape[1] // chunk_size

  def _to_chunks(x):
    # (B, S, H, D) -> (B, N, H, C, D)
    return x.reshape(batch, num_chunks, chunk_size, num_heads, -1).transpose(0, 1, 3, 2, 4)

  def _to_chunks_scalar(x):
    # (B, S, H) -> (B, N, H, C)
    return x.reshape(batch, num_chunks, chunk_size, num_heads).transpose(0, 1, 3, 2)

  q_c = _to_chunks(query)
  k_c = _to_chunks(key)
  v_c = _to_chunks(value)
  g_c = _to_chunks_scalar(g)
  beta_c = _to_chunks_scalar(beta)

  # =========================================================================
  # STAGE 2: INTRA-CHUNK PRE-COMPUTATION
  # =========================================================================
  g_cumsum = jnp.cumsum(g_c, axis=-1)  # running log-decay within each chunk
  k_beta = k_c * beta_c[..., None]

  # BUG FIX: the intra-chunk attention matrix was previously bound to ``S``,
  # clobbering the sequence length and breaking the un-pad slice in Stage 4.
  # It now lives in its own name.
  attn = jnp.matmul(k_c, k_beta.swapaxes(-1, -2), precision=jax.lax.Precision.HIGHEST)
  attn = attn.astype(jnp.float32)
  g_diff = g_cumsum[..., :, None] - g_cumsum[..., None, :]
  # Strictly-lower-triangular causal mask inside a chunk; exp(-1e30) -> 0
  # zeroes the masked entries.
  strict_lower = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool), k=-1)
  g_diff = jnp.where(strict_lower, g_diff, -1e30)
  attn = attn * jnp.exp(g_diff)
  attn = jnp.where(strict_lower, attn, 0.0)

  # Invert (I + attn) via a unit-diagonal triangular solve.
  identity = jnp.eye(chunk_size, dtype=jnp.float32)
  rhs = jnp.broadcast_to(identity, attn.shape)
  inv = jax.scipy.linalg.solve_triangular(identity + attn, rhs, lower=True, unit_diagonal=True)

  v_beta = v_c * beta_c[..., None]
  u_chunks = jnp.matmul(inv, v_beta.astype(jnp.float32), precision=jax.lax.Precision.HIGHEST)
  u_chunks = u_chunks.astype(compute_dtype)

  k_beta_g = k_beta.astype(jnp.float32) * jnp.exp(g_cumsum)[..., None]
  w_chunks = jnp.matmul(inv, k_beta_g, precision=jax.lax.Precision.HIGHEST)
  w_chunks = w_chunks.astype(compute_dtype)

  # =========================================================================
  # STAGE 3: INTER-CHUNK RECURRENCE (Pallas kernel + shard_map)
  # =========================================================================
  # The kernel expects (Batch, Heads, NumChunks, ChunkSize, Dim) layout.
  w_p = w_chunks.transpose(0, 2, 1, 3, 4)
  u_p = u_chunks.transpose(0, 2, 1, 3, 4)
  q_p = q_c.transpose(0, 2, 1, 3, 4)
  k_p = k_c.transpose(0, 2, 1, 3, 4)
  v_p = v_c.transpose(0, 2, 1, 3, 4)
  g_p = g_cumsum.transpose(0, 2, 1, 3)
  beta_p = beta_c.transpose(0, 2, 1, 3)

  # Recurrent state carried across chunks: (B, H, K, V).
  if initial_state is None:
    h_init = jnp.zeros((batch, num_heads, k_dim, v_dim), dtype=compute_dtype)
  else:
    h_init = initial_state.astype(compute_dtype)

  if mesh is not None:
    # Shard batch over the data-parallel mesh axes and heads over the model
    # axes; ``None`` entries leave a dimension replicated/unsharded.
    axis_names = mesh.axis_names
    batch_axes = [ax for ax in ("data", "fsdp", "fsdp_transpose", "expert") if ax in axis_names]
    batch_spec = tuple(batch_axes) if batch_axes else None
    head_axes = [ax for ax in ("tensor", "model") if ax in axis_names]
    head_spec = tuple(head_axes) if head_axes else None

    chunked_spec = P(batch_spec, head_spec, None, None, None)  # (B, H, N, C, D)
    scalar_spec = P(batch_spec, head_spec, None, None)  # (B, H, N, C)
    state_spec = P(batch_spec, head_spec, None, None)  # (B, H, K, V)

    sharded_gdn = shard_map(
        gdn_pallas.gdn_pallas_layer,
        mesh=mesh,
        in_specs=(chunked_spec,) * 5 + (scalar_spec, scalar_spec, state_spec),
        out_specs=(chunked_spec, state_spec),  # kernel returns (out, final_state)
        check_rep=False,
    )
    o_pallas, h_final = sharded_gdn(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init)
  else:
    # Single device / no mesh: call the kernel directly.
    o_pallas, h_final = gdn_pallas.gdn_pallas_layer(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init)

  # =========================================================================
  # STAGE 4: FINALIZATION
  # =========================================================================
  # (B, H, N, C, V) -> (B, N, H, C, V) -> (B, S_padded, H, V)
  o = o_pallas.transpose(0, 2, 1, 3, 4).reshape(batch, -1, num_heads, v_dim)
  if pad_len > 0:
    # Drop the padded tail using the preserved sequence length (previously
    # this sliced with the clobbered attention matrix).
    o = o[:, :seq_len, :, :]
  return o.astype(initial_dtype), h_final
| 344 | + |
180 | 345 | def jax_chunk_gated_delta_rule( |
181 | 346 | query: Array, |
182 | 347 | key: Array, |
@@ -381,13 +546,18 @@ class Qwen3NextGatedDeltaNet(nnx.Module): |
381 | 546 | 2. output = Linear_out(y) |
382 | 547 | """ |
383 | 548 |
|
  def __init__(
      self,
      config: Config,
      dtype: DType = jnp.float32,
      model_mode: str = MODEL_MODE_TRAIN,
      *,
      rngs: nnx.Rngs,
      mesh: Mesh | None = None,
  ):
385 | 554 | """ |
386 | 555 | Args: |
387 | 556 | config: MaxText configuration object. |
388 | 557 | rngs: The random number generators for initialization, passed by the nnx.to_linen wrapper. |
389 | 558 | """ |
390 | 559 | self.config = config |
| 560 | + self.mesh = mesh |
391 | 561 | cfg = self.config |
392 | 562 |
|
393 | 563 | in_features = cfg.emb_dim |
@@ -637,16 +807,12 @@ def extract_state(c_in, v_len): |
637 | 807 | else: |
638 | 808 | recurrent_state = recurrent_state[:batch] |
639 | 809 |
|
640 | | - core_attn_out, recurrent_state_out = jax_chunk_gated_delta_rule( |
641 | | - query, |
642 | | - key, |
643 | | - value, |
644 | | - g, |
645 | | - beta, |
646 | | - chunk_size=cfg.gdn_chunk_size, |
647 | | - initial_state=recurrent_state, |
648 | | - use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn, |
| 810 | + core_attn_out, recurrent_state_out = pallas_chunk_gated_delta_rule( |
| 811 | + query, key, value, g, beta, |
| 812 | + chunk_size=cfg.gdn_chunk_size, |
| 813 | + use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn, |
649 | 814 | compute_dtype=cfg.dtype, |
| 815 | + mesh=self.mesh |
650 | 816 | ) |
651 | 817 |
|
652 | 818 | if model_mode != MODEL_MODE_TRAIN: |
@@ -982,7 +1148,11 @@ def __init__( |
982 | 1148 | rngs=rngs, |
983 | 1149 | ) |
984 | 1150 | else: |
      self.attention = Qwen3NextGatedDeltaNet(
          config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs, mesh=self.mesh
      )
986 | 1156 |
|
987 | 1157 | # Second LayerNorm, applied before the MoE block. |
988 | 1158 | self.post_attention_layernorm = Qwen3NextRMSNorm( |
|
0 commit comments