diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index f62c1d1997..6b02cd199f 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -622,7 +622,7 @@ ici_autoregressive_parallelism: 1 ici_pipeline_parallelism: 1 ici_expert_parallelism: 1 -# Enabling check_vma is recommended for improved performance. Only supported for EP / FSDP ICI parallelisms, shard_mode: "auto", use_ragged_sort: False, use_ring_of_experts: False, and use_tokamax_gmm=False. +# Enabling check_vma is recommended for improved performance. Only supported for EP / FSDP ICI parallelisms, shard_mode: "auto", use_ragged_sort: False, and use_tokamax_gmm=False. check_vma: False # Enable ZeRO-1 optimizer sharding over data axis diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index aa7824ef3c..409760c836 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2363,8 +2363,6 @@ def _validate_check_vma_is_supported(self): raise ValueError("check_vma is not yet supported with tokamax gmm kernel.") if self.use_ragged_sort: raise ValueError("check_vma is not yet supported with ragged sort kernel.") - if self.use_ring_of_experts: - raise ValueError("check_vma is not yet supported with ring of experts.") _allowed = {"ici_expert_parallelism", "ici_fsdp_parallelism"} active = [name for name in IciParallelism.model_fields if name not in _allowed and getattr(self, name) != 1] if active: diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index baa951b4fe..7ee8b7cbe5 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1364,12 +1364,12 @@ def route(x, logits, pre_bias_logits, rngs): if self.config.use_ring_of_experts: # The ring-of-experts strategy first duplicates the inputs to all # expert shards, and then routes within each shard. - - # Duplicate inputs to all expert shards. - x, logits, pre_bias_logits = tuple( - jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True) - for z in (x, logits, pre_bias_logits) - ) + if num_ep != 1: + # Duplicate inputs to all expert shards. + x, logits, pre_bias_logits = tuple( + jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True) + for z in (x, logits, pre_bias_logits) + ) # "Route" tokens within each shard. num_experts_per_shard = self.config.num_experts // num_ep