diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index f1a10fe..e6e984a 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -34,7 +34,11 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: """Creates a mesh with the given ICI size.""" devices_needed = ici_size - devices = jax.devices() + local_devices = jax.local_devices() + if devices_needed <= len(local_devices): + devices = local_devices + else: + devices = jax.devices() if len(devices) < devices_needed: raise ValueError( @@ -51,7 +55,7 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: first_device = devices[0] device_kind = first_device.device_kind print("Device kind: ", device_kind) - mesh_devices = mesh_utils.create_device_mesh(shape, devices=jax.devices()) + mesh_devices = mesh_utils.create_device_mesh(shape, devices=devices) mesh = Mesh(mesh_devices, axis_names) return mesh diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index 3396269..5e68ad0 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -1192,7 +1192,7 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M or strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N ): - num_devices = jax.device_count() + num_devices = jax.local_device_count() assert ( num_devices % 2 == 0 ), "Total devices must be divisible by 2 (chip size)" @@ -1200,10 +1200,10 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: mesh_shape = (num_chips, 2) mesh_axes = ("chip", "device") mesh = jax.sharding.Mesh( - np.array(jax.devices()).reshape(mesh_shape), mesh_axes + np.array(jax.local_devices()).reshape(mesh_shape), mesh_axes ) else: - mesh = Mesh(np.array(jax.devices()), axis_names="device") + mesh = Mesh(np.array(jax.local_devices()), axis_names="device") return mesh