Skip to content

Commit 56ad603

Browse files
committed
Rewrote latency test
1 parent f537a20 commit 56ad603

1 file changed

Lines changed: 10 additions & 20 deletions

File tree

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515

1616

1717
@triton.jit()
18-
def ping_pong(
18+
def load_remote(
1919
data,
2020
n_elements,
2121
skip,
2222
niter,
23-
flag,
2423
curr_rank,
2524
peer_rank,
2625
BLOCK_SIZE: tl.constexpr,
@@ -34,25 +33,18 @@ def ping_pong(
3433

3534
data_mask = offsets < n_elements
3635
time_stmp_mask = offsets < BLOCK_SIZE
37-
flag_mask = offsets < 1
3836

3937
for i in range(niter + skip):
4038
if i == skip:
4139
start = read_realtime()
4240
tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
43-
first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
44-
token_first_done = i + 1
45-
token_second_done = i + 2
46-
if curr_rank == first_rank:
47-
iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
48-
iris.atomic_xchg(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
49-
while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
50-
pass
51-
else:
52-
while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
53-
pass
54-
iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
55-
iris.atomic_xchg(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
41+
42+
# iris.load(data + offsets, curr_rank, peer_rank,heap_bases, data_mask)
43+
from_base = tl.load(heap_bases + curr_rank)
44+
to_base = tl.load(heap_bases + peer_rank)
45+
offset = tl.cast(data + offsets, tl.uint64) - from_base
46+
translated_ptr = tl.cast(tl.cast(to_base, tl.pointer_type(tl.int8)) + offset, (data + offsets).dtype)
47+
result = tl.load(translated_ptr, mask=data_mask, cache_modifier=".cv", volatile=True)
5648

5749
stop = read_realtime()
5850
tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)
@@ -244,19 +236,17 @@ def print_run_settings(
244236
local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")
245237

246238
source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
247-
flag = shmem.ones(1, dtype=torch.int32)
248239

249240
grid = lambda meta: (1,)
250241
for source_rank in range(num_ranks):
251242
for destination_rank in range(num_ranks):
252-
if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
243+
if cur_rank in [source_rank, destination_rank]:
253244
peer_for_me = destination_rank if cur_rank == source_rank else source_rank
254-
ping_pong[grid](
245+
load_remote[grid](
255246
source_buffer,
256247
BUFFER_LEN,
257248
skip,
258249
niter,
259-
flag,
260250
cur_rank,
261251
peer_for_me,
262252
BLOCK_SIZE,

0 commit comments

Comments
 (0)