Skip to content

Commit c8eafc0

Browse files
committed
Implement simple baseline using device_put/get
1 parent 1e6c308 commit c8eafc0

6 files changed

Lines changed: 308 additions & 0 deletions

File tree

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
benchmarks:
2+
- benchmark_name: host_device
3+
num_runs: 20
4+
benchmark_sweep_params:
5+
# Single Chip (1 Chip, 2 Devices)
6+
- {
7+
num_devices: 2,
8+
data_size_mb_list: [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
9+
}
10+
csv_path: "../microbenchmarks/host_device/single_chip"
11+
trace_dir: "../microbenchmarks/host_device/single_chip/trace"
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Host Device Microbenchmarks on tpu7x-2x2x1
2+
3+
This guide provides instructions for running Host Device (Host-to-Device and Device-to-Host) microbenchmarks on tpu7x-2x2x1 Google Kubernetes Engine (GKE) clusters. It covers creating a node pool, running the benchmarks, and viewing the output.
4+
5+
> [!WARNING]
6+
> This benchmark is currently a Work In Progress (WIP). Expected bandwidth numbers are not yet finalized.
7+
8+
## Create Node Pools
9+
10+
Follow [Setup section](../../Ironwood_Microbenchmarks_readme.md#setup) to create a GKE cluster with one 2x2x1 nodepool.
11+
12+
## Run Host Device Microbenchmarks
13+
14+
To run the microbenchmarks, apply the following Kubernetes configuration:
15+
```bash
16+
kubectl apply -f tpu7x-host-device-benchmark.yaml
17+
```
18+
19+
To extract the log of the microbenchmark, use `kubectl logs`:
20+
```bash
21+
kubectl logs tpu7x-host-device-benchmark
22+
```
23+
24+
Once the benchmark completes, you should see logs reporting bandwidth statistics.
25+
26+
To retrieve the complete results, including the trace and CSV output files, you must keep the pod running after the benchmark completes. To do this, add a `sleep` command to the `tpu7x-host-device-benchmark.yaml` file. You can then use `kubectl cp` to copy the output from the pod.
27+
28+
```bash
29+
kubectl cp tpu7x-host-device-benchmark:/microbenchmarks/host_device host_device
30+
```
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
apiVersion: v1
2+
kind: Pod
3+
metadata:
4+
name: tpu7x-host-device-benchmark
5+
spec:
6+
restartPolicy: Never
7+
nodeSelector:
8+
cloud.google.com/gke-tpu-accelerator: tpu7x
9+
cloud.google.com/gke-tpu-topology: 2x2x1
10+
containers:
11+
- name: tpu-job
12+
image: python:3.12
13+
ports:
14+
- containerPort: 8431
15+
securityContext:
16+
privileged: false
17+
command:
18+
- bash
19+
- -c
20+
- |
21+
set -ex
22+
23+
git clone https://github.com/AI-Hypercomputer/accelerator-microbenchmarks.git
24+
cd accelerator-microbenchmarks
25+
pip install -r requirements.txt
26+
27+
export TPU_VISIBLE_CHIPS=0
28+
bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device_single_chip.yaml
29+
30+
resources:
31+
requests:
32+
google.com/tpu: 4
33+
limits:
34+
google.com/tpu: 4
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/bin/bash
2+
3+
# Default values
4+
CONFIG_DIR="Ironwood/configs/host_device"
5+
SPECIFIC_CONFIG=""
6+
INTERLEAVED=false
7+
8+
# Helper function for usage
9+
usage() {
10+
echo "Usage: $0 [OPTIONS]"
11+
echo "Options:"
12+
echo " --config <path> Path to specific config file (optional)"
13+
echo " --interleaved Run with numactl --interleave=all"
14+
echo " --help Show this help message"
15+
exit 1
16+
}
17+
18+
# Parse arguments
19+
while [[ "$#" -gt 0 ]]; do
20+
case $1 in
21+
--config) SPECIFIC_CONFIG="$2"; shift ;;
22+
--interleaved) INTERLEAVED=true ;;
23+
--help) usage ;;
24+
*) echo "Unknown parameter passed: $1"; usage ;;
25+
esac
26+
shift
27+
done
28+
29+
echo "--- Starting Host-Device Transfer Benchmark (H2D/D2H) ---"
30+
echo "********************************************************"
31+
echo "WARNING: This benchmark is currently a WORK IN PROGRESS"
32+
echo "********************************************************"
33+
echo ""
34+
echo "Configuration:"
35+
echo " Interleaved: $INTERLEAVED"
36+
echo ""
37+
38+
if [ -n "$SPECIFIC_CONFIG" ]; then
39+
CONFIGS=("$SPECIFIC_CONFIG")
40+
else
41+
# Use nullglob to handle case where no files match (though unlikely here)
42+
shopt -s nullglob
43+
CONFIGS=("$CONFIG_DIR"/*.yaml)
44+
shopt -u nullglob
45+
fi
46+
47+
if [ ${#CONFIGS[@]} -eq 0 ]; then
48+
echo "No configuration files found!"
49+
exit 1
50+
fi
51+
52+
for CONFIG_FILE in "${CONFIGS[@]}"; do
53+
echo "--- Running Config: $CONFIG_FILE ---"
54+
CMD="python Ironwood/src/run_benchmark.py --config=${CONFIG_FILE}"
55+
56+
if [ "$INTERLEAVED" = true ]; then
57+
if command -v numactl &> /dev/null; then
58+
echo "Running with numactl --interleave=all"
59+
numactl --interleave=all $CMD
60+
else
61+
echo "Warning: numactl not found. Running without interleaving."
62+
$CMD
63+
fi
64+
else
65+
$CMD
66+
fi
67+
echo "--- Finished Config: $CONFIG_FILE ---"
68+
echo ""
69+
done
70+
71+
echo "--- All Benchmarks Finished ---"
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline)."""
2+
3+
import time
4+
import os
5+
from typing import Any, Dict, Tuple, List
6+
7+
import jax
8+
from jax import sharding
9+
import numpy as np
10+
from benchmark_utils import MetricsStatistics
11+
12+
13+
libtpu_init_args = [
14+
"--xla_tpu_dvfs_p_state=7",
15+
]
16+
os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args)
17+
# 64 GiB
18+
os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736"
19+
os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736"
20+
21+
def get_tpu_devices(num_devices: int):
22+
devices = jax.devices()
23+
if len(devices) < num_devices:
24+
raise RuntimeError(f"Require {num_devices} devices, found {len(devices)}")
25+
return devices[:num_devices]
26+
27+
def benchmark_host_device(
28+
num_devices: int,
29+
data_size_mb: int,
30+
num_runs: int = 100,
31+
trace_dir: str = None,
32+
) -> Dict[str, Any]:
33+
"""Benchmarks H2D/D2H transfer using simple device_put/device_get."""
34+
tpu_devices = get_tpu_devices(num_devices)
35+
36+
num_elements = 1024 * 1024 * data_size_mb // np.dtype(np.float32).itemsize
37+
38+
# Allocate Host Source Buffer
39+
host_data = np.ones((num_elements,), dtype=np.float32)
40+
41+
print(
42+
f"Benchmarking (Simple) Transfer with Data Size: {data_size_mb} MB on"
43+
f" {num_devices} devices for {num_runs} iterations"
44+
)
45+
46+
# Setup Mesh Sharding (1D)
47+
mesh = sharding.Mesh(
48+
np.array(tpu_devices).reshape((num_devices,)), axis_names=("x",)
49+
)
50+
# Shard the 1D array across "x"
51+
partition_spec = sharding.PartitionSpec("x")
52+
53+
data_sharding = sharding.NamedSharding(mesh, partition_spec)
54+
55+
# Performance Lists
56+
h2d_perf, d2h_perf = [], []
57+
58+
# Profiling Context
59+
import contextlib
60+
if trace_dir:
61+
profiler_context = jax.profiler.trace(trace_dir)
62+
else:
63+
profiler_context = contextlib.nullcontext()
64+
65+
with profiler_context:
66+
# Warmup
67+
for _ in range(2):
68+
device_array = jax.device_put(host_data, data_sharding)
69+
device_array.block_until_ready()
70+
host_out = np.array(device_array)
71+
device_array.delete()
72+
del host_out
73+
74+
for i in range(num_runs):
75+
# Step Context
76+
if trace_dir:
77+
step_context = jax.profiler.StepTraceAnnotation("host_device", step_num=i)
78+
else:
79+
step_context = contextlib.nullcontext()
80+
81+
with step_context:
82+
# H2D
83+
t0 = time.perf_counter()
84+
85+
# Simple device_put
86+
device_array = jax.device_put(host_data, data_sharding)
87+
device_array.block_until_ready()
88+
89+
t1 = time.perf_counter()
90+
h2d_perf.append((t1 - t0) * 1000)
91+
92+
# Verify H2D shape/sharding
93+
assert device_array.shape == host_data.shape
94+
assert device_array.sharding == data_sharding
95+
96+
# D2H
97+
t2 = time.perf_counter()
98+
99+
# Simple device_get
100+
# Note: device_get returns a numpy array (copy)
101+
_ = jax.device_get(device_array)
102+
103+
t3 = time.perf_counter()
104+
d2h_perf.append((t3 - t2) * 1000)
105+
106+
device_array.delete()
107+
108+
return {
109+
"H2D_Bandwidth_ms": h2d_perf,
110+
"D2H_Bandwidth_ms": d2h_perf,
111+
}
112+
113+
def benchmark_host_device_calculate_metrics(
114+
num_devices: int,
115+
data_size_mb: int,
116+
H2D_Bandwidth_ms: List[float],
117+
D2H_Bandwidth_ms: List[float],
118+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
119+
"""Calculates metrics for Host-Device transfer."""
120+
params = locals().items()
121+
122+
data_size_mib = data_size_mb
123+
124+
# Filter out list params from metadata to avoid explosion
125+
metadata_keys = {
126+
"num_devices",
127+
"data_size_mib",
128+
}
129+
metadata = {k: v for k, v in params if k in metadata_keys}
130+
131+
metrics = {}
132+
133+
def add_metric(name, ms_list):
134+
# Report Bandwidth (GiB/s)
135+
# Handle division by zero if ms is 0
136+
bw_list = [
137+
((data_size_mb / 1024) / (ms / 1000)) if ms > 0 else 0.0
138+
for ms in ms_list
139+
]
140+
stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)")
141+
metrics.update(stats_bw.serialize_statistics())
142+
143+
add_metric("H2D", H2D_Bandwidth_ms)
144+
add_metric("D2H", D2H_Bandwidth_ms)
145+
146+
return metadata, metrics

Ironwood/src/run_benchmark.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,17 @@
9595
"inference_silu_mul": "benchmark_inference_compute.silu_mul",
9696
"inference_sigmoid": "benchmark_inference_compute.sigmoid",
9797
}
98+
HOST_DEVICE_BENCHMARK_MAP = {
99+
"host_device": "benchmark_host_device.benchmark_host_device",
100+
}
98101
BENCHMARK_MAP = {}
99102
BENCHMARK_MAP.update(COLLECTIVE_BENCHMARK_MAP)
100103
BENCHMARK_MAP.update(MATMUL_BENCHMARK_MAP)
101104
BENCHMARK_MAP.update(CONVOLUTION_BENCHMARK_MAP)
102105
BENCHMARK_MAP.update(ATTENTION_BENCHMARK_MAP)
103106
BENCHMARK_MAP.update(HBM_BENCHMARK_MAP)
104107
BENCHMARK_MAP.update(COMPUTE_BENCHMARK_MAP)
108+
BENCHMARK_MAP.update(HOST_DEVICE_BENCHMARK_MAP)
105109

106110

107111
# Mapping from dtype string to actual dtype object
@@ -326,6 +330,12 @@ def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str):
326330
# csv_path = os.path.join(output_path, benchmark_name)
327331
trace_dir = os.path.join(output_path, benchmark_name, "trace")
328332
xla_dump_dir = os.path.join(output_path, benchmark_name, "hlo_graphs")
333+
# Inject num_runs from config if not present in params
334+
global_num_runs = benchmark_config.get("num_runs")
335+
if global_num_runs is not None:
336+
for param in benchmark_params:
337+
if "num_runs" not in param:
338+
param["num_runs"] = global_num_runs
329339

330340
if not benchmark_name:
331341
raise ValueError("Each benchmark must have a 'benchmark_name'.")
@@ -467,6 +477,12 @@ def run_benchmark_multithreaded(benchmark_config, output_path):
467477
if output_path != "":
468478
csv_path = os.path.join(output_path, benchmark_name)
469479
os.makedirs(csv_path, exist_ok=True)
480+
# Inject num_runs from config if not present in params
481+
global_num_runs = benchmark_config.get("num_runs")
482+
if global_num_runs is not None:
483+
for param in benchmark_params:
484+
if "num_runs" not in param:
485+
param["num_runs"] = global_num_runs
470486

471487
# Get the benchmark function
472488
benchmark_func, calculate_metrics_func = get_benchmark_functions(benchmark_name)

0 commit comments

Comments
 (0)