Skip to content

Commit a86475d

Browse files
committed
Add baseline pipelined flow to H2D benchmark
1 parent 7b090f1 commit a86475d

2 files changed

Lines changed: 103 additions & 26 deletions

File tree

Ironwood/configs/host_device/host_device.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ benchmarks:
33
num_runs: 20
44
benchmark_sweep_params:
55
- {
6-
data_size_mib_list: [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
6+
h2d_type: ["simple", "pipelined"],
7+
data_size_mib_list: [1, 16, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768],
78
}
89
csv_path: "../microbenchmarks/host_device"
910
trace_dir: "../microbenchmarks/host_device/trace"

Ironwood/src/benchmark_host_device.py

Lines changed: 101 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Any, Dict, Tuple, List
66

77
import jax
8-
from jax import sharding
8+
from jax import numpy as jnp
99
import numpy as np
1010
from benchmark_utils import MetricsStatistics
1111

@@ -23,17 +23,23 @@ def benchmark_host_device(
2323
data_size_mib: int,
2424
num_runs: int = 100,
2525
trace_dir: str = None,
26+
h2d_type: str = "simple",
2627
) -> Dict[str, Any]:
27-
"""Benchmarks H2D/D2H transfer using simple device_put/device_get."""
28+
"""Benchmarks H2D/D2H transfer using device_put/device_get."""
2829

2930
num_elements = 1024 * 1024 * data_size_mib // np.dtype(np.float32).itemsize
3031

3132
# Allocate Host Source Buffer
3233
column = 128
3334
host_data = np.random.normal(size=(num_elements // column, column)).astype(np.float32)
3435

36+
# Used in pipelined flow
37+
# TODO: turn into a param
38+
num_devices_to_perform_h2d = 1
39+
target_devices = jax.devices()[:num_devices_to_perform_h2d]
40+
3541
print(
36-
f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations",
42+
f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations with {h2d_type=}",
3743
flush=True
3844
)
3945

@@ -65,29 +71,98 @@ def benchmark_host_device(
6571

6672
with step_context:
6773
# H2D
68-
t0 = time.perf_counter()
69-
70-
# Simple device_put
71-
device_array = jax.device_put(host_data)
72-
device_array.block_until_ready()
73-
74-
t1 = time.perf_counter()
75-
h2d_perf.append((t1 - t0) * 1000)
76-
77-
# Verify H2D shape
78-
assert device_array.shape == host_data.shape
79-
80-
# D2H
81-
t2 = time.perf_counter()
82-
83-
# Simple device_get
84-
# Note: device_get returns a numpy array (copy)
85-
_ = jax.device_get(device_array)
86-
87-
t3 = time.perf_counter()
88-
d2h_perf.append((t3 - t2) * 1000)
74+
if h2d_type == "simple":
75+
t0 = time.perf_counter()
76+
# Simple device_put
77+
device_array = jax.device_put(host_data)
78+
device_array.block_until_ready()
79+
t1 = time.perf_counter()
80+
81+
# Verify H2D shape
82+
assert device_array.shape == host_data.shape
83+
84+
h2d_perf.append((t1 - t0) * 1000)
8985

90-
device_array.delete()
86+
# D2H
87+
t2 = time.perf_counter()
88+
89+
# Simple device_get
90+
# Note: device_get returns a numpy array (copy)
91+
_ = jax.device_get(device_array)
92+
93+
t3 = time.perf_counter()
94+
d2h_perf.append((t3 - t2) * 1000)
95+
96+
device_array.delete()
97+
elif h2d_type == "pipelined":
98+
target_chunk_size_mib = 16 # Sweet spot from profiling
99+
num_devices = len(target_devices)
100+
101+
tensors_on_device = []
102+
103+
# Calculate chunks per device
104+
data_per_dev = data_size_mib / num_devices
105+
chunks_per_dev = int(data_per_dev / target_chunk_size_mib)
106+
chunks_per_dev = max(1, chunks_per_dev)
107+
108+
chunks = np.array_split(host_data, chunks_per_dev * num_devices, axis=0)
109+
110+
t0 = time.perf_counter()
111+
if chunks_per_dev > 1:
112+
# We need to map chunks to the correct device
113+
# This simple example assumes chunks are perfectly divisible and ordered
114+
# In production, use `jax.sharding` mesh logic for complex layouts
115+
116+
# approach 1: simple for loop
117+
for idx, chunk in enumerate(chunks):
118+
if num_devices > 1:
119+
dev = target_devices[idx % num_devices]
120+
else:
121+
dev = target_devices[0]
122+
tensors_on_device.append(jax.device_put(chunk, dev))
123+
# Re-assemble array
124+
result = jnp.vstack(tensors_on_device)
125+
# Wait for all chunks to be transferred
126+
result.block_until_ready()
127+
128+
# approach 2: generator (slightly less overhead)
129+
# def chunk_generator(num_devices, chunks_per_dev):
130+
# for n in range(chunks_per_dev):
131+
# for d in range(num_devices):
132+
# # 1. Get the specific small chunk
133+
# chunk = chunks[d*chunks_per_dev+n]
134+
135+
# # 2. Trigger an individual DMA transfer for this specific chunk
136+
# # This is where NUMA-local memory access matters
137+
# yield jax.device_put(chunk, target_devices[d])
138+
139+
# # Re-assemble array
140+
# result = jnp.vstack(list(chunk_generator(num_devices, chunks_per_dev)))
141+
# # Wait for all chunks to be transferred
142+
# result.block_until_ready()
143+
else:
144+
print(f"Warning: {data_size_mib=} is not larger than {target_chunk_size_mib=}, falling back to standard JAX put.")
145+
# Fallback to standard JAX put for small data
146+
result = jax.device_put(host_data, target_devices[0])
147+
result.block_until_ready()
148+
149+
t1 = time.perf_counter()
150+
h2d_perf.append((t1 - t0) * 1000)
151+
152+
# D2H
153+
t2 = time.perf_counter()
154+
# Simple device_get
155+
# Note: device_get returns a numpy array (copy)
156+
_ = jax.device_get(result)
157+
158+
t3 = time.perf_counter()
159+
if not np.allclose(result, host_data):
160+
print("pipelined result not equal to host_data")
161+
d2h_perf.append((t3 - t2) * 1000)
162+
163+
for r in tensors_on_device:
164+
r.delete()
165+
del tensors_on_device
91166

92167
return {
93168
"H2D_Bandwidth_ms": h2d_perf,
@@ -98,6 +173,7 @@ def benchmark_host_device_calculate_metrics(
98173
data_size_mib: int,
99174
H2D_Bandwidth_ms: List[float],
100175
D2H_Bandwidth_ms: List[float],
176+
h2d_type: str = "simple",
101177
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
102178
"""Calculates metrics for Host-Device transfer."""
103179
params = locals().items()

0 commit comments

Comments
 (0)