Skip to content

Commit 1e6c308

Browse files
authored
Logging of sparse core usage for ICI collectives (#68)
* Logging of sparse core usage for ICI collectives
1 parent 7618323 commit 1e6c308

3 files changed

Lines changed: 102 additions & 7 deletions

File tree

Ironwood/src/benchmark_collectives.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
from typing import Any, Dict
77

8+
from benchmark_utils import find_sparsecore_usage_from_xplane
89
from benchmark_utils import get_lhs_named_shading
910
from benchmark_utils import get_out_sharding
1011
from benchmark_utils import MetricsStatistics
@@ -26,7 +27,7 @@
2627
SEED = 0
2728
GLOBAL_SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING
2829
GLOBAL_PSTATE = 7
29-
30+
LOG_SPARSECORE_USAGE = False
3031

3132
def create_mesh(ici_size: int, mesh_shape: str) -> Mesh:
3233
"""Creates a mesh with the given ICI size."""
@@ -85,6 +86,7 @@ def unified_ici_collectives_metrics(
8586
ici_average_time_ms_list: list[float],
8687
iteration: int,
8788
op_type: str,
89+
trace_dir: str = None,
8890
) -> Dict[str, Any]:
8991
"""Calculates the metrics for the ICI collectives benchmark."""
9092

@@ -147,8 +149,15 @@ def unified_ici_collectives_metrics(
147149
/ rank
148150
)
149151

150-
152+
153+
sparsecore_used = "NA"
154+
if LOG_SPARSECORE_USAGE:
155+
print("trace_dir: ", trace_dir)
156+
if trace_dir:
157+
sparsecore_used = find_sparsecore_usage_from_xplane(trace_dir)
158+
print("sparsecore_used: ", sparsecore_used)
151159
print("hlo first replica group: ", hlo_first_replica_group)
160+
152161
metadata = {
153162
"iteration": iteration,
154163
"op_type": op_type,
@@ -164,6 +173,7 @@ def unified_ici_collectives_metrics(
164173
"hlo_input_shape": json.dumps(hlo_input_shape),
165174
"hlo_output_shape": json.dumps(hlo_output_shape),
166175
"hlo_replica_groups": json.dumps(hlo_replica_groups),
176+
"sparsecore_used": sparsecore_used,
167177
}
168178
achieved_bw = [transferred_data*1000/my_time for my_time in ici_average_time_ms_list]
169179
achieved_bw_statistics = MetricsStatistics(
@@ -294,6 +304,7 @@ def data_generator():
294304
"ici_average_time_ms_list": time_ms_list,
295305
"matrix_shape": (m, n, k),
296306
"op_type": "AR",
307+
"trace_dir": trace_dir,
297308
}
298309

299310

@@ -308,6 +319,7 @@ def psum_benchmark_calculate_metrics(
308319
matrix_shape: tuple[int, int, int],
309320
xla_output: str,
310321
op_type: str,
322+
trace_dir: str,
311323
) -> Dict[str, Any]:
312324
"""Calculates the metrics for the psum benchmark."""
313325
# Build dictionary of all the parameters in the function
@@ -322,8 +334,10 @@ def psum_benchmark_calculate_metrics(
322334
ici_average_time_ms_list,
323335
matrix_dim,
324336
op_type,
337+
trace_dir,
325338
)
326339

340+
327341
def psum_scatter_benchmark(
328342
matrix_dim: int,
329343
dtype: jnp.dtype,
@@ -382,7 +396,7 @@ def f(x):
382396
)
383397
sharding_strategy_tuple = tuple(map(int, sharding_strategy.split("x")))
384398
op_dimension_tuple_multiplier = math.prod(sharding_strategy_tuple)
385-
m = op_dimension_tuple_multiplier * 2
399+
m = op_dimension_tuple_multiplier
386400
n = matrix_dim
387401
k = 256
388402

@@ -405,6 +419,7 @@ def data_generator():
405419
"ici_average_time_ms_list": time_ms_list,
406420
"matrix_shape": (m, n, k),
407421
"op_type": "RS",
422+
"trace_dir": trace_dir,
408423
}
409424

410425

@@ -419,6 +434,7 @@ def psum_scatter_benchmark_calculate_metrics(
419434
matrix_shape: tuple[int, int, int],
420435
xla_output: str,
421436
op_type: str,
437+
trace_dir: str,
422438
) -> Dict[str, Any]:
423439
"""Calculates the metrics for the psum_scatter benchmark."""
424440
# Build dictionary of all the parameters in the function
@@ -433,6 +449,7 @@ def psum_scatter_benchmark_calculate_metrics(
433449
ici_average_time_ms_list,
434450
matrix_dim,
435451
op_type,
452+
trace_dir,
436453
)
437454

438455
def all_gather_benchmark(
@@ -473,7 +490,9 @@ def all_gather_benchmark(
473490
"--xla_tpu_use_single_sparse_core_for_all_gather_offload=true",
474491
"--xla_tpu_use_tc_device_shape_on_sc=true",
475492
f"--xla_tpu_dvfs_p_state={GLOBAL_PSTATE}",
493+
"--xla_tpu_scoped_vmem_limit_kib=65536",
476494
]
495+
# libtpu_init_args=[ ]
477496
os.environ["LIBTPU_INIT_ARGS"] = " ".join(libtpu_init_args)
478497
mesh = create_mesh(ici_size, mesh_shape)
479498

@@ -513,7 +532,8 @@ def data_generator():
513532
return {
514533
"ici_average_time_ms_list": time_ms_list,
515534
"matrix_shape": (m, n, k),
516-
"op_type": "AG"
535+
"op_type": "AG",
536+
"trace_dir": trace_dir,
517537
}
518538

519539

@@ -528,6 +548,7 @@ def all_gather_benchmark_calculate_metrics(
528548
matrix_shape: tuple[int, int, int],
529549
xla_output: str,
530550
op_type: str,
551+
trace_dir: str,
531552
) -> Dict[str, Any]:
532553
"""Calculates the metrics for the all_gather benchmark."""
533554
# Build dictionary of all the parameters in the function
@@ -542,6 +563,7 @@ def all_gather_benchmark_calculate_metrics(
542563
ici_average_time_ms_list,
543564
matrix_dim,
544565
op_type,
566+
trace_dir,
545567
)
546568

547569

@@ -621,6 +643,7 @@ def data_generator():
621643
"ici_average_time_ms_list": time_ms_list,
622644
"matrix_shape": (m, n, k),
623645
"op_type": "A2A",
646+
"trace_dir": trace_dir,
624647
}
625648

626649

@@ -635,6 +658,7 @@ def all_to_all_benchmark_calculate_metrics(
635658
matrix_shape: tuple[int, int, int],
636659
xla_output: str,
637660
op_type: str,
661+
trace_dir: str,
638662
) -> Dict[str, Any]:
639663
"""Calculates the metrics for the all_to_all benchmark."""
640664
# Build dictionary of all the parameters in the function
@@ -649,5 +673,6 @@ def all_to_all_benchmark_calculate_metrics(
649673
ici_average_time_ms_list,
650674
matrix_dim,
651675
op_type,
676+
trace_dir,
652677
)
653678

Ironwood/src/benchmark_send_recv.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
"""Benchmarking p2p source target transfer."""
22

33
import os
4-
from typing import Any, Dict, List, Tuple
4+
from typing import Any, Dict, Tuple
55
import jax
66
from jax.experimental import mesh_utils
77
import jax.numpy as jnp
88
import jax.sharding
99
from benchmark_utils import (
10-
MetricsStatistics,
1110
get_trace,
1211
)
1312
from common import MARKER

Ironwood/src/benchmark_utils.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
from jax.sharding import Mesh
2525
from jax.sharding import NamedSharding
2626
from jax.sharding import PartitionSpec as P
27+
import gc
28+
import jax.extend
29+
from tensorflow.tsl.profiler.protobuf import xplane_pb2
2730

2831
# The dictionary to map a JAX (collective) function to its main HLO.
2932
TARGET_TASK_NAME_COLLECTIVES_MAP = {
@@ -121,6 +124,13 @@ def multiple_iteration_timeit_from_trace_throttling(
121124
return multiple_iteration_get_metrics_from_trace(trace)
122125

123126

127+
def clear_jax_memory():
128+
backend = jax.extend.backend.get_backend()
129+
for buf in backend.live_buffers():
130+
buf.delete()
131+
gc.collect()
132+
133+
124134
def multiple_iteration_timeit_from_trace(
125135
compute_func: Callable,
126136
data_generator: Callable,
@@ -153,11 +163,13 @@ def multiple_iteration_timeit_from_trace(
153163
print(f"[{task}] Running iteration {i} of {tries} with {matrix_dim}...")
154164
data_args = data_generator()
155165
jax.devices()
166+
156167
with jax.profiler.StepTraceAnnotation(task, step_num=i):
157168
with jax.named_scope(f"{MARKER}_{i}"):
169+
158170
result = compute_func(*data_args)
159171
jax.block_until_ready(result)
160-
172+
clear_jax_memory()
161173
trace = get_trace(tmp_trace_dir)
162174

163175
if trace_full_dir != tmp_trace_dir:
@@ -502,6 +514,65 @@ def get_trace(log_dir: str) -> dict[str, Any]:
502514
return trace
503515

504516

517+
def find_sparsecore_usage_from_xplane(log_dir: str) -> xplane_pb2.XSpace:
518+
"""Extract the XSpace object from the log directory.
519+
520+
Returns:
521+
An XSpace protobuf object.
522+
"""
523+
print("find_sparsecore_usage_from_xplane: ", log_dir)
524+
525+
# Handle partial log_dir
526+
if not (pathlib.Path(log_dir) / "plugins" / "profile").exists():
527+
potential_dirs = glob.glob(f"{log_dir}*")
528+
potential_dirs = [d for d in potential_dirs if os.path.isdir(d)]
529+
potential_dirs.sort(key=os.path.getmtime, reverse=True)
530+
531+
for d in potential_dirs:
532+
d_path = pathlib.Path(d)
533+
if (d_path / "plugins" / "profile").exists():
534+
log_dir = d
535+
print(f"Updated log_dir to match partial path: {log_dir}")
536+
break
537+
538+
# Check subdirectories recursively
539+
candidates = list(d_path.glob("**/plugins/profile"))
540+
if candidates:
541+
latest = max(candidates, key=lambda p: p.stat().st_mtime)
542+
log_dir = str(latest.parent.parent)
543+
print(f"Updated log_dir via recursive search: {log_dir}")
544+
break
545+
546+
trace_folders = (
547+
pathlib.Path(log_dir).absolute() / "plugins" / "profile"
548+
).iterdir()
549+
latest_trace_folder = max(trace_folders, key=os.path.getmtime)
550+
551+
# XPlane files usually end with .xplane.pb
552+
xplane_files = list(latest_trace_folder.glob("*.xplane.pb"))
553+
try:
554+
(xplane_file,) = xplane_files
555+
except ValueError as value_error:
556+
raise ValueError(
557+
f"Invalid trace folder: {latest_trace_folder}. Expected 1"
558+
f" '*.xplane.pb' file, but found {len(xplane_files)}."
559+
) from value_error
560+
561+
with open(xplane_file, "rb") as f:
562+
serialized_space = f.read()
563+
564+
space = xplane_pb2.XSpace()
565+
space.ParseFromString(serialized_space)
566+
# print("space: ", space)
567+
sparsecore_found = False
568+
for _, plane in enumerate(space.planes):
569+
print("plane: ", plane.name)
570+
if "SparseCore" in plane.name:
571+
sparsecore_found = True
572+
break
573+
return sparsecore_found
574+
575+
505576
def get_metrics_from_trace(trace: dict[str, Any], task: str) -> list[float]:
506577
# Check if the given task name is a collective with corresponding TPU opertion.
507578
# This is a workaround and should be reverted or refactored in future.

0 commit comments

Comments (0)