Skip to content

Commit 7b090f1

Browse files
committed
[Automation] Add BMM into automation script
1 parent ded42a6 commit 7b090f1

5 files changed

Lines changed: 85 additions & 1 deletion

File tree

Ironwood/guides/automation/aggregator.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@
2929
"tflops_per_sec_per_device_avg", "tflops_per_sec_per_device_min",
3030
"tflops_per_sec_per_device_max",
3131
],
32+
"bmm": [
33+
"b", "m", "n", "k", "dtype", "step_time_ms_num_runs",
34+
"tflops_per_sec_per_device_p50", "tflops_per_sec_per_device_p90",
35+
"tflops_per_sec_per_device_p95", "tflops_per_sec_per_device_p99",
36+
"tflops_per_sec_per_device_avg", "tflops_per_sec_per_device_min",
37+
"tflops_per_sec_per_device_max",
38+
],
3239
}
3340

3441
def download_from_gcs(bucket_path: str, local_dir: str):
@@ -86,15 +93,27 @@ def aggregate_gemm(directories: list[str], picked_columns: list[str]) -> pd.Data
8693
aggregated_df = pd.concat([aggregated_df, df[picked_columns].rename(columns={"step_time_ms_num_runs": "num_runs"})], ignore_index=True)
8794
return aggregated_df
8895

96+
def aggregate_bmm(directories: list[str], picked_columns: list[str]) -> pd.DataFrame:
    """Aggregate BMM benchmark results from the TSV files in ``directories``.

    Every ``*.tsv`` file located directly inside each directory is read,
    reduced to ``picked_columns``, and its ``step_time_ms_num_runs`` column
    (if selected) is renamed to ``num_runs``.

    Args:
        directories: Directories to scan for ``*.tsv`` result files.
        picked_columns: Names of the columns to keep from each file.

    Returns:
        A single DataFrame holding the selected columns of all files with a
        fresh ``0..n-1`` index; an empty DataFrame when no TSV file is found;
        ``None`` when ``directories`` is empty (kept for existing callers).

    Raises:
        KeyError: If a file does not contain one of ``picked_columns``.
    """
    if not directories:
        return None

    # Collect the per-file frames and concatenate once at the end instead of
    # re-copying the growing result for every file.
    frames = []
    for directory in directories:
        # glob returns paths in arbitrary order; sort so that the row order of
        # the aggregated table is reproducible between runs.
        for file in sorted(glob.glob(f"{directory}/*.tsv")):
            df = pd.read_csv(file, sep="\t")
            frames.append(
                df[picked_columns].rename(columns={"step_time_ms_num_runs": "num_runs"})
            )

    if not frames:
        # pd.concat raises on an empty list; keep returning an empty frame.
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)
106+
89107
aggregate_function = {
90108
"collectives": aggregate_collectives,
91109
"hbm": aggregate_hbm,
92110
"host_device": aggregate_host_device,
93111
"gemm": aggregate_gemm,
112+
"bmm": aggregate_bmm,
94113
}
95114

96115
def aggregate_results(bucket_path: str, local_dir: str):
97-
categories = ["collectives", "hbm", "host_device", "gemm"]
116+
categories = ["collectives", "hbm", "host_device", "gemm", "bmm"]
98117
directories = {}
99118
results = {}
100119
for category in categories:

Ironwood/guides/automation/automation_launch.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ yaml_names=(
1515
"tpu7x-2x2x1-hbm.yaml"
1616
"tpu7x-2x2x1-host_device.yaml"
1717
"tpu7x-2x2x1-gemm.yaml"
18+
"tpu7x-2x2x1-bmm.yaml"
1819
"tpu7x-2x2x1-collectives.yaml"
1920
"tpu7x-2x2x2-collectives.yaml"
2021
"tpu7x-2x2x4-collectives.yaml"
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Manifest for the single-device BMM benchmark on a tpu7x 2x2x1 slice.
# NOTE(review): ${JOB_NAME}, ${GCS_SA_NAME} and ${GCS_PATH} look like template
# placeholders filled in by the launch script before `kubectl apply` - confirm.

# Headless Service (clusterIP: None) selecting the Job's pods by job-name.
apiVersion: v1
kind: Service
metadata:
  name: headless-svc-${JOB_NAME}
spec:
  clusterIP: None
  selector:
    job-name: ${JOB_NAME}
---
# One-pod indexed Job that clones the benchmark repo and runs the BMM config.
apiVersion: batch/v1
kind: Job
metadata:
  name: ${JOB_NAME}
  labels:
    kueue.x-k8s.io/queue-name: user-queue-2x2x1
spec:
  completionMode: Indexed
  # Created suspended; presumably Kueue unsuspends it on admission - confirm.
  suspend: true
  parallelism: 1
  completions: 1
  # No retries: a failed benchmark run is not restarted.
  backoffLimit: 0
  template:
    spec:
      subdomain: headless-svc-${JOB_NAME}
      serviceAccountName: ${GCS_SA_NAME}
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu7x
        cloud.google.com/gke-tpu-topology: 2x2x1
      containers:
      - name: jax-tpu
        image: python:3.12
        securityContext:
          privileged: false
        env:
        - name: JAX_PLATFORMS
          value: "tpu,cpu"
        - name: TPU_VMODULE
          value: "singleton_tpu_system_manager=10,tpu_version_flag=10,device_util=10,device_scanner=10,mesh_builder=10,master=10"
        - name: XLA_IR_DEBUG
          value: "1"
        - name: XLA_HLO_DEBUG
          value: "1"
        command:
        - bash
        - -c
        - |
          set -ex

          git clone https://github.com/AI-Hypercomputer/accelerator-microbenchmarks.git
          cd accelerator-microbenchmarks
          git checkout tpu7x-auto
          pip install -r requirements.txt

          # Quote the expansions so the path is never word-split or globbed.
          GCS_BUCKET_DIR="${GCS_PATH}"
          python Ironwood/src/run_benchmark.py --config="Ironwood/configs/bmm/single_device_bmm.yaml" --gcs-bucket-csv-dir="${GCS_BUCKET_DIR}"
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

Ironwood/src/benchmark_bmm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,6 @@ def single_device_bmm_calculate_metrics(
131131
total_flops,
132132
total_flops_all_devices,
133133
PEAK_FLOPS_PER_DEVICE,
134+
dtype=dtype.dtype.name,
135+
b=b,
134136
)

Ironwood/src/benchmark_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,6 +1111,7 @@ def unified_flops_metrics(
11111111
total_flops_all_devices: int,
11121112
peak_TFLOPS_per_device: float,
11131113
dtype: str = None,
1114+
b: int = None,
11141115
) -> Dict[str, Any]:
11151116
"""Calculates the metrics for the naive matmul benchmark."""
11161117
# Build dictionary of all the parameters in the function

0 commit comments

Comments
 (0)