|
| 1 | +"""Benchmarks Host-to-Device and Device-to-Host transfer performance.""" |
| 2 | + |
| 3 | +import concurrent.futures |
| 4 | +import gc |
| 5 | +import time |
| 6 | +import os |
| 7 | +from typing import Any, Dict, Tuple, List |
| 8 | + |
| 9 | +import jax |
| 10 | +from jax import sharding |
| 11 | +import numpy as np |
| 12 | +from benchmark_utils import MetricsStatistics |
| 13 | + |
| 14 | +os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736" # 64 GiB |
| 15 | +os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736" |
| 16 | + |
| 17 | +def get_tpu_devices(num_devices: int): |
| 18 | + devices = jax.devices() |
| 19 | + if len(devices) < num_devices: |
| 20 | + raise RuntimeError(f"Require {num_devices} devices, found {len(devices)}") |
| 21 | + return devices[:num_devices] |
| 22 | + |
| 23 | +def _run_chunked(host_data, data_sharding, host_shards, target_devices, num_devices, chunks_per_device): |
| 24 | + # Smart Chunked H2D |
| 25 | + chk_h2d_start = time.perf_counter() |
| 26 | + total_workers = num_devices * chunks_per_device |
| 27 | + with concurrent.futures.ThreadPoolExecutor(max_workers=total_workers) as executor: |
| 28 | + chunked_futures = [] |
| 29 | + for shard, dev in zip(host_shards, target_devices): |
| 30 | + sub_chunks = np.array_split(shard, chunks_per_device, axis=0) |
| 31 | + for chunk in sub_chunks: |
| 32 | + chunked_futures.append( |
| 33 | + executor.submit(jax.device_put, chunk, dev) |
| 34 | + ) |
| 35 | + chunked_buffers = [f.result() for f in chunked_futures] |
| 36 | + for db in chunked_buffers: |
| 37 | + db.block_until_ready() |
| 38 | + chk_h2d_end = time.perf_counter() |
| 39 | + h2d_ms = (chk_h2d_end - chk_h2d_start) * 1000 |
| 40 | + for db in chunked_buffers: |
| 41 | + db.delete() |
| 42 | + |
| 43 | + # Smart Chunked D2H |
| 44 | + data_on_device = jax.device_put(host_data, data_sharding) |
| 45 | + data_on_device.block_until_ready() |
| 46 | + |
| 47 | + chk_d2h_start = time.perf_counter() |
| 48 | + with concurrent.futures.ThreadPoolExecutor(max_workers=total_workers) as executor: |
| 49 | + d2h_futures = [] |
| 50 | + for shard in data_on_device.addressable_shards: |
| 51 | + # Direct slicing on device array to avoid copy |
| 52 | + shard_len = shard.data.shape[0] |
| 53 | + chunk_size = (shard_len + chunks_per_device - 1) // chunks_per_device |
| 54 | + for i in range(chunks_per_device): |
| 55 | + start = i * chunk_size |
| 56 | + end = min((i + 1) * chunk_size, shard_len) |
| 57 | + if start < end: |
| 58 | + d2h_futures.append( |
| 59 | + executor.submit(jax.device_get, shard.data[start:end]) |
| 60 | + ) |
| 61 | + _ = [f.result() for f in d2h_futures] |
| 62 | + chk_d2h_end = time.perf_counter() |
| 63 | + d2h_ms = (chk_d2h_end - chk_d2h_start) * 1000 |
| 64 | + data_on_device.delete() |
| 65 | + |
| 66 | + return h2d_ms, d2h_ms |
| 67 | + |
| 68 | + |
| 69 | +def _run_warmup(host_data, data_sharding, data_size_mb): |
| 70 | + # --- ADAPTIVE WARM UP --- |
| 71 | + if data_size_mb <= 128: |
| 72 | + warmup_iters = 50 |
| 73 | + elif data_size_mb >= 8192: |
| 74 | + warmup_iters = 3 |
| 75 | + else: |
| 76 | + warmup_iters = 10 |
| 77 | + |
| 78 | + for _ in range(warmup_iters): |
| 79 | + data_on_device = jax.device_put(host_data, data_sharding) |
| 80 | + data_on_device.block_until_ready() |
| 81 | + _ = jax.device_get(data_on_device) |
| 82 | + data_on_device.delete() |
| 83 | + |
| 84 | + gc.collect() |
| 85 | + |
| 86 | +def _get_chunks_per_device(data_size_mb, num_devices): |
| 87 | + # --- SMART CHUNKING CONFIG --- |
| 88 | + target_chunk_size_mb = 16 |
| 89 | + max_global_threads = 256 |
| 90 | + |
| 91 | + data_per_device_mb = data_size_mb / num_devices |
| 92 | + |
| 93 | + if data_per_device_mb < target_chunk_size_mb: |
| 94 | + chunks_per_device = 1 |
| 95 | + else: |
| 96 | + chunks_per_device = int(data_per_device_mb / target_chunk_size_mb) |
| 97 | + |
| 98 | + total_threads = num_devices * chunks_per_device |
| 99 | + if total_threads > max_global_threads: |
| 100 | + chunks_per_device = max(1, int(max_global_threads / num_devices)) |
| 101 | + |
| 102 | + return chunks_per_device |
| 103 | + |
| 104 | + |
| 105 | +def benchmark_host_device( |
| 106 | + mesh_shape: str, |
| 107 | + data_size_mb: int, |
| 108 | + num_runs: int = 100, |
| 109 | + trace_dir: str = None, |
| 110 | +) -> Dict[str, Any]: |
| 111 | + """Benchmarks H2D/D2H transfer using smart chunking.""" |
| 112 | + dims = [int(d) for d in mesh_shape.split("x")] |
| 113 | + mesh_shape = tuple(dims) |
| 114 | + |
| 115 | + num_devices = int(np.prod(mesh_shape)) |
| 116 | + tpu_devices = get_tpu_devices(num_devices) |
| 117 | + |
| 118 | + rows = 1024 * data_size_mb // np.dtype(np.float32).itemsize |
| 119 | + |
| 120 | + host_data = np.ones((rows, 8, 128), dtype=np.float32) |
| 121 | + |
| 122 | + print( |
| 123 | + f"Benchmarking Transfer with Data Size: {data_size_mb} MB on" |
| 124 | + f" {num_devices} devices for {num_runs} iterations" |
| 125 | + ) |
| 126 | + |
| 127 | + # Setup Mesh Sharding |
| 128 | + if len(mesh_shape) == 1: |
| 129 | + mesh = sharding.Mesh( |
| 130 | + np.array(tpu_devices).reshape(mesh_shape), axis_names=("x",) |
| 131 | + ) |
| 132 | + data_sharding = sharding.NamedSharding(mesh, sharding.PartitionSpec("x")) |
| 133 | + else: |
| 134 | + mesh = sharding.Mesh( |
| 135 | + np.array(tpu_devices).reshape(mesh_shape), axis_names=("x", "y") |
| 136 | + ) |
| 137 | + data_sharding = sharding.NamedSharding( |
| 138 | + mesh, sharding.PartitionSpec(("x", "y")) |
| 139 | + ) |
| 140 | + |
| 141 | + # --- ADAPTIVE WARM UP --- |
| 142 | + _run_warmup(host_data, data_sharding, data_size_mb) |
| 143 | + |
| 144 | + # Pre-calculate sharding info |
| 145 | + dummy_put = jax.device_put(host_data[:num_devices], data_sharding) |
| 146 | + target_devices = [s.device for s in dummy_put.addressable_shards] |
| 147 | + dummy_put.delete() |
| 148 | + |
| 149 | + host_shards = np.split(host_data, num_devices, axis=0) |
| 150 | + |
| 151 | + # Performance Lists |
| 152 | + h2d_perf, d2h_perf = [], [] |
| 153 | + |
| 154 | + # --- SMART CHUNKING CONFIG --- |
| 155 | + chunks_per_device = _get_chunks_per_device(data_size_mb, num_devices) |
| 156 | + |
| 157 | + # Profiling Context |
| 158 | + if trace_dir: |
| 159 | + profiler_context = jax.profiler.trace(trace_dir) |
| 160 | + else: |
| 161 | + # No-op context manager |
| 162 | + import contextlib |
| 163 | + profiler_context = contextlib.nullcontext() |
| 164 | + |
| 165 | + with profiler_context: |
| 166 | + for i in range(num_runs): |
| 167 | + # Step Context |
| 168 | + if trace_dir: |
| 169 | + step_context = jax.profiler.StepTraceAnnotation("host_device", step_num=i) |
| 170 | + else: |
| 171 | + step_context = contextlib.nullcontext() |
| 172 | + |
| 173 | + with step_context: |
| 174 | + # Optimized Chunked Transfer (Sole Strategy) |
| 175 | + h2d_ms, d2h_ms = _run_chunked( |
| 176 | + host_data, data_sharding, host_shards, target_devices, |
| 177 | + num_devices, chunks_per_device |
| 178 | + ) |
| 179 | + h2d_perf.append(h2d_ms) |
| 180 | + d2h_perf.append(d2h_ms) |
| 181 | + |
| 182 | + del host_data, host_shards |
| 183 | + gc.collect() |
| 184 | + |
| 185 | + return { |
| 186 | + "H2D_Bandwidth": h2d_perf, |
| 187 | + "D2H_Bandwidth": d2h_perf, |
| 188 | + "Chunk_Count": chunks_per_device, |
| 189 | + "Thread_Count": num_devices * chunks_per_device, |
| 190 | + } |
| 191 | + |
| 192 | +def benchmark_host_device_calculate_metrics( |
| 193 | + mesh_shape: str, |
| 194 | + data_size_mb: int, |
| 195 | + H2D_Bandwidth: List[float], |
| 196 | + D2H_Bandwidth: List[float], |
| 197 | + Chunk_Count: int, |
| 198 | + Thread_Count: int, |
| 199 | +) -> Tuple[Dict[str, Any], Dict[str, Any]]: |
| 200 | + """Calculates metrics for Host-Device transfer.""" |
| 201 | + params = locals().items() |
| 202 | + |
| 203 | + # Filter out list params from metadata to avoid explosion |
| 204 | + metadata_keys = {"mesh_shape", "data_size_mb", "Chunk_Count", "Thread_Count"} |
| 205 | + metadata = {k: v for k, v in params if k in metadata_keys} |
| 206 | + |
| 207 | + metrics = {} |
| 208 | + |
| 209 | + def add_metric(name, ms_list): |
| 210 | + # Report Bandwidth (GiB/s) |
| 211 | + # Handle division by zero if ms is 0 |
| 212 | + bw_list = [ |
| 213 | + ((data_size_mb / 1024) / (ms / 1000)) if ms > 0 else 0.0 |
| 214 | + for ms in ms_list |
| 215 | + ] |
| 216 | + stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)") |
| 217 | + metrics.update(stats_bw.serialize_statistics()) |
| 218 | + |
| 219 | + add_metric("H2D", H2D_Bandwidth) |
| 220 | + add_metric("D2H", D2H_Bandwidth) |
| 221 | + |
| 222 | + return metadata, metrics |
0 commit comments