@@ -245,11 +245,13 @@ def test_two_layer_dispatch_no_handle_aliasing(self):
245245 w = jax .lax .with_sharding_constraint (topk_w , NamedSharding (self .mesh , dp_spec ))
246246
247247 def one_layer (hk , idx , toks , w_ ):
248- recv_t , recv_w , hm , tc = ep_dispatch (
249- hk , idx , toks , w_ , self .recv_capacity_per_rank
248+ recv_t , recv_w , hm , tc = ep_dispatch (hk , idx , toks , w_ , self .recv_capacity_per_rank )
249+ recv_t = jax .lax .with_sharding_constraint (
250+ recv_t , NamedSharding (self .mesh , ep_spec_3d )
251+ )
252+ recv_w = jax .lax .with_sharding_constraint (
253+ recv_w , NamedSharding (self .mesh , ep_spec_2d )
250254 )
251- recv_t = jax .lax .with_sharding_constraint (recv_t , NamedSharding (self .mesh , ep_spec_3d ))
252- recv_w = jax .lax .with_sharding_constraint (recv_w , NamedSharding (self .mesh , ep_spec_2d ))
253255 return ep_combine (
254256 hk , hm , tc , recv_t , recv_w , T_global , out_sharding = (("dp" , "ep" ), None )
255257 )
@@ -269,12 +271,14 @@ def run(idx, ta_, tb_, w_):
269271 np .testing .assert_allclose (
270272 np .asarray (out_a_g .astype (jnp .float32 )),
271273 np .asarray (tokens .astype (jnp .float32 )),
272- atol = 5e-2 , rtol = 5e-2 ,
274+ atol = 5e-2 ,
275+ rtol = 5e-2 ,
273276 )
274277 np .testing .assert_allclose (
275278 np .asarray (out_b_g .astype (jnp .float32 )),
276279 np .asarray (tokens_b .astype (jnp .float32 )),
277- atol = 5e-2 , rtol = 5e-2 ,
280+ atol = 5e-2 ,
281+ rtol = 5e-2 ,
278282 )
279283
280284 def test_primitive_prepare (self ):
@@ -328,7 +332,10 @@ def run(idx, toks, w):
328332 weighted , NamedSharding (self .mesh , ep_spec_3d )
329333 )
330334 out = ep_combine_fwd (
331- self .hk , hm , weighted , T_global ,
335+ self .hk ,
336+ hm ,
337+ weighted ,
338+ T_global ,
332339 out_partition_spec = (("dp" , "ep" ), None ),
333340 )
334341 return jax .lax .with_sharding_constraint (out , NamedSharding (self .mesh , dp_spec ))
@@ -372,7 +379,9 @@ def loss_fn(toks):
372379 toks = jax .lax .with_sharding_constraint (toks , NamedSharding (self .mesh , dp_spec ))
373380 idx = jax .lax .with_sharding_constraint (topk_idx , NamedSharding (self .mesh , dp_spec ))
374381 w = jax .lax .with_sharding_constraint (topk_w , NamedSharding (self .mesh , dp_spec ))
375- recv_t , recv_w , hm , tc = ep_dispatch (self .hk , idx , toks , w , self .recv_capacity_per_rank )
382+ recv_t , recv_w , hm , tc = ep_dispatch (
383+ self .hk , idx , toks , w , self .recv_capacity_per_rank
384+ )
376385 recv_t = jax .lax .with_sharding_constraint (
377386 recv_t , NamedSharding (self .mesh , ep_spec_3d )
378387 )
@@ -420,7 +429,9 @@ def test_dispatch_combine_3d_input_output(self):
420429
421430 @jax .jit
422431 def run (idx , toks , w ):
423- recv_t , recv_w , hm , _tc = ep_dispatch (self .hk , idx , toks , w , self .recv_capacity_per_rank )
432+ recv_t , recv_w , hm , _tc = ep_dispatch (
433+ self .hk , idx , toks , w , self .recv_capacity_per_rank
434+ )
424435 recv_t = jax .lax .with_sharding_constraint (recv_t , NamedSharding (self .mesh , ep_t ))
425436 recv_w = jax .lax .with_sharding_constraint (recv_w , NamedSharding (self .mesh , ep_w ))
426437 out = ep_combine (
@@ -463,7 +474,9 @@ def test_dispatch_combine_dp_only_first_dim(self):
463474
464475 @jax .jit
465476 def run (idx , toks , w ):
466- recv_t , recv_w , hm , _tc = ep_dispatch (self .hk , idx , toks , w , self .recv_capacity_per_rank )
477+ recv_t , recv_w , hm , _tc = ep_dispatch (
478+ self .hk , idx , toks , w , self .recv_capacity_per_rank
479+ )
467480 recv_t = jax .lax .with_sharding_constraint (recv_t , NamedSharding (self .mesh , ep_t ))
468481 recv_w = jax .lax .with_sharding_constraint (recv_w , NamedSharding (self .mesh , ep_w ))
469482 out = ep_combine (
@@ -641,7 +654,9 @@ def run(idx, toks, w):
641654 idx = jax .lax .with_sharding_constraint (idx , NamedSharding (self .mesh , dp_spec ))
642655 toks = jax .lax .with_sharding_constraint (toks , NamedSharding (self .mesh , dp_spec ))
643656 w = jax .lax .with_sharding_constraint (w , NamedSharding (self .mesh , dp_spec ))
644- recv_t , recv_w , hm , tc = ep_dispatch (self .hk , idx , toks , w , self .recv_capacity_per_rank )
657+ recv_t , recv_w , hm , tc = ep_dispatch (
658+ self .hk , idx , toks , w , self .recv_capacity_per_rank
659+ )
645660 recv_t = jax .lax .with_sharding_constraint (
646661 recv_t , NamedSharding (self .mesh , ep_spec_3d )
647662 )
@@ -688,7 +703,9 @@ def fwd(eo, toks, idx, w):
688703 w = jax .lax .with_sharding_constraint (w , NamedSharding (self .mesh , dp_spec ))
689704 _rt , rw , hm , tc = ep_dispatch (self .hk , idx , toks , w , self .recv_capacity_per_rank )
690705 rw = jax .lax .with_sharding_constraint (rw , NamedSharding (self .mesh , ep_spec_2d ))
691- combined = ep_combine (self .hk , hm , tc , eo , rw , T_dp , out_sharding = (("dp" , "ep" ), None ))
706+ combined = ep_combine (
707+ self .hk , hm , tc , eo , rw , T_dp , out_sharding = (("dp" , "ep" ), None )
708+ )
692709 return jax .lax .with_sharding_constraint (combined , NamedSharding (self .mesh , dp_spec ))
693710
694711 # jax.vjp + pinned cotangent feeds ep_combine_bwd/ep_dispatch_bwd
0 commit comments