TE EP integration to MoEBlock #3116
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…em_reloc gating Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rce at dispatch Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… static layer registration Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…er + NVTEEpHandle struct (NVTE_EP_HANDLE_CACHE_SIZE=-1 disables eviction) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CCL_EP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…hout it Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ogging.h Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…_COPY_{ON,OFF}
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…tyAllSymm Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…CUDA Toolkit) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… for wheel install Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…bmodules Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…rop submodule header mirror Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…al CommWindow Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…espace Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…I in EP files Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…headers Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…lint runtime/int) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…pe lifetime) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…16 max_token_dtype Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
… with_sharding_constraint Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…trap Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…EpLayerConfig type) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ives (lint 10.00) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…; define NVTE_WITH_NCCL_EP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ract, drop dead helpers Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Reset 33 local commits onto phuong/ep-3-jax @ c34771d (her latest with EpConfig + EpLayerConfig API, NCCL bumped to 808d2433) and re-applied the three deltas uniquely ours:
* transformer_engine/jax/moe.py: replaces upstream's multi-backend MoE block with our TE-EP-only single-custom-vjp rewrite. Adapted to her new API surface: tex.EpLayerConfig replaces tex.ep_make_handle (no more EpHandle pool/cache); 5 EP callsites rewired (cfg passed in place of handle, ep_prepare arg order swapped, top_k= dropped from ep_dispatch_bwd since it's now in cfg).
* tests/jax/test_te_ep_moe.py: TE-EP MoE test (kept), with ep_bootstrap kwargs ep_size= and allow_handle_mem_reloc= dropped (no longer supported; ep_size is derived from mesh axes and the handle_mem reloc gating is gone).
* tests/jax/run_te_ep_moe.sh: multi-process launcher (kept).

Pre-sync state preserved at branch teddy/te_ep_integration.backup-pre-phuong-sync.
0ff3bff to bd14fe6
use_bias: bool = False
# Per-expert router bias added before the top-k. Only meaningful when
# score_function='sigmoid'.
use_expert_bias: bool = False
nit SR: can we rename use_bias -> use_ffn_bias and use_expert_bias -> use_expert_routing_bias?
# Minimum per-expert slot alignment fed to ``tex.ep_prepare``. Default 0
# uses the natural slot count; set to e.g. 128 to satisfy FP8 grouped-GEMM
# tile alignment.
align_size: int = 0
Placeholder comment for me to fix this so that align_size is inferred automatically from the recipe and doesn't need to be specified by the user.
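One possible shape for that inference, as a minimal sketch under assumptions: the `Recipe` enum and the `infer_align_size` helper are hypothetical and not part of the TE API. Only the two values (0 for the natural slot count, 128 for FP8 grouped GEMM) come from the field comment above.

```python
from enum import Enum


class Recipe(Enum):
    """Hypothetical stand-in for a quantization recipe; not a TE type."""
    HIGH_PRECISION = "high_precision"
    FP8 = "fp8"


# Assumption: FP8 grouped GEMM wants 128-aligned slots (per the field comment);
# 0 keeps the natural slot count.
_ALIGN_BY_RECIPE = {Recipe.HIGH_PRECISION: 0, Recipe.FP8: 128}


def infer_align_size(recipe: Recipe) -> int:
    """Pick the per-expert slot alignment from the recipe alone."""
    return _ALIGN_BY_RECIPE[recipe]
```

With something like this, the `align_size` field could be dropped from the public config and computed where the recipe is already known.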
nn.with_logical_partitioning(self.bias_init, ("exp",)),
(self.num_experts,),
self.dtype,
jnp.float32,
Is the router always in fp32, so that this expert bias must be as well? If so, can we add a small comment indicating this?
__all__ = ["moe", "PermutationBackend"]
def _with_sharding_constraint_cast_bwd(x: jnp.ndarray, sharding) -> jnp.ndarray:
Why do we need this utility function? I haven't seen something like this required for our other VJPs
# is a frozen dataclass of ints); the rest are jnp.ndarray,
# GroupedNoScaleTensor (already a pytree), or None when aux_loss_coeff == 0.
@register_pytree_node_class
@dataclass
I think this tree_flatten was from my patch, but looking at the diff I think it'd be better to use the @flax_struct.dataclass you were using on the permutation dataclasses, since that seems to auto-populate a default pytree flatten/unflatten for us.
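For context, this is roughly the boilerplate the hand-written route carries. It is a sketch only: the class name and fields are hypothetical, not the ones in the diff. The `(children, aux_data)` split follows the protocol that `jax.tree_util.register_pytree_node_class` expects; the decorator itself is left out so the sketch runs without JAX installed.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class ResidualsSketch:
    """Hypothetical residuals container; fields are illustrative only."""
    probs: Any        # array-like leaf
    routing_map: Any  # array-like leaf
    top_k: int        # static metadata, not traced

    def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[int, ...]]:
        # Leaves go in children; static Python values go in aux_data.
        children = (self.probs, self.routing_map)
        aux_data = (self.top_k,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children) -> "ResidualsSketch":
        probs, routing_map = children
        (top_k,) = aux_data
        return cls(probs=probs, routing_map=routing_map, top_k=top_k)
```

Every new field has to be threaded through both methods by hand. With a Flax struct dataclass the leaf/static split would instead be declared per field (for example via `flax.struct.field(pytree_node=False)` for static values), which is the simplification the comment above is after.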
else:
    d_recv_w_from_intermediate = jnp.zeros_like(recv_w_flat)

# Activation bwd. Mirror the fwd's fp32 promotion of silu+multiply
Why is this dtype casting required? I don't recall us needing it for the non-MoE LNMLP block
# local expert. We must size to that worst case or NCCL EP's HT kernel
# rejects the dispatch buffer with ``invalid argument``.
natural_spe = num_ep * max_tokens_per_rank  # = (B // dp_size) * S
# NCCL EP requires each expert-major output block to be at least
Do we have a use-case for user-specified alignments beyond 128 currently? If NCCL EP requires an alignment of at least 128, and since an alignment of 128 is sufficient for all TE grouped GEMM types, would it make sense to instead hardcode _ALIGN_SIZE = 128 as a constant at the top of the file for now, to simplify this MoEBlock API?
We can always expand the API to support a user-specified align size in the future.
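A minimal sketch of what the hardcoded variant might look like. The helper name `slots_per_expert` is hypothetical, and rounding the worst-case count up to a multiple of the constant is an assumption about how the alignment would be applied; the `natural_spe` formula is taken from the snippet above.

```python
_ALIGN_SIZE = 128  # module-level constant, replacing the user-facing align_size


def slots_per_expert(num_ep: int, max_tokens_per_rank: int) -> int:
    """Worst-case per-expert slot count, rounded up to _ALIGN_SIZE."""
    # Worst case: every rank routes all of its tokens to the same local expert.
    natural_spe = num_ep * max_tokens_per_rank
    # Ceiling division without floats, then scale back to a slot count.
    return -(-natural_spe // _ALIGN_SIZE) * _ALIGN_SIZE
```

For example, `slots_per_expert(4, 100)` rounds the natural 400 slots up to 512, while `slots_per_expert(2, 128)` stays at 256 because it is already aligned.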
batch_pspec_axis = (*data_parallelism_axes, ep_axis)
ep3_spec = P(batch_pspec_axis, None, None)
ep2_spec = P(batch_pspec_axis, None)
x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, ep3_spec))
Which axis name inputs are physical mesh axes and which can be logical axes? I see above x = with_sharding_constraint_by_logical_axes(x, input_axes) but here we directly use jax.lax.with_sharding_constraint, which only supports mesh axes.
No need to make any changes for now, I just want to assess which are which and then we can discuss if it makes sense to support logical on some/all or if some are required to be physical axes. Thanks!
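To make the distinction concrete, here is a pure-Python sketch of the two kinds of names. The mesh axis names, the rule table, and both helpers are hypothetical; only the `"exp"` logical name appears in the diff. Logical names go through a rule table first, while a spec handed directly to a mesh-level constraint must already consist of mesh axes.

```python
from typing import Dict, Optional, Sequence, Tuple, Union

MeshAxes = Union[None, str, Tuple[str, ...]]

MESH_AXES = ("dp", "ep")  # hypothetical physical mesh axis names
# Hypothetical logical-to-mesh rules; None means "replicated".
LOGICAL_RULES: Dict[str, MeshAxes] = {
    "batch": ("dp", "ep"),
    "exp": "ep",
    "embed": None,
}


def resolve_logical(axes: Sequence[Optional[str]]) -> Tuple[MeshAxes, ...]:
    """Translate logical axis names into mesh axes via the rule table."""
    return tuple(None if a is None else LOGICAL_RULES[a] for a in axes)


def check_physical(spec: Sequence[MeshAxes]) -> None:
    """Raise if any entry of a spec names something that is not a mesh axis."""
    for entry in spec:
        names = () if entry is None else (entry,) if isinstance(entry, str) else entry
        for name in names:
            if name not in MESH_AXES:
                raise ValueError(f"{name!r} is not a mesh axis")
```

Under these assumptions, `resolve_logical(("batch", None, "embed"))` yields a spec that passes `check_physical`, whereas passing the logical name `"batch"` straight to `check_physical` raises. That is the failure mode to watch for if `ep_axis` or `data_parallelism_axes` were ever given as logical names.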
# `grad_pre_combine * w` sees them. Padded positions in sparse_probs
# are already zero (routing_map is False there); only the rare
# underflow path emits NaN.
sparse_probs = jnp.where(jnp.isnan(sparse_probs), 0, sparse_probs).astype(dtype)
Is this NaN filtering a debugging artifact or something we need in the final version?
Description
Will rebase and squash the commits on this branch shortly before merging.
Will also change the JAX APIs if needed once the TE EP JAX work merges.