Skip to content

Commit 641c73b

Browse files
committed
Cleanup and commit benchmark_host_device.py
1 parent afc2320 commit 641c73b

1 file changed

Lines changed: 214 additions & 38 deletions

File tree

Ironwood/src/benchmark_host_device.py

Lines changed: 214 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,44 @@
1-
"""Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline)."""
1+
"""Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline)."""
22

3-
import time
4-
import os
5-
from typing import Any, Dict, Tuple, List
3+
import time
4+
import os
5+
from typing import Any, Dict, Tuple, List
66

7-
import jax
8-
from jax import numpy as jnp
9-
import numpy as np
10-
from benchmark_utils import MetricsStatistics
7+
import jax
8+
from jax import numpy as jnp
9+
import numpy as np
10+
from benchmark_utils import MetricsStatistics
1111

1212

13-
libtpu_init_args = [
14-
"--xla_tpu_dvfs_p_state=7",
15-
]
16-
os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args)
17-
# 64 GiB
18-
os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736"
19-
os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736"
13+
# Flags handed to libtpu through the LIBTPU_INIT_ARGS environment variable
# (space-joined below). NOTE(review): presumably this must run before the TPU
# runtime is initialised for the flag to take effect -- confirm.
libtpu_init_args = [
14+
"--xla_tpu_dvfs_p_state=7",
15+
]
16+
os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args)
17+
# 64 GiB (68719476736 bytes); the same value is used for both premapped-buffer
# settings below.
18+
os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = "68719476736"
19+
os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = "68719476736"
2020

2121

