@@ -20,9 +20,7 @@ def get_tpu_devices(num_devices: int):
2020 raise RuntimeError (f"Require { num_devices } devices, found { len (devices )} " )
2121 return devices [:num_devices ]
2222
23- def _run_chunked (host_data , data_sharding , host_shards , target_devices , num_devices , chunks_per_device ):
24- # Smart Chunked H2D
25- chk_h2d_start = time .perf_counter ()
23+ def _run_h2d_chunked (host_shards , target_devices , num_devices , chunks_per_device ):
2624 total_workers = num_devices * chunks_per_device
2725 with concurrent .futures .ThreadPoolExecutor (max_workers = total_workers ) as executor :
2826 chunked_futures = []
@@ -35,16 +33,10 @@ def _run_chunked(host_data, data_sharding, host_shards, target_devices, num_devi
3533 chunked_buffers = [f .result () for f in chunked_futures ]
3634 for db in chunked_buffers :
3735 db .block_until_ready ()
38- chk_h2d_end = time .perf_counter ()
39- h2d_ms = (chk_h2d_end - chk_h2d_start ) * 1000
40- for db in chunked_buffers :
41- db .delete ()
36+ return chunked_buffers
4237
43- # Smart Chunked D2H
44- data_on_device = jax .device_put (host_data , data_sharding )
45- data_on_device .block_until_ready ()
46-
47- chk_d2h_start = time .perf_counter ()
38+ def _run_d2h_chunked (data_on_device , num_devices , chunks_per_device ):
39+ total_workers = num_devices * chunks_per_device
4840 with concurrent .futures .ThreadPoolExecutor (max_workers = total_workers ) as executor :
4941 d2h_futures = []
5042 for shard in data_on_device .addressable_shards :
@@ -58,49 +50,94 @@ def _run_chunked(host_data, data_sharding, host_shards, target_devices, num_devi
5850 d2h_futures .append (
5951 executor .submit (jax .device_get , shard .data [start :end ])
6052 )
61- _ = [f .result () for f in d2h_futures ]
62- chk_d2h_end = time .perf_counter ()
63- d2h_ms = (chk_d2h_end - chk_d2h_start ) * 1000
64- data_on_device .delete ()
65-
66- return h2d_ms , d2h_ms
67-
53+ for f in d2h_futures :
54+ f .result ()
6855
69- def _run_warmup (host_data , data_sharding , data_size_mb ):
70- # --- ADAPTIVE WARM UP ---
71- if data_size_mb <= 128 :
72- warmup_iters = 50
73- elif data_size_mb >= 8192 :
74- warmup_iters = 3
75- else :
76- warmup_iters = 10
77-
78- for _ in range (warmup_iters ):
79- data_on_device = jax .device_put (host_data , data_sharding )
80- data_on_device .block_until_ready ()
81- _ = jax .device_get (data_on_device )
82- data_on_device .delete ()
83-
84- gc .collect ()
8556
86- def _get_chunks_per_device (data_size_mb , num_devices ):
87- # --- SMART CHUNKING CONFIG ---
88- target_chunk_size_mb = 16
89- max_global_threads = 256
57+ def _find_optimal_chunk_size (
58+ run_fn ,
59+ num_devices ,
60+ data_size_mb ,
61+ search_min_size_mb = 1 ,
62+ max_global_threads = 256
63+ ):
64+ """Finds optimal chunk size by iterating over candidates."""
65+ print (" Searching for optimal chunk size..." )
9066
67+ # Generate size candidates
68+ candidates_mb = []
69+ curr = search_min_size_mb
9170 data_per_device_mb = data_size_mb / num_devices
71+
72+ # Iterate until we cover the full data size per device
73+ while curr <= data_per_device_mb :
74+ candidates_mb .append (curr )
75+ curr *= 2
76+ # Ensure we test at least one candidate (e.g. if data < min_size)
77+ if not candidates_mb :
78+ candidates_mb .append (data_per_device_mb )
9279
93- if data_per_device_mb < target_chunk_size_mb :
94- chunks_per_device = 1
95- else :
96- chunks_per_device = int (data_per_device_mb / target_chunk_size_mb )
97-
98- total_threads = num_devices * chunks_per_device
99- if total_threads > max_global_threads :
100- chunks_per_device = max (1 , int (max_global_threads / num_devices ))
80+ # Map sizes to counts, keeping track of unique counts to test
81+ candidates_counts = []
82+ seen_counts = set ()
10183
102- return chunks_per_device
84+ for size_mb in candidates_mb :
85+ if size_mb > data_per_device_mb :
86+ count = 1
87+ else :
88+ count = int (data_per_device_mb / size_mb )
89+ if count < 1 : count = 1
90+
91+ # Filter by max global threads
92+ if (count * num_devices ) > max_global_threads :
93+ continue
94+
95+ if count not in seen_counts :
96+ candidates_counts .append (count )
97+ seen_counts .add (count )
98+
99+ # Sort candidates (counts) ascending for clean output
100+ candidates_counts .sort ()
101+
102+ if not candidates_counts :
103+ candidates_counts = [1 ]
103104
105+ best_chunk_count = 1
106+ best_median_bw = - 1.0
107+
108+ # 5 search iterations + 3 warmup (before search)
109+ warmup_iters = 3
110+ search_iters = 5
111+
112+ try :
113+ for _ in range (warmup_iters ):
114+ run_fn (1 ) # Warmup with 1 chunk
115+ except Exception :
116+ pass
117+
118+ for chunk_count in candidates_counts :
119+ times_ms = []
120+ try :
121+ for _ in range (search_iters ):
122+ t_start = time .perf_counter ()
123+ res = run_fn (chunk_count )
124+ t_end = time .perf_counter ()
125+
126+ if isinstance (res , (int , float )):
127+ times_ms .append (res )
128+ else :
129+ times_ms .append ((t_end - t_start ) * 1000 )
130+
131+ median_ms = np .median (times_ms )
132+ if median_ms > 0 :
133+ if best_median_bw < 0 or median_ms < best_median_bw :
134+ best_median_bw = median_ms
135+ best_chunk_count = chunk_count
136+ except Exception as e :
137+ continue
138+
139+ print (f" Found optimal chunk count: { best_chunk_count } (approx size: { data_per_device_mb / best_chunk_count :.2f} MB)" )
140+ return best_chunk_count
104141
105142def benchmark_host_device (
106143 mesh_shape : str ,
@@ -138,21 +175,47 @@ def benchmark_host_device(
138175 mesh , sharding .PartitionSpec (("x" , "y" ))
139176 )
140177
141- # --- ADAPTIVE WARM UP ---
142- _run_warmup (host_data , data_sharding , data_size_mb )
143-
144178 # Pre-calculate sharding info
145179 dummy_put = jax .device_put (host_data [:num_devices ], data_sharding )
146180 target_devices = [s .device for s in dummy_put .addressable_shards ]
147181 dummy_put .delete ()
148182
149183 host_shards = np .split (host_data , num_devices , axis = 0 )
150184
185+ # --- SEARCH OPTIMAL CHUNKS ---
186+ # Define wrappers for search
187+
188+ def h2d_run_fn (c ):
189+ bufs = _run_h2d_chunked (host_shards , target_devices , num_devices , c )
190+ for b in bufs : b .delete ()
191+
192+ # H2D Search
193+ h2d_chunks = _find_optimal_chunk_size (h2d_run_fn , num_devices , data_size_mb )
194+
195+ # D2H Search
196+ # We need persistent data on device for D2H search to avoid H2D overhead in D2H measurement
197+ data_on_device_for_search = jax .device_put (host_data , data_sharding )
198+ data_on_device_for_search .block_until_ready ()
199+
200+ def d2h_run_fn (c ):
201+ # Force a new buffer to avoid host-side caching of device_get
202+ # Adding 0.0 creates a new DeviceArray with same sharding
203+ fresh_data = jax .lax .add (data_on_device_for_search , 0.0 )
204+ fresh_data .block_until_ready ()
205+
206+ t0 = time .perf_counter ()
207+ _run_d2h_chunked (fresh_data , num_devices , c )
208+ t1 = time .perf_counter ()
209+
210+ fresh_data .delete ()
211+ return (t1 - t0 ) * 1000
212+
213+ d2h_chunks = _find_optimal_chunk_size (d2h_run_fn , num_devices , data_size_mb )
214+
215+ data_on_device_for_search .delete ()
216+
151217 # Performance Lists
152218 h2d_perf , d2h_perf = [], []
153-
154- # --- SMART CHUNKING CONFIG ---
155- chunks_per_device = _get_chunks_per_device (data_size_mb , num_devices )
156219
157220 # Profiling Context
158221 if trace_dir :
@@ -171,37 +234,54 @@ def benchmark_host_device(
171234 step_context = contextlib .nullcontext ()
172235
173236 with step_context :
174- # Optimized Chunked Transfer (Sole Strategy)
175- h2d_ms , d2h_ms = _run_chunked (
176- host_data , data_sharding , host_shards , target_devices ,
177- num_devices , chunks_per_device
237+ # H2D
238+ t0 = time . perf_counter ()
239+ chunked_buffers = _run_h2d_chunked (
240+ host_shards , target_devices , num_devices , h2d_chunks
178241 )
179- h2d_perf .append (h2d_ms )
180- d2h_perf .append (d2h_ms )
242+ t1 = time .perf_counter ()
243+ h2d_perf .append ((t1 - t0 ) * 1000 )
244+
245+ for db in chunked_buffers :
246+ db .delete ()
247+
248+ # D2H
249+ # We need data on device again
250+ data_on_device = jax .device_put (host_data , data_sharding )
251+ data_on_device .block_until_ready ()
252+
253+ t2 = time .perf_counter ()
254+ _run_d2h_chunked (data_on_device , num_devices , d2h_chunks )
255+ t3 = time .perf_counter ()
256+ d2h_perf .append ((t3 - t2 ) * 1000 )
257+
258+ data_on_device .delete ()
181259
182260 del host_data , host_shards
183261 gc .collect ()
184262
185263 return {
186264 "H2D_Bandwidth" : h2d_perf ,
187265 "D2H_Bandwidth" : d2h_perf ,
188- "Chunk_Count" : chunks_per_device ,
189- "Thread_Count" : num_devices * chunks_per_device ,
266+ "H2D_Chunk_Size_MB" : (data_size_mb / num_devices ) / h2d_chunks if h2d_chunks > 0 else 0 ,
267+ "D2H_Chunk_Size_MB" : (data_size_mb / num_devices ) / d2h_chunks if d2h_chunks > 0 else 0 ,
268+ "Thread_Count" : num_devices * max (h2d_chunks , d2h_chunks ), # Approx
190269 }
191270
192271def benchmark_host_device_calculate_metrics (
193272 mesh_shape : str ,
194273 data_size_mb : int ,
195274 H2D_Bandwidth : List [float ],
196275 D2H_Bandwidth : List [float ],
197- Chunk_Count : int ,
276+ H2D_Chunk_Size_MB : float ,
277+ D2H_Chunk_Size_MB : float ,
198278 Thread_Count : int ,
199279) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
200280 """Calculates metrics for Host-Device transfer."""
201281 params = locals ().items ()
202282
203283 # Filter out list params from metadata to avoid explosion
204- metadata_keys = {"mesh_shape" , "data_size_mb" , "Chunk_Count " , "Thread_Count" }
204284+ metadata_keys = {"mesh_shape" , "data_size_mb" , "H2D_Chunk_Size_MB" , "D2H_Chunk_Size_MB" , "Thread_Count" }
205285 metadata = {k : v for k , v in params if k in metadata_keys }
206286
207287 metrics = {}
0 commit comments