|
1 | | -"""Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline).""" |
| 1 | + """Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline).""" |
2 | 2 |
|
3 | | -import time |
4 | | -import os |
5 | | -from typing import Any, Dict, Tuple, List |
| 3 | + import time |
| 4 | + import os |
| 5 | + from typing import Any, Dict, Tuple, List |
6 | 6 |
|
7 | | -import jax |
8 | | -from jax import numpy as jnp |
9 | | -import numpy as np |
10 | | -from benchmark_utils import MetricsStatistics |
| 7 | + import jax |
| 8 | + from jax import numpy as jnp |
| 9 | + import numpy as np |
| 10 | + from benchmark_utils import MetricsStatistics |
11 | 11 |
|
12 | 12 |
|
13 | | -libtpu_init_args = [ |
14 | | - "--xla_tpu_dvfs_p_state=7", |
15 | | -] |
16 | | -os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) |
17 | | -# 64 GiB |
18 | | -os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736" |
19 | | -os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736" |
| 13 | + libtpu_init_args = [ |
| 14 | + "--xla_tpu_dvfs_p_state=7", |
| 15 | + ] |
| 16 | + os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args) |
| 17 | + # 64 GiB |
| 18 | + os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736" |
| 19 | + os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736" |
20 | 20 |
|
21 | 21 |
|
22 | | -def benchmark_host_device( |
23 | | - data_size_mib: int, |
24 | | - num_runs: int = 100, |
25 | | - trace_dir: str = None, |
26 | | - h2d_type: str = "simple", |
27 | | -) -> Dict[str, Any]: |
28 | | - """Benchmarks H2D/D2H transfer using device_put/device_get.""" |
29 | | - |
30 | | - num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize |
31 | | - |
32 | | - # Allocate Host Source Buffer |
33 | | - column = 128 |
34 | | - host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32) |
35 | | - |
36 | | - # Used in pipelined flow |
37 | | - # TODO: turn into a param |
38 | | - num_devices_to_perform_h2d = 1 |
39 | | - target_devices = jax.devices()[:num_devices_to_perform_h2d] |
| 22 | + def benchmark_host_device( |
| 23 | + h2d_type: str, |
| 24 | + data_size_mib: int, |
| 25 | + num_runs: int = 100, |
| 26 | + trace_dir: str = None, |
| 27 | + ) -> Dict[str, Any]: |
| 28 | + """Benchmarks H2D/D2H transfer using device_put/device_get.""" |
| 29 | + |
| 30 | + num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize |
| 31 | + |
| 32 | + # Allocate Host Source Buffer |
| 33 | + column = 128 |
| 34 | + host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32) |
| 35 | + |
| 36 | + # Used in pipelined flow |
| 37 | + # TODO: turn into a param |
| 38 | + num_devices_to_perform_h2d = 1 |
| 39 | + target_devices = jax.devices()[:num_devices_to_perform_h2d] |
40 | 40 |
|
| 41 | +<<<<<<< Updated upstream |
41 | 42 | print( |
42 | 43 | f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations with {h2d_type=}", |
43 | 44 | flush=True |
@@ -196,13 +197,188 @@ def add_metric(name, ms_list): |
196 | 197 | for ms in ms_list |
197 | 198 | ] |
198 | 199 | stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)") |
| 200 | +======= |
| 201 | +>>>>>>> Stashed changes |
199 | 202 | print( |
200 | | - f" {name}_bw (GiB/s) median: {stats_bw.statistics['p50']}, P95: {stats_bw.statistics['p95']}", |
| 203 | + f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations with {h2d_type=}", |
201 | 204 | flush=True |
202 | 205 | ) |
203 | | - metrics.update(stats_bw.serialize_statistics()) |
204 | 206 |
|
205 | | - add_metric("H2D", H2D_Bandwidth_ms) |
206 | | - add_metric("D2H", D2H_Bandwidth_ms) |
| 207 | + # Performance Lists |
| 208 | + h2d_perf, d2h_perf = [], [] |
| 209 | + |
| 210 | + # Profiling Context |
| 211 | + import contextlib |
| 212 | + if trace_dir: |
| 213 | + profiler_context = jax.profiler.trace(trace_dir) |
| 214 | + else: |
| 215 | + profiler_context = contextlib.nullcontext() |
| 216 | + |
| 217 | + with profiler_context: |
| 218 | + # Warmup |
| 219 | + for _ in range(2): |
| 220 | + device_array = jax.device_put(host_data) |
| 221 | + device_array.block_until_ready() |
| 222 | + host_out = np.array(device_array) |
| 223 | + device_array.delete() |
| 224 | + del host_out |
| 225 | + |
| 226 | + for i in range(num_runs): |
| 227 | + # Step Context |
| 228 | + if trace_dir: |
| 229 | + step_context = jax.profiler.StepTraceAnnotation("host_device", step_num=i) |
| 230 | + else: |
| 231 | + step_context = contextlib.nullcontext() |
| 232 | + |
| 233 | + with step_context: |
| 234 | + # H2D |
| 235 | + if h2d_type == "simple": |
| 236 | + t0 = time.perf_counter() |
| 237 | + # Simple device_put |
| 238 | + device_array = jax.device_put(host_data) |
| 239 | + device_array.block_until_ready() |
| 240 | + t1 = time.perf_counter() |
| 241 | + |
| 242 | + # Verify H2D shape |
| 243 | + assert device_array.shape == host_data.shape |
| 244 | + |
| 245 | + h2d_perf.append((t1 - t0) * 1000) |
| 246 | + |
| 247 | + # D2H |
| 248 | + t2 = time.perf_counter() |
| 249 | + |
| 250 | + # Simple device_get |
| 251 | + # Note: device_get returns a numpy array (copy) |
| 252 | + _ = jax.device_get(device_array) |
| 253 | + |
| 254 | + t3 = time.perf_counter() |
| 255 | + d2h_perf.append((t3 - t2) * 1000) |
| 256 | + |
| 257 | + device_array.delete() |
| 258 | + elif h2d_type == "pipelined": |
| 259 | + target_chunk_size_mib = 16 # Sweet spot from profiling |
| 260 | + num_devices = len(target_devices) |
| 261 | + |
| 262 | + tensors_on_device = [] |
| 263 | + |
| 264 | + # Calculate chunks per device |
| 265 | + data_per_dev = data_size_mib / num_devices |
| 266 | + chunks_per_dev = int(data_per_dev / target_chunk_size_mib) |
| 267 | + chunks_per_dev = max(1, chunks_per_dev) |
| 268 | + |
| 269 | + chunks = np.array_split(host_data, chunks_per_dev * num_devices, axis=0) |
| 270 | + |
| 271 | + t0 = time.perf_counter() |
| 272 | + if chunks_per_dev > 1: |
| 273 | + # We need to map chunks to the correct device |
| 274 | + # This simple example assumes chunks are perfectly divisible and ordered |
| 275 | + # In production, use `jax.sharding` mesh logic for complex layouts |
| 276 | + |
| 277 | + # approach 1: simple for loop |
| 278 | + for idx, chunk in enumerate(chunks): |
| 279 | + if num_devices > 1: |
| 280 | + dev = target_devices[idx % num_devices] |
| 281 | + else: |
| 282 | + dev = target_devices[0] |
| 283 | + tensors_on_device.append(jax.device_put(chunk, dev)) |
| 284 | + # Re-assemble array |
| 285 | + # result = jnp.vstack(tensors_on_device) |
| 286 | + # Wait for all chunks to be transferred |
| 287 | + # result.block_until_ready() |
| 288 | + |
| 289 | + # Don't re-assemble |
| 290 | + for tensor in tensors_on_device: |
| 291 | + tensor.block_until_ready() |
| 292 | + |
| 293 | + # approach 2: generator (slightly less overhead) |
| 294 | + # def chunk_generator(num_devices, chunks_per_dev): |
| 295 | + # for n in range(chunks_per_dev): |
| 296 | + # for d in range(num_devices): |
| 297 | + # # 1. Get the specific small chunk |
| 298 | + # chunk = chunks[d*chunks_per_dev+n] |
| 299 | + |
| 300 | + # # 2. Trigger an individual DMA transfer for this specific chunk |
| 301 | + # # This is where NUMA-local memory access matters |
| 302 | + # yield jax.device_put(chunk, target_devices[d]) |
| 303 | + |
| 304 | + # # Re-assemble array |
| 305 | + # result = jnp.vstack(list(chunk_generator(num_devices, chunks_per_dev))) |
| 306 | + # # Wait for all chunks to be transferred |
| 307 | + # result.block_until_ready() |
| 308 | + else: |
| 309 | + print(f"Warning: {data_size_mib=} is not larger than {target_chunk_size_mib=}, falling back to standard JAX put.") |
| 310 | + # Fallback to standard JAX put for small data |
| 311 | + result = jax.device_put(host_data, target_devices[0]) |
| 312 | + result.block_until_ready() |
| 313 | + |
| 314 | + t1 = time.perf_counter() |
| 315 | + h2d_perf.append((t1 - t0) * 1000) |
| 316 | + |
| 317 | + # D2H |
| 318 | + t2 = time.perf_counter() |
| 319 | + # Simple device_get |
| 320 | + # Note: device_get returns a numpy array (copy) |
| 321 | + # result = jnp.vstack(tensors_on_device) |
| 322 | + # _ = jax.device_get(result) |
| 323 | + # del tensors_on_device |
| 324 | + |
| 325 | + # device_put instead |
| 326 | + tensors_on_host = [] |
| 327 | + for tensor in tensors_on_device: |
| 328 | + tensors_on_host.append(jax.device_put(x, jax.devices("cpu")[0])) |
| 329 | + for tensor in tensors_on_host: |
| 330 | + tensor.block_until_ready() |
| 331 | + |
| 332 | + t3 = time.perf_counter() |
| 333 | + if not np.allclose(result, host_data): |
| 334 | + print("pipelined result not equal to host_data") |
| 335 | + d2h_perf.append((t3 - t2) * 1000) |
| 336 | + |
| 337 | + for r in tensors_on_device: |
| 338 | + r.delete() |
| 339 | + del tensors_on_device |
| 340 | + for r in tensors_on_host: |
| 341 | + r.delete() |
| 342 | + del tensors_on_host |
| 343 | + |
| 344 | + return { |
| 345 | + "H2D_Bandwidth_ms": h2d_perf, |
| 346 | + "D2H_Bandwidth_ms": d2h_perf, |
| 347 | + } |
| 348 | + |
| 349 | + def benchmark_host_device_calculate_metrics( |
| 350 | + data_size_mib: int, |
| 351 | + H2D_Bandwidth_ms: List[float], |
| 352 | + D2H_Bandwidth_ms: List[float], |
| 353 | + h2d_type: str = "simple", |
| 354 | + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: |
| 355 | + """Calculates metrics for Host-Device transfer.""" |
| 356 | + params = locals().items() |
| 357 | + |
| 358 | + # Filter out list params from metadata to avoid explosion |
| 359 | + metadata_keys = { |
| 360 | + "data_size_mib", |
| 361 | + } |
| 362 | + metadata = {k: v for k, v in params if k in metadata_keys} |
| 363 | + metadata["dtype"] = "float32" |
| 364 | + |
| 365 | + metrics = {} |
| 366 | + |
| 367 | + def add_metric(name, ms_list): |
| 368 | + # Report Bandwidth (GiB/s) |
| 369 | + # Handle division by zero if ms is 0 |
| 370 | + bw_list = [ |
| 371 | + ((data_size_mib / 1024) / (ms / 1000)) if ms > 0 else 0.0 |
| 372 | + for ms in ms_list |
| 373 | + ] |
| 374 | + stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)") |
| 375 | + print( |
| 376 | + f" {name}_bw (GiB/s) median: {stats_bw.statistics['p50']}, P95: {stats_bw.statistics['p95']}", |
| 377 | + flush=True |
| 378 | + ) |
| 379 | + metrics.update(stats_bw.serialize_statistics()) |
| 380 | + |
| 381 | + add_metric("H2D", H2D_Bandwidth_ms) |
| 382 | + add_metric("D2H", D2H_Bandwidth_ms) |
207 | 383 |
|
208 | | - return metadata, metrics |
| 384 | + return metadata, metrics |
0 commit comments