1818os .environ ["TPU_PREMAPPED_BUFFER_SIZE" ] = "68719476736"
1919os .environ ["TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES" ] = "68719476736"
2020
21- def get_tpu_devices (num_devices : int ):
22- devices = jax .devices ()
23- if len (devices ) < num_devices :
24- raise RuntimeError (f"Require { num_devices } devices, found { len (devices )} " )
25- return devices [:num_devices ]
2621
2722def benchmark_host_device (
28- num_devices : int ,
29- data_size_mb : int ,
23+ data_size_mib : int ,
3024 num_runs : int = 100 ,
3125 trace_dir : str = None ,
3226) -> Dict [str , Any ]:
3327 """Benchmarks H2D/D2H transfer using simple device_put/device_get."""
34- tpu_devices = get_tpu_devices (num_devices )
3528
36- num_elements = 1024 * 1024 * data_size_mb // np .dtype (np .float32 ).itemsize
29+ num_elements = 1024 * 1024 * data_size_mib // np .dtype (np .float32 ).itemsize
3730
3831 # Allocate Host Source Buffer
39- host_data = np .random .normal (size = (num_elements ,)).astype (np .float32 )
32+ column = 128
33+ host_data = np .random .normal (size = (num_elements // column , column )).astype (np .float32 )
4034
4135 print (
42- f"Benchmarking (Simple) Transfer with Data Size: { data_size_mb } MB on"
43- f" { num_devices } devices for { num_runs } iterations"
36+ f"Benchmarking Transfer with Data Size: { data_size_mib } MB for { num_runs } iterations"
4437 )
4538
46- # Setup Mesh Sharding (1D)
47- mesh = sharding .Mesh (
48- np .array (tpu_devices ).reshape ((num_devices ,)), axis_names = ("x" ,)
49- )
50- # Shard the 1D array across "x"
51- partition_spec = sharding .PartitionSpec ("x" )
52-
53- data_sharding = sharding .NamedSharding (mesh , partition_spec )
54-
5539 # Performance Lists
5640 h2d_perf , d2h_perf = [], []
57-
41+
5842 # Profiling Context
5943 import contextlib
6044 if trace_dir :
@@ -65,7 +49,7 @@ def benchmark_host_device(
6549 with profiler_context :
6650 # Warmup
6751 for _ in range (2 ):
68- device_array = jax .device_put (host_data , data_sharding )
52+ device_array = jax .device_put (host_data )
6953 device_array .block_until_ready ()
7054 host_out = np .array (device_array )
7155 device_array .delete ()
@@ -83,15 +67,14 @@ def benchmark_host_device(
8367 t0 = time .perf_counter ()
8468
8569 # Simple device_put
86- device_array = jax .device_put (host_data , data_sharding )
70+ device_array = jax .device_put (host_data )
8771 device_array .block_until_ready ()
8872
8973 t1 = time .perf_counter ()
9074 h2d_perf .append ((t1 - t0 ) * 1000 )
9175
92- # Verify H2D shape/sharding
76+ # Verify H2D shape
9377 assert device_array .shape == host_data .shape
94- assert device_array .sharding == data_sharding
9578
9679 # D2H
9780 t2 = time .perf_counter ()
@@ -111,19 +94,15 @@ def benchmark_host_device(
11194 }
11295
11396def benchmark_host_device_calculate_metrics (
114- num_devices : int ,
115- data_size_mb : int ,
97+ data_size_mib : int ,
11698 H2D_Bandwidth_ms : List [float ],
11799 D2H_Bandwidth_ms : List [float ],
118100) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
119101 """Calculates metrics for Host-Device transfer."""
120102 params = locals ().items ()
121103
122- data_size_mib = data_size_mb
123-
124104 # Filter out list params from metadata to avoid explosion
125105 metadata_keys = {
126- "num_devices" ,
127106 "data_size_mib" ,
128107 }
129108 metadata = {k : v for k , v in params if k in metadata_keys }
@@ -134,7 +113,7 @@ def add_metric(name, ms_list):
134113 # Report Bandwidth (GiB/s)
135114 # Handle division by zero if ms is 0
136115 bw_list = [
137- ((data_size_mb / 1024 ) / (ms / 1000 )) if ms > 0 else 0.0
116+ ((data_size_mib / 1024 ) / (ms / 1000 )) if ms > 0 else 0.0
138117 for ms in ms_list
139118 ]
140119 stats_bw = MetricsStatistics (bw_list , f"{ name } _bw (GiB/s)" )
0 commit comments