Skip to content

Commit 1907768

Browse files
committed
[WIP] separate matmul and all reduce
1 parent 7df6677 commit 1907768

3 files changed

Lines changed: 262 additions & 13 deletions

File tree

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Benchmark configuration for the gemm_all_reduce microbenchmark family.
# All three variants share the same output directories and problem size so
# their results are directly comparable.
benchmarks:
  # Matmul only: C = A @ B, no collective.
  - benchmark_name: gemm_only
    trace_dir: "../microbenchmarks/gemm_all_reduce"
    csv_path: "../microbenchmarks/gemm_all_reduce"
    xlml_metrics_dir: "../microbenchmarks/gemm_all_reduce"
    xla_dump_dir: "../microbenchmarks/gemm_all_reduce/hlo_graphs"
    num_runs: 10
    benchmark_params:
      - m: 4096
        k: 4096
        n: 4096
        dtype: "bfloat16"

  # AllReduce only: input is C [M, N]; no matmul is performed.
  - benchmark_name: all_reduce_only
    trace_dir: "../microbenchmarks/gemm_all_reduce"
    csv_path: "../microbenchmarks/gemm_all_reduce"
    xlml_metrics_dir: "../microbenchmarks/gemm_all_reduce"
    xla_dump_dir: "../microbenchmarks/gemm_all_reduce/hlo_graphs"
    num_runs: 10
    benchmark_params:
      - m: 4096
        k: 4096  # Passed to maintain signature, though not used for shape of C
        n: 4096
        dtype: "bfloat16"

  # Fused variant: matmul immediately followed by the AllReduce.
  - benchmark_name: gemm_all_reduce
    trace_dir: "../microbenchmarks/gemm_all_reduce"
    csv_path: "../microbenchmarks/gemm_all_reduce"
    xlml_metrics_dir: "../microbenchmarks/gemm_all_reduce"
    xla_dump_dir: "../microbenchmarks/gemm_all_reduce/hlo_graphs"
    num_runs: 10
    benchmark_params:
      - m: 4096
        k: 4096
        n: 4096
        dtype: "bfloat16"

Ironwood/src/benchmark_gemm_all_reduce.py

