Skip to content

Commit 4bd634e

Browse files
committed
Add new dtypes support for GEMM
Make sure `bfloat16`, `float32`, `float16`, `float8`, `float4` is supported for GEMM.
1 parent 8d1bf96 commit 4bd634e

3 files changed

Lines changed: 75 additions & 2 deletions

File tree

Ironwood/configs/training/gemm_multiple_run.yaml

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,35 @@ benchmarks:
66
xla_dump_dir: "../microbenchmarks/gemm_multiple_run_bf16/hlo_graphs"
77
benchmark_sweep_params:
88
- {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'bfloat16'}
9+
10+
- benchmark_name: "gemm_multiple_run"
11+
trace_dir: "../microbenchmarks/gemm_multiple_run_f32"
12+
csv_path: "../microbenchmarks/gemm_multiple_run_f32"
13+
xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_f32"
14+
xla_dump_dir: "../microbenchmarks/gemm_multiple_run_f32/hlo_graphs"
15+
benchmark_sweep_params:
16+
- {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float32'}
17+
18+
- benchmark_name: "gemm_multiple_run"
19+
trace_dir: "../microbenchmarks/gemm_multiple_run_fp16"
20+
csv_path: "../microbenchmarks/gemm_multiple_run_fp16"
21+
xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp16"
22+
xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp16/hlo_graphs"
23+
benchmark_sweep_params:
24+
- {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float16'}
25+
926
- benchmark_name: "gemm_multiple_run"
1027
trace_dir: "../microbenchmarks/gemm_multiple_run_fp8"
1128
csv_path: "../microbenchmarks/gemm_multiple_run_fp8"
1229
xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp8"
1330
xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp8/hlo_graphs"
1431
benchmark_sweep_params:
15-
- {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float8'}
32+
- {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float8'}
33+
34+
- benchmark_name: "gemm_multiple_run"
35+
trace_dir: "../microbenchmarks/gemm_multiple_run_fp4"
36+
csv_path: "../microbenchmarks/gemm_multiple_run_fp4"
37+
xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp4"
38+
xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp4/hlo_graphs"
39+
benchmark_sweep_params:
40+
- {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float4'}

Ironwood/configs/training/gemm_multiple_run_more.yaml

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,37 @@ benchmarks:
1313
- {m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'bfloat16'}
1414
- {m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'bfloat16'}
1515
- {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'bfloat16'}
16+
17+
- benchmark_name: "gemm_multiple_run"
18+
trace_dir: "../microbenchmarks/gemm_multiple_run_f32"
19+
csv_path: "../microbenchmarks/gemm_multiple_run_f32"
20+
xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_f32"
21+
xla_dump_dir: "../microbenchmarks/gemm_multiple_run_f32/hlo_graphs"
22+
benchmark_sweep_params:
23+
- {m: 128, k: 128, n: 128, num_runs: 100, dtype: 'float32'}
24+
- {m: 256, k: 256, n: 256, num_runs: 100, dtype: 'float32'}
25+
- {m: 512, k: 512, n: 512, num_runs: 100, dtype: 'float32'}
26+
- {m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'float32'}
27+
- {m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float32'}
28+
- {m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float32'}
29+
- {m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float32'}
30+
- {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float32'}
31+
32+
- benchmark_name: "gemm_multiple_run"
33+
trace_dir: "../microbenchmarks/gemm_multiple_run_fp16"
34+
csv_path: "../microbenchmarks/gemm_multiple_run_fp16"
35+
xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp16"
36+
xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp16/hlo_graphs"
37+
benchmark_sweep_params:
38+
- {m: 128, k: 128, n: 128, num_runs: 100, dtype: 'float16'}
39+
- {m: 256, k: 256, n: 256, num_runs: 100, dtype: 'float16'}
40+
- {m: 512, k: 512, n: 512, num_runs: 100, dtype: 'float16'}
41+
- {m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'float16'}
42+
- {m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float16'}
43+
- {m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float16'}
44+
- {m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float16'}
45+
- {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float16'}
46+
1647
- benchmark_name: "gemm_multiple_run"
1748
trace_dir: "../microbenchmarks/gemm_multiple_run_fp8"
1849
csv_path: "../microbenchmarks/gemm_multiple_run_fp8"
@@ -26,4 +57,19 @@ benchmarks:
2657
- {m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float8'}
2758
- {m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float8'}
2859
- {m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float8'}
29-
- {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float8'}
60+
- {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float8'}
61+
62+
- benchmark_name: "gemm_multiple_run"
63+
trace_dir: "../microbenchmarks/gemm_multiple_run_fp4"
64+
csv_path: "../microbenchmarks/gemm_multiple_run_fp4"
65+
xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp4"
66+
xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp4/hlo_graphs"
67+
benchmark_sweep_params:
68+
- {m: 128, k: 128, n: 128, num_runs: 100, dtype: 'float4'}
69+
- {m: 256, k: 256, n: 256, num_runs: 100, dtype: 'float4'}
70+
- {m: 512, k: 512, n: 512, num_runs: 100, dtype: 'float4'}
71+
- {m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'float4'}
72+
- {m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float4'}
73+
- {m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float4'}
74+
- {m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float4'}
75+
- {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float4'}

Ironwood/src/run_benchmark.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@
114114
"float32": jax.numpy.float32,
115115
"int32": jax.numpy.int32,
116116
"float8": jax.numpy.float8_e4m3fn,
117+
"float16": jax.numpy.float16,
118+
"float4": jax.numpy.float4_e2m1fn,
117119
# Add other dtypes as needed
118120
}
119121

0 commit comments

Comments
 (0)