177 changes: 167 additions & 10 deletions src/maxtext/models/qwen3.py
@@ -43,9 +43,12 @@
from maxtext.layers.moe import RoutedMoE
from maxtext.layers.initializers import nd_dense_init, variable_to_logically_partitioned

from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from maxtext.utils import max_utils
from maxtext.inference import page_manager, kvcache

from maxtext.scratch_code import gdn_pallas, gdn_pallas2, gdn_pallas3

# -----------------------------------------
# Qwen3-Next Layer Implementations
@@ -177,6 +180,163 @@ def scan_body(prev_state, x):
return core_attn_out, final_state if output_final_state else None


def pallas_chunk_gated_delta_rule(
query: jax.Array,
key: jax.Array,
value: jax.Array,
g: jax.Array,
beta: jax.Array,
chunk_size: int = 64,
initial_state: None | jax.Array = None,
use_qk_norm_in_gdn: bool = False,
compute_dtype: jnp.dtype = jnp.bfloat16,
mesh: Mesh | None = None,
) -> tuple[jax.Array, None | jax.Array]:
"""
Pallas-accelerated version of Gated Delta Rule.
"""
# =========================================================================
# STAGE 1: PREPARATION & PADDING
# =========================================================================
initial_dtype = query.dtype
if use_qk_norm_in_gdn:
from maxtext.layers.normalizations import l2norm
query = l2norm(query, dim=-1, eps=1e-6)
key = l2norm(key, dim=-1, eps=1e-6)

g = g.astype(jnp.float32)
query = query.astype(compute_dtype)
key = key.astype(compute_dtype)
value = value.astype(compute_dtype)
beta = beta.astype(compute_dtype)

scale = jax.lax.rsqrt(jnp.array(query.shape[-1], dtype=jnp.float32)).astype(compute_dtype)
query = query * scale

B, seq_len, H, K_dim = key.shape
V_dim = value.shape[-1]

pad_len = (chunk_size - (seq_len % chunk_size)) % chunk_size
if pad_len > 0:
pad_fn = lambda x, val=0.0: jnp.pad(x, ((0,0), (0, pad_len)) + ((0,0),)*(x.ndim-2), constant_values=val)
query = pad_fn(query)
key = pad_fn(key)
value = pad_fn(value)
g = pad_fn(g)
beta = pad_fn(beta)
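  # Padding is intended to be state-neutral: padded keys and betas are zero, so (assuming
  # the kernel weights state updates by the keys) they should not change the recurrent
  # state, and the padded output positions are sliced off again in Stage 4.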

num_chunks = query.shape[1] // chunk_size

def to_chunk(x):
return x.reshape(B, num_chunks, chunk_size, H, -1).transpose(0, 1, 3, 2, 4)
def to_chunk_scalar(x):
return x.reshape(B, num_chunks, chunk_size, H).transpose(0, 1, 3, 2)
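  # Chunked layouts: tensors become (B, NumChunks, H, ChunkSize, Dim),
  # per-token scalars (g, beta) become (B, NumChunks, H, ChunkSize).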

q_c = to_chunk(query)
k_c = to_chunk(key)
v_c = to_chunk(value)
g_c = to_chunk_scalar(g)
beta_c = to_chunk_scalar(beta)

# =========================================================================
# STAGE 2: INTRA-CHUNK PRE-COMPUTATION
# =========================================================================
g_cumsum = jnp.cumsum(g_c, axis=-1)
k_beta = k_c * beta_c[..., None]

S = jnp.matmul(k_beta, k_c.swapaxes(-1, -2), precision=jax.lax.Precision.HIGHEST)
S = S.astype(jnp.float32)
g_diff = g_cumsum[..., :, None] - g_cumsum[..., None, :]
mask = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=bool), k=-1)
g_diff = jnp.where(mask, g_diff, -1e30)
S = S * jnp.exp(g_diff)
S = jnp.where(mask, S, 0.0)
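  # S[i, j] = beta_i * <k_i, k_j> * exp(g_cumsum_i - g_cumsum_j) for j < i within a chunk,
  # zero elsewhere: a strictly lower-triangular, decay-weighted key-similarity matrix.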

identity = jnp.eye(chunk_size, dtype=jnp.float32)
identity_broadcasted = jnp.broadcast_to(identity, S.shape)
A = jax.scipy.linalg.solve_triangular(identity + S, identity_broadcasted, lower=True, unit_diagonal=True)
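  # (I + S) is unit lower triangular, so the triangular solve returns A = (I + S)^-1 exactly;
  # the commented-out block below is a mathematically equivalent matmul-only alternative.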

# OPTIMIZED TPU INVERSION: (I+S)^-1 using logarithmic expansion
# Since S is strictly lower triangular, S^N = 0. We can invert it with pure matmuls.
# X = -S
# A = identity + X
# prec = jax.lax.Precision.HIGHEST
# X_pow = jnp.matmul(X, X, precision=prec)

# num_iters = int(math.ceil(math.log2(chunk_size))) - 1
# for _ in range(num_iters):
# A = A + jnp.matmul(A, X_pow, precision=prec)
# X_pow = jnp.matmul(X_pow, X_pow, precision=prec)

v_beta = v_c * beta_c[..., None]
u_chunks = jnp.matmul(A, v_beta.astype(jnp.float32), precision=jax.lax.Precision.HIGHEST)
u_chunks = u_chunks.astype(compute_dtype)

k_beta_g = k_beta.astype(jnp.float32) * jnp.exp(g_cumsum)[..., None]
w_chunks = jnp.matmul(A, k_beta_g, precision=jax.lax.Precision.HIGHEST)
w_chunks = w_chunks.astype(compute_dtype)
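  # u_chunks and w_chunks are the intra-chunk corrected values and decayed keys
  # (the WY/UT-transform step of chunked delta-rule formulations) consumed by the
  # inter-chunk recurrence below.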

# =========================================================================
# STAGE 3: INTER-CHUNK RECURRENCE (Pallas Kernel + shard_map)
# =========================================================================
# Transpose to (Batch, Heads, NumChunks, ChunkSize, Dim) for Pallas
w_p = w_chunks.transpose(0, 2, 1, 3, 4)
u_p = u_chunks.transpose(0, 2, 1, 3, 4)
q_p = q_c.transpose(0, 2, 1, 3, 4)
k_p = k_c.transpose(0, 2, 1, 3, 4)
v_p = v_c.transpose(0, 2, 1, 3, 4)
g_p = g_cumsum.transpose(0, 2, 1, 3)
beta_p = beta_c.transpose(0, 2, 1, 3)

# Handle initial state
if initial_state is None:
h_init = jnp.zeros((B, H, K_dim, V_dim), dtype=compute_dtype)
else:
h_init = initial_state.astype(compute_dtype)

kernel_to_use = gdn_pallas3.gdn_pallas_layer
# Invoke Kernel
if mesh is not None:
# Mesh Partitioning
axis_names = mesh.axis_names
batch_axes = [ax for ax in ('data', 'fsdp', 'fsdp_transpose', 'expert') if ax in axis_names]
batch_spec = tuple(batch_axes) if batch_axes else None
head_axes = [ax for ax in ('tensor', 'model') if ax in axis_names]
head_spec = tuple(head_axes) if head_axes else None
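    # Assumes the Pallas kernel treats each (batch, head) slice independently, so the batch
    # dim can be sharded over data-like axes and the head dim over tensor/model axes.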

# Specs: B, H, ...
# h_init is (B, H, K, V)
in_specs = P(batch_spec, head_spec, None, None, None)
scalar_specs = P(batch_spec, head_spec, None, None)
state_spec = P(batch_spec, head_spec, None, None)

sharded_gdn = shard_map(
kernel_to_use,
mesh=mesh,
in_specs=(in_specs, in_specs, in_specs, in_specs, in_specs, scalar_specs, scalar_specs, state_spec),
out_specs=(in_specs, state_spec),
check_rep=False
)

o_pallas, h_final = sharded_gdn(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init)
else:
# Single Device
o_pallas, h_final = kernel_to_use(w_p, u_p, q_p, k_p, v_p, g_p, beta_p, h_init)

o_chunks = o_pallas.transpose(0, 2, 3, 1, 4)
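  # o_pallas is expected in the (B, H, NumChunks, ChunkSize, V) kernel layout; the transpose
  # restores (B, NumChunks, ChunkSize, H, V) before flattening back to (B, S, H, V).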

# =========================================================================
# STAGE 4: FINALIZATION
# =========================================================================
o = o_chunks.reshape(B, -1, H, V_dim)

if pad_len > 0:
o = o[:, :seq_len, :, :]

o = o.astype(initial_dtype)

return o, h_final
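
# Illustrative usage (assumed shapes, not part of the kernel API): with q/k of shape
# (2, 1024, 16, 128), v of shape (2, 1024, 16, 256), and g/beta of shape (2, 1024, 16),
#   o, h = pallas_chunk_gated_delta_rule(q, k, v, g, beta, chunk_size=64, mesh=mesh)
# returns o of shape (2, 1024, 16, 256) and a final state of shape (2, 16, 128, 256).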

def jax_chunk_gated_delta_rule(
query: Array,
key: Array,
@@ -381,13 +541,14 @@ class Qwen3NextGatedDeltaNet(nnx.Module):
2. output = Linear_out(y)
"""

def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs):
def __init__(self, config: Config, dtype: DType = jnp.float32, model_mode: str = MODEL_MODE_TRAIN, *, rngs: nnx.Rngs, mesh: Mesh | None = None):
"""
Args:
config: MaxText configuration object.
rngs: The random number generators for initialization, passed by the nnx.to_linen wrapper.
"""
self.config = config
self.mesh = mesh
cfg = self.config

in_features = cfg.emb_dim
@@ -637,16 +798,12 @@ def extract_state(c_in, v_len):
else:
recurrent_state = recurrent_state[:batch]

core_attn_out, recurrent_state_out = jax_chunk_gated_delta_rule(
query,
key,
value,
g,
beta,
chunk_size=cfg.gdn_chunk_size,
initial_state=recurrent_state,
use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn,
core_attn_out, recurrent_state_out = pallas_chunk_gated_delta_rule(
    query, key, value, g, beta,
    chunk_size=cfg.gdn_chunk_size,
    initial_state=recurrent_state,
    use_qk_norm_in_gdn=cfg.use_qk_norm_in_gdn,
    compute_dtype=cfg.dtype,
    mesh=self.mesh,
)

if model_mode != MODEL_MODE_TRAIN: