-
Notifications
You must be signed in to change notification settings - Fork 39
Expand file tree
/
Copy pathatomic_add_bench.py
More file actions
executable file
·254 lines (206 loc) · 8.54 KB
/
atomic_add_bench.py
File metadata and controls
executable file
·254 lines (206 loc) · 8.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
import argparse
import json
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import triton
import triton.language as tl
import sys
import iris
from examples.common.utils import torch_dtype_from_str
# Seed Python's and PyTorch's random number generators with the same fixed
# value at import time so repeated runs start from an identical RNG state.
random.seed(123)
torch.manual_seed(123)
@triton.jit
def atomic_add_kernel(
    source_buffer,  # pointer to the first element of the buffer being updated
    buffer_size,  # total number of elements in the buffer
    source_rank: tl.constexpr,
    destination_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,  # pointer to the heap bases pointers
):
    """Atomically add 1 to every in-bounds element of ``source_buffer``.

    Each program instance handles one run of ``BLOCK_SIZE`` consecutive
    element offsets and issues a single masked ``iris.atomic_add`` for them,
    forwarding ``source_rank``, ``destination_rank`` and ``heap_bases_ptr``
    unchanged. The add uses ``sem="relaxed"`` and ``scope="sys"``.
    """
    # Element offsets owned by this program instance.
    element_offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # The last block may run past the end of the buffer; mask those lanes off.
    in_bounds = element_offsets < buffer_size

    # The value returned by the atomic is not used by this benchmark.
    atomic_result = iris.atomic_add(
        source_buffer + element_offsets,
        1,
        source_rank,
        destination_rank,
        heap_bases_ptr,
        mask=in_bounds,
        sem="relaxed",
        scope="sys",
    )
def parse_args(argv=None):
    """Parse the command-line options of the atomic-add benchmark.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            original behaviour; passing a list lets callers and tests drive
            the parser without touching the process command line.

    Returns:
        dict: Parsed options keyed by destination name (``datatype``,
        ``buffer_size``, ``block_size``, ``verbose``, ``validate``,
        ``heap_size``, ``output_file``, ``num_experiments``, ``num_warmup``,
        ``num_ranks``).
    """
    parser = argparse.ArgumentParser(
        # The previous text ("Message Passing") was inherited from another
        # example and misdescribed this tool in --help output.
        description="Parse atomic add benchmark configuration.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-t",
        "--datatype",
        type=str,
        default="fp16",
        choices=["fp16", "fp32", "bf16", "int32", "int64"],
        help="Datatype of computation",
    )
    parser.add_argument("-z", "--buffer_size", type=int, default=1 << 32, help="Buffer Size")
    parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("-d", "--validate", action="store_true", help="Enable validation output")
    parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size")
    parser.add_argument("-o", "--output_file", type=str, default="", help="Output file")
    parser.add_argument("-x", "--num_experiments", type=int, default=16, help="Number of experiments")
    parser.add_argument("-w", "--num_warmup", type=int, default=4, help="Number of warmup experiments")
    parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

    return vars(parser.parse_args(argv))
