|
5 | 5 | import triton.experimental.tle.language as tle |
6 | 6 | from triton.compiler.errors import CompilationError |
7 | 7 |
|
8 | | -from test_tle_utils import compile_musa, require_mthreads_libtriton |
| 8 | +from test_tle_utils import compile_musa, compile_to_ttir, require_mthreads_libtriton |
9 | 9 |
|
10 | 10 | require_mthreads_libtriton() |
11 | 11 |
|
@@ -40,6 +40,40 @@ def _local_ptr_full_view_kernel(out_ptr): |
40 | 40 | tl.store(out_ptr + tl.arange(0, 16), loaded) |
41 | 41 |
|
42 | 42 |
|
@triton.jit
def _local_ptr_atomic_add_kernel(out_ptr, BLOCK: tl.constexpr):
    # Atomically adds (i + 1) to element i of a BLOCK-element int32 buffer that
    # is allocated through tle.gpu.alloc with an all-zero init_value.
    # out_ptr[0:BLOCK] receives what tl.atomic_add returned, and
    # out_ptr[BLOCK:2*BLOCK] receives the buffer contents re-read afterwards.
    lane = tl.arange(0, BLOCK)
    zeros = tl.full((BLOCK, ), 0, tl.int32)
    buf = tle.gpu.alloc((BLOCK, ), dtype=tl.int32, init_value=zeros, nv_mma_shared_layout=False)
    slot_ptrs = tle.gpu.local_ptr(buf, (lane, ))
    # The atomic goes through the pointers produced by tle.gpu.local_ptr.
    returned = tl.atomic_add(slot_ptrs, lane.to(tl.int32) + 1, sem="relaxed", scope="cta")
    # Read the buffer back before any global store, as the original ordering did.
    reloaded = tl.load(slot_ptrs)
    tl.store(out_ptr + lane, returned)
    tl.store(out_ptr + BLOCK + lane, reloaded)
| 54 | + |
| 55 | + |
@triton.jit
def _local_ptr_atomic_cas_kernel(out_ptr):
    # Runs tl.atomic_cas(ptr, 3, 9) on a one-element int32 buffer whose
    # init_value is 3. out_ptr[0] receives the value returned by the CAS and
    # out_ptr[1] receives the buffer contents loaded afterwards.
    seed = tl.full((1, ), 3, tl.int32)
    cell = tle.gpu.alloc((1, ), dtype=tl.int32, init_value=seed, nv_mma_shared_layout=False)
    cell_ptr = tle.gpu.local_ptr(cell, (0, ))
    returned = tl.atomic_cas(cell_ptr, 3, 9, sem="relaxed", scope="cta")
    reloaded = tl.load(cell_ptr)
    tl.store(out_ptr, returned)
    tl.store(out_ptr + 1, reloaded)
| 65 | + |
| 66 | + |
@triton.jit
def _local_ptr_atomic_cas_update_kernel(out_ptr):
    # Same CAS as tl.atomic_cas(ptr, 3, 9) on a one-element int32 buffer seeded
    # with 3, but the CAS return value is discarded: only the buffer contents
    # loaded after the CAS are written, to out_ptr[0].
    seed = tl.full((1, ), 3, tl.int32)
    cell = tle.gpu.alloc((1, ), dtype=tl.int32, init_value=seed, nv_mma_shared_layout=False)
    cell_ptr = tle.gpu.local_ptr(cell, (0, ))
    tl.atomic_cas(cell_ptr, 3, 9, sem="relaxed", scope="cta")
    tl.store(out_ptr, tl.load(cell_ptr))
| 75 | + |
| 76 | + |
43 | 77 | @triton.jit |
44 | 78 | def _local_ptr_non_integer_index_kernel(out_ptr): |
45 | 79 | smem = tle.gpu.alloc((16, ), dtype=tl.float32, nv_mma_shared_layout=False) |
@@ -110,6 +144,44 @@ def test_tle_local_ptr_full_view_store_load_rewrites_to_memdesc_ops(): |
110 | 144 | assert "musa_tle.local_pointers" not in llir, llir |
111 | 145 |
|
112 | 146 |
|
def test_tle_local_ptr_atomic_ops_accept_addrspace3_ttir():
    """Check the TTIR text produced for the atomic add and CAS kernels.

    Both kernels are compiled with ``compile_to_ttir``; the resulting text must
    contain the relaxed/cta atomic ops and the ``!tt.ptr<i32, 3>`` operand
    signatures listed below.
    """
    add_signature = {"out_ptr": "*i32", "BLOCK": "constexpr"}
    add_ir = compile_to_ttir(_local_ptr_atomic_add_kernel, signature=add_signature, constexprs={"BLOCK": 16})
    cas_ir = compile_to_ttir(_local_ptr_atomic_cas_kernel, signature={"out_ptr": "*i32"})

    # (IR text, fragment that must appear in it); checked in this order.
    expectations = (
        (add_ir, "tt.atomic_rmw add, relaxed, cta"),
        (add_ir, "(tensor<16x!tt.ptr<i32, 3>>, tensor<16xi32>, tensor<16xi1>) -> tensor<16xi32>"),
        (cas_ir, "tt.atomic_cas relaxed, cta"),
        (cas_ir, "(!tt.ptr<i32, 3>, i32, i32) -> i32"),
    )
    for ir_text, fragment in expectations:
        # Dump the full IR on failure so the mismatch can be inspected.
        assert fragment in ir_text, ir_text
