"""Benchmarks GEMM + AllReduce to simulate data-parallel (DP) gradient sync."""

import os
from typing import Any, Dict, Tuple

# pylint: disable=g-importing-member
from benchmark_utils import (
    iteration_timeit,
    ShardingStrategy,
    get_lhs_named_shading,
    get_rhs_named_shading,
    get_out_sharding,
    create_mesh,
    handle_based_on_sharding,
    unified_flops_metrics,
    MetricsStatistics,
    get_metrics_helper,
    str_to_dtype,
    get_peak_flops_multiplier,
)
from common import MARKER
import jax
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp

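# XLA / libtpu tuning flags. Judging by their names, the first group enables
# async collective fusion and compute/collective overlap on the TensorCore;
# the second group offloads all-reduces to the SparseCore. Exact semantics
# depend on the libtpu build.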
os.environ["LIBTPU_INIT_ARGS"] = (
    "--xla_tpu_enable_async_collective_fusion=true "
    "--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true "
    "--xla_tpu_enable_async_collective_fusion_multiple_steps=true "
    "--xla_tpu_overlap_compute_collective_tc=true "
    "--xla_enable_async_all_gather=true "
    "--xla_enable_async_collective_permute=true "
    "--xla_tpu_enable_all_experimental_scheduler_features=true "
    "--xla_tpu_accumulate_into_mrb=true "
    "--xla_tpu_scoped_vmem_limit_kib=65536 "
    "--xla_tpu_vmem_scavenging_mode=NONE "
    "--xla_tpu_dvfs_p_state=7 "

    "--xla_jf_debug_level=3 "
    "--xla_sc_disable_megacore_partitioning=true "
    "--xla_tpu_disable_sparse_core_collective_offload_remover=true "
    "--xla_tpu_enable_all_reduce_offload_tracing=true "
    "--xla_tpu_enable_all_reduce_scatter_fusion=false "
    "--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true "
    "--xla_tpu_pad_operations_input_tiles=true "
    "--xla_tpu_sparse_core_all_reduce_offload_min_size_in_bytes=0 "
    "--xla_tpu_use_tc_device_shape_on_sc=true "
)

# Matmul shapes: A(M, K) x B(K, N) = C(M, N), followed by AllReduce(C).
SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING
SEED = 0
PEAK_FLOPS_PER_DEVICE = 2307  # Peak TFLOP/s per core (device) at FP8 precision.


def gemm_all_reduce(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype = jnp.bfloat16,
    num_runs: int = 1,
    trace_dir: str | None = None,
) -> Dict[str, Any]:
  """Benchmarks AllReduce(matmul(A, B)).

  A: [M, K]
  B: [K, N]
  C = A @ B: [M, N]
  Output = AllReduce(C)
  """

  dtype_str = jnp.dtype(dtype).name
  print(
      f"Running gemm_all_reduce benchmark with m={m}, k={k}, n={n},"
      f" dtype={dtype_str}, runs={num_runs}"
  )

  def f(x, y):
    with jax.named_scope(MARKER):
      # Matmul: accumulate in float32, then cast back to `dtype` so the
      # AllReduce moves M * N * itemsize bytes per device.
      acc = jax.numpy.einsum(
          "ij,jk->ik", x, y, preferred_element_type=jnp.float32
      )
      c = acc.astype(dtype)

      # AllReduce (psum) across all devices on the "device" mesh axis.
      out = jax.lax.psum(c, axis_name="device")
      return out

  # This benchmark simulates the data-parallel (DP) backward pass:
  # 1. Local gradient computation: each device computes
  #    `Grads = Activations.T @ GradOutput`.
  #    - Here, `acc = x @ y` corresponds to `(M, K) @ (K, N) -> (M, N)`.
  #    - `K` represents the local batch size (the contracting dimension).
  #    - `M` and `N` represent the weight dimensions (e.g. hidden size).
  #    - The inputs `x` and `y` are local to each device; whether replicated
  #      or split, the compute itself is local.
  # 2. Gradient synchronization: `AllReduce(Grads)`.
  #    - `out = psum(c, axis_name="device")` sums the partial gradients
  #      across all devices.

  # We use `ShardingStrategy.NO_SHARDING` for the mesh. In `benchmark_utils`,
  # this creates a mesh with a single "device" axis containing all devices.
  # Inside `shard_map` (with `check_rep=False` and fully replicated in_specs
  # P(None, None)), each device receives the input arrays and executes `f`.
  # `psum("device")` then performs the AllReduce across all devices in the
  # mesh.
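  # Because the inputs are fully replicated here, every device computes the
  # same partial product C, and the psum output is simply num_devices * C.
  # The communication volume matches a real gradient sync; only the values
  # are degenerate.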

  mesh = create_mesh(SHARDING_STRATEGY)
  lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY)
  rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY)
  out_sharding = get_out_sharding(SHARDING_STRATEGY)

  # Note: `out_sharding` for NO_SHARDING is P(None, None). The output of `f`
  # (post-psum) is identical across devices (replicated).

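  # Inside `shard_map`, `f` sees per-device shards and `psum` lowers to a
  # real cross-device AllReduce; `jax.jit` then compiles the mapped
  # computation as a whole.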
  jit_sharded_f = jax.jit(
      shard_map(
          f,
          mesh,
          in_specs=(
              lhs_sharding.spec,
              rhs_sharding.spec,
          ),
          out_specs=out_sharding,
          check_rep=False,
      )
  )

  lhs_shape = (m, k)
  rhs_shape = (k, n)

  lhs_dtype = dtype
  rhs_dtype = dtype

  key = jax.random.key(SEED)

  def data_generator():
    """Creates new random data on host and puts it on device."""
    nonlocal key
    key, key_lhs, key_rhs = jax.random.split(key, 3)

    # Create random data on host.
    lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(lhs_dtype)
    rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(rhs_dtype)

    # Put on device (HBM).
    lhs_device = jax.device_put(lhs_host, lhs_sharding)
    rhs_device = jax.device_put(rhs_host, rhs_sharding)

    return (lhs_device, rhs_device)

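  # Fresh inputs per timed run (assuming `iteration_timeit` invokes
  # `data_generator` before each iteration) prevent XLA from reusing a
  # cached result and keep the measurement honest.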
  time_ms_list = iteration_timeit(
      jit_sharded_f,
      data_generator,
      matrix_dim=f"{dtype_str}_{m}x{n}x{k}",
      tries=num_runs,
      task=f"gemm_all_reduce_{dtype_str}",
      trace_dir=trace_dir,
  )
  return {
      "time_ms_list": time_ms_list,
  }


def gemm_all_reduce_calculate_metrics(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype,
    time_ms_list: list[float],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  """Computes FLOPs and AllReduce bandwidth metrics from measured times."""
  # Calculate FLOPs (matmul).
  total_flops = 2 * m * k * n
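  # E.g. m = k = n = 8192 gives 2 * 8192**3 ≈ 1.1e12 FLOPs per invocation.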

  total_flops_per_device, total_flops_all_devices = handle_based_on_sharding(
      total_flops, SHARDING_STRATEGY
  )

  dtype_str = jnp.dtype(dtype).name
  peak_flops_multiplier = get_peak_flops_multiplier(dtype_str)
  peak_flops = PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier

  # Calculate bandwidth for the AllReduce: it moves matrix C (M x N).
  matrix_c_size_bytes = m * n * jnp.dtype(dtype).itemsize
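  # "Algorithmic" bandwidth divides these bytes by the measured time; for a
  # ring AllReduce over D devices, wire ("bus") traffic is higher by a
  # factor of 2 * (D - 1) / D, so algo bandwidth understates link usage.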

  # Unified FLOPs + AllReduce bandwidth metrics.
  metadata, metrics = unified_flops_metrics(
      m,
      n,
      k,
      time_ms_list,
      total_flops_per_device,
      total_flops_all_devices,
      peak_flops,
      dtype=dtype_str,
      total_bytes=matrix_c_size_bytes,
      bandwidth_metric_name="all_reduce_algo_bw_gbs",
  )

  return metadata, metrics
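

# Minimal usage sketch (hypothetical shapes and run count): benchmark once,
# then derive metrics from the measured times. Pass trace_dir to capture a
# profiler trace.
if __name__ == "__main__":
  M, K, N = 8192, 8192, 8192
  results = gemm_all_reduce(M, K, N, dtype=jnp.bfloat16, num_runs=5)
  metadata, metrics = gemm_all_reduce_calculate_metrics(
      M, K, N, jnp.bfloat16, results["time_ms_list"]
  )
  print(metadata)
  print(metrics)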