55from typing import Any , Dict , Tuple , List
66
77import jax
8- from jax import sharding
8+ from jax import numpy as jnp
99import numpy as np
1010from benchmark_utils import MetricsStatistics
1111
@@ -23,17 +23,23 @@ def benchmark_host_device(
2323 data_size_mib : int ,
2424 num_runs : int = 100 ,
2525 trace_dir : str = None ,
26+ h2d_type : str = "simple" ,
2627) -> Dict [str , Any ]:
27- """Benchmarks H2D/D2H transfer using simple device_put/device_get."""
28+ """Benchmarks H2D/D2H transfer using device_put/device_get."""
2829
2930 num_elements = 1024 * 1024 * data_size_mib // np .dtype (np .float32 ).itemsize
3031
3132 # Allocate Host Source Buffer
3233 column = 128
3334 host_data = np .random .normal (size = (num_elements // column , column )).astype (np .float32 )
3435
36+ # Used in pipelined flow
37+ # TODO: turn into a param
38+ num_devices_to_perform_h2d = 1
39+ target_devices = jax .devices ()[:num_devices_to_perform_h2d ]
40+
3541 print (
36- f"Benchmarking Transfer with Data Size: { data_size_mib } MB for { num_runs } iterations" ,
42+ f"Benchmarking Transfer with Data Size: { data_size_mib } MB for { num_runs } iterations with { h2d_type = } " ,
3743 flush = True
3844 )
3945
@@ -65,29 +71,98 @@ def benchmark_host_device(
6571
6672 with step_context :
6773 # H2D
68- t0 = time .perf_counter ()
69-
70- # Simple device_put
71- device_array = jax .device_put (host_data )
72- device_array .block_until_ready ()
73-
74- t1 = time .perf_counter ()
75- h2d_perf .append ((t1 - t0 ) * 1000 )
76-
77- # Verify H2D shape
78- assert device_array .shape == host_data .shape
79-
80- # D2H
81- t2 = time .perf_counter ()
82-
83- # Simple device_get
84- # Note: device_get returns a numpy array (copy)
85- _ = jax .device_get (device_array )
86-
87- t3 = time .perf_counter ()
88- d2h_perf .append ((t3 - t2 ) * 1000 )
74+ if h2d_type == "simple" :
75+ t0 = time .perf_counter ()
76+ # Simple device_put
77+ device_array = jax .device_put (host_data )
78+ device_array .block_until_ready ()
79+ t1 = time .perf_counter ()
80+
81+ # Verify H2D shape
82+ assert device_array .shape == host_data .shape
83+
84+ h2d_perf .append ((t1 - t0 ) * 1000 )
8985
90- device_array .delete ()
86+ # D2H
87+ t2 = time .perf_counter ()
88+
89+ # Simple device_get
90+ # Note: device_get returns a numpy array (copy)
91+ _ = jax .device_get (device_array )
92+
93+ t3 = time .perf_counter ()
94+ d2h_perf .append ((t3 - t2 ) * 1000 )
95+
96+ device_array .delete ()
97+ elif h2d_type == "pipelined" :
98+ target_chunk_size_mib = 16 # Sweet spot from profiling
99+ num_devices = len (target_devices )
100+
101+ tensors_on_device = []
102+
103+ # Calculate chunks per device
104+ data_per_dev = data_size_mib / num_devices
105+ chunks_per_dev = int (data_per_dev / target_chunk_size_mib )
106+ chunks_per_dev = max (1 , chunks_per_dev )
107+
108+ chunks = np .array_split (host_data , chunks_per_dev * num_devices , axis = 0 )
109+
110+ t0 = time .perf_counter ()
111+ if chunks_per_dev > 1 :
112+ # We need to map chunks to the correct device
113+ # This simple example assumes chunks are perfectly divisible and ordered
114+ # In production, use `jax.sharding` mesh logic for complex layouts
115+
116+ # approach 1: simple for loop
117+ for idx , chunk in enumerate (chunks ):
118+ if num_devices > 1 :
119+ dev = target_devices [idx % num_devices ]
120+ else :
121+ dev = target_devices [0 ]
122+ tensors_on_device .append (jax .device_put (chunk , dev ))
123+ # Re-assemble array
124+ result = jnp .vstack (tensors_on_device )
125+ # Wait for all chunks to be transferred
126+ result .block_until_ready ()
127+
128+ # approach 2: generator (slightly less overhead)
129+ # def chunk_generator(num_devices, chunks_per_dev):
130+ # for n in range(chunks_per_dev):
131+ # for d in range(num_devices):
132+ # # 1. Get the specific small chunk
133+ # chunk = chunks[d*chunks_per_dev+n]
134+
135+ # # 2. Trigger an individual DMA transfer for this specific chunk
136+ # # This is where NUMA-local memory access matters
137+ # yield jax.device_put(chunk, target_devices[d])
138+
139+ # # Re-assemble array
140+ # result = jnp.vstack(list(chunk_generator(num_devices, chunks_per_dev)))
141+ # # Wait for all chunks to be transferred
142+ # result.block_until_ready()
143+ else :
144+ print (f"Warning: { data_size_mib = } is not larger than { target_chunk_size_mib = } , falling back to standard JAX put." )
145+ # Fallback to standard JAX put for small data
146+ result = jax .device_put (host_data , target_devices [0 ])
147+ result .block_until_ready ()
148+
149+ t1 = time .perf_counter ()
150+ h2d_perf .append ((t1 - t0 ) * 1000 )
151+
152+ # D2H
153+ t2 = time .perf_counter ()
154+ # Simple device_get
155+ # Note: device_get returns a numpy array (copy)
156+ _ = jax .device_get (result )
157+
158+ t3 = time .perf_counter ()
159+ if not np .allclose (result , host_data ):
160+ print ("pipelined result not equal to host_data" )
161+ d2h_perf .append ((t3 - t2 ) * 1000 )
162+
163+ for r in tensors_on_device :
164+ r .delete ()
165+ del tensors_on_device
91166
92167 return {
93168 "H2D_Bandwidth_ms" : h2d_perf ,
@@ -98,6 +173,7 @@ def benchmark_host_device_calculate_metrics(
98173 data_size_mib : int ,
99174 H2D_Bandwidth_ms : List [float ],
100175 D2H_Bandwidth_ms : List [float ],
176+ h2d_type : str = "simple" ,
101177) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
102178 """Calculates metrics for Host-Device transfer."""
103179 params = locals ().items ()
0 commit comments