@@ -78,7 +78,7 @@ def psum_benchmark(
7878 # DCN benchmark
7979 if dcn_size > 1 :
8080
81- @partial (shard_map , mesh = mesh , in_specs = P ("dcn" , None ), out_specs = P (None , None ))
81+ @partial (shard_map , mesh = mesh , in_specs = P ("dcn" , None ), out_specs = P (None ))
8282 def f (x ):
8383 return jax .lax .psum (x , "dcn" )
8484
@@ -99,12 +99,12 @@ def f(x):
9999 # ICI benchmark
100100 if ici_size > 1 :
101101
102- @partial (shard_map , mesh = mesh , in_specs = P (None , "ici" ), out_specs = P (None , None ))
102+ @partial (shard_map , mesh = mesh , in_specs = P (None , None ), out_specs = P (None , None ))
103103 def f (x ):
104104 return jax .lax .psum (x , "ici" )
105105
106106 sharded_matrix = jax .device_put (
107- matrix , jax .sharding .NamedSharding (mesh , P (None , "ici" ))
107+ matrix , jax .sharding .NamedSharding (mesh , P (None , None ))
108108 )
109109 jitted_op = jax .jit (f )
110110 ici_average_time_ms_list = simple_timeit (
@@ -235,12 +235,12 @@ def f(x):
235235 # ICI benchmark
236236 if ici_size > 1 :
237237
238- @partial (shard_map , mesh = mesh , in_specs = P (None , "ici" ), out_specs = P (None , "ici" ))
238+ @partial (shard_map , mesh = mesh , in_specs = P (None , None ), out_specs = P (None , "ici" ))
239239 def f (x ):
240240 return jax .lax .psum_scatter (x , "ici" , tiled = True )
241241
242242 sharded_matrix = jax .device_put (
243- matrix , jax .sharding .NamedSharding (mesh , P (None , "ici" ))
243+ matrix , jax .sharding .NamedSharding (mesh , P (None , None ))
244244 )
245245 jitted_op = jax .jit (f )
246246 ici_average_time_ms_list = simple_timeit (
@@ -443,9 +443,9 @@ def all_gather_benchmark_calculate_metrics(
443443 # each sharded matrix size is matrix_size_gbyte / ici_size and then it needs
444444 # to use (ici_size - 1) steps in a ring algorithm
445445 ici_bandwidth_gbyte_s_list = [
446- matrix_size_gbyte
447- * (ici_size - 1 )
448- / ici_size
446+ matrix_size_gbyte
447+ * (ici_size - 1 )
448+ / ici_size
449449 / (ici_average_time_ms / 1e3 )
450450 for ici_average_time_ms in ici_average_time_ms_list
451451 ]
0 commit comments