Lines changed: 226 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
MetricsStatistics,
1717
get_metrics_helper,
1818
str_to_dtype,
19-
get_peak_flops_multiplier
19+
get_peak_flops_multiplier,
20+
unified_bytes_metrics,
2021
)
2122
from common import MARKER
2223
import jax
@@ -178,18 +179,6 @@ def gemm_all_reduce_calculate_metrics(
178179
peak_flops_multiplier = get_peak_flops_multiplier(dtype_str)
179180
peak_flops = PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier
180181

181-
# Unified FLOPs metrics
182-
metadata, metrics = unified_flops_metrics(
183-
m,
184-
n,
185-
k,
186-
time_ms_list,
187-
total_flops_per_device,
188-
total_flops_all_devices,
189-
peak_flops,
190-
dtype=dtype_str,
191-
)
192-
193182
# Calculate Bandwidth for AllReduce
194183
# AllReduce moves Matrix C: M x N
195184
matrix_c_size_bytes = m * n * dtype.dtype.itemsize
@@ -201,3 +190,227 @@ def gemm_all_reduce_calculate_metrics(
201190
)
202191

203192
return metadata, metrics
193+
194+
195+
def gemm_only(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype = jnp.bfloat16,
    num_runs: int = 1,
    trace_dir: "str | None" = None,
) -> Dict[str, Any]:
    """Benchmarks only the Matmul part of gemm_all_reduce.

    A: [M, K]
    B: [K, N]
    C = A @ B: [M, N]

    Args:
        m: Rows of A (and of C).
        k: Contraction dimension (columns of A / rows of B).
        n: Columns of B (and of C).
        dtype: Element type of the operands; the f32 accumulator is cast
            back to this type for the output.
        num_runs: Number of timed iterations.
        trace_dir: If set, directory where profiler traces are written.

    Returns:
        Dict with key "time_ms_list": per-iteration wall times in ms.
    """
    dtype_str = dtype.dtype.name
    print(f"Running gemm_only benchmark with m={m}, k={k}, n={n}, dtype={dtype_str}, runs={num_runs}")

    def f(x, y):
        with jax.named_scope(MARKER):
            # Matmul only — accumulate in float32, then cast back to the
            # benchmark dtype (mirrors the gemm_all_reduce kernel minus psum).
            acc = jax.numpy.einsum(
                "ij,jk->ik", x, y, preferred_element_type=jnp.float32
            )
            c = acc.astype(dtype)
        return c

    mesh = create_mesh(SHARDING_STRATEGY)
    lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY)
    rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY)
    out_sharding = get_out_sharding(SHARDING_STRATEGY)

    jit_sharded_f = jax.jit(
        shard_map(
            f,
            mesh,
            in_specs=(
                lhs_sharding.spec,
                rhs_sharding.spec,
            ),
            out_specs=out_sharding,
            check_rep=False,
        )
    )

    lhs_shape = (m, k)
    rhs_shape = (k, n)

    key = jax.random.key(SEED)

    def data_generator():
        """Creates new random data on host and puts it on device."""
        nonlocal key
        key, key_lhs, key_rhs = jax.random.split(key, 3)

        # Create random data on host.
        lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(dtype)
        rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(dtype)

        # Put on device (HBM) with the benchmark shardings.
        lhs_device = jax.device_put(lhs_host, lhs_sharding)
        rhs_device = jax.device_put(rhs_host, rhs_sharding)

        return (lhs_device, rhs_device)

    time_ms_list = iteration_timeit(
        jit_sharded_f,
        data_generator,
        matrix_dim=f"{dtype_str}_{m}x{n}x{k}",
        tries=num_runs,
        task=f"gemm_only_{dtype_str}",
        trace_dir=trace_dir,
    )
    return {
        "time_ms_list": time_ms_list,
    }
274+
275+
276+
def gemm_only_calculate_metrics(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype,
    time_ms_list: list[float],
) -> "tuple[Any, Any]":
    """Computes FLOPs-based metrics for the gemm_only benchmark.

    Args:
        m, k, n: GEMM problem dimensions (C = A[M,K] @ B[K,N]).
        dtype: Element type used in the benchmark (sets the peak-FLOPs
            multiplier).
        time_ms_list: Per-iteration wall times in ms from `gemm_only`.

    Returns:
        The (metadata, metrics) pair produced by `unified_flops_metrics`.
    """
    # Matmul FLOPs: one multiply + one add per (m, k, n) triple.
    total_flops = 2 * m * k * n

    # Split the total between per-device and all-device counts according
    # to the active sharding strategy.
    total_flops_per_device, total_flops_all_devices = handle_based_on_sharding(
        total_flops, SHARDING_STRATEGY
    )

    dtype_str = dtype.dtype.name
    peak_flops_multiplier = get_peak_flops_multiplier(dtype_str)
    peak_flops = PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier

    metadata, metrics = unified_flops_metrics(
        m,
        n,
        k,
        time_ms_list,
        total_flops_per_device,
        total_flops_all_devices,
        peak_flops,
        dtype=dtype_str,
    )

    return metadata, metrics
