Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Ironwood/configs/host_device/host_device.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
benchmarks:
- benchmark_name: host_device
num_runs: 20
benchmark_sweep_params:
- {
data_size_mib_list: [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
}
csv_path: "../microbenchmarks/host_device"
trace_dir: "../microbenchmarks/host_device/trace"
11 changes: 0 additions & 11 deletions Ironwood/configs/host_device/host_device_single_chip.yaml

This file was deleted.

3 changes: 1 addition & 2 deletions Ironwood/guides/host_device/tpu7x-host-device-benchmark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ spec:
cd accelerator-microbenchmarks
pip install -r requirements.txt

export TPU_VISIBLE_CHIPS=0
bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device_single_chip.yaml
bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device.yaml

resources:
requests:
Expand Down
43 changes: 11 additions & 32 deletions Ironwood/src/benchmark_host_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,27 @@
os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736"
os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736"

def get_tpu_devices(num_devices: int):
    """Return the first ``num_devices`` devices reported by ``jax.devices()``.

    Args:
        num_devices: How many devices to select. Must be non-negative.

    Returns:
        A list holding the first ``num_devices`` entries of ``jax.devices()``.

    Raises:
        ValueError: If ``num_devices`` is negative.
        RuntimeError: If fewer than ``num_devices`` devices are available.
    """
    if num_devices < 0:
        # A negative count would pass the availability check below and then
        # make the slice silently drop devices from the end of the list.
        raise ValueError(f"num_devices must be non-negative, got {num_devices}")

    devices = jax.devices()
    if len(devices) < num_devices:
        raise RuntimeError(f"Require {num_devices} devices, found {len(devices)}")
    return devices[:num_devices]

def benchmark_host_device(
num_devices: int,
data_size_mb: int,
data_size_mib: int,
num_runs: int = 100,
trace_dir: str = None,
) -> Dict[str, Any]:
"""Benchmarks H2D/D2H transfer using simple device_put/device_get."""
tpu_devices = get_tpu_devices(num_devices)

num_elements = 1024 * 1024 * data_size_mb // np.dtype(np.float32).itemsize
num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize

# Allocate Host Source Buffer
host_data = np.random.normal(size=(num_elements,)).astype(np.float32)
column = 128
host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32)

print(
f"Benchmarking (Simple) Transfer with Data Size: {data_size_mb} MB on"
f" {num_devices} devices for {num_runs} iterations"
f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations"
)

# Setup Mesh Sharding (1D)
mesh = sharding.Mesh(
np.array(tpu_devices).reshape((num_devices,)), axis_names=("x",)
)
# Shard the 1D array across "x"
partition_spec = sharding.PartitionSpec("x")

data_sharding = sharding.NamedSharding(mesh, partition_spec)

# Performance Lists
h2d_perf, d2h_perf = [], []

# Profiling Context
import contextlib
if trace_dir:
Expand All @@ -65,7 +49,7 @@ def benchmark_host_device(
with profiler_context:
# Warmup
for _ in range(2):
device_array = jax.device_put(host_data, data_sharding)
device_array = jax.device_put(host_data)
device_array.block_until_ready()
host_out = np.array(device_array)
device_array.delete()
Expand All @@ -83,15 +67,14 @@ def benchmark_host_device(
t0 = time.perf_counter()

# Simple device_put
device_array = jax.device_put(host_data, data_sharding)
device_array = jax.device_put(host_data)
device_array.block_until_ready()

t1 = time.perf_counter()
h2d_perf.append((t1 - t0) * 1000)

# Verify H2D shape/sharding
# Verify H2D shape
assert device_array.shape == host_data.shape
assert device_array.sharding == data_sharding

# D2H
t2 = time.perf_counter()
Expand All @@ -111,19 +94,15 @@ def benchmark_host_device(
}

def benchmark_host_device_calculate_metrics(
num_devices: int,
data_size_mb: int,
data_size_mib: int,
H2D_Bandwidth_ms: List[float],
D2H_Bandwidth_ms: List[float],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Calculates metrics for Host-Device transfer."""
params = locals().items()

data_size_mib = data_size_mb

# Filter out list params from metadata to avoid explosion
metadata_keys = {
"num_devices",
"data_size_mib",
}
metadata = {k: v for k, v in params if k in metadata_keys}
Expand All @@ -134,7 +113,7 @@ def add_metric(name, ms_list):
# Report Bandwidth (GiB/s)
# Handle division by zero if ms is 0
bw_list = [
((data_size_mb / 1024) / (ms / 1000)) if ms > 0 else 0.0
((data_size_mib / 1024) / (ms / 1000)) if ms > 0 else 0.0
for ms in ms_list
]
stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)")
Expand Down