diff --git a/Ironwood/configs/host_device/host_device.yaml b/Ironwood/configs/host_device/host_device.yaml new file mode 100644 index 00000000..0b48800c --- /dev/null +++ b/Ironwood/configs/host_device/host_device.yaml @@ -0,0 +1,9 @@ +benchmarks: +- benchmark_name: host_device + num_runs: 20 + benchmark_sweep_params: + - { + data_size_mib_list: [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] + } + csv_path: "../microbenchmarks/host_device" + trace_dir: "../microbenchmarks/host_device/trace" diff --git a/Ironwood/configs/host_device/host_device_single_chip.yaml b/Ironwood/configs/host_device/host_device_single_chip.yaml deleted file mode 100644 index f8e15b72..00000000 --- a/Ironwood/configs/host_device/host_device_single_chip.yaml +++ /dev/null @@ -1,11 +0,0 @@ -benchmarks: -- benchmark_name: host_device - num_runs: 20 - benchmark_sweep_params: - # Single Chip (1 Chip, 2 Devices) - - { - num_devices: 2, - data_size_mb_list: [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] - } - csv_path: "../microbenchmarks/host_device/single_chip" - trace_dir: "../microbenchmarks/host_device/single_chip/trace" diff --git a/Ironwood/guides/host_device/tpu7x-host-device-benchmark.yaml b/Ironwood/guides/host_device/tpu7x-host-device-benchmark.yaml index 1c486e91..8c027c01 100644 --- a/Ironwood/guides/host_device/tpu7x-host-device-benchmark.yaml +++ b/Ironwood/guides/host_device/tpu7x-host-device-benchmark.yaml @@ -24,8 +24,7 @@ spec: cd accelerator-microbenchmarks pip install -r requirements.txt - export TPU_VISIBLE_CHIPS=0 - bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device_single_chip.yaml + bash ./Ironwood/scripts/run_host_device_benchmark.sh --config Ironwood/configs/host_device/host_device.yaml resources: requests: diff --git a/Ironwood/src/benchmark_host_device.py b/Ironwood/src/benchmark_host_device.py index d8ed139a..a17864da 100644 --- a/Ironwood/src/benchmark_host_device.py +++ 
b/Ironwood/src/benchmark_host_device.py @@ -18,43 +18,27 @@ os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736" os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736" -def get_tpu_devices(num_devices: int): - devices = jax.devices() - if len(devices) < num_devices: - raise RuntimeError(f"Require {num_devices} devices, found {len(devices)}") - return devices[:num_devices] def benchmark_host_device( - num_devices: int, - data_size_mb: int, + data_size_mib: int, num_runs: int = 100, trace_dir: str = None, ) -> Dict[str, Any]: """Benchmarks H2D/D2H transfer using simple device_put/device_get.""" - tpu_devices = get_tpu_devices(num_devices) - num_elements = 1024 * 1024 * data_size_mb // np.dtype(np.float32).itemsize + num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize # Allocate Host Source Buffer - host_data = np.random.normal(size=(num_elements,)).astype(np.float32) + column = 128 + host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32) print( - f"Benchmarking (Simple) Transfer with Data Size: {data_size_mb} MB on" - f" {num_devices} devices for {num_runs} iterations" + f"Benchmarking Transfer with Data Size: {data_size_mib} MiB for {num_runs} iterations" ) - # Setup Mesh Sharding (1D) - mesh = sharding.Mesh( - np.array(tpu_devices).reshape((num_devices,)), axis_names=("x",) - ) - # Shard the 1D array across "x" - partition_spec = sharding.PartitionSpec("x") - - data_sharding = sharding.NamedSharding(mesh, partition_spec) - # Performance Lists h2d_perf, d2h_perf = [], [] - + # Profiling Context import contextlib if trace_dir: @@ -65,7 +49,7 @@ def benchmark_host_device( with profiler_context: # Warmup for _ in range(2): - device_array = jax.device_put(host_data, data_sharding) + device_array = jax.device_put(host_data) device_array.block_until_ready() host_out = np.array(device_array) device_array.delete() @@ -83,15 +67,14 @@ def benchmark_host_device( t0 = time.perf_counter() # Simple 
device_put - device_array = jax.device_put(host_data, data_sharding) + device_array = jax.device_put(host_data) device_array.block_until_ready() t1 = time.perf_counter() h2d_perf.append((t1 - t0) * 1000) - # Verify H2D shape/sharding + # Verify H2D shape assert device_array.shape == host_data.shape - assert device_array.sharding == data_sharding # D2H t2 = time.perf_counter() @@ -111,19 +94,15 @@ def benchmark_host_device( } def benchmark_host_device_calculate_metrics( - num_devices: int, - data_size_mb: int, + data_size_mib: int, H2D_Bandwidth_ms: List[float], D2H_Bandwidth_ms: List[float], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Calculates metrics for Host-Device transfer.""" params = locals().items() - data_size_mib = data_size_mb - # Filter out list params from metadata to avoid explosion metadata_keys = { - "num_devices", "data_size_mib", } metadata = {k: v for k, v in params if k in metadata_keys} @@ -134,7 +113,7 @@ def add_metric(name, ms_list): # Report Bandwidth (GiB/s) # Handle division by zero if ms is 0 bw_list = [ - ((data_size_mb / 1024) / (ms / 1000)) if ms > 0 else 0.0 + ((data_size_mib / 1024) / (ms / 1000)) if ms > 0 else 0.0 for ms in ms_list ] stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)")