Skip to content

Commit 19ace8e

Browse files
author
Utkarsh Sharma
committed
Revert "Fix psum & psum_scatter sharding logic"
This reverts commit 0c848de.
1 parent 0c848de commit 19ace8e

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

src/benchmark_collectives.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def psum_benchmark(
7878
# DCN benchmark
7979
if dcn_size > 1:
8080

81-
@partial(shard_map, mesh=mesh, in_specs=P("dcn", None), out_specs=P(None, None))
81+
@partial(shard_map, mesh=mesh, in_specs=P("dcn", None), out_specs=P(None))
8282
def f(x):
8383
return jax.lax.psum(x, "dcn")
8484

@@ -99,12 +99,12 @@ def f(x):
9999
# ICI benchmark
100100
if ici_size > 1:
101101

102-
@partial(shard_map, mesh=mesh, in_specs=P(None, "ici"), out_specs=P(None, None))
102+
@partial(shard_map, mesh=mesh, in_specs=P(None, None), out_specs=P(None, None))
103103
def f(x):
104104
return jax.lax.psum(x, "ici")
105105

106106
sharded_matrix = jax.device_put(
107-
matrix, jax.sharding.NamedSharding(mesh, P(None, "ici"))
107+
matrix, jax.sharding.NamedSharding(mesh, P(None, None))
108108
)
109109
jitted_op = jax.jit(f)
110110
ici_average_time_ms_list = simple_timeit(
@@ -235,12 +235,12 @@ def f(x):
235235
# ICI benchmark
236236
if ici_size > 1:
237237

238-
@partial(shard_map, mesh=mesh, in_specs=P(None, "ici"), out_specs=P(None, "ici"))
238+
@partial(shard_map, mesh=mesh, in_specs=P(None, None), out_specs=P(None, "ici"))
239239
def f(x):
240240
return jax.lax.psum_scatter(x, "ici", tiled=True)
241241

242242
sharded_matrix = jax.device_put(
243-
matrix, jax.sharding.NamedSharding(mesh, P(None, "ici"))
243+
matrix, jax.sharding.NamedSharding(mesh, P(None, None))
244244
)
245245
jitted_op = jax.jit(f)
246246
ici_average_time_ms_list = simple_timeit(
@@ -443,9 +443,9 @@ def all_gather_benchmark_calculate_metrics(
443443
# each sharded matrix size is matrix_size_gbyte / ici_size and then it needs
444444
# to use (ici_size - 1) steps in a ring algorithm
445445
ici_bandwidth_gbyte_s_list = [
446-
matrix_size_gbyte
447-
* (ici_size - 1)
448-
/ ici_size
446+
matrix_size_gbyte
447+
* (ici_size - 1)
448+
/ ici_size
449449
/ (ici_average_time_ms / 1e3)
450450
for ici_average_time_ms in ici_average_time_ms_list
451451
]

0 commit comments

Comments
 (0)