Skip to content

Commit a0360ec

Browse files
committed
fix copilot issues
1 parent dda5cc2 commit a0360ec

2 files changed

Lines changed: 10 additions & 1 deletion

File tree

tests/ccl/test_all_gather_gluon.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
Test suite for all-gather collective operation using Gluon.
66
"""
77

8+
import os
9+
810
import pytest
911
import torch
1012
import torch.distributed as dist
@@ -77,8 +79,14 @@ def test_all_gather_gluon(impl, mode, vary, dtype, M, N, block_size_m, block_siz
7779
async_op = mode != "eager_barrier"
7880
capture = mode == "graph"
7981

80-
shmem = iris.iris(2**33) # 8 GB
82+
# Size the heap to hold the input (M*N elements) plus the gathered output (max_ranks*M*N elements), with headroom
83+
max_ranks = int(os.environ.get("WORLD_SIZE", 8))
84+
elem_size = torch.tensor([], dtype=dtype).element_size()
85+
needed = (1 + max_ranks) * M * N * elem_size
86+
heap_size = max(2**30, int(needed * 2)) # 2x headroom, minimum 1GB
87+
shmem = iris.iris(heap_size)
8188
rank, world_size = shmem.get_rank(), shmem.get_num_ranks()
89+
torch.cuda.set_device(rank)
8290
src = torch.empty((M, N), dtype=dtype, device=f"cuda:{rank}")
8391
stage_buf, result, config = _make_buffers(impl, shmem, rank, world_size, M, N, dtype, block_size_m, block_size_n)
8492
shmem.barrier()

tests/ccl/test_all_to_all_gluon.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def test_all_to_all_gluon(impl, mode, vary, dtype, M, N):
5656

5757
shmem = iris.iris(2**33) # 8 GB
5858
rank, world_size = shmem.get_rank(), shmem.get_num_ranks()
59+
torch.cuda.set_device(rank)
5960
width = N * world_size
6061
src = torch.empty((M, width), dtype=dtype, device=f"cuda:{rank}")
6162
stage_buf = shmem.zeros((M, width), dtype=dtype)

0 commit comments

Comments
 (0)