| 159 | + |
| 160 | + |
def test_tle_local_ptr_atomic_add_lowers_through_mthreads_llvm():
    """Compile the atomic-add kernel with ``compile_musa`` and inspect its IR.

    The TTGIR must still contain ``tt.atomic_rmw`` on a ``!tt.ptr<i32, 3>``
    tensor, and ``musa_tle.local_pointers`` must not appear in the LLVM IR.
    """
    kernel_signature = {"out_ptr": "*i32", "BLOCK": "constexpr"}
    artifact = compile_musa(_local_ptr_atomic_add_kernel, signature=kernel_signature, constexprs={"BLOCK": 16})

    ttgir_text, llir_text = artifact.asm["ttgir"], artifact.asm["llir"]
    for fragment in ("tt.atomic_rmw", "tensor<16x!tt.ptr<i32, 3>"):
        assert fragment in ttgir_text, ttgir_text
    assert "musa_tle.local_pointers" not in llir_text, llir_text
| 173 | + |
| 174 | + |
def test_tle_local_ptr_atomic_cas_lowers_through_mthreads_llvm():
    """Compile the atomic-CAS kernel with ``compile_musa`` and inspect its IR.

    The TTGIR must contain ``tt.atomic_cas`` and a ``-> !tt.ptr<i32, 3>``
    result type, and ``musa_tle.local_pointers`` must not appear in the LLVM IR.
    """
    artifact = compile_musa(_local_ptr_atomic_cas_kernel, signature={"out_ptr": "*i32"})

    ttgir_text, llir_text = artifact.asm["ttgir"], artifact.asm["llir"]
    for fragment in ("tt.atomic_cas", "-> !tt.ptr<i32, 3>"):
        assert fragment in ttgir_text, ttgir_text
    assert "musa_tle.local_pointers" not in llir_text, llir_text
| 183 | + |
| 184 | + |
113 | 185 | def test_tle_local_ptr_rejects_non_integer_indices(): |
114 | 186 | with pytest.raises(CompilationError, match="local_ptr indices must use integer dtypes"): |
115 | 187 | compile_musa(_local_ptr_non_integer_index_kernel, signature={"out_ptr": "*fp32"}) |
@@ -158,3 +230,25 @@ def test_tle_local_ptr_full_view_runtime_round_trip(): |
158 | 230 |
|
159 | 231 | ref = torch.arange(0, 16, dtype=torch.float32) + 7.0 |
160 | 232 | torch.testing.assert_close(out.cpu(), ref, rtol=0, atol=0) |
| 233 | + |
| 234 | + |
@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available")
def test_tle_local_ptr_atomic_add_runtime_round_trip():
    """Launch the atomic-add kernel on a MUSA device and compare both output halves.

    The first ``n`` outputs must be all zeros and the last ``n`` must be
    ``1..n``, compared exactly (``rtol=0, atol=0``).
    """
    n = 16
    result = torch.empty((n * 2, ), device="musa", dtype=torch.int32)

    _local_ptr_atomic_add_kernel[(1, )](result, BLOCK=n, num_warps=1)

    expected_old = torch.zeros((n, ), dtype=torch.int32)
    expected_after = torch.arange(1, n + 1, dtype=torch.int32)
    # First half: values returned by the atomic; second half: buffer re-read.
    for device_slice, expected in ((result[:n], expected_old), (result[n:], expected_after)):
        torch.testing.assert_close(device_slice.cpu(), expected, rtol=0, atol=0)
| 246 | + |
| 247 | + |
@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available")
def test_tle_local_ptr_atomic_cas_runtime_round_trip():
    """Launch the CAS-update kernel on a MUSA device and expect the single output to be 9."""
    result = torch.empty((1, ), device="musa", dtype=torch.int32)

    _local_ptr_atomic_cas_update_kernel[(1, )](result, num_warps=1)

    expected = torch.tensor([9], dtype=torch.int32)
    torch.testing.assert_close(result.cpu(), expected, rtol=0, atol=0)
0 commit comments