
Commit 7df6677

Matmul + All Reduce benchmark
1 parent 7638ebf commit 7df6677

4 files changed: 234 additions & 3 deletions

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
benchmarks:
  - benchmark_name: "gemm_all_reduce"
    trace_dir: "../microbenchmarks/gemm_all_reduce"
    csv_path: "../microbenchmarks/gemm_all_reduce"
    xlml_metrics_dir: "../microbenchmarks/gemm_all_reduce"
    xla_dump_dir: "../microbenchmarks/gemm_all_reduce/hlo_graphs"
    num_runs: 10
    benchmark_sweep_params:
      - {m: 1024, k: 1024, n: 1024, dtype: "bfloat16"}
      - {m: 2048, k: 2048, n: 2048, dtype: "bfloat16"}
      - {m: 4096, k: 4096, n: 4096, dtype: "bfloat16"}
      - {m: 8192, k: 8192, n: 8192, dtype: "bfloat16"}
      - {m: 16384, k: 16384, n: 16384, dtype: "bfloat16"}
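
Each `benchmark_sweep_params` entry above is one set of keyword arguments for the benchmark, so this config describes five `gemm_all_reduce` runs of ten iterations each. A loader along these lines (a hedged sketch, assuming PyYAML; `expand_sweep` and the config filename are hypothetical, the real loading lives in run_benchmark.py) would expand it:

```python
# Hypothetical sketch of expanding a sweep config like the one above.
# `expand_sweep` and "config.yaml" are illustrative names, not the repo's API.
import yaml


def expand_sweep(config_path: str):
  with open(config_path) as f:
    config = yaml.safe_load(f)
  for bench in config["benchmarks"]:
    for params in bench["benchmark_sweep_params"]:
      # One benchmark call per sweep point, e.g. gemm_all_reduce(m=1024, ...).
      yield bench["benchmark_name"], {**params, "num_runs": bench["num_runs"]}


for name, kwargs in expand_sweep("config.yaml"):
  print(name, kwargs)  # gemm_all_reduce {'m': 1024, ..., 'num_runs': 10}
```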
Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
"""Benchmarks gemm + all_reduce for DP gradient sync simulation."""

import os
from typing import Any, Dict

# pylint: disable=g-importing-member
from benchmark_utils import (
    iteration_timeit,
    ShardingStrategy,
    get_lhs_named_shading,
    get_rhs_named_shading,
    get_out_sharding,
    create_mesh,
    handle_based_on_sharding,
    unified_flops_metrics,
    MetricsStatistics,
    get_metrics_helper,
    str_to_dtype,
    get_peak_flops_multiplier,
)
from common import MARKER
import jax
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp


os.environ["LIBTPU_INIT_ARGS"] = (
    "--xla_tpu_enable_async_collective_fusion=true "
    "--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true "
    "--xla_tpu_enable_async_collective_fusion_multiple_steps=true "
    "--xla_tpu_overlap_compute_collective_tc=true "
    "--xla_enable_async_all_gather=true "
    "--xla_enable_async_collective_permute=true "
    "--xla_tpu_enable_all_experimental_scheduler_features=true "
    "--xla_tpu_accumulate_into_mrb=true "
    "--xla_tpu_scoped_vmem_limit_kib=65536 "
    "--xla_tpu_vmem_scavenging_mode=NONE "
    "--xla_tpu_dvfs_p_state=7 "
    "--xla_jf_debug_level=3 "
    "--xla_sc_disable_megacore_partitioning=true "
    "--xla_tpu_disable_sparse_core_collective_offload_remover=true "
    "--xla_tpu_enable_all_reduce_offload_tracing=true "
    "--xla_tpu_enable_all_reduce_scatter_fusion=false "
    "--xla_tpu_enable_sparse_core_collective_offload_all_reduce=true "
    "--xla_tpu_pad_operations_input_tiles=true "
    "--xla_tpu_sparse_core_all_reduce_offload_min_size_in_bytes=0 "
    "--xla_tpu_use_tc_device_shape_on_sc=true "
)

# Matmul shapes: A(M, K) x B(K, N) = C(M, N), followed by AllReduce(C).
SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING
SEED = 0
PEAK_FLOPS_PER_DEVICE = 2307  # Peak TFLOP/s per device (core) at FP8 precision.


def gemm_all_reduce(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype = jnp.bfloat16,
    num_runs: int = 1,
    trace_dir: str | None = None,
) -> Dict[str, Any]:
  """Benchmarks AllReduce(matmul(A, B)).

  A: [M, K]
  B: [K, N]
  C = A @ B: [M, N]
  Output = AllReduce(C)
  """

  dtype_str = dtype.dtype.name
  print(
      f"Running gemm_all_reduce benchmark with m={m}, k={k}, n={n}, "
      f"dtype={dtype_str}, runs={num_runs}"
  )

  def f(x, y):
    with jax.named_scope(MARKER):
      # Matmul with float32 accumulation, cast back to the input dtype.
      acc = jnp.einsum("ij,jk->ik", x, y, preferred_element_type=jnp.float32)
      c = acc.astype(dtype)

      # AllReduce (psum) across the "device" mesh axis.
      out = jax.lax.psum(c, axis_name="device")
    return out

  # This benchmark simulates the data-parallel (DP) backward pass:
  # 1. Local gradient computation: each device computes
  #    `Grads = Activations.T @ GradOutput`.
  #    - Here, `acc = x @ y` corresponds to `(M, K) @ (K, N) -> (M, N)`.
  #    - `K` represents the local batch size (the contracting dimension).
  #    - `M` and `N` represent the weight dimensions (e.g. hidden size).
  #    - The inputs `x` and `y` are effectively local to each device
  #      (replicated or split, the compute is local).
  # 2. Gradient synchronization: `AllReduce(Grads)`.
  #    - `out = psum(c, axis_name="device")` sums the partial gradients across
  #      all devices.

  # We use `ShardingStrategy.NO_SHARDING` for the mesh. In `benchmark_utils`,
  # this creates a mesh with a single "device" axis containing all devices.
  # Inside `shard_map` (with `check_rep=False` and fully replicated in_specs
  # P(None, None)), each device receives the full input arrays and executes
  # `f`; `psum("device")` then performs the AllReduce across all devices in
  # the mesh.

  mesh = create_mesh(SHARDING_STRATEGY)
  lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY)
  rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY)
  out_sharding = get_out_sharding(SHARDING_STRATEGY)

  # Note: `out_sharding` for NO_SHARDING is P(None, None). The output of `f`
  # (post-psum) is mathematically consistent across devices (replicated).

  jit_sharded_f = jax.jit(
      shard_map(
          f,
          mesh,
          in_specs=(
              lhs_sharding.spec,
              rhs_sharding.spec,
          ),
          out_specs=out_sharding,
          check_rep=False,
      )
  )

  lhs_shape = (m, k)
  rhs_shape = (k, n)

  lhs_dtype = dtype
  rhs_dtype = dtype

  key = jax.random.key(SEED)

  def data_generator():
    """Creates new random data on host and puts it on device."""
    nonlocal key
    key, key_lhs, key_rhs = jax.random.split(key, 3)

    # Create random data on host.
    lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(lhs_dtype)
    rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(rhs_dtype)

    # Put on device (HBM).
    lhs_device = jax.device_put(lhs_host, lhs_sharding)
    rhs_device = jax.device_put(rhs_host, rhs_sharding)

    return (lhs_device, rhs_device)

  time_ms_list = iteration_timeit(
      jit_sharded_f,
      data_generator,
      matrix_dim=f"{dtype_str}_{m}x{n}x{k}",
      tries=num_runs,
      task=f"gemm_all_reduce_{dtype_str}",
      trace_dir=trace_dir,
  )
  return {
      "time_ms_list": time_ms_list,
  }


def gemm_all_reduce_calculate_metrics(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype,
    time_ms_list: list[float],
) -> tuple[Dict[str, Any], Dict[str, Any]]:
  """Computes FLOPs and AllReduce-bandwidth metrics from measured times."""
  # Matmul FLOPs: one multiply and one add per output element.
  total_flops = 2 * m * k * n

  total_flops_per_device, total_flops_all_devices = handle_based_on_sharding(
      total_flops, SHARDING_STRATEGY
  )

  dtype_str = dtype.dtype.name
  peak_flops_multiplier = get_peak_flops_multiplier(dtype_str)
  peak_flops = PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier

  # Bandwidth for the AllReduce: it moves matrix C, which is M x N.
  matrix_c_size_bytes = m * n * dtype.dtype.itemsize

  # Unified FLOPs + AllReduce-bandwidth metrics.
  metadata, metrics = unified_flops_metrics(
      m,
      n,
      k,
      time_ms_list,
      total_flops_per_device,
      total_flops_all_devices,
      peak_flops,
      dtype=dtype_str,
      total_bytes=matrix_c_size_bytes,
      bandwidth_metric_name="all_reduce_algo_bw_gbs",
  )

  return metadata, metrics
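
The file above depends on the repo's TPU helpers, but the core pattern it times, a local matmul whose result is summed across devices with `psum` inside `shard_map`, can be reproduced standalone. A minimal sketch, assuming CPU-simulated devices via the `xla_force_host_platform_device_count` flag (the shapes and device count here are illustrative only):

```python
import os
# Simulate 4 devices on CPU so the psum is observable without a TPU.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("device",))


def f(x, y):
  # Local matmul with float32 accumulation, as in the benchmark above.
  c = jnp.einsum("ij,jk->ik", x, y, preferred_element_type=jnp.float32)
  # AllReduce: sum the per-device products over the "device" axis.
  return jax.lax.psum(c.astype(jnp.bfloat16), axis_name="device")


g = jax.jit(
    shard_map(
        f,
        mesh=mesh,
        in_specs=(P(None, None), P(None, None)),  # fully replicated inputs
        out_specs=P(None, None),
        check_rep=False,
    )
)

x = jnp.ones((8, 8), jnp.bfloat16)
out = g(x, x)  # each replica computes x @ x (all 8s); psum over 4 devices
print(out[0, 0])  # 32.0
```

With fully replicated inputs every device computes the same product, so the psum output is simply `num_devices * (x @ y)`, which is exactly the replicated, post-AllReduce output the benchmark's `out_specs=P(None, None)` describes.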

Ironwood/src/benchmark_utils.py

Lines changed: 17 additions & 3 deletions
@@ -1111,6 +1111,8 @@ def unified_flops_metrics(
     total_flops_all_devices: int,
     peak_TFLOPS_per_device: float,
     dtype: str = None,
+    total_bytes: int = None,
+    bandwidth_metric_name: str = "GBytes/s/device",
 ) -> Dict[str, Any]:
   """Calculates the metrics for the naive matmul benchmark."""
   # Build dictionary of all the parameters in the function
@@ -1140,13 +1142,25 @@ def unified_flops_metrics(
       metrics_list=tflops_per_sec_all_devices, metrics_name="tflops_per_sec"
   )
   mfu_statistics = MetricsStatistics(metrics_list=mfu, metrics_name="MFU")
+
+  bw_print_str = ""
+  if total_bytes is not None:
+    bw_list = [(total_bytes / 1e9) / t_s for t_s in average_time_s_list]
+    bw_statistics = MetricsStatistics(
+        metrics_list=bw_list, metrics_name=bandwidth_metric_name
+    )
+    metrics.update(bw_statistics.serialize_statistics())
+    metadata["total_bytes"] = total_bytes
+    bw_print_str = f", Bandwidth (median): {bw_statistics.statistics['p50']:.2f} {bandwidth_metric_name}"
+
   dtype_prefix = f"[{dtype}] " if dtype is not None else ""
   print(
       f"{dtype_prefix}"
       f"Total floating-point ops: {total_flops}, Step Time (median): {average_time_ms_statistics.statistics['p50']:.2f}, "
       f"Throughput (median): {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOP / second / device, "
       f"TotalThroughput (median): {tflops_per_sec_all_devices_statistics.statistics['p50']:.2f} TFLOP / second, "
       f"MFU: {mfu_statistics.statistics['p50']:.2%}"
+      f"{bw_print_str}"
   )
   # print()
   # time_ms_list =
@@ -1260,15 +1274,15 @@ def get_peak_flops_multiplier(in_dtype_str: str) -> float:
   (PEAK_FLOPS_PER_DEVICE) based on the input data type.
   """
   in_dtype_lower = in_dtype_str.lower()
-  if in_dtype_lower == "fp8":
+  if in_dtype_lower in ("fp8", "float8_e4m3fn"):
     # FP8 is 2x faster than BF16.
     # The baseline PEAK_FLOPS_PER_DEVICE is 1153.5 * 2 = 2307, which is FP8 peak.
     # So the multiplier should be 1.0.
     return 1.0
-  elif in_dtype_lower == "bf16" or in_dtype_lower == "fp16":
+  elif in_dtype_lower in ("bf16", "bfloat16", "fp16", "float16"):
     # BF16/FP16 is 2x slower than FP8 peak.
     return 0.5
-  elif in_dtype_lower == "fp32":
+  elif in_dtype_lower in ("fp32", "float32"):
     # FP32 is 4x slower than FP8 peak.
     return 0.25
   else:
Ironwood/src/run_benchmark.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@
     "gemm_multiple_run": "benchmark_gemm.gemm_multiple_run",
     "gemm_throttling": "benchmark_gemm_throttling.gemm_throttling",
     "gemm": "benchmark_gemm.gemm",
+    "gemm_all_reduce": "benchmark_gemm_all_reduce.gemm_all_reduce",
     "gemm_accum": "benchmark_gemm.gemm_accum",
     "quantization": "benchmark_compute.quantization",
     "transpose_quantization": "benchmark_compute.transpose_quantization",
