10 changes: 8 additions & 2 deletions src/maxtext/kernels/megablox/backend.py
@@ -298,6 +298,7 @@ def _calculate_bytes(x: jax.Array | qpl.QArray) -> int:
"tiling",
"transpose_rhs",
"interpret",
"vma_axes",
],
)
def gmm(
@@ -310,6 +311,7 @@ def gmm(
existing_out: jnp.ndarray | None = None,
transpose_rhs: bool = False,
interpret: bool = False,
vma_axes: tuple = tuple(),
) -> jnp.ndarray:
"""Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.

@@ -522,7 +524,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
}
call_gmm = qpl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type, vma=set(vma_axes)),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=2,
in_specs=[
@@ -558,13 +560,15 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
return out


# Computes drhs, the expert weight gradient.
@functools.partial(
jax.jit,
static_argnames=[
"preferred_element_type",
"tiling",
"num_actual_groups",
"interpret",
"vma_axes",
],
)
def tgmm(
@@ -577,6 +581,7 @@ def tgmm(
num_actual_groups: int | None = None,
existing_out: jnp.ndarray | None = None,
interpret: bool = False,
vma_axes: tuple = tuple(),
) -> jnp.ndarray:
"""Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].

@@ -773,9 +778,10 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
"prefer_element_type": jnp.dtype(preferred_element_type).name,
"num_actual_groups": num_actual_groups,
}
# Computes the transposed grouped matmul (drhs) via the Pallas kernel call.
call_gmm = qpl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type),
out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type, vma=set(vma_axes)),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=2,
in_specs=[
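The new `vma_axes` argument is folded into the kernel's output spec as `vma=set(vma_axes)` on the `jax.ShapeDtypeStruct`, which declares that the Pallas output varies across those manual mesh axes when `shard_map` runs with `check_vma` enabled. Below is a minimal sketch of how a caller might thread the argument through; it assumes a recent JAX that exposes `jax.shard_map` with a `check_vma` flag and a `vma`-aware `ShapeDtypeStruct`, and the mesh axis name, shapes, and partition specs are illustrative rather than taken from MaxText configs.

```python
# Schematic usage sketch (assumptions noted above); argument order follows the
# gmm signature in backend.py, but the data below is made up.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from maxtext.kernels.megablox import backend

mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
# Only mesh axes that actually shard data need to be declared as varying.
vma_axes = tuple(name for name, size in mesh.shape.items() if size > 1)

def run(lhs, rhs, group_sizes):
  # backend.gmm tags its out_shape with vma=set(vma_axes), so the per-shard
  # output passes shard_map's check_vma consistency check.
  return backend.gmm(lhs, rhs, group_sizes, jnp.bfloat16, vma_axes=vma_axes)

sharded_run = jax.shard_map(
    run,
    mesh=mesh,
    in_specs=(P("fsdp", None), P(None, None, None), P(None)),
    out_specs=P("fsdp", None),
    check_vma=True,
)
```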
24 changes: 18 additions & 6 deletions src/maxtext/kernels/megablox/ops.py
@@ -44,8 +44,8 @@ def gmm(
weight_gather_axes: List[Tuple[str, int]] | None = None,
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
combine_scopes: bool = False,
# TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
qwix_rule: qwix.QtRule | None = None,
lhs_vma_axes: tuple = tuple(),
rhs_vma_axes: tuple = tuple(),
):
"""Grouped matrix multiplication operation."""
quantization_rule = None
@@ -64,9 +64,13 @@
act_calibration_method="absmax",
)

gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0] # pylint: disable=C3001
_gmm_fwd_vma = functools.partial(_gmm_fwd, lhs_vma_axes=tuple())
_gmm_bwd_vma = functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype, lhs_vma_axes=tuple(), rhs_vma_axes=("expert",))
gmm_fwd_bwd = lambda *args: _gmm_fwd_vma(*args)[0] # pylint: disable=C3001
# Define a custom backward pass (VJP) for efficiency.
# It computes dlhs (activation gradients for earlier layers) and drhs (weight gradients).
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13))
gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
gmm_fwd_bwd.defvjp(_gmm_fwd_vma, _gmm_bwd_vma)
return gmm_fwd_bwd(
lhs,
rhs,
@@ -85,6 +89,7 @@
)


# Forward pass: wraps the backend GMM kernel.
def _gmm_fwd(
lhs: jnp.ndarray,
rhs: jnp.ndarray,
@@ -100,6 +105,7 @@ def _gmm_fwd(
quantization_rule: qwix.QtRule | None = None,
use_tokamax_backend: bool = False,
weight_gather_axes: List[Tuple[str, int]] | None = None,
lhs_vma_axes: tuple = tuple(),
) -> tuple[
jnp.ndarray,
tuple[
@@ -129,7 +135,7 @@
calibration_method=quantization_rule.weight_calibration_method,
)
# QAG is only supported under the following conditions:
if use_tokamax_backend:
if use_tokamax_backend:  # skipped when the Tokamax backend is disabled (the VMA path)
if quantization_rule and quantization_rule.bwd_qtype:
if quantization_rule.weight_calibration_method.startswith("fixed") and isinstance(rhs, qpl.QArray):
if weight_gather_axes:
@@ -159,10 +165,12 @@
existing_out,
transpose_rhs=transpose_rhs,
interpret=interpret,
vma_axes=tuple(),
)
return out, (lhs, rhs, group_sizes, group_offset)


# custom backward function
def _gmm_bwd(
lhs_dtype: jax.typing.DTypeLike,
rhs_dtype: jax.typing.DTypeLike,
@@ -182,6 +190,8 @@
jnp.ndarray | None,
],
grad: jnp.ndarray,
lhs_vma_axes: tuple = tuple(), # axes for SiLU output - fsdp
rhs_vma_axes: tuple = tuple(), # axes for W_out - expert
) -> tuple[jnp.ndarray, jnp.ndarray, None, None, jnp.ndarray]:
"""Backward function for throughput GMM VJP."""
del preferred_element_type
@@ -223,7 +233,7 @@
channelwise_axes=[] if quantization_rule.disable_channelwise_axes else [1],
calibration_method=quantization_rule.bwd_calibration_method,
)
if use_tokamax_backend:
if use_tokamax_backend: # false
dlhs = tokamax_backend.gmm(
lhs=dlhs_dout,
rhs=rhs,
@@ -263,6 +273,7 @@
group_offset,
transpose_rhs=not transpose_rhs,
interpret=interpret,
vma_axes=lhs_vma_axes,
)
drhs = backend.tgmm(
lhs.swapaxes(0, 1),
@@ -273,6 +284,7 @@
group_offset,
num_actual_groups,
interpret=interpret,
vma_axes=("expert",),
)

# NOTE: If the rhs transposition is fused into the forward pass we need to
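Above, the VMA options are bound with `functools.partial` before the forward and backward functions are handed to `jax.custom_vjp`, so both sides of the VJP see consistent settings without adding differentiable arguments. A toy sketch of that pattern, where the made-up `scale` keyword stands in for `lhs_vma_axes`/`rhs_vma_axes`:

```python
# Minimal sketch of binding extra options via functools.partial before defvjp.
import functools
import jax
import jax.numpy as jnp

def _fwd(x, w, scale=1.0):
  # Forward pass returns the primal output plus residuals for the backward pass.
  return scale * (x @ w), (x, w)

def _bwd(res, g, scale=1.0):
  # Backward pass maps the output cotangent g to (dx, dw).
  x, w = res
  return scale * (g @ w.T), scale * (x.T @ g)

fwd = functools.partial(_fwd, scale=2.0)
bwd = functools.partial(_bwd, scale=2.0)

f = jax.custom_vjp(lambda x, w: fwd(x, w)[0])
f.defvjp(fwd, bwd)

x, w = jnp.ones((4, 3)), jnp.ones((3, 2))
print(jax.grad(lambda x: f(x, w).sum())(x).shape)  # (4, 3)
```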
62 changes: 26 additions & 36 deletions src/maxtext/layers/moe.py
@@ -640,9 +640,7 @@ def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True
sorted_selected_experts = jnp.argsort(flatten_selected_experts)
# sort inputs for number of selected experts
replicated_inputs_2d = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype(
self.dtype
)
sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype(self.dtype)
group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)
# Return the experts for each sorted input.
expert_indices = jnp.arange(self.num_experts)
@@ -893,9 +891,13 @@ def sparse_matmul(
):
"""Perform sparse matrix multiplication of inputs and Experts."""

def gmm(
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
):
vma_axes = tuple(axis for axis in self.config.mesh_axes if self.mesh.shape[axis] > 1)
use_vma = not self.config.use_tokamax_gmm

def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes):
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
tokamax_group_sizes = group_sizes
@@ -1033,14 +1035,13 @@ def gmm(
batch_logical_axis = "activation_batch_no_exp_moe"

if self.get_tensor_transpose_parallelism_size() > 1:
input_partition_pspec = self._logical_to_mesh_axes(
(batch_logical_axis, "activation_norm_length_moe", "activation_embed_moe")
)
input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))
w0_bias_pspec = self._logical_to_mesh_axes(("exp", None))
w1_bias_pspec = self._logical_to_mesh_axes(("exp", None))
wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
else:
input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length_moe", None))
input_partition_pspec = self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", None))
# expert weights are sharded by exp
w0_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
w1_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_mlp"))
wo_bias_pspec = self._logical_to_mesh_axes(("exp", "activation_embed_moe"))
@@ -1103,7 +1104,7 @@ def gmm(
P(), # Handle None or replicate the output
P(), # Handle None or replicate the output
),
check_vma=False,
check_vma=use_vma,
)
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
batch_size, sequence_length, _ = x.shape
@@ -1119,8 +1120,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r

# Duplicate inputs to all expert shards.
x, logits, pre_bias_logits = tuple(
jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
for z in (x, logits, pre_bias_logits)
jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True) for z in (x, logits, pre_bias_logits)
)

# "Route" tokens within each shard.
@@ -1262,6 +1262,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):

wi_combine_scopes = self.config.wi_combine_scopes
wo_combine_scopes = self.config.wo_combine_scopes
# x * W_gate
layer_w0 = gmm_fn(
x,
w0,
@@ -1274,8 +1275,8 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
if self.config.mlp_bias:
layer_w0 = layer_w0 + w0_bias
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
# x * W_up
layer_w1 = gmm_fn(
x,
w1,
@@ -1288,9 +1289,10 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
if self.config.mlp_bias:
layer_w1 = layer_w1 + w1_bias
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
# Combine the W_gate and W_up branches before the down projection.
intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)

# output of FFN
intermediate_output = gmm_fn(
intermediate_layer,
wo,
@@ -1435,10 +1437,8 @@ def reshape_and_update_weights(self, weights, indices):
# output of updated weights: (batch_size, seq_len, num_experts)
update_weights = jnp.zeros((weights.shape[0], weights.shape[1], self.num_experts), dtype=self.dtype)
index_update = (
self._maybe_shard_with_logical(
jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp_moe", None, None)
),
self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp_moe", None)),
self._maybe_shard_with_logical(jnp.arange(weights.shape[0])[:, None, None], ("activation_batch_no_exp", None, None)),
self._maybe_shard_with_logical(jnp.arange(weights.shape[1])[:, None], ("activation_length_no_exp", None)),
indices,
)
weight_sharding = (
@@ -1664,9 +1664,7 @@ def aqt_einsum(*args, **kwargs): # pylint: disable=unused-argument
einsum_op = jnp.einsum
return einsum_op

def maybe_all_gather_kernel_weight_in_expert_parallelism(
self, kernel: jax.Array, kernel_axes: Tuple[Optional[str], ...]
):
def maybe_all_gather_kernel_weight_in_expert_parallelism(self, kernel: jax.Array, kernel_axes: Tuple[Optional[str], ...]):
"""All-gather kernel weight in expert parallelism if needed."""
if self.get_expert_parallelism_size() > 1:
# This will trigger all-gather using weight_dtype
@@ -1691,14 +1689,10 @@ def dense_matmul(
) -> tuple[jax.Array, Optional[jax.Array], Optional[jax.Array]]:
"""Dense matrix multiplication."""
# gate_logits: batch, length, expert
gate_logits = self._maybe_shard_with_logical(
gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
)
gate_logits = self._maybe_shard_with_logical(gate_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None))
if self.config.model_name.startswith("deepseek3"):
# pre_bias_logits is None for non-DeepSeek v3 models
pre_bias_logits = self._maybe_shard_with_logical(
pre_bias_logits, ("activation_batch_moe", "activation_length_no_exp_moe", None)
)
pre_bias_logits = self._maybe_shard_with_logical(pre_bias_logits, ("activation_batch", "activation_norm_length", None))
top_k_weights, top_k_indices = self.get_topk(gate_logits, pre_bias_logits, self.rngs)
is_llama4_decoder_layer = self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4
if is_llama4_decoder_layer:
@@ -1711,9 +1705,7 @@ def dense_matmul(
# Calculate load balance loss
if self.config.model_call_mode != "inference":
softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
lb_loss = (
self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None
)
lb_loss = self.load_balance_loss(top_k_indices, softmax_probs) if self.config.load_balance_loss_weight > 0.0 else None
else:
lb_loss = None

Expand Down Expand Up @@ -1990,9 +1982,7 @@ def retrieve_quantized_weight(
# This is called only during tracing. This is to invoke creation of
# quantized tensor inside AqtEinsum. After jit, this will become no-op and
# will not affect performance.
_ = self.dense_matmul(
inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias
)
_ = self.dense_matmul(inputs, gate_logits, pre_bias_logits, w0_kernel, w1_kernel, wo_kernel, w0_bias, w1_bias, wo_bias)

w0_kernel = self.variables["aqt"]["AqtEinsum_0"]["AqtDotGeneral_0"]["qrhs"]["frozen"]
w1_kernel = self.variables["aqt"]["AqtEinsum_1"]["AqtDotGeneral_0"]["qrhs"]["frozen"]
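In `sparse_matmul`, the varying axes are derived from the mesh itself: an axis of size 1 does not shard anything, so it is left out of `vma_axes`, and `check_vma` is turned on only for the Megablox (non-Tokamax) path. A small sketch of that derivation; the two-axis mesh and the config placeholders are illustrative assumptions.

```python
# Sketch of deriving VMA axes from a mesh (axis names are illustrative).
import jax
import numpy as np
from jax.sharding import Mesh

devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("fsdp", "expert"))
mesh_axes = ("fsdp", "expert")  # stands in for config.mesh_axes

# Axes of size 1 are dropped: they do not partition any data.
vma_axes = tuple(axis for axis in mesh_axes if mesh.shape[axis] > 1)
use_tokamax_gmm = False            # placeholder for config.use_tokamax_gmm
use_vma = not use_tokamax_gmm      # check_vma only applies to the Megablox gmm

print(vma_axes, use_vma)  # e.g. ("fsdp",) True on a multi-device host
```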
2 changes: 1 addition & 1 deletion src/maxtext/models/deepseek.py
@@ -180,7 +180,7 @@ def mlp_op(self, x, deterministic, *args, **kwargs):
def with_logical_constraint(self, x):
return maybe_shard_with_logical(
x,
logical_axes=self.logical_axis_names,
logical_axes=tuple(self.logical_axis_names),
mesh=self.mesh,
shard_mode=self.config.shard_mode,
debug_sharding=self.config.debug_sharding,
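The change above passes the logical axis names as a tuple instead of whatever sequence `logical_axis_names` happens to be. One plausible reason, offered here as an assumption about the motivation rather than a statement about `maybe_shard_with_logical` internals, is that tuples are hashable while lists are not, which matters anywhere the axes feed a cache key or a static `jit` argument:

```python
# Standard-Python sketch: a list of axis names cannot be hashed, a tuple can.
axes = ["activation_batch", "activation_norm_length", "activation_embed"]
try:
  hash(axes)
except TypeError:
  print("list of axis names is unhashable")
print(isinstance(hash(tuple(axes)), int))  # True: usable as a cache/static key
```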
2 changes: 1 addition & 1 deletion src/maxtext/utils/gradient_accumulation.py
@@ -68,7 +68,7 @@ def _maybe_shard_with_name(inputs, sharding_names):
return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding)

# For more efficient DP/ZeRO-1 + GA
if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1:
if config.shard_mode == ShardMode.EXPLICIT and model.mesh.shape.get("data", 1) > 1:
ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings)
grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings)
else:
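The accumulation condition now reads the data-parallel degree from the mesh instead of `config.ici_data_parallelism`, so it reflects the mesh that was actually built. A minimal sketch of that lookup; the axis name "data" matches the condition above, while the mesh construction itself is an illustrative assumption.

```python
# Sketch: query an axis size directly from the mesh; a missing axis counts as 1.
import jax
import numpy as np
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
data_parallelism = mesh.shape.get("data", 1)  # mesh.shape maps axis name -> size
if data_parallelism > 1:
  print("use reduced/unreduced gradient shardings for DP/ZeRO-1 + GA")
```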