Skip to content

Skip ring of experts all gather when EP = 1#4086

Open
Shuwen-Fang wants to merge 1 commit into
mainfrom
fix/vma-check-support
Open

Skip ring of experts all gather when EP = 1#4086
Shuwen-Fang wants to merge 1 commit into
mainfrom
fix/vma-check-support

Conversation

@Shuwen-Fang

@Shuwen-Fang Shuwen-Fang commented Jun 5, 2026

Copy link
Copy Markdown
Collaborator

Description

Currently runs error out when EP = 1 and ring of experts is used, fix by skipping all gather when EP = 1. This PR updates the logic to only all gather when EP > 1.

Tests

CI tests

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 5, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

ici_expert_parallelism: 1

# Enabling check_vma is recommended for improved performance. Only supported for EP / FSDP ICI parallelisms, shard_mode: "auto", use_ragged_sort: False, use_ring_of_experts: False, and use_tokamax_gmm=False.
# Enabling check_vma is recommended for improved performance. Only supported for EP / FSDP ICI parallelisms, shard_mode: "auto", use_ragged_sort: False, and use_tokamax_gmm=False.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

emmm why this flag is removed?

Comment thread src/maxtext/layers/moe.py
x, logits, pre_bias_logits = tuple(
jax.lax.all_gather(z, axis_name=self._expert_parallelism_name, tiled=True)
for z in (x, logits, pre_bias_logits)
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should error out if use_ring_of_experts=true and num_ep==1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants