@@ -38,167 +38,6 @@ def benchmark_host_device(
3838 num_devices_to_perform_h2d = 1
3939 target_devices = jax .devices ()[:num_devices_to_perform_h2d ]
4040
41- < << << << Updated upstream
42- print (
43- f"Benchmarking Transfer with Data Size: { data_size_mib } MB for { num_runs } iterations with { h2d_type = } " ,
44- flush = True
45- )
46-
47- # Performance Lists
48- h2d_perf , d2h_perf = [], []
49-
50- # Profiling Context
51- import contextlib
52- if trace_dir :
53- profiler_context = jax .profiler .trace (trace_dir )
54- else :
55- profiler_context = contextlib .nullcontext ()
56-
57- with profiler_context :
58- # Warmup
59- for _ in range (2 ):
60- device_array = jax .device_put (host_data )
61- device_array .block_until_ready ()
62- host_out = np .array (device_array )
63- device_array .delete ()
64- del host_out
65-
66- for i in range (num_runs ):
67- # Step Context
68- if trace_dir :
69- step_context = jax .profiler .StepTraceAnnotation ("host_device" , step_num = i )
70- else :
71- step_context = contextlib .nullcontext ()
72-
73- with step_context :
74- # H2D
75- if h2d_type == "simple" :
76- t0 = time .perf_counter ()
77- # Simple device_put
78- device_array = jax .device_put (host_data )
79- device_array .block_until_ready ()
80- t1 = time .perf_counter ()
81-
82- # Verify H2D shape
83- assert device_array .shape == host_data .shape
84-
85- h2d_perf .append ((t1 - t0 ) * 1000 )
86-
87- # D2H
88- t2 = time .perf_counter ()
89-
90- # Simple device_get
91- # Note: device_get returns a numpy array (copy)
92- _ = jax .device_get (device_array )
93-
94- t3 = time .perf_counter ()
95- d2h_perf .append ((t3 - t2 ) * 1000 )
96-
97- device_array .delete ()
98- elif h2d_type == "pipelined" :
99- target_chunk_size_mib = 16 # Sweet spot from profiling
100- num_devices = len (target_devices )
101-
102- tensors_on_device = []
103-
104- # Calculate chunks per device
105- data_per_dev = data_size_mib / num_devices
106- chunks_per_dev = int (data_per_dev / target_chunk_size_mib )
107- chunks_per_dev = max (1 , chunks_per_dev )
108-
109- chunks = np .array_split (host_data , chunks_per_dev * num_devices , axis = 0 )
110-
111- t0 = time .perf_counter ()
112- if chunks_per_dev > 1 :
113- # We need to map chunks to the correct device
114- # This simple example assumes chunks are perfectly divisible and ordered
115- # In production, use `jax.sharding` mesh logic for complex layouts
116-
117- # approach 1: simple for loop
118- for idx , chunk in enumerate (chunks ):
119- if num_devices > 1 :
120- dev = target_devices [idx % num_devices ]
121- else :
122- dev = target_devices [0 ]
123- tensors_on_device .append (jax .device_put (chunk , dev ))
124- # Re-assemble array
125- result = jnp .vstack (tensors_on_device )
126- # Wait for all chunks to be transferred
127- result .block_until_ready ()
128-
129- # approach 2: generator (slightly less overhead)
130- # def chunk_generator(num_devices, chunks_per_dev):
131- # for n in range(chunks_per_dev):
132- # for d in range(num_devices):
133- # # 1. Get the specific small chunk
134- # chunk = chunks[d*chunks_per_dev+n]
135-
136- # # 2. Trigger an individual DMA transfer for this specific chunk
137- # # This is where NUMA-local memory access matters
138- # yield jax.device_put(chunk, target_devices[d])
139-
140- # # Re-assemble array
141- # result = jnp.vstack(list(chunk_generator(num_devices, chunks_per_dev)))
142- # # Wait for all chunks to be transferred
143- # result.block_until_ready()
144- else :
145- print (f"Warning: { data_size_mib = } is not larger than { target_chunk_size_mib = } , falling back to standard JAX put." )
146- # Fallback to standard JAX put for small data
147- result = jax .device_put (host_data , target_devices [0 ])
148- result .block_until_ready ()
149-
150- t1 = time .perf_counter ()
151- h2d_perf .append ((t1 - t0 ) * 1000 )
152-
153- # D2H
154- t2 = time .perf_counter ()
155- # Simple device_get
156- # Note: device_get returns a numpy array (copy)
157- _ = jax .device_get (result )
158-
159- t3 = time .perf_counter ()
160- if not np .allclose (result , host_data ):
161- print ("pipelined result not equal to host_data" )
162- d2h_perf .append ((t3 - t2 ) * 1000 )
163-
164- for r in tensors_on_device :
165- r .delete ()
166- del tensors_on_device
167-
168- return {
169- "H2D_Bandwidth_ms" : h2d_perf ,
170- "D2H_Bandwidth_ms" : d2h_perf ,
171- }
172-
173- def benchmark_host_device_calculate_metrics (
174- data_size_mib : int ,
175- H2D_Bandwidth_ms : List [float ],
176- D2H_Bandwidth_ms : List [float ],
177- h2d_type : str = "simple" ,
178- ) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
179- """Calculates metrics for Host-Device transfer."""
180- params = locals ().items ()
181-
182- # Filter out list params from metadata to avoid explosion
183- metadata_keys = {
184- "data_size_mib" ,
185- }
186- metadata = {k : v for k , v in params if k in metadata_keys }
187- metadata ["dtype" ] = "float32"
188- metadata ["h2d_type" ] = h2d_type
189-
190- metrics = {}
191-
192- def add_metric (name , ms_list ):
193- # Report Bandwidth (GiB/s)
194- # Handle division by zero if ms is 0
195- bw_list = [
196- ((data_size_mib / 1024 ) / (ms / 1000 )) if ms > 0 else 0.0
197- for ms in ms_list
198- ]
199- stats_bw = MetricsStatistics (bw_list , f"{ name } _bw (GiB/s)" )
200- == == == =
201- >> >> >> > Stashed changes
20241 print (
20342 f"Benchmarking Transfer with Data Size: { data_size_mib } MB for { num_runs } iterations with { h2d_type = } " ,
20443 flush = True
0 commit comments