Skip to content

Commit ce217c7

Browse files
committed
Finalize native API host-device benchmark
1. Use 2D dimension to match memory layout 2. Use default device_get and device_put
1 parent a497d26 commit ce217c7

4 files changed

Lines changed: 23 additions & 44 deletions

File tree

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
benchmarks:
2+
- benchmark_name: host_device
3+
num_runs: 20
4+
benchmark_sweep_params:
5+
- {
6+
data_size_mib_list: [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
7+
}
8+
csv_path: "../microbenchmarks/host_device"
9+
trace_dir: "../microbenchmarks/host_device/trace"

Ironwood/configs/host_device/host_device_single_chip.yaml

Lines changed: 0 additions & 11 deletions
This file was deleted.

Ironwood/guides/host_device/tpu7x-host-device-benchmark.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ spec:
2424
cd accelerator-microbenchmarks
2525
pip install -r requirements.txt
2626
27-
export TPU_VISIBLE_CHIPS=0
28-
bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device_single_chip.yaml
27+
bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device.yaml
2928
3029
resources:
3130
requests:

Ironwood/src/benchmark_host_device.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,43 +18,28 @@
1818
os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736"
1919
os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736"
2020

21-
def get_tpu_devices(num_devices: int):
22-
devices = jax.devices()
23-
if len(devices) < num_devices:
24-
raise RuntimeError(f"Require {num_devices} devices, found {len(devices)}")
25-
return devices[:num_devices]
21+
2622

2723
def benchmark_host_device(
28-
num_devices: int,
29-
data_size_mb: int,
24+
data_size_mib: int,
3025
num_runs: int = 100,
3126
trace_dir: str = None,
3227
) -> Dict[str, Any]:
3328
"""Benchmarks H2D/D2H transfer using simple device_put/device_get."""
34-
tpu_devices = get_tpu_devices(num_devices)
3529

36-
num_elements = 1024 * 1024 * data_size_mb // np.dtype(np.float32).itemsize
30+
num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize
3731

3832
# Allocate Host Source Buffer
39-
host_data = np.random.normal(size=(num_elements,)).astype(np.float32)
33+
column = 128
34+
host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32)
4035

4136
print(
42-
f"Benchmarking (Simple) Transfer with Data Size: {data_size_mb} MB on"
43-
f" {num_devices} devices for {num_runs} iterations"
37+
f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations"
4438
)
4539

46-
# Setup Mesh Sharding (1D)
47-
mesh = sharding.Mesh(
48-
np.array(tpu_devices).reshape((num_devices,)), axis_names=("x",)
49-
)
50-
# Shard the 1D array across "x"
51-
partition_spec = sharding.PartitionSpec("x")
52-
53-
data_sharding = sharding.NamedSharding(mesh, partition_spec)
54-
5540
# Performance Lists
5641
h2d_perf, d2h_perf = [], []
57-
42+
5843
# Profiling Context
5944
import contextlib
6045
if trace_dir:
@@ -65,7 +50,7 @@ def benchmark_host_device(
6550
with profiler_context:
6651
# Warmup
6752
for _ in range(2):
68-
device_array = jax.device_put(host_data, data_sharding)
53+
device_array = jax.device_put(host_data)
6954
device_array.block_until_ready()
7055
host_out = np.array(device_array)
7156
device_array.delete()
@@ -83,15 +68,14 @@ def benchmark_host_device(
8368
t0 = time.perf_counter()
8469

8570
# Simple device_put
86-
device_array = jax.device_put(host_data, data_sharding)
71+
device_array = jax.device_put(host_data)
8772
device_array.block_until_ready()
8873

8974
t1 = time.perf_counter()
9075
h2d_perf.append((t1 - t0) * 1000)
9176

92-
# Verify H2D shape/sharding
77+
# Verify H2D shape
9378
assert device_array.shape == host_data.shape
94-
assert device_array.sharding == data_sharding
9579

9680
# D2H
9781
t2 = time.perf_counter()
@@ -111,19 +95,17 @@ def benchmark_host_device(
11195
}
11296

11397
def benchmark_host_device_calculate_metrics(
114-
num_devices: int,
115-
data_size_mb: int,
98+
data_size_mib: int,
11699
H2D_Bandwidth_ms: List[float],
117100
D2H_Bandwidth_ms: List[float],
118101
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
119102
"""Calculates metrics for Host-Device transfer."""
120103
params = locals().items()
121104

122-
data_size_mib = data_size_mb
105+
data_size_mib = data_size_mib
123106

124107
# Filter out list params from metadata to avoid explosion
125108
metadata_keys = {
126-
"num_devices",
127109
"data_size_mib",
128110
}
129111
metadata = {k: v for k, v in params if k in metadata_keys}
@@ -134,7 +116,7 @@ def add_metric(name, ms_list):
134116
# Report Bandwidth (GiB/s)
135117
# Handle division by zero if ms is 0
136118
bw_list = [
137-
((data_size_mb / 1024) / (ms / 1000)) if ms > 0 else 0.0
119+
((data_size_mib / 1024) / (ms / 1000)) if ms > 0 else 0.0
138120
for ms in ms_list
139121
]
140122
stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)")

0 commit comments

Comments
 (0)