Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Ironwood/src/benchmark_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
def create_mesh(ici_size: int, mesh_shape: str) -> Mesh:
"""Creates a mesh with the given ICI size."""
devices_needed = ici_size
devices = jax.devices()
local_devices = jax.local_devices()
if devices_needed <= len(local_devices):
devices = local_devices
else:
devices = jax.devices()

if len(devices) < devices_needed:
raise ValueError(
Expand All @@ -51,7 +55,7 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh:
first_device = devices[0]
device_kind = first_device.device_kind
print("Device kind: ", device_kind)
mesh_devices = mesh_utils.create_device_mesh(shape, devices=jax.devices())
mesh_devices = mesh_utils.create_device_mesh(shape, devices=devices)
mesh = Mesh(mesh_devices, axis_names)
return mesh

Expand Down
6 changes: 3 additions & 3 deletions Ironwood/src/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,18 +1192,18 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh:
strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M
or strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N
):
num_devices = jax.device_count()
num_devices = jax.local_device_count()
assert (
num_devices % 2 == 0
), "Total devices must be divisible by 2 (chip size)"
num_chips = num_devices // 2
mesh_shape = (num_chips, 2)
mesh_axes = ("chip", "device")
mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(mesh_shape), mesh_axes
np.array(jax.local_devices()).reshape(mesh_shape), mesh_axes
)
else:
mesh = Mesh(np.array(jax.devices()), axis_names="device")
mesh = Mesh(np.array(jax.local_devices()), axis_names="device")
return mesh


Expand Down