diff --git a/Ironwood/configs/training/gemm_multiple_run.yaml b/Ironwood/configs/training/gemm_multiple_run.yaml index 7730720d..cc481b53 100644 --- a/Ironwood/configs/training/gemm_multiple_run.yaml +++ b/Ironwood/configs/training/gemm_multiple_run.yaml @@ -6,10 +6,35 @@ benchmarks: xla_dump_dir: "../microbenchmarks/gemm_multiple_run_bf16/hlo_graphs" benchmark_sweep_params: - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'bfloat16'} + +- benchmark_name: "gemm_multiple_run" + trace_dir: "../microbenchmarks/gemm_multiple_run_f32" + csv_path: "../microbenchmarks/gemm_multiple_run_f32" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_f32" + xla_dump_dir: "../microbenchmarks/gemm_multiple_run_f32/hlo_graphs" + benchmark_sweep_params: + - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float32'} + +- benchmark_name: "gemm_multiple_run" + trace_dir: "../microbenchmarks/gemm_multiple_run_fp16" + csv_path: "../microbenchmarks/gemm_multiple_run_fp16" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp16" + xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp16/hlo_graphs" + benchmark_sweep_params: + - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float16'} + - benchmark_name: "gemm_multiple_run" trace_dir: "../microbenchmarks/gemm_multiple_run_fp8" csv_path: "../microbenchmarks/gemm_multiple_run_fp8" xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp8" xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp8/hlo_graphs" benchmark_sweep_params: - - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float8'} \ No newline at end of file + - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float8'} + +- benchmark_name: "gemm_multiple_run" + trace_dir: "../microbenchmarks/gemm_multiple_run_fp4" + csv_path: "../microbenchmarks/gemm_multiple_run_fp4" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp4" + xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp4/hlo_graphs" + benchmark_sweep_params: + - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float4'} \ No newline at end of file diff --git a/Ironwood/configs/training/gemm_multiple_run_more.yaml b/Ironwood/configs/training/gemm_multiple_run_more.yaml index 7e32ee89..ea89f98b 100644 --- a/Ironwood/configs/training/gemm_multiple_run_more.yaml +++ b/Ironwood/configs/training/gemm_multiple_run_more.yaml @@ -13,6 +13,37 @@ benchmarks: - {m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'bfloat16'} - {m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'bfloat16'} - {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'bfloat16'} + +- benchmark_name: "gemm_multiple_run" + trace_dir: "../microbenchmarks/gemm_multiple_run_f32" + csv_path: "../microbenchmarks/gemm_multiple_run_f32" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_f32" + xla_dump_dir: "../microbenchmarks/gemm_multiple_run_f32/hlo_graphs" + benchmark_sweep_params: + - {m: 128, k: 128, n: 128, num_runs: 100, dtype: 'float32'} + - {m: 256, k: 256, n: 256, num_runs: 100, dtype: 'float32'} + - {m: 512, k: 512, n: 512, num_runs: 100, dtype: 'float32'} + - {m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'float32'} + - {m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float32'} + - {m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float32'} + - {m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float32'} + - {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float32'} + +- benchmark_name: "gemm_multiple_run" + trace_dir: "../microbenchmarks/gemm_multiple_run_fp16" + csv_path: "../microbenchmarks/gemm_multiple_run_fp16" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp16" + xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp16/hlo_graphs" + benchmark_sweep_params: + - {m: 128, k: 128, n: 128, num_runs: 100, dtype: 'float16'} + - {m: 256, k: 256, n: 256, num_runs: 100, dtype: 'float16'} + - {m: 512, k: 512, n: 512, num_runs: 100, dtype: 'float16'} + - {m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'float16'} + - {m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float16'} + - {m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float16'} + - {m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float16'} + - {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float16'} + - benchmark_name: "gemm_multiple_run" trace_dir: "../microbenchmarks/gemm_multiple_run_fp8" csv_path: "../microbenchmarks/gemm_multiple_run_fp8" @@ -26,4 +57,19 @@ benchmarks: - {m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float8'} - {m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float8'} - {m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float8'} - - {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float8'} \ No newline at end of file + - {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float8'} + +- benchmark_name: "gemm_multiple_run" + trace_dir: "../microbenchmarks/gemm_multiple_run_fp4" + csv_path: "../microbenchmarks/gemm_multiple_run_fp4" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp4" + xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp4/hlo_graphs" + benchmark_sweep_params: + - {m: 128, k: 128, n: 128, num_runs: 100, dtype: 'float4'} + - {m: 256, k: 256, n: 256, num_runs: 100, dtype: 'float4'} + - {m: 512, k: 512, n: 512, num_runs: 100, dtype: 'float4'} + - {m: 1024, k: 1024, n: 1024, num_runs: 100, dtype: 'float4'} + - {m: 2048, k: 2048, n: 2048, num_runs: 100, dtype: 'float4'} + - {m: 4096, k: 4096, n: 4096, num_runs: 100, dtype: 'float4'} + - {m: 16384, k: 16384, n: 16384, num_runs: 100, dtype: 'float4'} + - {m: 32768, k: 32768, n: 32768, num_runs: 100, dtype: 'float4'} \ No newline at end of file diff --git a/Ironwood/src/benchmark_gemm.py b/Ironwood/src/benchmark_gemm.py index c8c27bbe..e653fde1 100644 --- a/Ironwood/src/benchmark_gemm.py +++ b/Ironwood/src/benchmark_gemm.py @@ -120,7 +120,7 @@ def data_generator(): # Run the benchmark print("Running gemm_multiple_run benchmark", num_runs) - dtype_str = "fp8" if dtype==jax.numpy.float8_e4m3fn else "bf16" + dtype_str = dtype.dtype.name time_ms_list = multiple_iteration_timeit_from_trace( jit_sharded_f, data_generator, diff --git a/Ironwood/src/run_benchmark.py b/Ironwood/src/run_benchmark.py index 2b703487..b44aab75 100644 --- a/Ironwood/src/run_benchmark.py +++ b/Ironwood/src/run_benchmark.py @@ -114,6 +114,8 @@ "float32": jax.numpy.float32, "int32": jax.numpy.int32, "float8": jax.numpy.float8_e4m3fn, + "float16": jax.numpy.float16, + "float4": jax.numpy.float4_e2m1fn, # Add other dtypes as needed }