1515
1616
1717@triton .jit ()
18- def ping_pong (
18+ def load_remote (
1919 data ,
2020 n_elements ,
2121 skip ,
2222 niter ,
23- flag ,
2423 curr_rank ,
2524 peer_rank ,
2625 BLOCK_SIZE : tl .constexpr ,
@@ -34,25 +33,18 @@ def ping_pong(
3433
3534 data_mask = offsets < n_elements
3635 time_stmp_mask = offsets < BLOCK_SIZE
37- flag_mask = offsets < 1
3836
3937 for i in range (niter + skip ):
4038 if i == skip :
4139 start = read_realtime ()
4240 tl .store (mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets , start , time_stmp_mask )
43- first_rank = tl .minimum (curr_rank , peer_rank ) if (i % 2 ) == 0 else tl .maximum (curr_rank , peer_rank )
44- token_first_done = i + 1
45- token_second_done = i + 2
46- if curr_rank == first_rank :
47- iris .store (data + offsets , i , curr_rank , peer_rank , heap_bases , mask = data_mask )
48- iris .atomic_xchg (flag + offsets , token_first_done , curr_rank , peer_rank , heap_bases , mask = flag_mask )
49- while tl .load (flag , cache_modifier = ".cv" , volatile = True ) != token_second_done :
50- pass
51- else :
52- while tl .load (flag , cache_modifier = ".cv" , volatile = True ) != token_first_done :
53- pass
54- iris .store (data + offsets , i , curr_rank , peer_rank , heap_bases , mask = data_mask )
55- iris .atomic_xchg (flag + offsets , token_second_done , curr_rank , peer_rank , heap_bases , mask = flag_mask )
41+
42+ # iris.load(data + offsets, curr_rank, peer_rank,heap_bases, data_mask)
43+ from_base = tl .load (heap_bases + curr_rank )
44+ to_base = tl .load (heap_bases + peer_rank )
45+ offset = tl .cast (data + offsets , tl .uint64 ) - from_base
46+ translated_ptr = tl .cast (tl .cast (to_base , tl .pointer_type (tl .int8 )) + offset , (data + offsets ).dtype )
47+ result = tl .load (translated_ptr , mask = data_mask , cache_modifier = ".cv" , volatile = True )
5648
5749 stop = read_realtime ()
5850 tl .store (mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets , stop , time_stmp_mask )
@@ -244,19 +236,17 @@ def print_run_settings(
244236 local_latency = torch .zeros ((num_ranks ), dtype = torch .float32 , device = "cuda" )
245237
246238 source_buffer = shmem .ones (BUFFER_LEN , dtype = dtype )
247- flag = shmem .ones (1 , dtype = torch .int32 )
248239
249240 grid = lambda meta : (1 ,)
250241 for source_rank in range (num_ranks ):
251242 for destination_rank in range (num_ranks ):
252- if source_rank != destination_rank and cur_rank in [source_rank , destination_rank ]:
243+ if cur_rank in [source_rank , destination_rank ]:
253244 peer_for_me = destination_rank if cur_rank == source_rank else source_rank
254- ping_pong [grid ](
245+ load_remote [grid ](
255246 source_buffer ,
256247 BUFFER_LEN ,
257248 skip ,
258249 niter ,
259- flag ,
260250 cur_rank ,
261251 peer_for_me ,
262252 BLOCK_SIZE ,
0 commit comments