Commit 252715f

Add PCIe transfer benchmark (H2D/D2H) to Ironwood
- Implement benchmark_pcie_transfer.py to measure H2D and D2H transfer performance using JAX, supporting various transfer modes (Standard, Parallel, Threaded, Chunked).
- Integrate the new benchmark into run_benchmark.py.
- Add configuration files for single device, single chip, and single VM topologies in configs/pcie_transfer/.
- Add scripts/run_pcie_transfer_benchmark.sh for bulk execution with numactl interleaving option.
1 parent 1e6c308 commit 252715f

6 files changed

Lines changed: 376 additions & 0 deletions

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
benchmarks:
- benchmark_name: host_device
  num_runs: 20
  benchmark_sweep_params:
  # Single Chip (1 Chip, 2 Devices)
  - {mesh_shape: "1x2", data_size_mb_list: [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]}

  csv_path: "../microbenchmarks/host_device/single_chip"
  trace_dir: "../microbenchmarks/host_device/single_chip/trace"
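
To make the sweep sizes concrete, here is a minimal sketch of how each `data_size_mb` entry maps to a host array, assuming the shape arithmetic used by the benchmark source in this commit (`rows = 1024 * data_size_mb // itemsize`, array shape `(rows, 8, 128)`, `float32`). The helper names `host_array_shape`, `total_bytes`, and `per_device_rows` are hypothetical and exist only for illustration; nothing is allocated.

```python
import math
from typing import Tuple

FLOAT32_BYTES = 4  # itemsize of np.float32


def host_array_shape(data_size_mb: int) -> Tuple[int, int, int]:
    """Shape of the float32 host array for one sweep entry (hypothetical helper)."""
    rows = 1024 * data_size_mb // FLOAT32_BYTES
    return (rows, 8, 128)


def total_bytes(shape: Tuple[int, ...]) -> int:
    """Total size in bytes of a float32 array with the given shape."""
    return math.prod(shape) * FLOAT32_BYTES


def per_device_rows(data_size_mb: int, mesh_shape: str) -> int:
    """Rows each device receives when axis 0 is split evenly across the mesh."""
    num_devices = math.prod(int(d) for d in mesh_shape.split("x"))
    return host_array_shape(data_size_mb)[0] // num_devices


for size_mb in [1, 16, 32768]:
    shape = host_array_shape(size_mb)
    mib = total_bytes(shape) // 2**20
    print(f"{size_mb} -> shape {shape}, {mib} MiB total, "
          f"{per_device_rows(size_mb, '1x2')} rows per device")
```

Each row holds 8 * 128 * 4 = 4096 bytes, so a `data_size_mb` of N produces exactly N MiB of host data. The largest entry, 32768, therefore corresponds to a 32 GiB host array.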
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# Host Device Microbenchmarks on tpu7x-2x2x1

This guide provides instructions for running Host Device (Host-to-Device and Device-to-Host) microbenchmarks on tpu7x-2x2x1 Google Kubernetes Engine (GKE) clusters. It covers creating a node pool, running the benchmarks, and viewing the output.

> [!NOTE]
> This benchmark is currently a Work In Progress (WIP). Expected bandwidth numbers are not yet finalized.

## Create Node Pools

Follow the [Setup section](../../Ironwood_Microbenchmarks_readme.md#setup) to create a GKE cluster with one 2x2x1 node pool.

## Run Host Device Microbenchmarks

To run the microbenchmarks, apply the following Kubernetes configuration:
```bash
kubectl apply -f tpu7x-host-device-benchmark.yaml
```

To view the benchmark logs, use `kubectl logs`:
```bash
kubectl logs tpu7x-host-device-benchmark
```

Once the benchmark completes, the logs report H2D and D2H bandwidth statistics.

To retrieve the complete results, including the trace and CSV output files, you must keep the pod running after the benchmark completes. To do this, add a `sleep` command after the benchmark script in the `tpu7x-host-device-benchmark.yaml` file. You can then use `kubectl cp` to copy the output from the pod.

```bash
kubectl cp tpu7x-host-device-benchmark:/microbenchmarks/host_device host_device
```
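
When reading the logged statistics, it may help to know how a bandwidth figure relates to a measured transfer time. The following is a minimal sketch that mirrors the conversion in the `add_metric` helper of this commit's benchmark source; the function name `to_gib_per_s` is hypothetical.

```python
def to_gib_per_s(data_size_mb: float, elapsed_ms: float) -> float:
    """Convert one timed transfer into GiB/s (hypothetical helper).

    data_size_mb is the size of the full host array, so the result is the
    aggregate bandwidth across all devices in the mesh.
    """
    if elapsed_ms <= 0:
        # Same guard as the benchmark: avoid dividing by zero.
        return 0.0
    return (data_size_mb / 1024) / (elapsed_ms / 1000)


# 1024 MB moved in 500 ms is 1 GiB in 0.5 s.
print(to_gib_per_s(1024, 500.0))  # 2.0
```

Because the numerator is the total array size rather than the per-device shard, the reported value should be read as combined throughput for the whole mesh.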
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
apiVersion: v1
kind: Pod
metadata:
  name: tpu7x-host-device-benchmark
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu7x
    cloud.google.com/gke-tpu-topology: 2x2x1
  containers:
  - name: tpu-job
    image: python:3.12
    ports:
    - containerPort: 8431
    securityContext:
      privileged: false
    command:
    - bash
    - -c
    - |
      set -ex

      git clone https://github.com/AI-Hypercomputer/accelerator-microbenchmarks.git
      cd accelerator-microbenchmarks
      pip install -r requirements.txt

      bash ./Ironwood/scripts/run_host_device_benchmark.sh

    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
#!/bin/bash

# Default values
CONFIG_DIR="Ironwood/configs/host_device"
SPECIFIC_CONFIG=""
INTERLEAVED=false

# Helper function for usage
usage() {
    echo "Usage: $0 [OPTIONS]"
    echo "Options:"
    echo "  --config <path>   Path to specific config file (optional)"
    echo "  --interleaved     Run with numactl --interleave=all"
    echo "  --help            Show this help message"
    exit 1
}

# Parse arguments
while [[ "$#" -gt 0 ]]; do
    case $1 in
        --config) SPECIFIC_CONFIG="$2"; shift ;;
        --interleaved) INTERLEAVED=true ;;
        --help) usage ;;
        *) echo "Unknown parameter passed: $1"; usage ;;
    esac
    shift
done

echo "--- Starting Host-Device Transfer Benchmark (H2D/D2H) ---"
echo "Note: This benchmark is work in progress"
echo "Interleaved: $INTERLEAVED"

if [ -n "$SPECIFIC_CONFIG" ]; then
    CONFIGS=("$SPECIFIC_CONFIG")
else
    # Use nullglob to handle case where no files match (though unlikely here)
    shopt -s nullglob
    CONFIGS=("$CONFIG_DIR"/*.yaml)
    shopt -u nullglob
fi

if [ ${#CONFIGS[@]} -eq 0 ]; then
    echo "No configuration files found!"
    exit 1
fi

for CONFIG_FILE in "${CONFIGS[@]}"; do
    echo "--- Running Config: $CONFIG_FILE ---"
    CMD=(python Ironwood/src/run_benchmark.py "--config=${CONFIG_FILE}")

    if [ "$INTERLEAVED" = true ]; then
        if command -v numactl &> /dev/null; then
            echo "Running with numactl --interleave=all"
            numactl --interleave=all "${CMD[@]}"
        else
            echo "Warning: numactl not found. Running without interleaving."
            "${CMD[@]}"
        fi
    else
        "${CMD[@]}"
    fi
    echo "--- Finished Config: $CONFIG_FILE ---"
    echo ""
done

echo "--- All Benchmarks Finished ---"
Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
"""Benchmarks Host-to-Device and Device-to-Host transfer performance."""

import concurrent.futures
import contextlib
import gc
import time
import os
from typing import Any, Dict, Tuple, List

import jax
from jax import sharding
import numpy as np
from benchmark_utils import MetricsStatistics

os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736"  # 64 GiB
os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736"

def get_tpu_devices(num_devices: int):
    devices = jax.devices()
    if len(devices) < num_devices:
        raise RuntimeError(f"Require {num_devices} devices, found {len(devices)}")
    return devices[:num_devices]

def _run_chunked(host_data, data_sharding, host_shards, target_devices, num_devices, chunks_per_device):
    # Smart Chunked H2D
    chk_h2d_start = time.perf_counter()
    total_workers = num_devices * chunks_per_device
    with concurrent.futures.ThreadPoolExecutor(max_workers=total_workers) as executor:
        chunked_futures = []
        for shard, dev in zip(host_shards, target_devices):
            sub_chunks = np.array_split(shard, chunks_per_device, axis=0)
            for chunk in sub_chunks:
                chunked_futures.append(
                    executor.submit(jax.device_put, chunk, dev)
                )
        chunked_buffers = [f.result() for f in chunked_futures]
        for db in chunked_buffers:
            db.block_until_ready()
    chk_h2d_end = time.perf_counter()
    h2d_ms = (chk_h2d_end - chk_h2d_start) * 1000
    for db in chunked_buffers:
        db.delete()

    # Smart Chunked D2H
    data_on_device = jax.device_put(host_data, data_sharding)
    data_on_device.block_until_ready()

    chk_d2h_start = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=total_workers) as executor:
        d2h_futures = []
        for shard in data_on_device.addressable_shards:
            # Direct slicing on device array to avoid copy
            shard_len = shard.data.shape[0]
            chunk_size = (shard_len + chunks_per_device - 1) // chunks_per_device
            for i in range(chunks_per_device):
                start = i * chunk_size
                end = min((i + 1) * chunk_size, shard_len)
                if start < end:
                    d2h_futures.append(
                        executor.submit(jax.device_get, shard.data[start:end])
                    )
        _ = [f.result() for f in d2h_futures]
    chk_d2h_end = time.perf_counter()
    d2h_ms = (chk_d2h_end - chk_d2h_start) * 1000
    data_on_device.delete()

    return h2d_ms, d2h_ms


def _run_warmup(host_data, data_sharding, data_size_mb):
    # --- ADAPTIVE WARM UP ---
    if data_size_mb <= 128:
        warmup_iters = 50
    elif data_size_mb >= 8192:
        warmup_iters = 3
    else:
        warmup_iters = 10

    for _ in range(warmup_iters):
        data_on_device = jax.device_put(host_data, data_sharding)
        data_on_device.block_until_ready()
        _ = jax.device_get(data_on_device)
        data_on_device.delete()

    gc.collect()

def _get_chunks_per_device(data_size_mb, num_devices):
    # --- SMART CHUNKING CONFIG ---
    target_chunk_size_mb = 16
    max_global_threads = 256

    data_per_device_mb = data_size_mb / num_devices

    if data_per_device_mb < target_chunk_size_mb:
        chunks_per_device = 1
    else:
        chunks_per_device = int(data_per_device_mb / target_chunk_size_mb)

    total_threads = num_devices * chunks_per_device
    if total_threads > max_global_threads:
        chunks_per_device = max(1, int(max_global_threads / num_devices))

    return chunks_per_device


def benchmark_host_device(
    mesh_shape: str,
    data_size_mb: int,
    num_runs: int = 100,
    trace_dir: str = None,
) -> Dict[str, Any]:
    """Benchmarks H2D/D2H transfer using smart chunking."""
    dims = [int(d) for d in mesh_shape.split("x")]
    mesh_shape = tuple(dims)

    num_devices = int(np.prod(mesh_shape))
    tpu_devices = get_tpu_devices(num_devices)

    rows = 1024 * data_size_mb // np.dtype(np.float32).itemsize

    host_data = np.ones((rows, 8, 128), dtype=np.float32)

    print(
        f"Benchmarking Transfer with Data Size: {data_size_mb} MB on"
        f" {num_devices} devices for {num_runs} iterations"
    )

    # Setup Mesh Sharding
    if len(mesh_shape) == 1:
        mesh = sharding.Mesh(
            np.array(tpu_devices).reshape(mesh_shape), axis_names=("x",)
        )
        data_sharding = sharding.NamedSharding(mesh, sharding.PartitionSpec("x"))
    else:
        mesh = sharding.Mesh(
            np.array(tpu_devices).reshape(mesh_shape), axis_names=("x", "y")
        )
        data_sharding = sharding.NamedSharding(
            mesh, sharding.PartitionSpec(("x", "y"))
        )

    # --- ADAPTIVE WARM UP ---
    _run_warmup(host_data, data_sharding, data_size_mb)

    # Pre-calculate sharding info
    dummy_put = jax.device_put(host_data[:num_devices], data_sharding)
    target_devices = [s.device for s in dummy_put.addressable_shards]
    dummy_put.delete()

    host_shards = np.split(host_data, num_devices, axis=0)

    # Performance Lists
    h2d_perf, d2h_perf = [], []

    # --- SMART CHUNKING CONFIG ---
    chunks_per_device = _get_chunks_per_device(data_size_mb, num_devices)

    # Profiling Context
    if trace_dir:
        profiler_context = jax.profiler.trace(trace_dir)
    else:
        # No-op context manager
        profiler_context = contextlib.nullcontext()

    with profiler_context:
        for i in range(num_runs):
            # Step Context
            if trace_dir:
                step_context = jax.profiler.StepTraceAnnotation("host_device", step_num=i)
            else:
                step_context = contextlib.nullcontext()

            with step_context:
                # Optimized Chunked Transfer (Sole Strategy)
                h2d_ms, d2h_ms = _run_chunked(
                    host_data, data_sharding, host_shards, target_devices,
                    num_devices, chunks_per_device
                )
                h2d_perf.append(h2d_ms)
                d2h_perf.append(d2h_ms)

    del host_data, host_shards
    gc.collect()

    return {
        "H2D_Bandwidth": h2d_perf,  # per-iteration times in ms; converted to GiB/s below
        "D2H_Bandwidth": d2h_perf,  # per-iteration times in ms; converted to GiB/s below
        "Chunk_Count": chunks_per_device,
        "Thread_Count": num_devices * chunks_per_device,
    }

def benchmark_host_device_calculate_metrics(
    mesh_shape: str,
    data_size_mb: int,
    H2D_Bandwidth: List[float],
    D2H_Bandwidth: List[float],
    Chunk_Count: int,
    Thread_Count: int,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Calculates metrics for Host-Device transfer."""
    params = locals().items()

    # Filter out list params from metadata to avoid explosion
    metadata_keys = {"mesh_shape", "data_size_mb", "Chunk_Count", "Thread_Count"}
    metadata = {k: v for k, v in params if k in metadata_keys}

    metrics = {}

    def add_metric(name, ms_list):
        # Report Bandwidth (GiB/s)
        # Handle division by zero if ms is 0
        bw_list = [
            ((data_size_mb / 1024) / (ms / 1000)) if ms > 0 else 0.0
            for ms in ms_list
        ]
        stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)")
        metrics.update(stats_bw.serialize_statistics())

    add_metric("H2D", H2D_Bandwidth)
    add_metric("D2H", D2H_Bandwidth)

    return metadata, metrics
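
As a worked example of the chunking heuristic in `_get_chunks_per_device`, the sketch below re-implements the same arithmetic as a standalone function and evaluates it for the sizes in the single-chip config (2 devices). The function name `chunks_per_device` and its keyword defaults are illustrative restatements of the constants above, not part of the commit, and the sketch needs neither JAX nor a TPU.

```python
def chunks_per_device(
    data_size_mb: float,
    num_devices: int,
    target_chunk_size_mb: int = 16,
    max_global_threads: int = 256,
) -> int:
    """Standalone restatement of the chunk-count heuristic (illustrative only)."""
    per_device_mb = data_size_mb / num_devices
    if per_device_mb < target_chunk_size_mb:
        chunks = 1
    else:
        # Aim for roughly target_chunk_size_mb per chunk.
        chunks = int(per_device_mb / target_chunk_size_mb)
    # Cap the total worker count across all devices.
    if num_devices * chunks > max_global_threads:
        chunks = max(1, int(max_global_threads / num_devices))
    return chunks


NUM_DEVICES = 2  # the "1x2" mesh from the single-chip config
for size_mb in [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]:
    c = chunks_per_device(size_mb, NUM_DEVICES)
    per_chunk_mb = size_mb / NUM_DEVICES / c
    print(f"{size_mb:>6} MB: {c:>3} chunks/device, "
          f"{NUM_DEVICES * c:>3} threads, {per_chunk_mb:.1f} MB per chunk")
```

On two devices the thread cap takes effect from 8192 MB upward: the chunk count stays at 128 per device, so the per-chunk size grows past the 16 MB target (to 128 MB at 32768 MB) instead of the thread count growing further.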