Skip to content

Commit e3cf453

Browse files
authored
Finalize native API host-device benchmark (#82)
1. Use 2D dimension to match memory layout. 2. Use default device_get and device_put.
1 parent d8f04a8 commit e3cf453

4 files changed

Lines changed: 21 additions & 45 deletions

File tree

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
benchmarks:
2+
- benchmark_name: host_device
3+
num_runs: 20
4+
benchmark_sweep_params:
5+
- {
6+
data_size_mib_list: [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
7+
}
8+
csv_path: "../microbenchmarks/host_device"
9+
trace_dir: "../microbenchmarks/host_device/trace"

Ironwood/configs/host_device/host_device_single_chip.yaml

Lines changed: 0 additions & 11 deletions
This file was deleted.

Ironwood/guides/host_device/tpu7x-host-device-benchmark.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ spec:
2424
cd accelerator-microbenchmarks
2525
pip install -r requirements.txt
2626
27-
export TPU_VISIBLE_CHIPS=0
28-
bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device_single_chip.yaml
27+
bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device.yaml
2928
3029
resources:
3130
requests:

Ironwood/src/benchmark_host_device.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,43 +18,27 @@
1818
os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736"
1919
os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736"
2020

21-
def get_tpu_devices(num_devices: int):
22-
devices = jax.devices()
23-
if len(devices) < num_devices:
24-
raise RuntimeError(f"Require {num_devices} devices, found {len(devices)}")
25-
return devices[:num_devices]
2621

2722
def benchmark_host_device(
28-
num_devices: int,
29-
data_size_mb: int,
23+
data_size_mib: int,
3024
num_runs: int = 100,
3125
trace_dir: str = None,
3226
) -> Dict[str, Any]:
3327
"""Benchmarks H2D/D2H transfer using simple device_put/device_get."""
34-
tpu_devices = get_tpu_devices(num_devices)
3528

36-
num_elements = 1024 * 1024 * data_size_mb // np.dtype(np.float32).itemsize
29+
num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize
3730

3831
# Allocate Host Source Buffer
39-
host_data = np.random.normal(size=(num_elements,)).astype(np.float32)
32+
column = 128
33+
host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32)
4034

4135
print(
42-
f"Benchmarking (Simple) Transfer with Data Size: {data_size_mb} MB on"
43-
f" {num_devices} devices for {num_runs} iterations"
36+
f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations"
4437
)
4538

46-
# Setup Mesh Sharding (1D)
47-
mesh = sharding.Mesh(
48-
np.array(tpu_devices).reshape((num_devices,)), axis_names=("x",)
49-
)
50-
# Shard the 1D array across "x"
51-
partition_spec = sharding.PartitionSpec("x")
52-
53-
data_sharding = sharding.NamedSharding(mesh, partition_spec)
54-
5539
# Performance Lists
5640
h2d_perf, d2h_perf = [], []
57-
41+
5842
# Profiling Context
5943
import contextlib
6044
if trace_dir:
@@ -65,7 +49,7 @@ def benchmark_host_device(
6549
with profiler_context:
6650
# Warmup
6751
for _ in range(2):
68-
device_array = jax.device_put(host_data, data_sharding)
52+
device_array = jax.device_put(host_data)
6953
device_array.block_until_ready()
7054
host_out = np.array(device_array)
7155
device_array.delete()
@@ -83,15 +67,14 @@ def benchmark_host_device(
8367
t0 = time.perf_counter()
8468

8569
# Simple device_put
86-
device_array = jax.device_put(host_data, data_sharding)
70+
device_array = jax.device_put(host_data)
8771
device_array.block_until_ready()
8872

8973
t1 = time.perf_counter()
9074
h2d_perf.append((t1 - t0) * 1000)
9175

92-
# Verify H2D shape/sharding
76+
# Verify H2D shape
9377
assert device_array.shape == host_data.shape
94-
assert device_array.sharding == data_sharding
9578

9679
# D2H
9780
t2 = time.perf_counter()
@@ -111,19 +94,15 @@ def benchmark_host_device(
11194
}
11295

11396
def benchmark_host_device_calculate_metrics(
114-
num_devices: int,
115-
data_size_mb: int,
97+
data_size_mib: int,
11698
H2D_Bandwidth_ms: List[float],
11799
D2H_Bandwidth_ms: List[float],
118100
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
119101
"""Calculates metrics for Host-Device transfer."""
120102
params = locals().items()
121103

122-
data_size_mib = data_size_mb
123-
124104
# Filter out list params from metadata to avoid explosion
125105
metadata_keys = {
126-
"num_devices",
127106
"data_size_mib",
128107
}
129108
metadata = {k: v for k, v in params if k in metadata_keys}
@@ -134,7 +113,7 @@ def add_metric(name, ms_list):
134113
# Report Bandwidth (GiB/s)
135114
# Handle division by zero if ms is 0
136115
bw_list = [
137-
((data_size_mb / 1024) / (ms / 1000)) if ms > 0 else 0.0
116+
((data_size_mib / 1024) / (ms / 1000)) if ms > 0 else 0.0
138117
for ms in ms_list
139118
]
140119
stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)")

0 commit comments

Comments
 (0)