Commit 1fa3c99

Add PCIe transfer benchmark (H2D/D2H) to Ironwood
- Implement benchmark_pcie_transfer.py to measure H2D and D2H transfer performance using JAX, supporting various transfer modes (Standard, Parallel, Threaded, Chunked).
- Integrate the new benchmark into run_benchmark.py.
- Add configuration files for single device, single chip, and single VM topologies in configs/pcie_transfer/.
- Add scripts/run_pcie_transfer_benchmark.sh for bulk execution with numactl interleaving option.
1 parent f0ac0db commit 1fa3c99

3 files changed: 127 additions & 54 deletions
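For orientation before the per-file diffs: the H2D and D2H numbers come from timing jax.device_put (host to device) and jax.device_get (device to host), which is exactly what the pre-change baseline in benchmark_host_device.py below does. A minimal self-contained sketch of that measurement, with an illustrative buffer size that is not taken from the commit:

```python
import time

import jax
import numpy as np

# Illustrative 64 MiB host buffer; any JAX backend (CPU/GPU/TPU) works.
host = np.ones((16384, 8, 128), dtype=np.float32)
dev = jax.devices()[0]

t0 = time.perf_counter()
on_device = jax.device_put(host, dev)  # H2D copy
on_device.block_until_ready()          # wait for the transfer to finish
t1 = time.perf_counter()
_ = jax.device_get(on_device)          # D2H copy (returns a NumPy array)
t2 = time.perf_counter()

gib = host.nbytes / 2**30
print(f"H2D: {gib / (t1 - t0):.2f} GiB/s, D2H: {gib / (t2 - t1):.2f} GiB/s")
```

The commit replaces this single-shot pattern with sharded, multi-threaded, chunked transfers; see the Python diff below.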

Ironwood/guides/host_device/host_device.md

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 This guide provides instructions for running Host Device (Host-to-Device and Device-to-Host) microbenchmarks on tpu7x-2x2x1 Google Kubernetes Engine (GKE) clusters. It covers creating a node pool, running the benchmarks, and viewing the output.
 
-> [!WARNING]
+> [!NOTE]
 > This benchmark is currently a Work In Progress (WIP). Expected bandwidth numbers are not yet finalized.
 
 ## Create Node Pools

Ironwood/scripts/run_host_device_benchmark.sh

Lines changed: 2 additions & 7 deletions
@@ -27,13 +27,8 @@ while [[ "$#" -gt 0 ]]; do
 done
 
 echo "--- Starting Host-Device Transfer Benchmark (H2D/D2H) ---"
-echo "********************************************************"
-echo "WARNING: This benchmark is currently a WORK IN PROGRESS"
-echo "********************************************************"
-echo ""
-echo "Configuration:"
-echo "  Interleaved: $INTERLEAVED"
-echo ""
+echo "Note: This benchmark is work in progress"
+echo "Interleaved: $INTERLEAVED"
 
 if [ -n "$SPECIFIC_CONFIG" ]; then
   CONFIGS=("$SPECIFIC_CONFIG")

Ironwood/src/benchmark_host_device.py

Lines changed: 124 additions & 46 deletions
@@ -1,5 +1,7 @@
-"""Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline)."""
+"""Benchmarks Host-to-Device and Device-to-Host transfer performance."""
 
+import concurrent.futures
+import gc
 import time
 import os
 from typing import Any, Dict, Tuple, List
@@ -9,8 +11,7 @@
 import numpy as np
 from benchmark_utils import MetricsStatistics
 
-# 64 GiB
-os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736"
+os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736"  # 64 GiB
 os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736"
 
 def get_tpu_devices(num_devices: int):
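The two environment variables above must be in place before JAX initializes the TPU runtime, which is why the module sets them at import time; 68719476736 bytes is 64 GiB. A trivial sanity check of the arithmetic (the constant name is ours, not the commit's):

```python
# 64 GiB expressed in bytes, matching the hard-coded value in the diff.
PREMAPPED_BYTES = 64 * 1024**3
assert PREMAPPED_BYTES == 68719476736
```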
@@ -19,26 +20,107 @@ def get_tpu_devices(num_devices: int):
     raise RuntimeError(f"Require {num_devices} devices, found {len(devices)}")
   return devices[:num_devices]
 
+def _run_chunked(host_data, data_sharding, host_shards, target_devices, num_devices, chunks_per_device):
+  # Smart Chunked H2D
+  chk_h2d_start = time.perf_counter()
+  total_workers = num_devices * chunks_per_device
+  with concurrent.futures.ThreadPoolExecutor(max_workers=total_workers) as executor:
+    chunked_futures = []
+    for shard, dev in zip(host_shards, target_devices):
+      sub_chunks = np.array_split(shard, chunks_per_device, axis=0)
+      for chunk in sub_chunks:
+        chunked_futures.append(
+            executor.submit(jax.device_put, chunk, dev)
+        )
+    chunked_buffers = [f.result() for f in chunked_futures]
+  for db in chunked_buffers:
+    db.block_until_ready()
+  chk_h2d_end = time.perf_counter()
+  h2d_ms = (chk_h2d_end - chk_h2d_start) * 1000
+  for db in chunked_buffers:
+    db.delete()
+
+  # Smart Chunked D2H
+  data_on_device = jax.device_put(host_data, data_sharding)
+  data_on_device.block_until_ready()
+
+  chk_d2h_start = time.perf_counter()
+  with concurrent.futures.ThreadPoolExecutor(max_workers=total_workers) as executor:
+    d2h_futures = []
+    for shard in data_on_device.addressable_shards:
+      # Direct slicing on device array to avoid copy
+      shard_len = shard.data.shape[0]
+      chunk_size = (shard_len + chunks_per_device - 1) // chunks_per_device
+      for i in range(chunks_per_device):
+        start = i * chunk_size
+        end = min((i + 1) * chunk_size, shard_len)
+        if start < end:
+          d2h_futures.append(
+              executor.submit(jax.device_get, shard.data[start:end])
+          )
+    _ = [f.result() for f in d2h_futures]
+  chk_d2h_end = time.perf_counter()
+  d2h_ms = (chk_d2h_end - chk_d2h_start) * 1000
+  data_on_device.delete()
+
+  return h2d_ms, d2h_ms
+
+
+def _run_warmup(host_data, data_sharding, data_size_mb):
+  # --- ADAPTIVE WARM UP ---
+  if data_size_mb <= 128:
+    warmup_iters = 50
+  elif data_size_mb >= 8192:
+    warmup_iters = 3
+  else:
+    warmup_iters = 10
+
+  for _ in range(warmup_iters):
+    data_on_device = jax.device_put(host_data, data_sharding)
+    data_on_device.block_until_ready()
+    _ = jax.device_get(data_on_device)
+    data_on_device.delete()
+
+  gc.collect()
+
+def _get_chunks_per_device(data_size_mb, num_devices):
+  # --- SMART CHUNKING CONFIG ---
+  target_chunk_size_mb = 16
+  max_global_threads = 256
+
+  data_per_device_mb = data_size_mb / num_devices
+
+  if data_per_device_mb < target_chunk_size_mb:
+    chunks_per_device = 1
+  else:
+    chunks_per_device = int(data_per_device_mb / target_chunk_size_mb)
+
+  total_threads = num_devices * chunks_per_device
+  if total_threads > max_global_threads:
+    chunks_per_device = max(1, int(max_global_threads / num_devices))
+
+  return chunks_per_device
+
+
 def benchmark_host_device(
     mesh_shape: str,
     data_size_mb: int,
     num_runs: int = 100,
     trace_dir: str = None,
 ) -> Dict[str, Any]:
-  """Benchmarks H2D/D2H transfer using simple device_put/device_get."""
+  """Benchmarks H2D/D2H transfer using smart chunking."""
   dims = [int(d) for d in mesh_shape.split("x")]
   mesh_shape = tuple(dims)
 
   num_devices = int(np.prod(mesh_shape))
   tpu_devices = get_tpu_devices(num_devices)
 
-  num_elements = 1024 * 1024 * data_size_mb // np.dtype(np.float32).itemsize
+  rows = 1024 * data_size_mb // np.dtype(np.float32).itemsize
 
-  # Allocate Host Source Buffer
-  host_data = np.ones((num_elements,), dtype=np.float32)
+  host_data = np.ones((rows, 8, 128), dtype=np.float32)
 
   print(
-      f"Benchmarking (Simple) Transfer with Data Size: {data_size_mb} MB on"
+      f"Benchmarking Transfer with Data Size: {data_size_mb} MB on"
       f" {num_devices} devices for {num_runs} iterations"
   )
 
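To make the new heuristics concrete: host_data is now shaped (rows, 8, 128) in float32, and since rows = 1024 * data_size_mb / 4, each row slab is 8 * 128 * 4 = 4096 bytes, so the buffer still totals exactly data_size_mb MiB. The chunking config targets roughly 16 MB chunks per device and caps the global thread count at 256. A standalone sketch of the same arithmetic (function name ours) with a few worked cases:

```python
def chunks_for(data_size_mb: int, num_devices: int) -> int:
    # Mirrors _get_chunks_per_device: ~16 MB chunks, capped at 256 threads.
    target_chunk_size_mb, max_global_threads = 16, 256
    per_device = data_size_mb / num_devices
    chunks = 1 if per_device < target_chunk_size_mb else int(per_device / target_chunk_size_mb)
    if num_devices * chunks > max_global_threads:
        chunks = max(1, int(max_global_threads / num_devices))
    return chunks

assert chunks_for(1024, 4) == 16   # 256 MB/device -> 16 x 16 MB chunks, 64 threads
assert chunks_for(32, 4) == 1      # 8 MB/device is below the 16 MB target
assert chunks_for(16384, 8) == 32  # 128 chunks/device would need 1024 threads; capped to 256/8
```

So a 1024 MB transfer on a 4-device mesh issues 64 concurrent device_put calls of 16 MB each, while very large transfers are throttled by the 256-thread cap.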

@@ -47,25 +129,37 @@ def benchmark_host_device(
     mesh = sharding.Mesh(
         np.array(tpu_devices).reshape(mesh_shape), axis_names=("x",)
     )
-    # Shard the 1D array across "x"
-    partition_spec = sharding.PartitionSpec("x")
+    data_sharding = sharding.NamedSharding(mesh, sharding.PartitionSpec("x"))
   else:
     mesh = sharding.Mesh(
         np.array(tpu_devices).reshape(mesh_shape), axis_names=("x", "y")
     )
-    # Shard the 1D array across BOTH "x" and "y" (product sharding)
-    partition_spec = sharding.PartitionSpec(("x", "y"))
-
-  data_sharding = sharding.NamedSharding(mesh, partition_spec)
+    data_sharding = sharding.NamedSharding(
+        mesh, sharding.PartitionSpec(("x", "y"))
+    )
 
+  # --- ADAPTIVE WARM UP ---
+  _run_warmup(host_data, data_sharding, data_size_mb)
+
+  # Pre-calculate sharding info
+  dummy_put = jax.device_put(host_data[:num_devices], data_sharding)
+  target_devices = [s.device for s in dummy_put.addressable_shards]
+  dummy_put.delete()
+
+  host_shards = np.split(host_data, num_devices, axis=0)
+
   # Performance Lists
   h2d_perf, d2h_perf = [], []
+
+  # --- SMART CHUNKING CONFIG ---
+  chunks_per_device = _get_chunks_per_device(data_size_mb, num_devices)
 
   # Profiling Context
-  import contextlib
   if trace_dir:
     profiler_context = jax.profiler.trace(trace_dir)
   else:
+    # No-op context manager
+    import contextlib
     profiler_context = contextlib.nullcontext()
 
   with profiler_context:
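The sharding setup above spreads host_data's leading axis across every device: PartitionSpec("x") on a 1-D mesh, PartitionSpec(("x", "y")) (product sharding) on a 2-D mesh. The same setup can be exercised without TPUs via XLA's host-platform device-count flag; a sketch under that assumption (the flag and the 2x2 mesh are illustrative, not part of the commit):

```python
import os
# Illustrative only: emulate 4 devices on CPU. Must be set before jax imports.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
from jax import sharding
import numpy as np

devices = np.array(jax.devices()[:4]).reshape((2, 2))
mesh = sharding.Mesh(devices, axis_names=("x", "y"))
spec = sharding.PartitionSpec(("x", "y"))  # product sharding over axis 0
data_sharding = sharding.NamedSharding(mesh, spec)

arr = jax.device_put(np.ones((1024, 8, 128), np.float32), data_sharding)
# Each of the 4 shards should own 256 rows: [(256, 8, 128), ...]
print([s.data.shape for s in arr.addressable_shards])
```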
@@ -77,53 +171,37 @@ def benchmark_host_device(
         step_context = contextlib.nullcontext()
 
       with step_context:
-        # H2D
-        t0 = time.perf_counter()
-
-        # Simple device_put
-        device_array = jax.device_put(host_data, data_sharding)
-        device_array.block_until_ready()
-
-        t1 = time.perf_counter()
-        h2d_perf.append((t1 - t0) * 1000)
-
-        # Verify H2D shape/sharding
-        assert device_array.shape == host_data.shape
-        assert device_array.sharding == data_sharding
-
-        # D2H
-        t2 = time.perf_counter()
-
-        # Simple device_get
-        # Note: device_get returns a numpy array (copy)
-        _ = jax.device_get(device_array)
-
-        t3 = time.perf_counter()
-        d2h_perf.append((t3 - t2) * 1000)
-
-        device_array.delete()
+        # Optimized Chunked Transfer (Sole Strategy)
+        h2d_ms, d2h_ms = _run_chunked(
+            host_data, data_sharding, host_shards, target_devices,
+            num_devices, chunks_per_device
+        )
+        h2d_perf.append(h2d_ms)
+        d2h_perf.append(d2h_ms)
+
+  del host_data, host_shards
+  gc.collect()
 
   return {
       "H2D_Bandwidth": h2d_perf,
       "D2H_Bandwidth": d2h_perf,
+      "Chunk_Count": chunks_per_device,
+      "Thread_Count": num_devices * chunks_per_device,
   }
 
 def benchmark_host_device_calculate_metrics(
     mesh_shape: str,
     data_size_mb: int,
     H2D_Bandwidth: List[float],
     D2H_Bandwidth: List[float],
+    Chunk_Count: int,
+    Thread_Count: int,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
   """Calculates metrics for Host-Device transfer."""
   params = locals().items()
 
-  data_size_mib = data_size_mb
-
   # Filter out list params from metadata to avoid explosion
-  metadata_keys = {
-      "mesh_shape",
-      "data_size_mib",
-  }
+  metadata_keys = {"mesh_shape", "data_size_mb", "Chunk_Count", "Thread_Count"}
   metadata = {k: v for k, v in params if k in metadata_keys}
 
   metrics = {}
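To round out the picture, a hypothetical driver for the new entry point; the mesh shape, transfer size, and GiB/s conversion below are ours, and the benchmark treats data_size_mb as MiB (it allocates 1024 * 1024 * data_size_mb bytes):

```python
# Hypothetical usage; "2x2" and 1024 are placeholder arguments.
data_size_mb = 1024
results = benchmark_host_device(mesh_shape="2x2", data_size_mb=data_size_mb, num_runs=20)
print(f"chunks/device: {results['Chunk_Count']}, threads: {results['Thread_Count']}")

size_gib = data_size_mb / 1024
for key in ("H2D_Bandwidth", "D2H_Bandwidth"):
    best_ms = min(results[key])  # per-iteration wall-clock time in milliseconds
    print(f"{key}: best {best_ms:.2f} ms -> {size_gib / (best_ms / 1000):.2f} GiB/s")
```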