299+
300+
301+
def all_reduce_only(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype = jnp.bfloat16,
    num_runs: int = 1,
    trace_dir: "str | None" = None,
) -> Dict[str, Any]:
    """Benchmarks only the AllReduce part of gemm_all_reduce independently.

    Input: C [M, N]
    Output = AllReduce(C)

    Args:
        m: Rows of C.
        k: Not used for any shape here; kept so the signature matches
            gemm_all_reduce (it still appears in the matrix_dim label).
        n: Columns of C.
        dtype: Element type of C.
        num_runs: Number of timed iterations.
        trace_dir: If set, directory where profiler traces are written.

    Returns:
        Dict with key "time_ms_list": per-iteration wall times in ms.
    """
    dtype_str = dtype.dtype.name
    print(f"Running all_reduce_only benchmark with m={m}, k={k}, n={n}, dtype={dtype_str}, runs={num_runs}")

    def f(c):
        with jax.named_scope(MARKER):
            # AllReduce (psum) over the mesh axis named "device".
            out = jax.lax.psum(c, axis_name="device")
        return out

    mesh = create_mesh(SHARDING_STRATEGY)

    # The input to this benchmark stands in for C as produced by the GEMM
    # stage of gemm_all_reduce, so it must have the same per-device layout
    # that gemm_all_reduce's output spec describes. Reuse that spec for the
    # input as well as the output.
    input_sharding = get_out_sharding(SHARDING_STRATEGY)
    out_sharding = get_out_sharding(SHARDING_STRATEGY)

    jit_sharded_f = jax.jit(
        shard_map(
            f,
            mesh,
            in_specs=(input_sharding,),
            out_specs=out_sharding,
            check_rep=False,
        )
    )

    # Shape of C.
    c_shape = (m, n)

    key = jax.random.key(SEED)

    def data_generator():
        """Creates new random data on host and puts it on device."""
        nonlocal key
        key, key_c = jax.random.split(key, 2)

        # Create random data on host.
        c_host = jax.random.normal(key_c, c_shape).astype(dtype)

        # input_sharding is a bare PartitionSpec; device_put needs a
        # NamedSharding so it knows which mesh to place onto.
        named_input_sharding = jax.sharding.NamedSharding(mesh, input_sharding)
        c_device = jax.device_put(c_host, named_input_sharding)

        return (c_device,)

    time_ms_list = iteration_timeit(
        jit_sharded_f,
        data_generator,
        matrix_dim=f"{dtype_str}_{m}x{n}x{k}",
        tries=num_runs,
        task=f"all_reduce_only_{dtype_str}",
        trace_dir=trace_dir,
    )
    return {
        "time_ms_list": time_ms_list,
    }
390+
391+
392+
def all_reduce_only_calculate_metrics(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype,
    time_ms_list: list[float],
) -> "tuple[Any, Any]":
    """Computes bandwidth metrics for the all_reduce_only benchmark.

    Args:
        m: Rows of C.
        k: Unused here; kept so the signature matches the other
            *_calculate_metrics helpers.
        n: Columns of C.
        dtype: Element type of C (sets bytes per element).
        time_ms_list: Per-iteration wall times in ms from `all_reduce_only`.

    Returns:
        The (metadata, metrics) pair produced by `unified_bytes_metrics`.
    """
    # The AllReduce moves matrix C: M x N elements of `dtype`.
    matrix_c_size_bytes = m * n * dtype.dtype.itemsize

    # Aggregate estimate for unified_bytes_metrics: assume each device moves
    # a full copy of C. NOTE(review): a ring AllReduce moves less than a full
    # copy per device (~2*(d-1)/d of it) — confirm this matches how
    # unified_bytes_metrics interprets total_bytes_all_devices.
    num_devices = jax.device_count()
    total_bytes_all_devices = matrix_c_size_bytes * num_devices

    metadata, metrics = unified_bytes_metrics(
        m, n, time_ms_list,
        total_bytes=matrix_c_size_bytes,
        total_bytes_all_devices=total_bytes_all_devices,
        dtype=dtype.dtype.name
    )

    return metadata, metrics

Ironwood/src/run_benchmark.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
"gemm_throttling": "benchmark_gemm_throttling.gemm_throttling",
6363
"gemm": "benchmark_gemm.gemm",
6464
"gemm_all_reduce": "benchmark_gemm_all_reduce.gemm_all_reduce",
65+
"gemm_only": "benchmark_gemm_all_reduce.gemm_only",
66+
"all_reduce_only": "benchmark_gemm_all_reduce.all_reduce_only",
6567
"gemm_accum": "benchmark_gemm.gemm_accum",
6668
"quantization": "benchmark_compute.quantization",
6769
"transpose_quantization": "benchmark_compute.transpose_quantization",

0 commit comments

Comments
 (0)