Skip to content

Commit 0ff3bff

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d74f4ce commit 0ff3bff

6 files changed

Lines changed: 109 additions & 107 deletions

File tree

tests/jax/test_multi_process_ep.py

Lines changed: 29 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -245,11 +245,13 @@ def test_two_layer_dispatch_no_handle_aliasing(self):
245245
w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec))
246246

247247
def one_layer(hk, idx, toks, w_):
248-
recv_t, recv_w, hm, tc = ep_dispatch(
249-
hk, idx, toks, w_, self.recv_capacity_per_rank
248+
recv_t, recv_w, hm, tc = ep_dispatch(hk, idx, toks, w_, self.recv_capacity_per_rank)
249+
recv_t = jax.lax.with_sharding_constraint(
250+
recv_t, NamedSharding(self.mesh, ep_spec_3d)
251+
)
252+
recv_w = jax.lax.with_sharding_constraint(
253+
recv_w, NamedSharding(self.mesh, ep_spec_2d)
250254
)
251-
recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_spec_3d))
252-
recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_spec_2d))
253255
return ep_combine(
254256
hk, hm, tc, recv_t, recv_w, T_global, out_sharding=(("dp", "ep"), None)
255257
)
@@ -269,12 +271,14 @@ def run(idx, ta_, tb_, w_):
269271
np.testing.assert_allclose(
270272
np.asarray(out_a_g.astype(jnp.float32)),
271273
np.asarray(tokens.astype(jnp.float32)),
272-
atol=5e-2, rtol=5e-2,
274+
atol=5e-2,
275+
rtol=5e-2,
273276
)
274277
np.testing.assert_allclose(
275278
np.asarray(out_b_g.astype(jnp.float32)),
276279
np.asarray(tokens_b.astype(jnp.float32)),
277-
atol=5e-2, rtol=5e-2,
280+
atol=5e-2,
281+
rtol=5e-2,
278282
)
279283

280284
def test_primitive_prepare(self):
@@ -328,7 +332,10 @@ def run(idx, toks, w):
328332
weighted, NamedSharding(self.mesh, ep_spec_3d)
329333
)
330334
out = ep_combine_fwd(
331-
self.hk, hm, weighted, T_global,
335+
self.hk,
336+
hm,
337+
weighted,
338+
T_global,
332339
out_partition_spec=(("dp", "ep"), None),
333340
)
334341
return jax.lax.with_sharding_constraint(out, NamedSharding(self.mesh, dp_spec))
@@ -372,7 +379,9 @@ def loss_fn(toks):
372379
toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec))
373380
idx = jax.lax.with_sharding_constraint(topk_idx, NamedSharding(self.mesh, dp_spec))
374381
w = jax.lax.with_sharding_constraint(topk_w, NamedSharding(self.mesh, dp_spec))
375-
recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank)
382+
recv_t, recv_w, hm, tc = ep_dispatch(
383+
self.hk, idx, toks, w, self.recv_capacity_per_rank
384+
)
376385
recv_t = jax.lax.with_sharding_constraint(
377386
recv_t, NamedSharding(self.mesh, ep_spec_3d)
378387
)
@@ -420,7 +429,9 @@ def test_dispatch_combine_3d_input_output(self):
420429

