
Commit 5dd6f85

Add bmm microbenchmark. (#97)

* [BMM] Add bmm microbenchmark
* Update hook in benchmark entry.
* Update BMM config
* Update timeit logic

1 parent 30db8d0 commit 5dd6f85

3 files changed: 214 additions & 0 deletions

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
benchmarks:
- benchmark_name: "single_device_bmm"
  trace_dir: "../microbenchmarks/single_device_bmm_bf16"
  csv_path: "../microbenchmarks/single_device_bmm_bf16"
  xlml_metrics_dir: "../microbenchmarks/single_device_bmm_bf16"
  xla_dump_dir: "../microbenchmarks/single_device_bmm_bf16/hlo_graphs"
  benchmark_sweep_params:
    - {b: 1, m: 128, k: 128, n: 128, num_runs: 100, dtype: 'bfloat16'}
    - {b: 1, m: 256, k: 256, n: 256, num_runs: 100, dtype: 'bfloat16'}
    - {b: 1, m: 512, k: 512, n: 512, num_runs: 100, dtype: 'bfloat16'}
    - {b: 1, m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'bfloat16'}
    - {b: 1, m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'bfloat16'}
    - {b: 1, m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'bfloat16'}
    - {b: 1, m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'bfloat16'}
    - {b: 1, m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'bfloat16'}

- benchmark_name: "single_device_bmm"
  trace_dir: "../microbenchmarks/single_device_bmm_f32"
  csv_path: "../microbenchmarks/single_device_bmm_f32"
  xlml_metrics_dir: "../microbenchmarks/single_device_bmm_f32"
  xla_dump_dir: "../microbenchmarks/single_device_bmm_f32/hlo_graphs"
  benchmark_sweep_params:
    - {b: 1, m: 128, k: 128, n: 128, num_runs: 100, dtype: 'float32'}
    - {b: 1, m: 256, k: 256, n: 256, num_runs: 100, dtype: 'float32'}
    - {b: 1, m: 512, k: 512, n: 512, num_runs: 100, dtype: 'float32'}
    - {b: 1, m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'float32'}
    - {b: 1, m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float32'}
    - {b: 1, m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float32'}
    - {b: 1, m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float32'}
    - {b: 1, m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float32'}

- benchmark_name: "single_device_bmm"
  trace_dir: "../microbenchmarks/single_device_bmm_fp16"
  csv_path: "../microbenchmarks/single_device_bmm_fp16"
  xlml_metrics_dir: "../microbenchmarks/single_device_bmm_fp16"
  xla_dump_dir: "../microbenchmarks/single_device_bmm_fp16/hlo_graphs"
  benchmark_sweep_params:
    - {b: 1, m: 128, k: 128, n: 128, num_runs: 100, dtype: 'float16'}
    - {b: 1, m: 256, k: 256, n: 256, num_runs: 100, dtype: 'float16'}
    - {b: 1, m: 512, k: 512, n: 512, num_runs: 100, dtype: 'float16'}
    - {b: 1, m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'float16'}
    - {b: 1, m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float16'}
    - {b: 1, m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float16'}
    - {b: 1, m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float16'}
    - {b: 1, m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float16'}

- benchmark_name: "single_device_bmm"
  trace_dir: "../microbenchmarks/single_device_bmm_fp8"
  csv_path: "../microbenchmarks/single_device_bmm_fp8"
  xlml_metrics_dir: "../microbenchmarks/single_device_bmm_fp8"
  xla_dump_dir: "../microbenchmarks/single_device_bmm_fp8/hlo_graphs"
  benchmark_sweep_params:
    - {b: 1, m: 128, k: 128, n: 128, num_runs: 100, dtype: 'float8'}
    - {b: 1, m: 256, k: 256, n: 256, num_runs: 100, dtype: 'float8'}
    - {b: 1, m: 512, k: 512, n: 512, num_runs: 100, dtype: 'float8'}
    - {b: 1, m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'float8'}
    - {b: 1, m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float8'}
    - {b: 1, m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float8'}
    - {b: 1, m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float8'}
    - {b: 1, m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float8'}

- benchmark_name: "single_device_bmm"
  trace_dir: "../microbenchmarks/single_device_bmm_fp4"
  csv_path: "../microbenchmarks/single_device_bmm_fp4"
  xlml_metrics_dir: "../microbenchmarks/single_device_bmm_fp4"
  xla_dump_dir: "../microbenchmarks/single_device_bmm_fp4/hlo_graphs"
  benchmark_sweep_params:
    - {b: 1, m: 128, k: 128, n: 128, num_runs: 100, dtype: 'float4'}
    - {b: 1, m: 256, k: 256, n: 256, num_runs: 100, dtype: 'float4'}
    - {b: 1, m: 512, k: 512, n: 512, num_runs: 100, dtype: 'float4'}
    - {b: 1, m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'float4'}
    - {b: 1, m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float4'}
    - {b: 1, m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float4'}
    - {b: 1, m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float4'}
    - {b: 1, m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float4'}
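
Each entry in benchmark_sweep_params is one keyword set for a single benchmark run. As a reading aid, here is a minimal sketch of how one sweep row maps onto a call to the benchmark function defined below; the runner wiring and the exact behavior of str_to_dtype are not shown in this diff, so both are assumptions here.

# Hypothetical expansion of one sweep row into a direct call; the real
# dispatch goes through run_benchmark.py and BENCHMARK_MAP (see below).
from benchmark_bmm import single_device_bmm
from benchmark_utils import str_to_dtype  # assumed to map 'bfloat16' -> jnp.bfloat16

row = {"b": 1, "m": 4096, "k": 4096, "n": 4096, "num_runs": 100, "dtype": "bfloat16"}

result = single_device_bmm(
    b=row["b"], m=row["m"], k=row["k"], n=row["n"],
    dtype=str_to_dtype(row["dtype"]),
    num_runs=row["num_runs"],
    trace_dir="../microbenchmarks/single_device_bmm_bf16",
)
print(result["time_ms_list"])  # per-iteration latencies in milliseconds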

Ironwood/src/benchmark_bmm.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
"""
Benchmarks bmm in various flavors.
Considered ops:
1. bmm
"""

import os
from typing import Any, Dict

# pylint: disable=g-importing-member
from benchmark_utils import (
    iteration_timeit,
    multiple_iteration_timeit_from_trace,
    ShardingStrategy,
    get_lhs_named_shading,
    get_rhs_named_shading,
    get_output_named_shading,
    get_out_sharding,
    create_mesh,
    handle_based_on_sharding,
    unified_flops_metrics,
    str_to_dtype,
    get_peak_flops_multiplier,
)
from common import MARKER
import jax
from jax.experimental.shard_map import shard_map
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


os.environ["LIBTPU_INIT_ARGS"] = (
    "--xla_tpu_enable_async_collective_fusion=true "
    "--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true "
    "--xla_tpu_enable_async_collective_fusion_multiple_steps=true "
    "--xla_tpu_overlap_compute_collective_tc=true "
    "--xla_enable_async_all_gather=true "
    "--xla_enable_async_collective_permute=true "
    "--xla_tpu_enable_all_experimental_scheduler_features=true "
    "--xla_tpu_accumulate_into_mrb=true "
    "--xla_tpu_scoped_vmem_limit_kib=65536 "
    "--xla_tpu_vmem_scavenging_mode=NONE "
    "--xla_tpu_dvfs_p_state=7"
)

TRACE_BASE_DIR = None
METRICS_JSONL_DIR = None
SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING
SEED = 0
PEAK_FLOPS_PER_DEVICE = 2307  # FP8 peak in TFLOP/s for a single core (device).


def single_device_bmm(
    b: int,
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype = jax.numpy.float8_e4m3fn,
    num_runs: int = 1,
    trace_dir: str = None,
) -> Dict[str, Any]:
    """Benchmarks OUT<B, M, N>:BF16 = IN0<B, M, K> x IN1<B, K, N> in the given dtype (FP8 by default); accumulation is FP32."""

    def f(x, y):
        with jax.named_scope(MARKER):
            acc = jax.numpy.einsum(
                "bij,bjk->bik", x, y, preferred_element_type=jnp.float32
            )
            return acc.astype(jnp.bfloat16)

    jit_sharded_f = jax.jit(f)

    lhs_shape = (b, m, k)
    rhs_shape = (b, k, n)

    lhs_dtype = dtype
    rhs_dtype = dtype

    key = jax.random.key(SEED)

    def data_generator():
        """Creates new random data on host and puts it on device."""
        nonlocal key  # Use and update the outer 'key'.
        key, key_lhs, key_rhs = jax.random.split(key, 3)

        # Create random data on host.
        lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(lhs_dtype)
        rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(rhs_dtype)

        # Put on device (HBM).
        return (lhs_host, rhs_host)

    # Run the benchmark.
    dtype_str = dtype.dtype.name
    time_ms_list = multiple_iteration_timeit_from_trace(
        jit_sharded_f,
        data_generator,
        matrix_dim=f"{dtype_str}_{b}x{m}x{n}x{k}",
        tries=num_runs,
        task="single_device_bmm",
        trace_dir=trace_dir,
    )

    return {"time_ms_list": time_ms_list}


def single_device_bmm_calculate_metrics(
    b: int,
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype,
    time_ms_list: list[float],
) -> Dict[str, Any]:
    # Calculate FLOPs.
    total_flops = 2 * b * m * k * n  # Total floating-point operations.
    total_flops, total_flops_all_devices = handle_based_on_sharding(
        total_flops, SHARDING_STRATEGY
    )
    return unified_flops_metrics(
        m,
        n,
        k,
        time_ms_list,
        total_flops,
        total_flops_all_devices,
        PEAK_FLOPS_PER_DEVICE,
    )
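
For a sense of what the metrics amount to, here is a back-of-envelope version of the FLOPs arithmetic above. The measured latency is an assumed placeholder, and the exact fields returned by unified_flops_metrics are not shown in this diff; the imported get_peak_flops_multiplier presumably rescales the FP8 peak for the other dtypes in the sweep.

# Worked example for one sweep point (b=1, m=k=n=4096); numbers are illustrative.
b, m, k, n = 1, 4096, 4096, 4096
total_flops = 2 * b * m * k * n                            # 2 * 4096^3 ≈ 1.37e11 FLOPs per bmm
time_ms = 0.25                                             # assumed measured latency, not a real result
achieved_tflops = total_flops / (time_ms * 1e-3) / 1e12    # ≈ 550 TFLOP/s
utilization = achieved_tflops / 2307                       # vs. PEAK_FLOPS_PER_DEVICE (FP8 peak) ≈ 24%
print(f"{achieved_tflops:.1f} TFLOP/s, {utilization:.1%} of FP8 peak")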

Ironwood/src/run_benchmark.py

Lines changed: 5 additions & 0 deletions
@@ -32,6 +32,10 @@
     "send_recv": "benchmark_send_recv.send_recv_benchmark",
 }

+BMM_BENCHMARK_MAP = {
+    "single_device_bmm": "benchmark_bmm.single_device_bmm",
+}
+
 MATMUL_BENCHMARK_MAP = {
     "naive_matmul": "benchmark_matmul.naive_matmul",
     "single_host_naive_matmul": "benchmark_matmul.single_host_naive_matmul",
@@ -99,6 +103,7 @@
     "host_device": "benchmark_host_device.benchmark_host_device",
 }
 BENCHMARK_MAP = {}
+BENCHMARK_MAP.update(BMM_BENCHMARK_MAP)
 BENCHMARK_MAP.update(COLLECTIVE_BENCHMARK_MAP)
 BENCHMARK_MAP.update(MATMUL_BENCHMARK_MAP)
 BENCHMARK_MAP.update(CONVOLUTION_BENCHMARK_MAP)
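
This hooks the bmm benchmark into the shared dispatch table; the map values are dotted "module.function" strings. Below is a minimal sketch of how such an entry could be resolved into a callable. The actual lookup code in run_benchmark.py is outside this diff, so the importlib-based approach is an assumption.

import importlib

def resolve_benchmark(name: str, benchmark_map: dict[str, str]):
    """Hypothetical resolver: 'benchmark_bmm.single_device_bmm' -> callable."""
    module_name, func_name = benchmark_map[name].rsplit(".", 1)
    return getattr(importlib.import_module(module_name), func_name)

# Example usage once BENCHMARK_MAP is built:
# bmm_fn = resolve_benchmark("single_device_bmm", BENCHMARK_MAP)
# times = bmm_fn(b=1, m=4096, k=4096, n=4096, num_runs=100)["time_ms_list"]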
