55import os
66from typing import Any , Dict
77
8+ from benchmark_utils import find_sparsecore_usage_from_xplane
89from benchmark_utils import get_lhs_named_shading
910from benchmark_utils import get_out_sharding
1011from benchmark_utils import MetricsStatistics
2627SEED = 0
2728GLOBAL_SHARDING_STRATEGY = ShardingStrategy .NO_SHARDING
2829GLOBAL_PSTATE = 7
29-
30+ LOG_SPARSECORE_USAGE = False
3031
3132def create_mesh (ici_size : int , mesh_shape : str ) -> Mesh :
3233 """Creates a mesh with the given ICI size."""
@@ -85,6 +86,7 @@ def unified_ici_collectives_metrics(
8586 ici_average_time_ms_list : list [float ],
8687 iteration : int ,
8788 op_type : str ,
89+ trace_dir : str = None ,
8890) -> Dict [str , Any ]:
8991 """Calculates the metrics for the ICI collectives benchmark."""
9092
@@ -147,8 +149,15 @@ def unified_ici_collectives_metrics(
147149 / rank
148150 )
149151
150-
152+
153+ sparsecore_used = "NA"
154+ if LOG_SPARSECORE_USAGE :
155+ print ("trace_dir: " , trace_dir )
156+ if trace_dir :
157+ sparsecore_used = find_sparsecore_usage_from_xplane (trace_dir )
158+ print ("sparsecore_used: " , sparsecore_used )
151159 print ("hlo first replica group: " , hlo_first_replica_group )
160+
152161 metadata = {
153162 "iteration" : iteration ,
154163 "op_type" : op_type ,
@@ -164,6 +173,7 @@ def unified_ici_collectives_metrics(
164173 "hlo_input_shape" : json .dumps (hlo_input_shape ),
165174 "hlo_output_shape" : json .dumps (hlo_output_shape ),
166175 "hlo_replica_groups" : json .dumps (hlo_replica_groups ),
176+ "sparsecore_used" : sparsecore_used ,
167177 }
168178 achieved_bw = [transferred_data * 1000 / my_time for my_time in ici_average_time_ms_list ]
169179 achieved_bw_statistics = MetricsStatistics (
@@ -294,6 +304,7 @@ def data_generator():
294304 "ici_average_time_ms_list" : time_ms_list ,
295305 "matrix_shape" : (m , n , k ),
296306 "op_type" : "AR" ,
307+ "trace_dir" : trace_dir ,
297308 }
298309
299310
@@ -308,6 +319,7 @@ def psum_benchmark_calculate_metrics(
308319 matrix_shape : tuple [int , int , int ],
309320 xla_output : str ,
310321 op_type : str ,
322+ trace_dir : str ,
311323) -> Dict [str , Any ]:
312324 """Calculates the metrics for the psum benchmark."""
313325 # Build dictionary of all the parameters in the function
@@ -322,8 +334,10 @@ def psum_benchmark_calculate_metrics(
322334 ici_average_time_ms_list ,
323335 matrix_dim ,
324336 op_type ,
337+ trace_dir ,
325338 )
326339
340+
327341def psum_scatter_benchmark (
328342 matrix_dim : int ,
329343 dtype : jnp .dtype ,
@@ -382,7 +396,7 @@ def f(x):
382396 )
383397 sharding_strategy_tuple = tuple (map (int , sharding_strategy .split ("x" )))
384398 op_dimension_tuple_multiplier = math .prod (sharding_strategy_tuple )
385- m = op_dimension_tuple_multiplier * 2
399+ m = op_dimension_tuple_multiplier
386400 n = matrix_dim
387401 k = 256
388402
@@ -405,6 +419,7 @@ def data_generator():
405419 "ici_average_time_ms_list" : time_ms_list ,
406420 "matrix_shape" : (m , n , k ),
407421 "op_type" : "RS" ,
422+ "trace_dir" : trace_dir ,
408423 }
409424
410425
@@ -419,6 +434,7 @@ def psum_scatter_benchmark_calculate_metrics(
419434 matrix_shape : tuple [int , int , int ],
420435 xla_output : str ,
421436 op_type : str ,
437+ trace_dir : str ,
422438) -> Dict [str , Any ]:
423439 """Calculates the metrics for the psum_scatter benchmark."""
424440 # Build dictionary of all the parameters in the function
@@ -433,6 +449,7 @@ def psum_scatter_benchmark_calculate_metrics(
433449 ici_average_time_ms_list ,
434450 matrix_dim ,
435451 op_type ,
452+ trace_dir ,
436453 )
437454
438455def all_gather_benchmark (
@@ -473,7 +490,9 @@ def all_gather_benchmark(
473490 "--xla_tpu_use_single_sparse_core_for_all_gather_offload=true" ,
474491 "--xla_tpu_use_tc_device_shape_on_sc=true" ,
475492 f"--xla_tpu_dvfs_p_state={ GLOBAL_PSTATE } " ,
493+ "--xla_tpu_scoped_vmem_limit_kib=65536" ,
476494 ]
495+ # libtpu_init_args=[ ]
477496 os .environ ["LIBTPU_INIT_ARGS" ] = " " .join (libtpu_init_args )
478497 mesh = create_mesh (ici_size , mesh_shape )
479498
@@ -513,7 +532,8 @@ def data_generator():
513532 return {
514533 "ici_average_time_ms_list" : time_ms_list ,
515534 "matrix_shape" : (m , n , k ),
516- "op_type" : "AG"
535+ "op_type" : "AG" ,
536+ "trace_dir" : trace_dir ,
517537 }
518538
519539
@@ -528,6 +548,7 @@ def all_gather_benchmark_calculate_metrics(
528548 matrix_shape : tuple [int , int , int ],
529549 xla_output : str ,
530550 op_type : str ,
551+ trace_dir : str ,
531552) -> Dict [str , Any ]:
532553 """Calculates the metrics for the all_gather benchmark."""
533554 # Build dictionary of all the parameters in the function
@@ -542,6 +563,7 @@ def all_gather_benchmark_calculate_metrics(
542563 ici_average_time_ms_list ,
543564 matrix_dim ,
544565 op_type ,
566+ trace_dir ,
545567 )
546568
547569
@@ -621,6 +643,7 @@ def data_generator():
621643 "ici_average_time_ms_list" : time_ms_list ,
622644 "matrix_shape" : (m , n , k ),
623645 "op_type" : "A2A" ,
646+ "trace_dir" : trace_dir ,
624647 }
625648
626649
@@ -635,6 +658,7 @@ def all_to_all_benchmark_calculate_metrics(
635658 matrix_shape : tuple [int , int , int ],
636659 xla_output : str ,
637660 op_type : str ,
661+ trace_dir : str ,
638662) -> Dict [str , Any ]:
639663 """Calculates the metrics for the all_to_all benchmark."""
640664 # Build dictionary of all the parameters in the function
@@ -649,5 +673,6 @@ def all_to_all_benchmark_calculate_metrics(
649673 ici_average_time_ms_list ,
650674 matrix_dim ,
651675 op_type ,
676+ trace_dir ,
652677 )
653678
0 commit comments