Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ ici_autoregressive_parallelism: 1
ici_pipeline_parallelism: 1
ici_expert_parallelism: 1

# Enabling check_vma is recommended for improved performance. Only supported for EP / FSDP ICI parallelisms, shard_mode: "auto", use_ragged_sort: False, use_ring_of_experts: False, and use_tokamax_gmm=False.
# Enabling check_vma is recommended for improved performance. Only supported for EP / FSDP ICI parallelisms, shard_mode: "auto", use_ragged_sort: False, and use_tokamax_gmm=False.
check_vma: False

# Enable ZeRO-1 optimizer sharding over data axis
Expand Down
2 changes: 0 additions & 2 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2363,8 +2363,6 @@ def _validate_check_vma_is_supported(self):
raise ValueError("check_vma is not yet supported with tokamax gmm kernel.")
if self.use_ragged_sort:
raise ValueError("check_vma is not yet supported with ragged sort kernel.")
if self.use_ring_of_experts:
raise ValueError("check_vma is not yet supported with ring of experts.")
_allowed = {"ici_expert_parallelism", "ici_fsdp_parallelism"}
active = [name for name in IciParallelism.model_fields if name not in _allowed and getattr(self, name) != 1]
if active:
Expand Down
12 changes: 6 additions & 6 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,12 +1364,12 @@ def route(x, logits, pre_bias_logits, rngs):
if self.config.use_ring_of_experts:
# The ring-of-experts strategy first duplicates the inputs to all
# expert shards, and then routes within each shard.

# Duplicate inputs to all expert shards.
x, logits, pre_bias_logits = tuple(
jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
for z in (x, logits, pre_bias_logits)
)
if num_ep != 1:
# Duplicate inputs to all expert shards.
x, logits, pre_bias_logits = tuple(
jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
for z in (x, logits, pre_bias_logits)
)

# "Route" tokens within each shard.
num_experts_per_shard = self.config.num_experts // num_ep
Expand Down
Loading