1- """Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline) ."""
1+ """Benchmarks Host-to-Device and Device-to-Host transfer performance."""
22
3+ import concurrent .futures
4+ import gc
35import time
46import os
57from typing import Any , Dict , Tuple , List
911import numpy as np
1012from benchmark_utils import MetricsStatistics
1113
# Both premapped-buffer settings use the same value: 64 GiB expressed in bytes.
# Environment values must be strings, so the byte count is rendered once here.
# NOTE(review): presumably these must be set before the TPU runtime is
# initialized to take effect -- confirm against the runtime documentation.
_PREMAPPED_BUFFER_BYTES = str(64 * 1024**3)
os.environ["TPU_PREMAPPED_BUFFER_SIZE"] = _PREMAPPED_BUFFER_BYTES
os.environ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES"] = _PREMAPPED_BUFFER_BYTES
1516
1617def get_tpu_devices (num_devices : int ):
@@ -19,26 +20,107 @@ def get_tpu_devices(num_devices: int):
1920 raise RuntimeError (f"Require { num_devices } devices, found { len (devices )} " )
2021 return devices [:num_devices ]
2122
23+ def _run_chunked (host_data , data_sharding , host_shards , target_devices , num_devices , chunks_per_device ):
24+ # Smart Chunked H2D
25+ chk_h2d_start = time .perf_counter ()
26+ total_workers = num_devices * chunks_per_device
27+ with concurrent .futures .ThreadPoolExecutor (max_workers = total_workers ) as executor :
28+ chunked_futures = []
29+ for shard , dev in zip (host_shards , target_devices ):
30+ sub_chunks = np .array_split (shard , chunks_per_device , axis = 0 )
31+ for chunk in sub_chunks :
32+ chunked_futures .append (
33+ executor .submit (jax .device_put , chunk , dev )
34+ )
35+ chunked_buffers = [f .result () for f in chunked_futures ]
36+ for db in chunked_buffers :
37+ db .block_until_ready ()
38+ chk_h2d_end = time .perf_counter ()
39+ h2d_ms = (chk_h2d_end - chk_h2d_start ) * 1000
40+ for db in chunked_buffers :
41+ db .delete ()
42+
43+ # Smart Chunked D2H
44+ data_on_device = jax .device_put (host_data , data_sharding )
45+ data_on_device .block_until_ready ()
46+
47+ chk_d2h_start = time .perf_counter ()
48+ with concurrent .futures .ThreadPoolExecutor (max_workers = total_workers ) as executor :
49+ d2h_futures = []
50+ for shard in data_on_device .addressable_shards :
51+ # Direct slicing on device array to avoid copy
52+ shard_len = shard .data .shape [0 ]
53+ chunk_size = (shard_len + chunks_per_device - 1 ) // chunks_per_device
54+ for i in range (chunks_per_device ):
55+ start = i * chunk_size
56+ end = min ((i + 1 ) * chunk_size , shard_len )
57+ if start < end :
58+ d2h_futures .append (
59+ executor .submit (jax .device_get , shard .data [start :end ])
60+ )
61+ _ = [f .result () for f in d2h_futures ]
62+ chk_d2h_end = time .perf_counter ()
63+ d2h_ms = (chk_d2h_end - chk_d2h_start ) * 1000
64+ data_on_device .delete ()
65+
66+ return h2d_ms , d2h_ms
67+
68+
69+ def _run_warmup (host_data , data_sharding , data_size_mb ):
70+ # --- ADAPTIVE WARM UP ---
71+ if data_size_mb <= 128 :
72+ warmup_iters = 50
73+ elif data_size_mb >= 8192 :
74+ warmup_iters = 3
75+ else :
76+ warmup_iters = 10
77+
78+ for _ in range (warmup_iters ):
79+ data_on_device = jax .device_put (host_data , data_sharding )
80+ data_on_device .block_until_ready ()
81+ _ = jax .device_get (data_on_device )
82+ data_on_device .delete ()
83+
84+ gc .collect ()
85+
86+ def _get_chunks_per_device (data_size_mb , num_devices ):
87+ # --- SMART CHUNKING CONFIG ---
88+ target_chunk_size_mb = 16
89+ max_global_threads = 256
90+
91+ data_per_device_mb = data_size_mb / num_devices
92+
93+ if data_per_device_mb < target_chunk_size_mb :
94+ chunks_per_device = 1
95+ else :
96+ chunks_per_device = int (data_per_device_mb / target_chunk_size_mb )
97+
98+ total_threads = num_devices * chunks_per_device
99+ if total_threads > max_global_threads :
100+ chunks_per_device = max (1 , int (max_global_threads / num_devices ))
101+
102+ return chunks_per_device
103+
104+
22105def benchmark_host_device (
23106 mesh_shape : str ,
24107 data_size_mb : int ,
25108 num_runs : int = 100 ,
26109 trace_dir : str = None ,
27110) -> Dict [str , Any ]:
28- """Benchmarks H2D/D2H transfer using simple device_put/device_get ."""
111+ """Benchmarks H2D/D2H transfer using smart chunking ."""
29112 dims = [int (d ) for d in mesh_shape .split ("x" )]
30113 mesh_shape = tuple (dims )
31114
32115 num_devices = int (np .prod (mesh_shape ))
33116 tpu_devices = get_tpu_devices (num_devices )
34117
35- num_elements = 1024 * 1024 * data_size_mb // np .dtype (np .float32 ).itemsize
118+ rows = 1024 * data_size_mb // np .dtype (np .float32 ).itemsize
36119
37- # Allocate Host Source Buffer
38- host_data = np .ones ((num_elements ,), dtype = np .float32 )
120+ host_data = np .ones ((rows , 8 , 128 ), dtype = np .float32 )
39121
40122 print (
41- f"Benchmarking (Simple) Transfer with Data Size: { data_size_mb } MB on"
123+ f"Benchmarking Transfer with Data Size: { data_size_mb } MB on"
42124 f" { num_devices } devices for { num_runs } iterations"
43125 )
44126
@@ -47,25 +129,37 @@ def benchmark_host_device(
47129 mesh = sharding .Mesh (
48130 np .array (tpu_devices ).reshape (mesh_shape ), axis_names = ("x" ,)
49131 )
50- # Shard the 1D array across "x"
51- partition_spec = sharding .PartitionSpec ("x" )
132+ data_sharding = sharding .NamedSharding (mesh , sharding .PartitionSpec ("x" ))
52133 else :
53134 mesh = sharding .Mesh (
54135 np .array (tpu_devices ).reshape (mesh_shape ), axis_names = ("x" , "y" )
55136 )
56- # Shard the 1D array across BOTH "x" and "y" (product sharding)
57- partition_spec = sharding .PartitionSpec (("x" , "y" ))
58-
59- data_sharding = sharding .NamedSharding (mesh , partition_spec )
137+ data_sharding = sharding .NamedSharding (
138+ mesh , sharding .PartitionSpec (("x" , "y" ))
139+ )
60140
141+ # --- ADAPTIVE WARM UP ---
142+ _run_warmup (host_data , data_sharding , data_size_mb )
143+
144+ # Pre-calculate sharding info
145+ dummy_put = jax .device_put (host_data [:num_devices ], data_sharding )
146+ target_devices = [s .device for s in dummy_put .addressable_shards ]
147+ dummy_put .delete ()
148+
149+ host_shards = np .split (host_data , num_devices , axis = 0 )
150+
61151 # Performance Lists
62152 h2d_perf , d2h_perf = [], []
153+
154+ # --- SMART CHUNKING CONFIG ---
155+ chunks_per_device = _get_chunks_per_device (data_size_mb , num_devices )
63156
64157 # Profiling Context
65- import contextlib
66158 if trace_dir :
67159 profiler_context = jax .profiler .trace (trace_dir )
68160 else :
161+ # No-op context manager
162+ import contextlib
69163 profiler_context = contextlib .nullcontext ()
70164
71165 with profiler_context :
@@ -77,53 +171,37 @@ def benchmark_host_device(
77171 step_context = contextlib .nullcontext ()
78172
79173 with step_context :
80- # H2D
81- t0 = time .perf_counter ()
82-
83- # Simple device_put
84- device_array = jax .device_put (host_data , data_sharding )
85- device_array .block_until_ready ()
86-
87- t1 = time .perf_counter ()
88- h2d_perf .append ((t1 - t0 ) * 1000 )
89-
90- # Verify H2D shape/sharding
91- assert device_array .shape == host_data .shape
92- assert device_array .sharding == data_sharding
93-
94- # D2H
95- t2 = time .perf_counter ()
96-
97- # Simple device_get
98- # Note: device_get returns a numpy array (copy)
99- _ = jax .device_get (device_array )
100-
101- t3 = time .perf_counter ()
102- d2h_perf .append ((t3 - t2 ) * 1000 )
103-
104- device_array .delete ()
174+ # Optimized Chunked Transfer (Sole Strategy)
175+ h2d_ms , d2h_ms = _run_chunked (
176+ host_data , data_sharding , host_shards , target_devices ,
177+ num_devices , chunks_per_device
178+ )
179+ h2d_perf .append (h2d_ms )
180+ d2h_perf .append (d2h_ms )
181+
182+ del host_data , host_shards
183+ gc .collect ()
105184
106185 return {
107186 "H2D_Bandwidth" : h2d_perf ,
108187 "D2H_Bandwidth" : d2h_perf ,
188+ "Chunk_Count" : chunks_per_device ,
189+ "Thread_Count" : num_devices * chunks_per_device ,
109190 }
110191
111192def benchmark_host_device_calculate_metrics (
112193 mesh_shape : str ,
113194 data_size_mb : int ,
114195 H2D_Bandwidth : List [float ],
115196 D2H_Bandwidth : List [float ],
197+ Chunk_Count : int ,
198+ Thread_Count : int ,
116199) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
117200 """Calculates metrics for Host-Device transfer."""
118201 params = locals ().items ()
119202
120- data_size_mib = data_size_mb
121-
122203 # Filter out list params from metadata to avoid explosion
123- metadata_keys = {
124- "mesh_shape" ,
125- "data_size_mib" ,
126- }
204+ metadata_keys = {"mesh_shape" , "data_size_mb" , "Chunk_Count" , "Thread_Count" }
127205 metadata = {k : v for k , v in params if k in metadata_keys }
128206
129207 metrics = {}
0 commit comments