     MetricsStatistics,
     get_metrics_helper,
     str_to_dtype,
-    get_peak_flops_multiplier
+    get_peak_flops_multiplier,
+    unified_bytes_metrics,
 )
 from common import MARKER
 import jax
@@ -178,18 +179,6 @@ def gemm_all_reduce_calculate_metrics(
     peak_flops_multiplier = get_peak_flops_multiplier(dtype_str)
     peak_flops = PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier
 
-    # Unified FLOPs metrics
-    metadata, metrics = unified_flops_metrics(
-        m,
-        n,
-        k,
-        time_ms_list,
-        total_flops_per_device,
-        total_flops_all_devices,
-        peak_flops,
-        dtype=dtype_str,
-    )
-
     # Calculate Bandwidth for AllReduce
     # AllReduce moves Matrix C: M x N
     matrix_c_size_bytes = m * n * dtype.dtype.itemsize
@@ -201,3 +190,227 @@ def gemm_all_reduce_calculate_metrics(
     )
 
     return metadata, metrics
+
+
+def gemm_only(
+    m: int,
+    k: int,
+    n: int,
+    dtype: jnp.dtype = jnp.bfloat16,
+    num_runs: int = 1,
+    trace_dir: str | None = None,
+) -> Dict[str, Any]:
+    """Benchmarks only the Matmul part of gemm_all_reduce.
+
+    A: [M, K]
+    B: [K, N]
+    C = A @ B: [M, N]
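+
+    (The AllReduce over C is benchmarked separately by all_reduce_only.)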
+    """
+
+    dtype_str = dtype.dtype.name
+    print(
+        f"Running gemm_only benchmark with m={m}, k={k}, n={n}, "
+        f"dtype={dtype_str}, runs={num_runs}"
+    )
+
+    def f(x, y):
+        with jax.named_scope(MARKER):
+            # Matmul
+            acc = jax.numpy.einsum(
+                "ij,jk->ik", x, y, preferred_element_type=jnp.float32
+            )
+            c = acc.astype(dtype)
+            return c
+
+    mesh = create_mesh(SHARDING_STRATEGY)
+    lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY)
+    rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY)
+    out_sharding = get_out_sharding(SHARDING_STRATEGY)
+
+    jit_sharded_f = jax.jit(
+        shard_map(
+            f,
+            mesh,
+            in_specs=(
+                lhs_sharding.spec,
+                rhs_sharding.spec,
+            ),
+            out_specs=out_sharding,
+            check_rep=False,
+        )
+    )
+
+    lhs_shape = (m, k)
+    rhs_shape = (k, n)
+
+    lhs_dtype = dtype
+    rhs_dtype = dtype
+
+    key = jax.random.key(SEED)
+
+    def data_generator():
+        """Creates new random data on host and puts it on device."""
+        nonlocal key
+        key, key_lhs, key_rhs = jax.random.split(key, 3)
+
+        # Create random data on host
+        lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(lhs_dtype)
+        rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(rhs_dtype)
+
+        # Put on device (HBM)
+        lhs_device = jax.device_put(lhs_host, lhs_sharding)
+        rhs_device = jax.device_put(rhs_host, rhs_sharding)
+
+        return (lhs_device, rhs_device)
+
+    time_ms_list = iteration_timeit(
+        jit_sharded_f,
+        data_generator,
+        matrix_dim=f"{dtype_str}_{m}x{n}x{k}",
+        tries=num_runs,
+        task=f"gemm_only_{dtype_str}",
+        trace_dir=trace_dir,
+    )
+    return {
+        "time_ms_list": time_ms_list,
+    }
+
+
+def gemm_only_calculate_metrics(
+    m: int,
+    k: int,
+    n: int,
+    dtype: jnp.dtype,
+    time_ms_list: list[float],
+) -> tuple[Dict[str, Any], Dict[str, Any]]:
+    # Calculate FLOPs (Matmul)
+    total_flops = 2 * m * k * n
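+    # (Each of the m*n outputs takes k multiplies and k-1 adds, conventionally
+    # counted as 2*m*k*n FLOPs.)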
+
+    total_flops_per_device, total_flops_all_devices = handle_based_on_sharding(
+        total_flops, SHARDING_STRATEGY
+    )
+
+    dtype_str = dtype.dtype.name
+    peak_flops_multiplier = get_peak_flops_multiplier(dtype_str)
+    peak_flops = PEAK_FLOPS_PER_DEVICE * peak_flops_multiplier
+
+    metadata, metrics = unified_flops_metrics(
+        m,
+        n,
+        k,
+        time_ms_list,
+        total_flops_per_device,
+        total_flops_all_devices,
+        peak_flops,
+        dtype=dtype_str,
+    )
+
+    return metadata, metrics
+
+
+def all_reduce_only(
+    m: int,
+    k: int,
+    n: int,
+    dtype: jnp.dtype = jnp.bfloat16,
+    num_runs: int = 1,
+    trace_dir: str | None = None,
+) -> Dict[str, Any]:
+    """Benchmarks only the AllReduce part of gemm_all_reduce independently.
+
+    Input: C [M, N]
+    Output = AllReduce(C)
+    """
+
+    dtype_str = dtype.dtype.name
+    print(
+        f"Running all_reduce_only benchmark with m={m}, k={k}, n={n}, "
+        f"dtype={dtype_str}, runs={num_runs}"
+    )
+
+    def f(c):
+        with jax.named_scope(MARKER):
+            # AllReduce (psum)
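+            # psum sums the per-device [M, N] copies of `c` over the mesh axis
+            # named "device" and leaves the summed result replicated on every
+            # device, i.e. an AllReduce.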
+            out = jax.lax.psum(c, axis_name="device")
+            return out
+
+    mesh = create_mesh(SHARDING_STRATEGY)
+    # The input to AllReduce is the Matmul output C [M, N]. In gemm_all_reduce,
+    # both in_specs are P(None, None) under NO_SHARDING, so inside shard_map
+    # each device holds a full [M, N] copy of `c`. get_out_sharding returns
+    # P(None, None) as well, so we reuse it as the input spec here to match
+    # C's per-device layout in gemm_all_reduce.
+    input_sharding = get_out_sharding(SHARDING_STRATEGY)
+    out_sharding = get_out_sharding(SHARDING_STRATEGY)
+
+    jit_sharded_f = jax.jit(
+        shard_map(
+            f,
+            mesh,
+            in_specs=(input_sharding,),
+            out_specs=out_sharding,
+            check_rep=False,
+        )
+    )
+
+    # Shape of C
+    c_shape = (m, n)
+    c_dtype = dtype
+
+    key = jax.random.key(SEED)
+
+    def data_generator():
+        """Creates new random data on host and puts it on device."""
+        nonlocal key
+        key, key_c = jax.random.split(key, 2)
+
+        # Create random data on host
+        c_host = jax.random.normal(key_c, c_shape).astype(c_dtype)
+
+        # Put on device (HBM)
+        # Wrap input_sharding (a PartitionSpec) in NamedSharding, because
+        # device_put needs to know the mesh.
+        named_input_sharding = jax.sharding.NamedSharding(mesh, input_sharding)
+        c_device = jax.device_put(c_host, named_input_sharding)
+
+        return (c_device,)
+
+    time_ms_list = iteration_timeit(
+        jit_sharded_f,
+        data_generator,
+        matrix_dim=f"{dtype_str}_{m}x{n}x{k}",
+        tries=num_runs,
+        task=f"all_reduce_only_{dtype_str}",
+        trace_dir=trace_dir,
+    )
+    return {
+        "time_ms_list": time_ms_list,
+    }
+
+
+def all_reduce_only_calculate_metrics(
+    m: int,
+    k: int,
+    n: int,
+    dtype: jnp.dtype,
+    time_ms_list: list[float],
+) -> tuple[Dict[str, Any], Dict[str, Any]]:
+    # Calculate Bandwidth for AllReduce
+    # AllReduce moves Matrix C: M x N
+    matrix_c_size_bytes = m * n * dtype.dtype.itemsize
+
+    # Use unified_bytes_metrics for bandwidth-bound operations.
+    # We estimate total_bytes_all_devices assuming full replication or
+    # reduction over all devices.
+    num_devices = jax.device_count()
+    total_bytes_all_devices = matrix_c_size_bytes * num_devices
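+    # (A ring AllReduce moves roughly 2*(N-1)/N * matrix_c_size_bytes per
+    # device, so this aggregate is a simplified estimate rather than an
+    # exact traffic count.)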
+
+    metadata, metrics = unified_bytes_metrics(
+        m,
+        n,
+        time_ms_list,
+        total_bytes=matrix_c_size_bytes,
+        total_bytes_all_devices=total_bytes_all_devices,
+        dtype=dtype.dtype.name,
+    )
+
+    return metadata, metrics