Skip to content

Commit 3a8d714

Browse files
committed
Auto searching chunk size for benchmarking
1 parent 1fa3c99 commit 3a8d714

1 file changed

Lines changed: 144 additions & 64 deletions

File tree

Ironwood/src/benchmark_host_device.py

Lines changed: 144 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,7 @@ def get_tpu_devices(num_devices: int):
2020
raise RuntimeError(f"Require {num_devices} devices, found {len(devices)}")
2121
return devices[:num_devices]
2222

23-
def _run_chunked(host_data, data_sharding, host_shards, target_devices, num_devices, chunks_per_device):
24-
# Smart Chunked H2D
25-
chk_h2d_start = time.perf_counter()
23+
def _run_h2d_chunked(host_shards, target_devices, num_devices, chunks_per_device):
2624
total_workers = num_devices * chunks_per_device
2725
with concurrent.futures.ThreadPoolExecutor(max_workers=total_workers) as executor:
2826
chunked_futures = []
@@ -35,16 +33,10 @@ def _run_chunked(host_data, data_sharding, host_shards, target_devices, num_devi
3533
chunked_buffers = [f.result() for f in chunked_futures]
3634
for db in chunked_buffers:
3735
db.block_until_ready()
38-
chk_h2d_end = time.perf_counter()
39-
h2d_ms = (chk_h2d_end - chk_h2d_start) * 1000
40-
for db in chunked_buffers:
41-
db.delete()
36+
return chunked_buffers
4237

43-
# Smart Chunked D2H
44-
data_on_device = jax.device_put(host_data, data_sharding)
45-
data_on_device.block_until_ready()
46-
47-
chk_d2h_start = time.perf_counter()
38+
def _run_d2h_chunked(data_on_device, num_devices, chunks_per_device):
39+
total_workers = num_devices * chunks_per_device
4840
with concurrent.futures.ThreadPoolExecutor(max_workers=total_workers) as executor:
4941
d2h_futures = []
5042
for shard in data_on_device.addressable_shards:
@@ -58,49 +50,94 @@ def _run_chunked(host_data, data_sharding, host_shards, target_devices, num_devi
5850
d2h_futures.append(
5951
executor.submit(jax.device_get, shard.data[start:end])
6052
)
61-
_ = [f.result() for f in d2h_futures]
62-
chk_d2h_end = time.perf_counter()
63-
d2h_ms = (chk_d2h_end - chk_d2h_start) * 1000
64-
data_on_device.delete()
65-
66-
return h2d_ms, d2h_ms
67-
53+
for f in d2h_futures:
54+
f.result()
6855

69-
def _run_warmup(host_data, data_sharding, data_size_mb):
70-
# --- ADAPTIVE WARM UP ---
71-
if data_size_mb <= 128:
72-
warmup_iters = 50
73-
elif data_size_mb >= 8192:
74-
warmup_iters = 3
75-
else:
76-
warmup_iters = 10
77-
78-
for _ in range(warmup_iters):
79-
data_on_device = jax.device_put(host_data, data_sharding)
80-
data_on_device.block_until_ready()
81-
_ = jax.device_get(data_on_device)
82-
data_on_device.delete()
83-
84-
gc.collect()
8556

86-
def _get_chunks_per_device(data_size_mb, num_devices):
87-
# --- SMART CHUNKING CONFIG ---
88-
target_chunk_size_mb = 16
89-
max_global_threads = 256
57+
def _find_optimal_chunk_size(
58+
run_fn,
59+
num_devices,
60+
data_size_mb,
61+
search_min_size_mb=1,
62+
max_global_threads=256
63+
):
64+
"""Finds optimal chunk size by iterating over candidates."""
65+
print(" Searching for optimal chunk size...")
9066

67+
# Generate size candidates
68+
candidates_mb = []
69+
curr = search_min_size_mb
9170
data_per_device_mb = data_size_mb / num_devices
71+
72+
# Iterate until we cover the full data size per device
73+
while curr <= data_per_device_mb:
74+
candidates_mb.append(curr)
75+
curr *= 2
76+
# Ensure we test at least one candidate (e.g. if data < min_size)
77+
if not candidates_mb:
78+
candidates_mb.append(data_per_device_mb)
9279

93-
if data_per_device_mb < target_chunk_size_mb:
94-
chunks_per_device = 1
95-
else:
96-
chunks_per_device = int(data_per_device_mb / target_chunk_size_mb)
97-
98-
total_threads = num_devices * chunks_per_device
99-
if total_threads > max_global_threads:
100-
chunks_per_device = max(1, int(max_global_threads / num_devices))
80+
# Map sizes to counts, keeping track of unique counts to test
81+
candidates_counts = []
82+
seen_counts = set()
10183

102-
return chunks_per_device
84+
for size_mb in candidates_mb:
85+
if size_mb > data_per_device_mb:
86+
count = 1
87+
else:
88+
count = int(data_per_device_mb / size_mb)
89+
if count < 1: count = 1
90+
91+
# Filter by max global threads
92+
if (count * num_devices) > max_global_threads:
93+
continue
94+
95+
if count not in seen_counts:
96+
candidates_counts.append(count)
97+
seen_counts.add(count)
98+
99+
# Sort candidates (counts) ascending for clean output
100+
candidates_counts.sort()
101+
102+
if not candidates_counts:
103+
candidates_counts = [1]
103104

105+
best_chunk_count = 1
106+
best_median_bw = -1.0
107+
108+
# 5 search iterations + 3 warmup (before search)
109+
warmup_iters = 3
110+
search_iters = 5
111+
112+
try:
113+
for _ in range(warmup_iters):
114+
run_fn(1) # Warmup with 1 chunk
115+
except Exception:
116+
pass
117+
118+
for chunk_count in candidates_counts:
119+
times_ms = []
120+
try:
121+
for _ in range(search_iters):
122+
t_start = time.perf_counter()
123+
res = run_fn(chunk_count)
124+
t_end = time.perf_counter()
125+
126+
if isinstance(res, (int, float)):
127+
times_ms.append(res)
128+
else:
129+
times_ms.append((t_end - t_start) * 1000)
130+
131+
median_ms = np.median(times_ms)
132+
if median_ms > 0:
133+
if best_median_bw < 0 or median_ms < best_median_bw:
134+
best_median_bw = median_ms
135+
best_chunk_count = chunk_count
136+
except Exception as e:
137+
continue
138+
139+
print(f" Found optimal chunk count: {best_chunk_count} (approx size: {data_per_device_mb/best_chunk_count:.2f} MB)")
140+
return best_chunk_count
104141

105142
def benchmark_host_device(
106143
mesh_shape: str,
@@ -138,21 +175,47 @@ def benchmark_host_device(
138175
mesh, sharding.PartitionSpec(("x", "y"))
139176
)
140177

141-
# --- ADAPTIVE WARM UP ---
142-
_run_warmup(host_data, data_sharding, data_size_mb)
143-
144178
# Pre-calculate sharding info
145179
dummy_put = jax.device_put(host_data[:num_devices], data_sharding)
146180
target_devices = [s.device for s in dummy_put.addressable_shards]
147181
dummy_put.delete()
148182

149183
host_shards = np.split(host_data, num_devices, axis=0)
150184

185+
# --- SEARCH OPTIMAL CHUNKS ---
186+
# Define wrappers for search
187+
188+
def h2d_run_fn(c):
189+
bufs = _run_h2d_chunked(host_shards, target_devices, num_devices, c)
190+
for b in bufs: b.delete()
191+
192+
# H2D Search
193+
h2d_chunks = _find_optimal_chunk_size(h2d_run_fn, num_devices, data_size_mb)
194+
195+
# D2H Search
196+
# We need persistent data on device for D2H search to avoid H2D overhead in D2H measurement
197+
data_on_device_for_search = jax.device_put(host_data, data_sharding)
198+
data_on_device_for_search.block_until_ready()
199+
200+
def d2h_run_fn(c):
201+
# Force a new buffer to avoid host-side caching of device_get
202+
# Adding 0.0 creates a new DeviceArray with same sharding
203+
fresh_data = jax.lax.add(data_on_device_for_search, 0.0)
204+
fresh_data.block_until_ready()
205+
206+
t0 = time.perf_counter()
207+
_run_d2h_chunked(fresh_data, num_devices, c)
208+
t1 = time.perf_counter()
209+
210+
fresh_data.delete()
211+
return (t1 - t0) * 1000
212+
213+
d2h_chunks = _find_optimal_chunk_size(d2h_run_fn, num_devices, data_size_mb)
214+
215+
data_on_device_for_search.delete()
216+
151217
# Performance Lists
152218
h2d_perf, d2h_perf = [], []
153-
154-
# --- SMART CHUNKING CONFIG ---
155-
chunks_per_device = _get_chunks_per_device(data_size_mb, num_devices)
156219

157220
# Profiling Context
158221
if trace_dir:
@@ -171,37 +234,54 @@ def benchmark_host_device(
171234
step_context = contextlib.nullcontext()
172235

173236
with step_context:
174-
# Optimized Chunked Transfer (Sole Strategy)
175-
h2d_ms, d2h_ms = _run_chunked(
176-
host_data, data_sharding, host_shards, target_devices,
177-
num_devices, chunks_per_device
237+
# H2D
238+
t0 = time.perf_counter()
239+
chunked_buffers = _run_h2d_chunked(
240+
host_shards, target_devices, num_devices, h2d_chunks
178241
)
179-
h2d_perf.append(h2d_ms)
180-
d2h_perf.append(d2h_ms)
242+
t1 = time.perf_counter()
243+
h2d_perf.append((t1 - t0) * 1000)
244+
245+
for db in chunked_buffers:
246+
db.delete()
247+
248+
# D2H
249+
# We need data on device again
250+
data_on_device = jax.device_put(host_data, data_sharding)
251+
data_on_device.block_until_ready()
252+
253+
t2 = time.perf_counter()
254+
_run_d2h_chunked(data_on_device, num_devices, d2h_chunks)
255+
t3 = time.perf_counter()
256+
d2h_perf.append((t3 - t2) * 1000)
257+
258+
data_on_device.delete()
181259

182260
del host_data, host_shards
183261
gc.collect()
184262

185263
return {
186264
"H2D_Bandwidth": h2d_perf,
187265
"D2H_Bandwidth": d2h_perf,
188-
"Chunk_Count": chunks_per_device,
189-
"Thread_Count": num_devices * chunks_per_device,
266+
"H2D_Chunk_Size_MB": (data_size_mb / num_devices) / h2d_chunks if h2d_chunks > 0 else 0,
267+
"D2H_Chunk_Size_MB": (data_size_mb / num_devices) / d2h_chunks if d2h_chunks > 0 else 0,
268+
"Thread_Count": num_devices * max(h2d_chunks, d2h_chunks), # Approx
190269
}
191270

192271
def benchmark_host_device_calculate_metrics(
193272
mesh_shape: str,
194273
data_size_mb: int,
195274
H2D_Bandwidth: List[float],
196275
D2H_Bandwidth: List[float],
197-
Chunk_Count: int,
276+
H2D_Chunk_Size_MB: float,
277+
D2H_Chunk_Size_MB: float,
198278
Thread_Count: int,
199279
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
200280
"""Calculates metrics for Host-Device transfer."""
201281
params = locals().items()
202282

203283
# Filter out list params from metadata to avoid explosion
204-
metadata_keys = {"mesh_shape", "data_size_mb", "Chunk_Count", "Thread_Count"}
284+
metadata_keys = {"mesh_shape", "data_size_mb", "H2D_Chunk_Size_MB", "D2H_Chunk_Size_MB", "Thread_Count"}
205285
metadata = {k: v for k, v in params if k in metadata_keys}
206286

207287
metrics = {}

0 commit comments

Comments (0)