1818os .environ ["TPU_PREMAPPED_BUFFER_SIZE" ] = "68719476736"
1919os .environ ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES" ] = "68719476736"
2020
21- def get_tpu_devices (num_devices : int ):
22- devices = jax .devices ()
23- if len (devices ) < num_devices :
24- raise RuntimeError (f"Require { num_devices } devices, found { len (devices )} " )
25- return devices [:num_devices ]
21+
2622
2723def benchmark_host_device (
28- num_devices : int ,
29- data_size_mb : int ,
24+ data_size_mib : int ,
3025 num_runs : int = 100 ,
3126 trace_dir : str = None ,
3227) -> Dict [str , Any ]:
3328 """Benchmarks H2D/D2H transfer using simple device_put/device_get."""
34- tpu_devices = get_tpu_devices (num_devices )
3529
36- num_elements = 1024 * 1024 * data_size_mb // np .dtype (np .float32 ).itemsize
30+ num_elements = 1024 * 1024 * data_size_mib // np .dtype (np .float32 ).itemsize
3731
3832 # Allocate Host Source Buffer
39- host_data = np .random .normal (size = (num_elements ,)).astype (np .float32 )
33+ column = 128
34+ host_data = np .random .normal (size = (num_elements // column , column )).astype (np .float32 )
4035
4136 print (
42- f"Benchmarking (Simple) Transfer with Data Size: { data_size_mb } MB on"
43- f" { num_devices } devices for { num_runs } iterations"
37+ f"Benchmarking Transfer with Data Size: { data_size_mib } MB for { num_runs } iterations"
4438 )
4539
46- # Setup Mesh Sharding (1D)
47- mesh = sharding .Mesh (
48- np .array (tpu_devices ).reshape ((num_devices ,)), axis_names = ("x" ,)
49- )
50- # Shard the 1D array across "x"
51- partition_spec = sharding .PartitionSpec ("x" )
52-
53- data_sharding = sharding .NamedSharding (mesh , partition_spec )
54-
5540 # Performance Lists
5641 h2d_perf , d2h_perf = [], []
57-
42+
5843 # Profiling Context
5944 import contextlib
6045 if trace_dir :
@@ -65,7 +50,7 @@ def benchmark_host_device(
6550 with profiler_context :
6651 # Warmup
6752 for _ in range (2 ):
68- device_array = jax .device_put (host_data , data_sharding )
53+ device_array = jax .device_put (host_data )
6954 device_array .block_until_ready ()
7055 host_out = np .array (device_array )
7156 device_array .delete ()
@@ -83,15 +68,14 @@ def benchmark_host_device(
8368 t0 = time .perf_counter ()
8469
8570 # Simple device_put
86- device_array = jax .device_put (host_data , data_sharding )
71+ device_array = jax .device_put (host_data )
8772 device_array .block_until_ready ()
8873
8974 t1 = time .perf_counter ()
9075 h2d_perf .append ((t1 - t0 ) * 1000 )
9176
92- # Verify H2D shape/sharding
77+ # Verify H2D shape
9378 assert device_array .shape == host_data .shape
94- assert device_array .sharding == data_sharding
9579
9680 # D2H
9781 t2 = time .perf_counter ()
@@ -111,19 +95,17 @@ def benchmark_host_device(
11195 }
11296
11397def benchmark_host_device_calculate_metrics (
114- num_devices : int ,
115- data_size_mb : int ,
98+ data_size_mib : int ,
11699 H2D_Bandwidth_ms : List [float ],
117100 D2H_Bandwidth_ms : List [float ],
118101) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
119102 """Calculates metrics for Host-Device transfer."""
120103 params = locals ().items ()
121104
122- data_size_mib = data_size_mb
105+ data_size_mib = data_size_mib
123106
124107 # Filter out list params from metadata to avoid explosion
125108 metadata_keys = {
126- "num_devices" ,
127109 "data_size_mib" ,
128110 }
129111 metadata = {k : v for k , v in params if k in metadata_keys }
@@ -134,7 +116,7 @@ def add_metric(name, ms_list):
134116 # Report Bandwidth (GiB/s)
135117 # Handle division by zero if ms is 0
136118 bw_list = [
137- ((data_size_mb / 1024 ) / (ms / 1000 )) if ms > 0 else 0.0
119+ ((data_size_mib / 1024 ) / (ms / 1000 )) if ms > 0 else 0.0
138120 for ms in ms_list
139121 ]
140122 stats_bw = MetricsStatistics (bw_list , f"{ name } _bw (GiB/s)" )
0 commit comments