def run_experiment(shmem, args, source_rank, destination_rank, source_buffer):
    """Time ``atomic_add_kernel`` for one (source, destination) rank pair.

    Only the process whose rank equals ``source_rank`` launches the kernel;
    every process takes part in the barriers and broadcasts, so all ranks
    must call this function with the same arguments.

    Args:
        shmem: Iris context. Used here for rank queries, logging
            (``info``/``error``), ``barrier``, ``broadcast`` and
            ``get_heap_bases``.
        args: Parsed options dict. Reads ``datatype``, ``verbose``,
            ``block_size``, ``num_experiments``, ``num_warmup`` and
            ``validate``.
        source_rank: Rank that launches the kernel.
        destination_rank: Rank forwarded to the kernel as the other endpoint
            of the atomic add; also the rank that runs validation.
        source_buffer: Tensor handed to the kernel. It is reset to zero by
            the ``preamble_fn`` given to ``iris.do_bench`` (presumably
            before each timed run -- confirm against ``iris.do_bench``).

    Returns:
        tuple: ``(bandwidth_gbps, buffer_copy)`` where ``bandwidth_gbps`` is
        the GiB/s figure computed on ``source_rank`` and broadcast to all
        ranks, and ``buffer_copy`` is ``source_buffer.clone()``.

    Raises:
        ValueError: If ``source_rank`` or ``destination_rank`` is not
            strictly less than the world size.
        SystemExit: With status 1 (after destroying the process group) when
            validation was requested and failed.
    """
    dtype = torch_dtype_from_str(args["datatype"])
    cur_rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    # Valid ranks are 0 .. world_size - 1, so the bound is strict.
    if source_rank >= world_size:
        raise ValueError(
            f"Source rank must be less than the world size. World size is {world_size} and source rank is {source_rank}."
        )
    elif destination_rank >= world_size:
        raise ValueError(
            f"Destination rank must be less than the world size. World size is {world_size} and destination rank is {destination_rank}."
        )

    if cur_rank == 0 and args["verbose"]:
        shmem.info(f"Measuring bandwidth between the ranks {source_rank} and {destination_rank}...")

    n_elements = source_buffer.numel()

    def grid(meta):
        # One program instance per BLOCK_SIZE-sized slice of the buffer.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    def run_atomic_add():
        # Non-source ranks do no work here; they only join the barriers.
        if cur_rank == source_rank:
            atomic_add_kernel[grid](
                source_buffer,
                n_elements,
                source_rank,
                destination_rank,
                args["block_size"],
                shmem.get_heap_bases(),
            )

    def preamble():
        source_buffer.fill_(0)

    # Warmup
    run_atomic_add()
    shmem.barrier()

    atomic_add_ms = iris.do_bench(
        run_atomic_add,
        barrier_fn=shmem.barrier,
        preamble_fn=preamble,
        n_repeat=args["num_experiments"],
        n_warmup=args["num_warmup"],
    )

    bandwidth_gbps = 0
    if cur_rank == source_rank:
        elapsed_sec = atomic_add_ms * 1e-3
        element_size_bytes = torch.tensor([], dtype=dtype).element_size()
        total_bytes = n_elements * element_size_bytes
        bandwidth_gbps = total_bytes / elapsed_sec / 2**30

        if args["verbose"]:
            shmem.info(f"Copied {total_bytes / 2**30:.2f} GiB in {elapsed_sec:.4f} seconds")
            shmem.info(f"Bandwidth between {source_rank} and {destination_rank} is {bandwidth_gbps:.4f} GiB/s")

    shmem.barrier()
    # Only source_rank measured a meaningful figure; share it with everyone.
    bandwidth_gbps = shmem.broadcast(bandwidth_gbps, source_rank)

    success = True
    if args["validate"] and cur_rank == destination_rank:
        if args["verbose"]:
            shmem.info("Validating output...")

        expected = torch.ones(n_elements, dtype=dtype, device="cuda")
        diff_mask = ~torch.isclose(source_buffer, expected)

        if torch.any(diff_mask):
            max_diff = (source_buffer - expected).abs().max().item()
            shmem.info(f"Max absolute difference: {max_diff}")
            # argmax over the 0/1 mask yields the first mismatching index.
            first_mismatch_idx = torch.argmax(diff_mask.float()).item()
            computed_val = source_buffer[first_mismatch_idx]
            expected_val = expected[first_mismatch_idx]
            shmem.error(f"First mismatch at index {first_mismatch_idx}: C={computed_val}, expected={expected_val}")
            success = False

        if success and args["verbose"]:
            shmem.info("Validation successful.")
        if not success and args["verbose"]:
            shmem.error("Validation failed.")

    # The verdict is computed on destination_rank only, so that rank must be
    # the broadcast root. Broadcasting from source_rank would replace a
    # failure with the source's untouched ``True`` whenever the ranks differ.
    success = shmem.broadcast(success, destination_rank)

    shmem.barrier()

    if not success:
        dist.destroy_process_group()
        sys.exit(1)
    return bandwidth_gbps, source_buffer.clone()