421430
@jax.jit
422431
def run(idx, toks, w):
423-
recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank)
432+
recv_t, recv_w, hm, _tc = ep_dispatch(
433+
self.hk, idx, toks, w, self.recv_capacity_per_rank
434+
)
424435
recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t))
425436
recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w))
426437
out = ep_combine(
@@ -463,7 +474,9 @@ def test_dispatch_combine_dp_only_first_dim(self):
463474

464475
@jax.jit
465476
def run(idx, toks, w):
466-
recv_t, recv_w, hm, _tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank)
477+
recv_t, recv_w, hm, _tc = ep_dispatch(
478+
self.hk, idx, toks, w, self.recv_capacity_per_rank
479+
)
467480
recv_t = jax.lax.with_sharding_constraint(recv_t, NamedSharding(self.mesh, ep_t))
468481
recv_w = jax.lax.with_sharding_constraint(recv_w, NamedSharding(self.mesh, ep_w))
469482
out = ep_combine(
@@ -641,7 +654,9 @@ def run(idx, toks, w):
641654
idx = jax.lax.with_sharding_constraint(idx, NamedSharding(self.mesh, dp_spec))
642655
toks = jax.lax.with_sharding_constraint(toks, NamedSharding(self.mesh, dp_spec))
643656
w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec))
644-
recv_t, recv_w, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank)
657+
recv_t, recv_w, hm, tc = ep_dispatch(
658+
self.hk, idx, toks, w, self.recv_capacity_per_rank
659+
)
645660
recv_t = jax.lax.with_sharding_constraint(
646661
recv_t, NamedSharding(self.mesh, ep_spec_3d)
647662
)
@@ -688,7 +703,9 @@ def fwd(eo, toks, idx, w):
688703
w = jax.lax.with_sharding_constraint(w, NamedSharding(self.mesh, dp_spec))
689704
_rt, rw, hm, tc = ep_dispatch(self.hk, idx, toks, w, self.recv_capacity_per_rank)
690705
rw = jax.lax.with_sharding_constraint(rw, NamedSharding(self.mesh, ep_spec_2d))
691-
combined = ep_combine(self.hk, hm, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None))
706+
combined = ep_combine(
707+
self.hk, hm, tc, eo, rw, T_dp, out_sharding=(("dp", "ep"), None)
708+
)
692709
return jax.lax.with_sharding_constraint(combined, NamedSharding(self.mesh, dp_spec))
693710

694711
# jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd

tests/jax/test_te_ep_moe.py

Lines changed: 10 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -112,8 +112,7 @@ def _read_mp_options():
112112

113113
if not _MP_ACTIVE:
114114
pytest.skip(
115-
"test_te_ep_moe.py requires the multiprocess launcher "
116-
"(run_te_ep_moe.sh). Skipping.",
115+
"test_te_ep_moe.py requires the multiprocess launcher (run_te_ep_moe.sh). Skipping.",
117116
allow_module_level=True,
118117
)
119118

