TE EP integration to MoEBlock #3116

Draft
tdophung wants to merge 52 commits into NVIDIA:main from tdophung:teddy/te_ep_integration

tdophung commented Jun 10, 2026

Collaborator

Description

I will rebase and squash the commits on this branch shortly before merging.
I will also update the JAX APIs if needed once the TE EP JAX changes merge.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng and others added 30 commits June 9, 2026 18:27
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…em_reloc gating

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rce at dispatch

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… static layer registration

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…er + NVTEEpHandle struct (NVTE_EP_HANDLE_CACHE_SIZE=-1 disables eviction)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…hout it

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ogging.h

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…_COPY_{ON,OFF}

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…tyAllSymm

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CUDA Toolkit)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… for wheel install

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…bmodules

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rop submodule header mirror

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…al CommWindow

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…espace

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…I in EP files

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
phu0ngng and others added 21 commits June 9, 2026 18:27
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…headers

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…lint runtime/int)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…pe lifetime)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…16 max_token_dtype

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with
EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied
the three deltas uniquely ours:

  * transformer_engine/jax/moe.py: replaces upstream's multi-backend
    MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted
    to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle
    (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed
    in place of handle, ep_prepare arg order swapped, top_k= dropped
    from ep_dispatch_bwd since it's now in cfg).
  * tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with
    ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped
    (no longer supported; ep_size is derived from mesh axes and the
    handle_mem reloc gating is gone).
  * tests/jax/run_te_ep_moe.sh: multi-process launcher (kept).

Pre-sync state preserved at branch
teddy/te_ep_integration.backup-pre-phuong-sync.
tdophung force-pushed the teddy/te_ep_integration branch from 0ff3bff to bd14fe6 on June 10, 2026 21:58
Comment thread transformer_engine/jax/flax/moe.py Outdated
use_bias: bool = False
# Per-expert router bias added before the top-k. Only meaningful when
# score_function='sigmoid'.
use_expert_bias: bool = False

Collaborator

nit SR: can we rename use_bias -> use_ffn_bias and use_expert_bias -> use_expert_routing_bias?

Comment thread transformer_engine/jax/flax/moe.py Outdated
# Minimum per-expert slot alignment fed to ``tex.ep_prepare``. Default 0
# uses the natural slot count; set to e.g. 128 to satisfy FP8 grouped-GEMM
# tile alignment.
align_size: int = 0

Collaborator

Placeholder comment for me to fix this so align_size is inferred automatically based on the recipe and doesn't need to be specified by the user

Comment thread transformer_engine/jax/flax/moe.py Outdated
nn.with_logical_partitioning(self.bias_init, ("exp",)),
(self.num_experts,),
self.dtype,
jnp.float32,

Collaborator

Is the router always in fp32, so that this expert bias must be fp32 as well? If so, can we add a small comment indicating this?
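To make the question concrete, here is a minimal sketch of sigmoid routing with a per-expert bias added before the top-k, with everything promoted to fp32. It is not the code from this PR: the function name `route_with_expert_bias` is hypothetical, and it is an assumption that the bias affects only expert selection while the combine weights come from the unbiased scores.

```python
import jax
import jax.numpy as jnp


def route_with_expert_bias(logits, expert_bias, top_k):
    """Sigmoid routing with a per-expert selection bias, computed in fp32."""
    # Promote to fp32 so the scores and the bias share one precision.
    scores = jax.nn.sigmoid(logits.astype(jnp.float32))
    # Assumption: the bias only influences which experts are selected.
    biased = scores + expert_bias.astype(jnp.float32)
    _, expert_idx = jax.lax.top_k(biased, top_k)
    # Assumption: combine weights are read from the unbiased scores.
    weights = jnp.take_along_axis(scores, expert_idx, axis=-1)
    return expert_idx, weights


# One token, four experts; the bias strongly favours expert 2.
logits = jnp.array([[2.0, 0.0, -1.0, 1.0]], dtype=jnp.bfloat16)
expert_bias = jnp.array([0.0, 0.0, 5.0, 0.0], dtype=jnp.float32)
expert_idx, weights = route_with_expert_bias(logits, expert_bias, top_k=2)
```

If the router does run in fp32, keeping the bias parameter in fp32 avoids an extra cast on every step, which may be what the `jnp.float32` in the snippet is for.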



__all__ = ["moe", "PermutationBackend"]
def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray:

Collaborator

Why do we need this utility function? I haven't seen something like this required for our other VJPs

# is a frozen dataclass of ints); the rest are jnp.ndarray,
# GroupedNoScaleTensor (already a pytree), or None when aux_loss_coeff == 0.
@register_pytree_node_class
@dataclass

Collaborator

I think this tree_flatten was from my patch, but looking at the diff I think it'd be better to use the @flax_struct.dataclass you were using on the permutation dataclasses since that seems to auto-populate a default pytree flatten/unflatten for us

else:
d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat)

# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply

Collaborator

Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block.
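For context, the pattern under discussion looks roughly like the sketch below: silu and the multiply are evaluated in fp32 and the result is cast back to the input dtype. This is an assumption about the shape of the forward pass, not the PR's code; the name `swiglu_fp32` is hypothetical, and the backward here comes from `jax.vjp` rather than from a hand-written rule.

```python
import jax
import jax.numpy as jnp


def swiglu_fp32(gate, up):
    """silu(gate) * up, evaluated in fp32 and cast back to the input dtype."""
    out = jax.nn.silu(gate.astype(jnp.float32)) * up.astype(jnp.float32)
    return out.astype(gate.dtype)


gate = jnp.array([0.5, -1.0, 2.0], dtype=jnp.bfloat16)
up = jnp.array([1.0, 3.0, -2.0], dtype=jnp.bfloat16)

# Autodiff differentiates through the casts, so the backward also runs in
# fp32 internally and hands back gradients in the input dtype.
out, vjp_fn = jax.vjp(swiglu_fp32, gate, up)
d_gate, d_up = vjp_fn(jnp.ones_like(out))
```

A manual backward that skips the promotion would compute the same derivative in bf16, so its gradients could differ slightly from what autodiff of the forward produces.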

# local expert. We must size to that worst case or NCCL EP's HT kernel
# rejects the dispatch buffer with ``invalid argument``.
natural_spe = num_ep * max_tokens_per_rank # = (B // dp_size) * S
# NCCL EP requires each expert-major output block to be at least

Collaborator

Do we have a use case for user-specified alignments beyond 128 currently? If NCCL EP requires an alignment of at least 128, and since an alignment of 128 is sufficient for all TE grouped GEMM types, would it make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now to simplify this MoEBlock API?

We can always expand the API to support a user-specified align size in the future
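A minimal sketch of what the hardcoded variant could look like, assuming the per-expert slot count is simply rounded up to the alignment. The helper names `round_up` and `slots_per_expert` are hypothetical; only `natural_spe`, `num_ep`, and `max_tokens_per_rank` come from the snippet above.

```python
_ALIGN_SIZE = 128  # hypothetical module-level constant, as suggested above


def round_up(value: int, multiple: int) -> int:
    """Round value up to the nearest multiple (multiple must be positive)."""
    return ((value + multiple - 1) // multiple) * multiple


def slots_per_expert(num_ep: int, max_tokens_per_rank: int) -> int:
    # Worst case: every rank routes all of its tokens to one local expert.
    natural_spe = num_ep * max_tokens_per_rank
    return round_up(natural_spe, _ALIGN_SIZE)


aligned = slots_per_expert(num_ep=4, max_tokens_per_rank=100)
```

With 4 ranks and 100 tokens per rank the natural count is 400, which rounds up to 512 slots per expert.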

batch_pspec_axis = (*data_parallelism_axes, ep_axis)
ep3_spec = P(batch_pspec_axis, None, None)
ep2_spec = P(batch_pspec_axis, None)
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, ep3_spec))

Collaborator

Which axis name inputs are physical mesh axes and which can be logical axes? I see above x = with_sharding_constraint_by_logical_axes(x, input_axes) but here we directly use jax.lax.with_sharding_constraint, which only supports mesh axes.

No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes. Thanks!
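The distinction can be sketched as follows: a `PartitionSpec` passed to `jax.lax.with_sharding_constraint` must name mesh axes, so a logical name has to be resolved through some rule table first. The `logical_rules` dict and `resolve` helper are hypothetical stand-ins for that lookup, and the 1x1 mesh with axes `dp` and `ep` is an assumption so the example runs on one device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1x1 mesh so the sketch runs on a single device.
devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = Mesh(devices, ("dp", "ep"))

# Hypothetical logical-to-mesh rules; logical names are resolved through a
# table like this before they can appear in a PartitionSpec.
logical_rules = {"batch": ("dp", "ep"), "seq": None, "hidden": None}


def resolve(logical_axes):
    """Translate a tuple of logical axis names into a mesh PartitionSpec."""
    return P(*(logical_rules[name] for name in logical_axes))


spec = resolve(("batch", "seq", "hidden"))  # P(("dp", "ep"), None, None)


@jax.jit
def constrain(x):
    # Only physical mesh axis names are valid at this point.
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))


x = jnp.ones((4, 16, 32))
y = constrain(x)
```

The batch dimension is sharded over the product of `dp` and `ep`, matching the shape of `ep3_spec` in the snippet above.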

# `grad_pre_combine * w` sees them. Padded positions in sparse_probs
# are already zero (routing_map is False there); only the rare
# underflow path emits NaN.
sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(dtype)

Collaborator

Is this NaN filtering a debugging artifact or something we need in the final version?
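A small sketch of what the filter does, using made-up values: a single NaN in `sparse_probs` survives multiplication with a finite gradient, while the `jnp.where` pattern from the snippet zeroes only the NaN entries. The array contents and the `bfloat16` target dtype are assumptions for illustration.

```python
import jax.numpy as jnp

sparse_probs = jnp.array([[0.5, jnp.nan], [0.0, 0.25]], dtype=jnp.float32)
grad_pre_combine = jnp.ones_like(sparse_probs)

# Without filtering, the NaN propagates into the product.
unfiltered = grad_pre_combine * sparse_probs

# Same pattern as the snippet: replace only NaN entries, keep the rest.
filtered = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(jnp.bfloat16)
```

Logging `jnp.isnan(sparse_probs).any()` over a few training steps would be one way to check whether the underflow path is ever hit in practice.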

3 participants