
Commit b03fe07

Script to test gdn changes
- Add backward pass checks & memory checks
- Add backward pass & memory consumption checks
- Update memory calcs
- Optimizations made to GDN impl in qwen3.py (3x speedup)
- Update dummy configs to align with q3-next
- Update tflops calc to align with WY-optimized GDN
- remove mixed precision
- Update config for chunk size
- update dtype
- Add NaN test in backward pass
- Fix exploding gradient in gdn
- Reintroduce mixed precision
- typo in bloat16
- typo fixed
- convert to float
- test pallas kernel for gdn
- wrong api name
- fix function positional args
- fix pallas code
- fix tensor indexing error
- only optimize forward pass
- update pallas code
- use float mask
- fix function returns
- add shardmap to kernel
- update with kernel agent suggestions
- fix matrix indexing
- fix matrix indexing
- mask before exp
- update benchmarking script
1 parent 093ab89 commit b03fe07

5 files changed

Lines changed: 968 additions & 9 deletions

File tree

src/maxtext/models/qwen3.py

Lines changed: 179 additions & 9 deletions
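
Several of the squashed messages above ("Fix exploding gradient in gdn", "mask before exp", "Add NaN test in backward pass") revolve around one numerical hazard: exponentiating gate differences before masking them. A minimal sketch of the pitfall with hypothetical gate values; the diff below applies the same mask-before-exp ordering inside pallas_chunk_gated_delta_rule:

import jax
import jax.numpy as jnp

chunk_size = 4
# Hypothetical cumulative log-gates; they grow very negative over a long chunk.
g_cumsum = jnp.array([0.0, -150.0, -300.0, -450.0])

# Strictly lower-triangular causal mask.
mask = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool), k=-1)

def unsafe(g):
  # Above the diagonal g_i - g_j reaches +450, so exp overflows to inf before
  # the mask zeroes it. The VJP of `where` then multiplies inf by a zero
  # cotangent, producing NaN gradients even though the forward value is finite.
  return jnp.sum(jnp.where(mask, jnp.exp(g[:, None] - g[None, :]), 0.0))

def safe(g):
  # Masking before exp clamps masked entries to -1e30, so exp underflows to 0
  # and both the forward and backward passes stay finite.
  return jnp.sum(jnp.exp(jnp.where(mask, g[:, None] - g[None, :], -1e30)))

print(jax.grad(unsafe)(g_cumsum))  # nan
print(jax.grad(safe)(g_cumsum))    # finite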
@@ -29,6 +29,7 @@
 from flax import linen as nn
 from flax import nnx
 
+<<<<<<< HEAD:src/maxtext/models/qwen3.py
 from maxtext.common.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN
 from maxtext.layers import attentions
 from maxtext.layers import initializers as max_initializers
@@ -43,9 +44,29 @@
 from maxtext.layers.moe import RoutedMoE
 from maxtext.layers.initializers import nd_dense_init, variable_to_logically_partitioned
 
+=======
+from jax.sharding import PartitionSpec as P
+from jax.experimental.shard_map import shard_map
+
+from MaxText.common_types import AttentionType, Config, DType, Array, BATCH, LENGTH_NO_EXP, EMBED, MODEL_MODE_TRAIN
+from MaxText.layers import attentions
+from MaxText.layers import initializers as max_initializers
+from MaxText.layers import moe
+from MaxText.layers import nnx_wrappers
+from MaxText.layers import quantizations
+from MaxText.layers.embeddings import Qwen3OmniMoeVisionPosEmbedInterpolate, PositionalEmbedding
+from MaxText.layers.normalizations import RMSNorm, l2norm, Qwen3NextRMSNorm, Qwen3NextRMSNormGated
+from MaxText.layers.quantizations import AqtQuantization as Quant
+from MaxText.layers.attentions import Attention
+from MaxText.layers.linears import DenseGeneral, MlpBlock
+from MaxText.layers.moe import RoutedMoE
+from MaxText.layers.initializers import nd_dense_init, variable_to_logically_partitioned
+from maxtext.inference import page_manager
+>>>>>>> 7461955dc (add shardmap to kernel):src/MaxText/layers/qwen3.py
 from maxtext.utils import max_utils
 from maxtext.inference import page_manager, kvcache
 
+from maxtext.scratch_code import gdn_pallas
 
 # -----------------------------------------
 # Qwen3-Next Layer Implementations
@@ -177,6 +198,150 @@ def scan_body(prev_state, x):
   return core_attn_out, final_state if output_final_state else None
 
 
+def pallas_chunk_gated_delta_rule(
+    query: jax.Array,
+    key: jax.Array,
+    value: jax.Array,
+    g: jax.Array,
+    beta: jax.Array,
+    chunk_size: int = 64,
+    initial_state: None | jax.Array = None,
+    use_qk_norm_in_gdn: bool = False,
+    compute_dtype: jnp.dtype = jnp.bfloat16,
+    mesh: Mesh | None = None,
+) -> tuple[jax.Array, None | jax.Array]:
+  """
+  Pallas-accelerated version of the Gated Delta Rule.
+  """
+  # =========================================================================
+  # STAGE 1: PREPARATION & PADDING
+  # =========================================================================
+  initial_dtype = query.dtype
+  if use_qk_norm_in_gdn:
+    from MaxText.layers.normalizations import l2norm
+    query = l2norm(query, dim=-1, eps=1e-6)
+    key = l2norm(key, dim=-1, eps=1e-6)
+
+  g = g.astype(jnp.float32)
+  query = query.astype(compute_dtype)
+  key = key.astype(compute_dtype)
+  value = value.astype(compute_dtype)
+  beta = beta.astype(compute_dtype)
+
+  scale = jax.lax.rsqrt(jnp.array(query.shape[-1], dtype=jnp.float32)).astype(compute_dtype)
+  query = query * scale
+
+  B, S, H, K_dim = key.shape
+  V_dim = value.shape[-1]
+
+  pad_len = (chunk_size - (S % chunk_size)) % chunk_size
+  if pad_len > 0:
+    pad_fn = lambda x, val=0.0: jnp.pad(x, ((0, 0), (0, pad_len)) + ((0, 0),) * (x.ndim - 2), constant_values=val)
+    query = pad_fn(query)
+    key = pad_fn(key)
+    value = pad_fn(value)
+    g = pad_fn(g)
+    beta = pad_fn(beta)
+
+  num_chunks = query.shape[1] // chunk_size
+
+  def to_chunk(x):
+    return x.reshape(B, num_chunks, chunk_size, H, -1).transpose(0, 1, 3, 2, 4)
+
+  def to_chunk_scalar(x):
+    return x.reshape(B, num_chunks, chunk_size, H).transpose(0, 1, 3, 2)
+
+  q_c = to_chunk(query)
+  k_c = to_chunk(key)
+  v_c = to_chunk(value)
+  g_c = to_chunk_scalar(g)
+  beta_c = to_chunk_scalar(beta)
+
+  # =========================================================================
+  # STAGE 2: INTRA-CHUNK PRE-COMPUTATION
+  # =========================================================================
+  g_cumsum = jnp.cumsum(g_c, axis=-1)
+  k_beta = k_c * beta_c[..., None]
+
+  # Named T so the sequence length S stays available for unpadding in Stage 4.
+  T = jnp.matmul(k_c, k_beta.swapaxes(-1, -2), precision=jax.lax.Precision.HIGHEST)
+  T = T.astype(jnp.float32)
+  g_diff = g_cumsum[..., :, None] - g_cumsum[..., None, :]
+  mask = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool), k=-1)
+  g_diff = jnp.where(mask, g_diff, -1e30)  # mask before exp to avoid overflow
+  T = T * jnp.exp(g_diff)
+  T = jnp.where(mask, T, 0.0)
+
+  identity = jnp.eye(chunk_size, dtype=jnp.float32)
+  identity_broadcasted = jnp.broadcast_to(identity, T.shape)
+  A = jax.scipy.linalg.solve_triangular(identity + T, identity_broadcasted, lower=True, unit_diagonal=True)
+
+  v_beta = v_c * beta_c[..., None]
+  u_chunks = jnp.matmul(A, v_beta.astype(jnp.float32), precision=jax.lax.Precision.HIGHEST)
+  u_chunks = u_chunks.astype(compute_dtype)
+
+  k_beta_g = k_beta.astype(jnp.float32) * jnp.exp(g_cumsum)[..., None]
+  w_chunks = jnp.matmul(A, k_beta_g, precision=jax.lax.Precision.HIGHEST)
+  w_chunks = w_chunks.astype(compute_dtype)
+
+  # =========================================================================
+  # STAGE 3: INTER-CHUNK RECURRENCE (Pallas Kernel + shard_map)
+  # =========================================================================
+  # Transpose to (Batch, Heads, NumChunks, ChunkSize, Dim) for Pallas.
+  w_p = w_chunks.transpose(0, 2, 1, 3, 4)
+  u_p = u_chunks.transpose(0, 2, 1, 3, 4)
+  q_p = q_c.transpose(0, 2, 1, 3, 4)
+  k_p = k_c.transpose(0, 2, 1, 3, 4)
+  v_p = v_c.transpose(0, 2, 1, 3, 4)
+  g_p = g_cumsum.transpose(0, 2, 1, 3)
+  beta_p = beta_c.transpose(0, 2, 1, 3)
+
+  # Handle initial state.
+  if initial_state is None:
+    h_init = jnp.zeros((B, H, K_dim, V_dim), dtype=compute_dtype)
+  else:
+    h_init = initial_state.astype(compute_dtype)
+
+  # Invoke kernel.
+  if mesh is not None:
+    # Mesh partitioning.
+    axis_names = mesh.axis_names
+    batch_axes = [ax for ax in ('data', 'fsdp', 'fsdp_transpose', 'expert') if ax in axis_names]
+    batch_spec = tuple(batch_axes) if batch_axes else None
+    head_axes = [ax for ax in ('tensor', 'model') if ax in axis_names]
+    head_spec = tuple(head_axes) if head_axes else None
+
+    # Specs over (B, H, ...); h_init is (B, H, K, V).
+    in_specs = P(batch_spec, head_spec, None, None, None)
+    scalar_specs = P(batch_spec, head_spec, None, None)
+    state_spec = P(batch_spec, head_spec, None, None)
+
+    sharded_gdn = shard_map(
+        gdn_pallas.gdn_pallas_layer,
+        mesh=mesh,
+        in_specs=(in_specs, in_specs, in_specs, in_specs, in_specs, scalar_specs, scalar_specs, state_spec),
+        out_specs=(in_specs, state_spec),  # Returns (out, final_state).
+        check_rep=False,
+    )
+
+    o_pallas, h_final = sharded_gdn(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init)
+  else:
+    # Single device.
+    o_pallas, h_final = gdn_pallas.gdn_pallas_layer(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init)
+
+  o_chunks = o_pallas.transpose(0, 2, 1, 3, 4)
+
+  # =========================================================================
+  # STAGE 4: FINALIZATION
+  # =========================================================================
+  o = o_chunks.reshape(B, -1, H, V_dim)
+
+  if pad_len > 0:
+    o = o[:, :S, :, :]
+
+  o = o.astype(initial_dtype)
+
+  return o, h_final
+
 def jax_chunk_gated_delta_rule(
     query: Array,
     key: Array,
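
Stage 3 hands the inter-chunk recurrence to gdn_pallas.gdn_pallas_layer, whose source is not part of this diff. The following is a pure-JAX sketch of what such a kernel computes, assuming the standard WY-form chunked gated delta rule (the same recurrence jax_chunk_gated_delta_rule scans over). gdn_scan_reference is a hypothetical name, not the kernel's API, and shapes follow the (B, H, NC, CS, D) layout produced by the transposes above.

import jax
import jax.numpy as jnp

def gdn_scan_reference(w, u, q, k, v, g, beta, h_init):
  """Inter-chunk scan: tensors are (B, H, NC, CS, D), gates (B, H, NC, CS), state (B, H, K, V)."""
  del v, beta  # Stage 2 already folded beta and v into u = A @ (beta * v).
  chunk_size = q.shape[-2]
  mask = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool))  # inclusive causal mask

  def step(h, inputs):
    w_i, u_i, q_i, k_i, g_i = inputs  # one chunk: (B, H, CS, D) and (B, H, CS)
    # Delta-rule correction: subtract the part of u already explained by the running state.
    v_new = u_i - jnp.einsum("bhck,bhkv->bhcv", w_i, h)
    # Inter-chunk output: decayed queries read the running state.
    o_inter = jnp.einsum("bhck,bhkv->bhcv", q_i * jnp.exp(g_i)[..., None], h)
    # Intra-chunk output: causal attention inside the chunk, masked before exp.
    attn = jnp.einsum("bhck,bhdk->bhcd", q_i, k_i)
    decay = jnp.exp(jnp.where(mask, g_i[..., :, None] - g_i[..., None, :], -1e30))
    o = o_inter + jnp.einsum("bhcd,bhdv->bhcv", attn * decay, v_new)
    # State update: decay by the chunk's total gate, then add the new outer products.
    g_last = g_i[..., -1:]
    k_scaled = k_i * jnp.exp(g_last[..., None] - g_i[..., None])
    h_new = h * jnp.exp(g_last)[..., None] + jnp.einsum("bhck,bhcv->bhkv", k_scaled, v_new)
    return h_new, o

  # Scan over the chunk axis (axis 2).
  xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 2, 0), (w, u, q, k, g))
  h_final, o = jax.lax.scan(step, h_init, xs)
  return jnp.moveaxis(o, 0, 2), h_final

Scanning over chunks rather than tokens is what makes the WY precomputation pay off: each step is a handful of matmuls over a chunk_size-by-d tile instead of chunk_size sequential rank-1 state updates.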
@@ -381,13 +546,18 @@ class Qwen3NextGatedDeltaNet(nnx.Module):
   2. output = Linear_out(y)
   """
 
+<<<<<<< HEAD:src/maxtext/models/qwen3.py
   def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs):
+=======
+  def __init__(self, config: Config, *, rngs: nnx.Rngs, mesh: Mesh = None):
+>>>>>>> 7461955dc (add shardmap to kernel):src/MaxText/layers/qwen3.py
     """
     Args:
       config: MaxText configuration object.
       rngs: The random number generators for initialization, passed by the nnx.to_linen wrapper.
     """
     self.config = config
+    self.mesh = mesh
     cfg = self.config
 
     in_features = cfg.emb_dim
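
The incoming side of the conflict threads a Mesh through the module so the kernel call can build its shard_map specs. A hypothetical construction under an assumed ('data', 'tensor') mesh, with axis names chosen to match the ones pallas_chunk_gated_delta_rule probes for; cfg stands in for a real MaxText Config:

import jax
from flax import nnx
from jax.sharding import Mesh
from jax.experimental import mesh_utils

# Hypothetical 4x2 layout (requires 8 devices): 'data' shards the batch
# dimension, 'tensor' shards the heads.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=("data", "tensor"))

gdn = Qwen3NextGatedDeltaNet(config=cfg, rngs=nnx.Rngs(params=0), mesh=mesh)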
@@ -637,16 +807,12 @@ def extract_state(c_in, v_len):
     else:
       recurrent_state = recurrent_state[:batch]
 
-    core_attn_out, recurrent_state_out = jax_chunk_gated_delta_rule(
-        query,
-        key,
-        value,
-        g,
-        beta,
-        chunk_size=cfg.gdn_chunk_size,
-        initial_state=recurrent_state,
-        use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn,
+    core_attn_out, recurrent_state_out = pallas_chunk_gated_delta_rule(
+        query, key, value, g, beta,
+        chunk_size=cfg.gdn_chunk_size,
+        use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn,
         compute_dtype=cfg.dtype,
+        mesh=self.mesh,
     )
 
     if model_mode != MODEL_MODE_TRAIN:
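
With both implementations still present in the file, a parity check is straightforward. A hedged sketch with made-up shapes (not the repo's benchmarking script); note that the new call site no longer forwards initial_state=recurrent_state, so both paths are compared from the default zero state here:

import jax
import jax.numpy as jnp

B, S, H, K, V = 2, 256, 4, 64, 128  # hypothetical shapes
keys = jax.random.split(jax.random.PRNGKey(0), 5)
q = jax.random.normal(keys[0], (B, S, H, K))
k = jax.random.normal(keys[1], (B, S, H, K))
v = jax.random.normal(keys[2], (B, S, H, V))
g = -jnp.abs(jax.random.normal(keys[3], (B, S, H)))           # log-gates must be <= 0
beta = jax.nn.sigmoid(jax.random.normal(keys[4], (B, S, H)))  # mixing coefficients in (0, 1)

out_ref, state_ref = jax_chunk_gated_delta_rule(q, k, v, g, beta, chunk_size=64)
out_pl, state_pl = pallas_chunk_gated_delta_rule(q, k, v, g, beta, chunk_size=64)

# The Pallas path computes in bfloat16, so tolerances must be loose.
assert jnp.allclose(out_ref, out_pl, atol=1e-2, rtol=1e-2)
assert not jnp.any(jnp.isnan(out_pl))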
@@ -982,7 +1148,11 @@ def __init__(
         rngs=rngs,
       )
     else:
+<<<<<<< HEAD:src/maxtext/models/qwen3.py
       self.attention = Qwen3NextGatedDeltaNet(config=cfg, dtype=cfg.dtype, model_mode=model_mode, rngs=rngs)
+=======
+      self.attention = Qwen3NextGatedDeltaNet(config=cfg, rngs=rngs, mesh=self.mesh)
+>>>>>>> 7461955dc (add shardmap to kernel):src/MaxText/layers/qwen3.py
 
     # Second LayerNorm, applied before the MoE block.
     self.post_attention_layernorm = Qwen3NextRMSNorm(
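
The "Add NaN test in backward pass" message pairs naturally with the parity check above. A minimal sketch reusing those tensors, assuming the Pallas path is differentiable (the "only optimize forward pass" message suggests the backward falls back to XLA):

def loss_fn(q, k, v, g, beta):
  out, _ = pallas_chunk_gated_delta_rule(q, k, v, g, beta, chunk_size=64)
  return jnp.sum(out.astype(jnp.float32) ** 2)

grads = jax.grad(loss_fn, argnums=(0, 1, 2, 3, 4))(q, k, v, g, beta)
for grad in grads:
  # The mask-before-exp ordering is what keeps these finite for long sequences.
  assert jnp.all(jnp.isfinite(grad))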