def print_bandwidth_matrix(
    matrix, label="Unidirectional ATOMIC_ADD bandwidth GiB/s [Remote atomic add]", output_file=None
):
    """Print a source-by-destination bandwidth table and optionally dump it as JSON.

    Args:
        matrix: 2-D array indexed as ``matrix[src, dst]``. ``matrix.shape[0]``
            is used as the rank count for both axes, so it is expected to be
            square.
        label: Title line printed above the table.
        output_file: Path of a ``.json`` file to write the per-pair results
            to. ``None`` or ``""`` disables file output.

    Raises:
        ValueError: If ``output_file`` is non-empty and does not end in
            ``.json``.
    """
    num_ranks = matrix.shape[0]
    col_width = 10  # Adjust for alignment

    print(f"\n{label}")
    header_cells = "".join(f"GPU {dst:02d}".rjust(col_width) for dst in range(num_ranks))
    print(" SRC\\DST ".ljust(col_width) + header_cells)

    for src in range(num_ranks):
        row_cells = "".join(f"{matrix[src, dst]:{col_width}.2f}" for dst in range(num_ranks))
        print(f"GPU {src:02d} ->".ljust(col_width) + row_cells)

    # Truthiness covers both the declared default (None) and the CLI default
    # (""); comparing against "" alone crashed on None.endswith(...).
    if not output_file:
        return

    if not output_file.endswith(".json"):
        raise ValueError(f"Unsupported output file extension: {output_file}")

    detailed_results = [
        {
            "source_gpu": f"GPU_{src:02d}",
            "destination_gpu": f"GPU_{dst:02d}",
            "source_rank": src,
            "destination_rank": dst,
            # Convert to a plain float so json can serialize NumPy scalars.
            "bandwidth_gbps": float(matrix[src, dst]),
        }
        for src in range(num_ranks)
        for dst in range(num_ranks)
    ]
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(detailed_results, f, indent=2)
def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
    """Run the full rank-pair sweep inside one spawned process.

    Joins the ``torch.distributed`` process group as rank ``local_rank``,
    creates the Iris context and one buffer of ``args["buffer_size"]`` bytes,
    then calls ``run_experiment`` for every (source, destination) pair.
    Rank 0 prints the resulting matrix before the group is torn down.
    """
    if torch.cuda.is_available():
        backend = "nccl"
    else:
        backend = "gloo"

    dist.init_process_group(
        backend=backend,
        init_method=init_url,
        world_size=world_size,
        rank=local_rank,
        device_id=torch.device(f"cuda:{local_rank}"),
    )

    # Main benchmark logic
    shmem = iris.iris(args["heap_size"])
    rank_count = shmem.get_num_ranks()
    results = np.zeros((rank_count, rank_count), dtype=np.float32)

    # buffer_size is given in bytes; convert it to an element count.
    dtype = torch_dtype_from_str(args["datatype"])
    bytes_per_element = torch.tensor([], dtype=dtype).element_size()
    element_count = args["buffer_size"] // bytes_per_element
    source_buffer = shmem.arange(element_count, device="cuda", dtype=dtype)

    for src in range(rank_count):
        for dst in range(rank_count):
            # The second return value (a copy of the buffer) is not needed.
            measurement = run_experiment(shmem, args, src, dst, source_buffer)
            results[src, dst] = measurement[0]
            shmem.barrier()

    if shmem.get_rank() == 0:
        print_bandwidth_matrix(results, output_file=args["output_file"])

    dist.barrier()
    dist.destroy_process_group()
def main():
    """Parse the command line and spawn one ``_worker`` process per rank."""
    args = parse_args()
    rank_count = args["num_ranks"]
    rendezvous_url = "tcp://127.0.0.1:29500"

    # mp.spawn supplies the process index as the first positional argument,
    # which _worker receives as local_rank.
    mp.spawn(
        fn=_worker,
        args=(rank_count, rendezvous_url, args),
        nprocs=rank_count,
        join=True,
    )


if __name__ == "__main__":
    main()