From 55d5c6e061a93bdceb34dbb8ba080c1777af12b2 Mon Sep 17 00:00:00 2001 From: Simran Kaur Date: Thu, 4 Jun 2026 18:20:00 +0000 Subject: [PATCH 1/3] Fix mesh creation to use local devices for single-host benchmarks --- Ironwood/src/benchmark_collectives.py | 8 ++++++-- Ironwood/src/benchmark_utils.py | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index cac1214..a1e35a2 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -35,7 +35,11 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: """Creates a mesh with the given ICI size.""" devices_needed = ici_size - devices = jax.devices() + local_devices = jax.local_devices() + if devices_needed <= len(local_devices): + devices = local_devices + else: + devices = jax.devices() if len(devices) < devices_needed: raise ValueError( @@ -52,7 +56,7 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh: first_device = devices[0] device_kind = first_device.device_kind print("Device kind: ", device_kind) - mesh_devices = mesh_utils.create_device_mesh(shape, devices=jax.devices()) + mesh_devices = mesh_utils.create_device_mesh(shape, devices=devices) mesh = Mesh(mesh_devices, axis_names) return mesh diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index d04e2c5..9d115b4 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -1191,7 +1191,7 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M or strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N ): - num_devices = jax.device_count() + num_devices = jax.local_device_count() assert ( num_devices % 2 == 0 ), "Total devices must be divisible by 2 (chip size)" @@ -1199,10 +1199,10 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: mesh_shape = (num_chips, 2) mesh_axes = ("chip", "device") mesh = jax.sharding.Mesh( - np.array(jax.devices()).reshape(mesh_shape), mesh_axes + np.array(jax.local_devices()).reshape(mesh_shape), mesh_axes ) else: - mesh = Mesh(np.array(jax.devices()), axis_names="device") + mesh = Mesh(np.array(jax.local_devices()), axis_names="device") return mesh From c1da258732c20558a983449326e86d1f7fc8c5bc Mon Sep 17 00:00:00 2001 From: Simran Kaur Date: Mon, 8 Jun 2026 14:18:11 +0000 Subject: [PATCH 2/3] Fix pylint errors in benchmark_collectives.py --- Ironwood/src/benchmark_collectives.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index a1e35a2..d034a6c 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -122,18 +122,20 @@ def unified_ici_collectives_metrics( else: replica_group_type = "non-parallel" - # Safe to access [0] without safeguard because JAX guarantees at least one device is - # always initialized (CPU fallback if no accelerator), and mesh creation has already - # validated that the requested number of devices exist. + # Safe to access [0] without safeguard because JAX guarantees at least + # one device is always initialized (CPU fallback if no accelerator), and + # mesh creation has already validated that the requested number of + # devices exist. device_kind = jax.devices()[0].device_kind if device_kind in V6E_DEVICE_KINDS: # For TPU v6e (Trillium), 1 physical chip = 1 JAX device. - # Ring collective communication volume per chip across N ranks is exactly (N - 1) shards. + # Ring collective communication volume per chip across N ranks is + # exactly (N - 1) shards. # There is no dual-core traffic multiplier needed. participating_ranks = rank - 1 tf_multiplier = 1 else: - # Dual-core logic for TPU v7x + # Dual-core logic for TPU v7x if replica_group_type == "parallel": participating_ranks = rank - 1 tf_multiplier = 2 From 946541b4ac008fbbf768e34b13ff5a49aa4d5a56 Mon Sep 17 00:00:00 2001 From: Simran Kaur Date: Tue, 9 Jun 2026 05:40:17 +0000 Subject: [PATCH 3/3] fix: support configurable mesh locality and device count for GEMM --- Ironwood/src/benchmark_gemm.py | 61 ++++++++++++++++++++++++++------- Ironwood/src/benchmark_utils.py | 45 ++++++++++++++++-------- 2 files changed, 78 insertions(+), 28 deletions(-) diff --git a/Ironwood/src/benchmark_gemm.py b/Ironwood/src/benchmark_gemm.py index c79bda5..4244a05 100644 --- a/Ironwood/src/benchmark_gemm.py +++ b/Ironwood/src/benchmark_gemm.py @@ -66,6 +66,7 @@ def gemm_multiple_run( dtype: jnp.dtype = jax.numpy.float8_e4m3fn, num_runs: int = 1, trace_dir: str = None, + run_on_local_node: bool = False, ) -> Dict[str, Any]: """Benchmarks the OUT:BF16 = IN0 dtype x IN1:dtype.""" @@ -79,7 +80,7 @@ def f(x, y): ) return acc.astype(jnp.bfloat16) - mesh = create_mesh(SHARDING_STRATEGY) + mesh = create_mesh(SHARDING_STRATEGY, local_mesh=run_on_local_node) lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY) out_sharding = get_out_sharding(SHARDING_STRATEGY) @@ -140,11 +141,15 @@ def gemm_multiple_run_calculate_metrics( n: int, dtype: jnp.dtype, time_ms_list: list[float], + run_on_local_node: bool = False, ) -> Dict[str, Any]: # Calculate FLOPs total_flops = 2 * m * k * n # Total floating-point operations + device_count = ( + jax.local_device_count() if run_on_local_node else jax.device_count() + ) total_flops, total_flops_all_devices = handle_based_on_sharding( - total_flops, SHARDING_STRATEGY + total_flops, SHARDING_STRATEGY, device_count=device_count ) peak_flops = ( PEAK_FLOPS_PER_DEVICE @@ -169,6 +174,7 @@ def gemm_simple( n: int, num_runs: int = 1, trace_dir: str = None, + run_on_local_node: bool = False, ) -> Dict[str, Any]: """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8.""" # Accumulation is FP32. @@ -180,7 +186,7 @@ def f(x, y): ) return acc.astype(jnp.bfloat16) - mesh = create_mesh(SHARDING_STRATEGY) + mesh = create_mesh(SHARDING_STRATEGY, local_mesh=run_on_local_node) lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY) out_sharding = get_out_sharding(SHARDING_STRATEGY) @@ -239,11 +245,15 @@ def gemm_simple_calculate_metrics( k: int, n: int, time_ms_list: list[float], + run_on_local_node: bool = False, ) -> Dict[str, Any]: # Calculate FLOPs total_flops = 2 * m * k * n # Total floating-point operations + device_count = ( + jax.local_device_count() if run_on_local_node else jax.device_count() + ) total_flops, total_flops_all_devices = handle_based_on_sharding( - total_flops, SHARDING_STRATEGY + total_flops, SHARDING_STRATEGY, device_count=device_count ) return unified_flops_metrics( m, @@ -264,6 +274,7 @@ def gemm_simple_with_dtype( out_dtype_str: str, num_runs: int = 1, trace_dir: str = None, + run_on_local_node: bool = False, ) -> Dict[str, Any]: """Benchmarks the OUT:BF16 = IN0:FP8 x IN1:FP8.""" # Accumulation is FP32. @@ -280,7 +291,7 @@ def f(x, y): ) return acc.astype(out_dtype) - mesh = create_mesh(SHARDING_STRATEGY) + mesh = create_mesh(SHARDING_STRATEGY, local_mesh=run_on_local_node) lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY) out_sharding = get_out_sharding(SHARDING_STRATEGY) @@ -337,11 +348,15 @@ def gemm_simple_with_dtype_calculate_metrics( in_dtype_str: str, out_dtype_str: str, time_ms_list: list[float], + run_on_local_node: bool = False, ) -> Dict[str, Any]: # Calculate FLOPs total_flops = (2 * k - 1) * m * n # Total floating-point operations + device_count = ( + jax.local_device_count() if run_on_local_node else jax.device_count() + ) total_flops, total_flops_all_devices = handle_based_on_sharding( - total_flops, SHARDING_STRATEGY + total_flops, SHARDING_STRATEGY, device_count=device_count ) # Get the multiplier by calling the utility function @@ -365,7 +380,12 @@ def gemm_simple_with_dtype_calculate_metrics( def gemm( - m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None + m: int, + k: int, + n: int, + num_runs: int = 1, + trace_dir: str = None, + run_on_local_node: bool = False, ) -> Dict[str, Any]: """OUT:BF16 = matmul(IN0:FP8, IN1:FP8) * outer_product(SF0:FP32 * SF1<1, N>:FP32).""" @@ -379,7 +399,7 @@ def f(x, y, scale_m, scale_n): result_fp32 = acc * scales return result_fp32.astype(jnp.bfloat16) - mesh = create_mesh(SHARDING_STRATEGY) + mesh = create_mesh(SHARDING_STRATEGY, local_mesh=run_on_local_node) lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) sf0_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY) @@ -448,12 +468,19 @@ def data_generator(): def gemm_calculate_metrics( - m: int, k: int, n: int, time_ms_list: list[float] + m: int, + k: int, + n: int, + time_ms_list: list[float], + run_on_local_node: bool = False, ) -> Dict[str, Any]: # Calculate FLOPs total_flops = 2 * m * k * n # Total floating-point operations + device_count = ( + jax.local_device_count() if run_on_local_node else jax.device_count() + ) total_flops, total_flops_all_devices = handle_based_on_sharding( - total_flops, SHARDING_STRATEGY + total_flops, SHARDING_STRATEGY, device_count=device_count ) return unified_flops_metrics( m, @@ -472,6 +499,7 @@ def gemm_accum( n: int, num_runs: int = 1, trace_dir: str = None, + run_on_local_node: bool = False, ) -> Dict[str, Any]: """OUT:FP32 += matmul(IN0:FP8, IN1:FP8) * outer_product(SF0:FP32 * SF1<1, N>:FP32).""" @@ -485,7 +513,7 @@ def f(out_buffer, x, y, scale_m, scale_n): result_fp32 = acc * scales return out_buffer + result_fp32 - mesh = create_mesh(SHARDING_STRATEGY) + mesh = create_mesh(SHARDING_STRATEGY, local_mesh=run_on_local_node) lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) sf0_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY) @@ -568,12 +596,19 @@ def data_generator(): def gemm_accum_calculate_metrics( - m: int, k: int, n: int, time_ms_list: list[float] + m: int, + k: int, + n: int, + time_ms_list: list[float], + run_on_local_node: bool = False, ) -> Dict[str, Any]: # Calculate FLOPs total_flops = 2 * m * k * n + m * n # Total floating-point operations + device_count = ( + jax.local_device_count() if run_on_local_node else jax.device_count() + ) total_flops, total_flops_all_devices = handle_based_on_sharding( - total_flops, SHARDING_STRATEGY + total_flops, SHARDING_STRATEGY, device_count=device_count ) return unified_flops_metrics( m, diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index 9d115b4..f39b711 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -1148,50 +1148,65 @@ def get_output_named_shading(mesh, strategy: ShardingStrategy): return NamedSharding(mesh, P(None, "device")) -def handle_per_device_based_on_sharding(value, strategy: ShardingStrategy): +def handle_per_device_based_on_sharding( + value, strategy: ShardingStrategy, device_count: int +): match strategy: case ShardingStrategy.NO_SHARDING: return value case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M: - return value // jax.device_count() + return value // device_count case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M: return value // 2 case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N: - return value // jax.device_count() + return value // device_count case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N: return value // 2 def handle_all_devices_based_on_sharding( - value: int, strategy: ShardingStrategy + value: int, strategy: ShardingStrategy, device_count: int ): match strategy: case ShardingStrategy.NO_SHARDING: - return value * jax.device_count() + return value * device_count case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M: return value case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M: - return value * jax.device_count() // 2 + return value * device_count // 2 case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N: return value case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N: - return value * jax.device_count() // 2 + return value * device_count // 2 -def handle_based_on_sharding(value: int, strategy: ShardingStrategy): +def handle_based_on_sharding( + value: int, strategy: ShardingStrategy, device_count: int | None = None +): + if device_count is None: + device_count = jax.device_count() total_value = value - value = handle_per_device_based_on_sharding(value, strategy) - total_value = handle_all_devices_based_on_sharding(total_value, strategy) + value = handle_per_device_based_on_sharding(value, strategy, device_count) + total_value = handle_all_devices_based_on_sharding( + total_value, strategy, device_count + ) return value, total_value -def create_mesh(strategy: ShardingStrategy) -> Mesh: - """Creates a mesh.""" +def create_mesh(strategy: ShardingStrategy, local_mesh: bool = False) -> Mesh: + """Creates a mesh. + + Args: + strategy: The sharding strategy to apply. + local_mesh: If True, restricts the mesh to local devices. + If False, uses all available devices. + """ + devices = jax.local_devices() if local_mesh else jax.devices() + num_devices = len(devices) if ( strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M or strategy == ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N ): - num_devices = jax.local_device_count() assert ( num_devices % 2 == 0 ), "Total devices must be divisible by 2 (chip size)" @@ -1199,10 +1214,10 @@ def create_mesh(strategy: ShardingStrategy) -> Mesh: mesh_shape = (num_chips, 2) mesh_axes = ("chip", "device") mesh = jax.sharding.Mesh( - np.array(jax.local_devices()).reshape(mesh_shape), mesh_axes + np.array(devices).reshape(mesh_shape), mesh_axes ) else: - mesh = Mesh(np.array(jax.local_devices()), axis_names="device") + mesh = Mesh(np.array(devices), axis_names="device") return mesh