22-
def benchmark_host_device(
23-
data_size_mib: int,
24-
num_runs: int = 100,
25-
trace_dir: str = None,
26-
h2d_type: str = "simple",
27-
) -> Dict[str, Any]:
28-
"""Benchmarks H2D/D2H transfer using device_put/device_get."""
29-
30-
num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize
31-
32-
# Allocate Host Source Buffer
33-
column = 128
34-
host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32)
35-
36-
# Used in pipelined flow
37-
# TODO: turn into a param
38-
num_devices_to_perform_h2d = 1
39-
target_devices = jax.devices()[:num_devices_to_perform_h2d]
22+
def benchmark_host_device(
    h2d_type: str,
    data_size_mib: int,
    num_runs: int = 100,
    trace_dir: str = None,
    *,
    num_devices_to_perform_h2d: int = 1,
    target_chunk_size_mib: int = 16,
) -> Dict[str, Any]:
    """Benchmarks H2D/D2H transfer using device_put/device_get.

    A float32 host buffer of ``data_size_mib`` MiB is copied to the device and
    back ``num_runs`` times (after two untimed warm-up round trips), and the
    wall-clock time of each direction is recorded per iteration.

    Args:
        h2d_type: ``"simple"`` copies the whole buffer with one
            ``jax.device_put`` and reads it back with ``jax.device_get``.
            ``"pipelined"`` splits the buffer into chunks, issues one
            ``jax.device_put`` per chunk (round-robin over the target
            devices) and copies each chunk back to the CPU device.
        data_size_mib: Size of the host buffer in MiB.
        num_runs: Number of timed iterations.
        trace_dir: When set, the run is wrapped in ``jax.profiler.trace`` and
            every iteration in a ``StepTraceAnnotation``.
        num_devices_to_perform_h2d: Number of devices (taken from the front of
            ``jax.devices()``) used by the pipelined flow.
        target_chunk_size_mib: Desired chunk size for the pipelined flow. When
            the per-device data does not span more than one chunk, the whole
            buffer is transferred in a single piece to the first device.

    Returns:
        A dict with the keys ``"H2D_Bandwidth_ms"`` and ``"D2H_Bandwidth_ms"``.
        Despite the key names, each value is a list of per-iteration transfer
        times in milliseconds.

    Raises:
        ValueError: If ``h2d_type`` is not a supported mode, or if
            ``num_devices_to_perform_h2d`` / ``target_chunk_size_mib`` is not
            positive.
    """
    # Function-scope import so the block does not depend on the module's
    # import list.
    import contextlib

    # Fail fast instead of silently returning empty timing lists.
    if h2d_type not in ("simple", "pipelined"):
        raise ValueError(
            f"Unsupported {h2d_type=}; expected 'simple' or 'pipelined'."
        )
    if num_devices_to_perform_h2d < 1:
        raise ValueError(f"{num_devices_to_perform_h2d=} must be at least 1.")
    if target_chunk_size_mib <= 0:
        raise ValueError(f"{target_chunk_size_mib=} must be positive.")

    num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize

    # Allocate Host Source Buffer
    column = 128
    host_data = np.random.normal(
        size=(num_elements // column, column)
    ).astype(np.float32)

    print(
        f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations with {h2d_type=}",
        flush=True
    )

    # The chunk layout is loop-invariant, so it is planned once here rather
    # than inside every timed iteration.
    pieces = piece_devices = host_device = None
    if h2d_type == "pipelined":
        target_devices = jax.devices()[:num_devices_to_perform_h2d]
        num_devices = len(target_devices)
        chunks_per_dev = max(
            1, int(data_size_mib / num_devices / target_chunk_size_mib)
        )
        if chunks_per_dev > 1:
            pieces = np.array_split(
                host_data, chunks_per_dev * num_devices, axis=0
            )
            # Round-robin placement; the list stays in chunk order so the
            # pieces can be concatenated back for verification.
            piece_devices = [
                target_devices[idx % num_devices]
                for idx in range(len(pieces))
            ]
        else:
            # Printed once, outside the timed region, so it cannot skew H2D.
            print(f"Warning: {data_size_mib=} is not larger than {target_chunk_size_mib=}, falling back to standard JAX put.")
            pieces = [host_data]
            piece_devices = [target_devices[0]]
        # Looked up once so the query is not part of the D2H timing.
        host_device = jax.devices("cpu")[0]

    # Performance Lists
    h2d_perf, d2h_perf = [], []

    # Profiling Context
    if trace_dir:
        profiler_context = jax.profiler.trace(trace_dir)
    else:
        profiler_context = contextlib.nullcontext()

    with profiler_context:
        # Warmup
        for _ in range(2):
            device_array = jax.device_put(host_data)
            device_array.block_until_ready()
            host_out = np.array(device_array)
            device_array.delete()
            del host_out

        for i in range(num_runs):
            # Step Context
            if trace_dir:
                step_context = jax.profiler.StepTraceAnnotation(
                    "host_device", step_num=i
                )
            else:
                step_context = contextlib.nullcontext()

            with step_context:
                if h2d_type == "simple":
                    h2d_ms, d2h_ms = _time_simple_round_trip(host_data)
                else:
                    h2d_ms, d2h_ms = _time_pipelined_round_trip(
                        host_data, pieces, piece_devices, host_device
                    )
            h2d_perf.append(h2d_ms)
            d2h_perf.append(d2h_ms)

    return {
        "H2D_Bandwidth_ms": h2d_perf,
        "D2H_Bandwidth_ms": d2h_perf,
    }


def _time_simple_round_trip(host_data):
    """Time one whole-buffer device_put / device_get pair; returns (h2d_ms, d2h_ms)."""
    # H2D
    t0 = time.perf_counter()
    device_array = jax.device_put(host_data)
    device_array.block_until_ready()
    t1 = time.perf_counter()

    # Verify H2D shape
    assert device_array.shape == host_data.shape

    # D2H
    t2 = time.perf_counter()
    # Note: device_get returns a numpy array (copy)
    _ = jax.device_get(device_array)
    t3 = time.perf_counter()

    device_array.delete()
    return (t1 - t0) * 1000, (t3 - t2) * 1000


def _time_pipelined_round_trip(host_data, pieces, piece_devices, host_device):
    """Time one chunked H2D transfer and the per-chunk copy back; returns (h2d_ms, d2h_ms)."""
    # H2D: issue every transfer first, then wait, so the copies can overlap.
    t0 = time.perf_counter()
    tensors_on_device = [
        jax.device_put(piece, dev)
        for piece, dev in zip(pieces, piece_devices)
    ]
    for tensor in tensors_on_device:
        tensor.block_until_ready()
    t1 = time.perf_counter()

    # D2H: copy each chunk to the CPU device without re-assembling on device.
    t2 = time.perf_counter()
    tensors_on_host = [
        jax.device_put(tensor, host_device) for tensor in tensors_on_device
    ]
    for tensor in tensors_on_host:
        tensor.block_until_ready()
    t3 = time.perf_counter()

    # Verification happens outside both timed regions.
    round_trip = np.concatenate(
        [np.asarray(tensor) for tensor in tensors_on_host], axis=0
    )
    if not np.allclose(round_trip, host_data):
        print("pipelined result not equal to host_data")

    for tensor in tensors_on_device:
        tensor.delete()
    for tensor in tensors_on_host:
        # When the target device is itself the CPU device, the host copy may
        # share the buffer that was just deleted.
        if not tensor.is_deleted():
            tensor.delete()

    return (t1 - t0) * 1000, (t3 - t2) * 1000
348+
349+
def benchmark_host_device_calculate_metrics(
    data_size_mib: int,
    H2D_Bandwidth_ms: List[float],
    D2H_Bandwidth_ms: List[float],
    h2d_type: str = "simple",
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Convert per-iteration transfer times into bandwidth statistics.

    Args:
        data_size_mib: Size in MiB of the buffer moved in each iteration.
        H2D_Bandwidth_ms: Host-to-device transfer times in milliseconds.
        D2H_Bandwidth_ms: Device-to-host transfer times in milliseconds.
        h2d_type: Accepted for signature compatibility; not used here.

    Returns:
        ``(metadata, metrics)``. ``metadata`` holds ``data_size_mib`` and the
        dtype label; ``metrics`` holds the serialized statistics of the H2D
        and D2H bandwidths in GiB/s.
    """
    # Only the scalar size goes into the metadata; the raw timing lists are
    # deliberately left out to keep it small.
    metadata = {"data_size_mib": data_size_mib, "dtype": "float32"}
    metrics = {}

    directions = (
        ("H2D", H2D_Bandwidth_ms),
        ("D2H", D2H_Bandwidth_ms),
    )
    for label, durations_ms in directions:
        # Bandwidth in GiB/s; a non-positive duration is reported as 0.0
        # rather than dividing by zero.
        bandwidths = []
        for ms in durations_ms:
            if ms > 0:
                bandwidths.append((data_size_mib / 1024) / (ms / 1000))
            else:
                bandwidths.append(0.0)

        stats_bw = MetricsStatistics(bandwidths, f"{label}_bw (GiB/s)")
        median = stats_bw.statistics['p50']
        p95 = stats_bw.statistics['p95']
        print(
            f" {label}_bw (GiB/s) median: {median}, P95: {p95}",
            flush=True
        )
        metrics.update(stats_bw.serialize_statistics())

    return metadata, metrics

0 commit comments

Comments
 (0)