@@ -231,9 +230,7 @@ def mesh():
231230
# Eager bootstrap: ep_bootstrap does a host-side NCCL UID allgather
232231
# and cannot run from inside jax.jit. Sized to the worst-case recv_pr
233232
# across _CONFIGS so every parametrized config is bootstrap-compatible.
234-
with mesh_obj, global_shard_guard(
235-
MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS)
236-
):
233+
with mesh_obj, global_shard_guard(MeshResource(ep_resource=EP_AXIS, fsdp_resource=FSDP_AXIS)):
237234
ep_bootstrap(
238235
world_size=num_procs,
239236
rank=jax.process_index(),
@@ -325,9 +322,7 @@ def _pure_jax_moe_reference(
325322
raise ValueError(f"Unsupported score_function={score_function!r}")
326323

327324
routing_weights_full = jnp.zeros((T, num_experts), dtype=jnp.float32)
328-
routing_weights_full = routing_weights_full.at[
329-
jnp.arange(T)[:, None], top_indices
330-
].set(weights)
325+
routing_weights_full = routing_weights_full.at[jnp.arange(T)[:, None], top_indices].set(weights)
331326

332327
# FFN. ``apply_topk_weights_early`` is a fusion knob that doesn't
333328
# change the math (wo is linear), so the reference is identical for
@@ -337,9 +332,7 @@ def _pure_jax_moe_reference(
337332
intermediate = jax.nn.silu(layer_w0.astype(jnp.float32)) * layer_w1.astype(jnp.float32)
338333
intermediate = intermediate.astype(x.dtype)
339334
expert_out = jnp.einsum("tem,emh->teh", intermediate, wo) # [T, E, H]
340-
output_2d = jnp.einsum(
341-
"te,teh->th", routing_weights_full.astype(x.dtype), expert_out
342-
)
335+
output_2d = jnp.einsum("te,teh->th", routing_weights_full.astype(x.dtype), expert_out)
343336
output = output_2d.reshape(B, S, H).astype(x.dtype)
344337

345338
if aux_loss_coeff > 0.0:
@@ -354,9 +347,7 @@ def _pure_jax_moe_reference(
354347
else: # sigmoid
355348
aux_scores = jax.nn.sigmoid(logits)
356349
if K > 1:
357-
aux_scores = aux_scores / (
358-
aux_scores.sum(axis=-1, keepdims=True) + 1e-20
359-
)
350+
aux_scores = aux_scores / (aux_scores.sum(axis=-1, keepdims=True) + 1e-20)
360351
routing_map = (routing_weights_full > 0).astype(jnp.int32)
361352
tokens_per_expert = jnp.sum(routing_map, axis=0) # [E]
362353
sum_probs_per_expert = jnp.sum(aux_scores, axis=0) # [E]
@@ -567,9 +558,7 @@ def _reference_kwargs_from_config(config, params_np):
567558
return dict(
568559
score_function=config.get("score_function", "softmax"),
569560
expert_bias=(
570-
jnp.asarray(params_np["expert_bias"])
571-
if config.get("use_expert_bias", False)
572-
else None
561+
jnp.asarray(params_np["expert_bias"]) if config.get("use_expert_bias", False) else None
573562
),
574563
)
575564

@@ -720,9 +709,7 @@ def test_aux_loss(self, mesh):
720709
# wired.
721710
aux_grads = _grad_aux_only(block, variables, mesh, x)
722711
g_gate = np.asarray(
723-
jax.device_get(
724-
_unwrap(aux_grads["params"]["gate_kernel"]).addressable_data(0)
725-
)
712+
jax.device_get(_unwrap(aux_grads["params"]["gate_kernel"]).addressable_data(0))
726713
)
727714
assert np.all(np.isfinite(g_gate)), "gate grad NaN/Inf under aux-only loss"
728715
assert np.any(g_gate != 0.0), "aux bwd should propagate to gate_kernel"
@@ -735,9 +722,7 @@ def test_combined_loss_grads(self, mesh):
735722
variables, _, _ = _init_apply(block, mesh, x, jax.random.PRNGKey(23))
736723
grads = _grad_step(block, variables, mesh, x, include_aux=True)
737724
for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
738-
g_local = np.asarray(
739-
jax.device_get(_unwrap(grads["params"][name]).addressable_data(0))
740-
)
725+
g_local = np.asarray(jax.device_get(_unwrap(grads["params"][name]).addressable_data(0)))
741726
assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf under main+aux"
742727
assert np.any(g_local != 0.0), f"{name} grad zero under main+aux"
743728

@@ -779,9 +764,7 @@ def test_init_apply_parity(self, mesh):
779764

780765
grads = _grad_step(block, variables, mesh, x)
781766
for name in ("gate_kernel", "wi_0", "wi_1", "wo"):
782-
g_local = np.asarray(
783-
jax.device_get(_unwrap(grads["params"][name]).addressable_data(0))
784-
)
767+
g_local = np.asarray(jax.device_get(_unwrap(grads["params"][name]).addressable_data(0)))
785768
assert np.all(np.isfinite(g_local)), f"{name} grad NaN/Inf"
786769
assert np.any(g_local != 0.0), f"{name} grad zero"
787770

@@ -801,9 +784,7 @@ def test_bootstrap_signature_mismatch_raises(self, mesh):
801784

802785
# Different hidden dim → different bootstrap signature.
803786
bigger_hidden = HIDDEN * 2
804-
x_b = jax.random.normal(
805-
jax.random.PRNGKey(16), (BATCH, SEQ, bigger_hidden), dtype=DTYPE
806-
)
787+
x_b = jax.random.normal(jax.random.PRNGKey(16), (BATCH, SEQ, bigger_hidden), dtype=DTYPE)
807788
block_b = MoEBlock(
808789
num_experts=NUM_EXPERTS,
809790
num_experts_per_tok=TOPK,

transformer_engine/jax/cpp_extensions/ep.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -931,7 +931,12 @@ def ep_combine_fwd(handle, handle_mem, expert_out, num_local_tokens, out_partiti
931931

932932
@compute_on("gpu_stream:collective")
933933
def ep_dispatch_bwd(
934-
handle, handle_mem, grad, g_recv_topk_weights, top_k, num_local_tokens,
934+
handle,
935+
handle_mem,
936+
grad,
937+
g_recv_topk_weights,
938+
top_k,
939+
num_local_tokens,
935940
out_partition_spec=None,
936941
):
937942
"""Backward of dispatch; returns (grad_tokens, grad_topk_weights)."""

transformer_engine/jax/csrc/extensions/inspect.cpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -121,14 +121,14 @@ Error_Type InspectFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type mi
121121

122122
XLA_FFI_DEFINE_HANDLER_SYMBOL(InspectHandler, InspectFFI,
123123
FFI::Bind()
124-
.Ctx<FFI_Stream_Type>() // stream
125-
.Arg<Buffer_Type>() // input
126-
.Arg<Buffer_Type>() // min
127-
.Arg<Buffer_Type>() // max
128-
.Arg<Buffer_Type>() // mean
129-
.Arg<Buffer_Type>() // std
130-
.Ret<Buffer_Type>() // output
131-
.Attr<std::string_view>("name") // probe name
124+
.Ctx<FFI_Stream_Type>() // stream
125+
.Arg<Buffer_Type>() // input
126+
.Arg<Buffer_Type>() // min
127+
.Arg<Buffer_Type>() // max
128+
.Arg<Buffer_Type>() // mean
129+
.Arg<Buffer_Type>() // std
130+
.Ret<Buffer_Type>() // output
131+
.Attr<std::string_view>("name") // probe name
132132
);
133133

134134
} // namespace jax

transformer_engine/jax/ep.py

Lines changed: 22 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -49,8 +49,7 @@ def _allgather_uid(uid_arr, world_size, uid_size):
4949
devices = np.asarray(jax.devices())
5050
if devices.size != world_size:
5151
raise RuntimeError(
52-
f"_allgather_uid fallback expected {world_size} global devices,"
53-
f" got {devices.size}."
52+
f"_allgather_uid fallback expected {world_size} global devices, got {devices.size}."
5453
)
5554
mesh = jax.sharding.Mesh(devices, ("_uid_all",))
5655
sharded = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("_uid_all", None))
@@ -268,8 +267,13 @@ def _dispatch_bwd(handle, recv_capacity_per_rank, res, g_outputs):
268267

269268
@partial(jax.custom_vjp, nondiff_argnums=(0, 5, 6))
270269
def ep_combine(
271-
handle, handle_mem, token_counts, expert_out, recv_topk_weights,
272-
num_local_tokens, out_sharding=None,
270+
handle,
271+
handle_mem,
272+
token_counts,
273+
expert_out,
274+
recv_topk_weights,
275+
num_local_tokens,
276+
out_sharding=None,
273277
):
274278
"""Reduce weighted expert outputs back to source ranks.
275279
@@ -291,8 +295,13 @@ def ep_combine(
291295
``[..., H]`` combined output shaped per ``num_local_tokens``.
292296
"""
293297
return _combine_fwd(
294-
handle, handle_mem, token_counts, expert_out, recv_topk_weights,
295-
num_local_tokens, out_sharding,
298+
handle,
299+
handle_mem,
300+
token_counts,
301+
expert_out,
302+
recv_topk_weights,
303+
num_local_tokens,
304+
out_sharding,
296305
)[0]
297306

298307

@@ -302,8 +311,13 @@ def _make_valid_mask(recv_topk_weights, dtype):
302311

303312

304313
def _combine_fwd(
305-
handle, handle_mem, token_counts, expert_out, recv_topk_weights,
306-
num_local_tokens, out_sharding,
314+
handle,
315+
handle_mem,
316+
token_counts,
317+
expert_out,
318+
recv_topk_weights,
319+
num_local_tokens,
320+
out_sharding,
307321
):
308322
del token_counts
309323
w = recv_topk_weights[..., None]

0 commit comments

Comments
 (0)