Skip to content

Commit 4b4f2c2

Browse files
committed
[TLE][MTHREADS] Support atomic operands
1 parent 38a529e commit 4b4f2c2

3 files changed

Lines changed: 125 additions & 1 deletion

File tree

third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ namespace triton {
2727
struct GlobalMemory : public SideEffects::Resource::Base<GlobalMemory> {
2828
StringRef getName() final { return "<GlobalMemory>"; }
2929
};
30+
// Side-effect resource tag whose getName() reports "<SharedMemory>".
// It is declared only when __TLE__ is defined and mirrors the GlobalMemory
// resource declared immediately above.
#ifdef __TLE__
31+
struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
32+
StringRef getName() final { return "<SharedMemory>"; }
33+
};
34+
#endif
3035

3136
class DialectInferLayoutInterface
3237
: public DialectInterface::Base<DialectInferLayoutInterface> {

third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
2020
// Interfaces
2121
//
2222
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
23+
// TableGen handle for the ::mlir::triton::SharedMemory side-effect resource.
// Only defined when __TLE__ is set, so memory-effect annotations can name it.
#ifdef __TLE__
24+
def SharedMemory : Resource<"::mlir::triton::SharedMemory">;
25+
#endif // __TLE__
2326

2427
//
2528
// Op Base
@@ -350,8 +353,13 @@ def TT_StoreOp : TT_Op<"store", [
350353
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
351354
SameOperandsAndResultShape,
352355
SameOperandsAndResultEncoding,
356+
#ifdef __TLE__
357+
TypesMatchWith<"value type matches ptr type", "ptr", "val",
358+
"getPointeeType($_self)">,
359+
#else
353360
TypesMatchWith<"ptr type matches value type", "val", "ptr",
354361
"getPointerTypeSameShape($_self)">,
362+
#endif // __TLE__
355363
TypesMatchWith<"mask type matches value type",
356364
"val", "mask", "getI1SameShape($_self)",
357365
"($_op.getOperands().size() <= 2) || std::equal_to<>()">
@@ -366,7 +374,12 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
366374

367375
let arguments = (ins
368376
TT_AtomicRMWAttr:$atomic_rmw_op,
377+
#ifdef __TLE__
378+
Arg<TT_PtrLike, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>,
379+
MemRead<SharedMemory>, MemWrite<SharedMemory>]>:$ptr,
380+
#else
369381
Arg<TT_PtrLike, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$ptr,
382+
#endif // __TLE__
370383
TT_Type:$val,
371384
Optional<TT_BoolLike>:$mask,
372385
TT_MemSemanticAttr:$sem,
@@ -386,10 +399,17 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
386399
def TT_AtomicCASOp : TT_Op<"atomic_cas", [
387400
SameOperandsAndResultShape,
388401
SameOperandsAndResultEncoding,
402+
#ifdef __TLE__
403+
TypesMatchWith<"cmp type matches ptr type", "ptr", "cmp",
404+
"getPointeeType($_self)">,
405+
TypesMatchWith<"value type matches ptr type", "ptr", "val",
406+
"getPointeeType($_self)">
407+
#else
389408
TypesMatchWith<"ptr type matches cmp type", "cmp", "ptr",
390409
"getPointerTypeSameShape($_self)">,
391410
TypesMatchWith<"ptr type matches value type", "val", "ptr",
392411
"getPointerTypeSameShape($_self)">
412+
#endif // __TLE__
393413
]> {
394414
let summary = "atomic cas";
395415

@@ -404,7 +424,12 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [
404424
}];
405425

406426
let arguments = (ins
427+
#ifdef __TLE__
428+
Arg<TT_PtrLike, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>,
429+
MemRead<SharedMemory>, MemWrite<SharedMemory>]>:$ptr,
430+
#else
407431
Arg<TT_PtrLike, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$ptr,
432+
#endif // __TLE__
408433
TT_Type:$cmp,
409434
TT_Type:$val,
410435
TT_MemSemanticAttr:$sem,

third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import triton.experimental.tle.language as tle
66
from triton.compiler.errors import CompilationError
77

8-
from test_tle_utils import compile_musa, require_mthreads_libtriton
8+
from test_tle_utils import compile_musa, compile_to_ttir, require_mthreads_libtriton
99

1010
require_mthreads_libtriton()
1111

@@ -40,6 +40,40 @@ def _local_ptr_full_view_kernel(out_ptr):
4040
tl.store(out_ptr + tl.arange(0, 16), loaded)
4141

4242

43+
# Kernel fixture: atomic add through pointers obtained from tle.gpu.local_ptr.
# Allocates a BLOCK-element int32 buffer with tle.gpu.alloc (initialised to 0),
# atomically adds (offset + 1) to each element with relaxed semantics and
# "cta" scope, then reloads the buffer.
# Output layout: out_ptr[0:BLOCK] holds the values returned by the atomic add,
# out_ptr[BLOCK:2*BLOCK] holds the values loaded afterwards.
@triton.jit
44+
def _local_ptr_atomic_add_kernel(out_ptr, BLOCK: tl.constexpr):
45+
offsets = tl.arange(0, BLOCK)
46+
init = tl.full((BLOCK, ), 0, tl.int32)
47+
smem = tle.gpu.alloc((BLOCK, ), dtype=tl.int32, init_value=init, nv_mma_shared_layout=False)
48+
ptrs = tle.gpu.local_ptr(smem, (offsets, ))
49+
increments = offsets.to(tl.int32) + 1
50+
old = tl.atomic_add(ptrs, increments, sem="relaxed", scope="cta")
51+
after = tl.load(ptrs)
52+
tl.store(out_ptr + offsets, old)
53+
tl.store(out_ptr + BLOCK + offsets, after)
54+
55+
56+
# Kernel fixture: scalar atomic compare-and-swap through a tle.gpu.local_ptr
# pointer. The one-element int32 buffer starts at 3; the CAS is issued with
# compare value 3 and new value 9 (relaxed semantics, "cta" scope).
# Output layout: out_ptr[0] holds the value returned by the CAS,
# out_ptr[1] holds the value loaded from the buffer afterwards.
@triton.jit
57+
def _local_ptr_atomic_cas_kernel(out_ptr):
58+
init = tl.full((1, ), 3, tl.int32)
59+
smem = tle.gpu.alloc((1, ), dtype=tl.int32, init_value=init, nv_mma_shared_layout=False)
60+
ptr = tle.gpu.local_ptr(smem, (0, ))
61+
old = tl.atomic_cas(ptr, 3, 9, sem="relaxed", scope="cta")
62+
after = tl.load(ptr)
63+
tl.store(out_ptr, old)
64+
tl.store(out_ptr + 1, after)
65+
66+
67+
# Kernel fixture: same CAS setup as _local_ptr_atomic_cas_kernel (buffer starts
# at 3, compare value 3, new value 9), but the CAS return value is discarded.
# Only the value loaded from the buffer after the CAS is written to out_ptr[0].
@triton.jit
68+
def _local_ptr_atomic_cas_update_kernel(out_ptr):
69+
init = tl.full((1, ), 3, tl.int32)
70+
smem = tle.gpu.alloc((1, ), dtype=tl.int32, init_value=init, nv_mma_shared_layout=False)
71+
ptr = tle.gpu.local_ptr(smem, (0, ))
72+
tl.atomic_cas(ptr, 3, 9, sem="relaxed", scope="cta")
73+
after = tl.load(ptr)
74+
tl.store(out_ptr, after)
75+
76+
4377
@triton.jit
4478
def _local_ptr_non_integer_index_kernel(out_ptr):
4579
smem = tle.gpu.alloc((16, ), dtype=tl.float32, nv_mma_shared_layout=False)
@@ -110,6 +144,44 @@ def test_tle_local_ptr_full_view_store_load_rewrites_to_memdesc_ops():
110144
assert "musa_tle.local_pointers" not in llir, llir
111145

112146

147+
# Compiles the atomic-add kernel (BLOCK=16) and the atomic-CAS kernel to TTIR
# and checks, by substring match on the printed IR, that:
#   - tt.atomic_rmw / tt.atomic_cas appear with "relaxed" semantics and "cta" scope;
#   - their pointer operands are typed !tt.ptr<i32, 3> (tensor form for the add,
#     scalar form for the CAS).
# The full IR text is passed as the assertion message to ease debugging.
def test_tle_local_ptr_atomic_ops_accept_addrspace3_ttir():
148+
add_ttir = compile_to_ttir(
149+
_local_ptr_atomic_add_kernel,
150+
signature={"out_ptr": "*i32", "BLOCK": "constexpr"},
151+
constexprs={"BLOCK": 16},
152+
)
153+
cas_ttir = compile_to_ttir(_local_ptr_atomic_cas_kernel, signature={"out_ptr": "*i32"})
154+
155+
assert "tt.atomic_rmw add, relaxed, cta" in add_ttir, add_ttir
156+
assert ("(tensor<16x!tt.ptr<i32, 3>>, tensor<16xi32>, tensor<16xi1>) -> tensor<16xi32>" in add_ttir), add_ttir
157+
assert "tt.atomic_cas relaxed, cta" in cas_ttir, cas_ttir
158+
assert "(!tt.ptr<i32, 3>, i32, i32) -> i32" in cas_ttir, cas_ttir
159+
160+
161+
# Compiles the atomic-add kernel (BLOCK=16) with compile_musa and inspects the
# "ttgir" and "llir" entries of the compiled artifact's asm mapping:
#   - the TTGIR must still contain tt.atomic_rmw on tensor<16x!tt.ptr<i32, 3>> pointers;
#   - the LLVM IR must no longer mention musa_tle.local_pointers.
def test_tle_local_ptr_atomic_add_lowers_through_mthreads_llvm():
162+
compiled = compile_musa(
163+
_local_ptr_atomic_add_kernel,
164+
signature={"out_ptr": "*i32", "BLOCK": "constexpr"},
165+
constexprs={"BLOCK": 16},
166+
)
167+
168+
ttgir = compiled.asm["ttgir"]
169+
llir = compiled.asm["llir"]
170+
assert "tt.atomic_rmw" in ttgir, ttgir
171+
assert "tensor<16x!tt.ptr<i32, 3>" in ttgir, ttgir
172+
assert "musa_tle.local_pointers" not in llir, llir
173+
174+
175+
# Compiles the atomic-CAS kernel with compile_musa and inspects the "ttgir" and
# "llir" entries of the compiled artifact's asm mapping:
#   - the TTGIR must contain tt.atomic_cas and some op yielding a scalar
#     !tt.ptr<i32, 3> result (the "-> !tt.ptr<i32, 3>" substring);
#   - the LLVM IR must no longer mention musa_tle.local_pointers.
# NOTE(review): the "-> !tt.ptr<i32, 3>" check presumably targets the
# local-pointer op rather than the CAS itself -- confirm against the printed IR.
def test_tle_local_ptr_atomic_cas_lowers_through_mthreads_llvm():
176+
compiled = compile_musa(_local_ptr_atomic_cas_kernel, signature={"out_ptr": "*i32"})
177+
178+
ttgir = compiled.asm["ttgir"]
179+
llir = compiled.asm["llir"]
180+
assert "tt.atomic_cas" in ttgir, ttgir
181+
assert "-> !tt.ptr<i32, 3>" in ttgir, ttgir
182+
assert "musa_tle.local_pointers" not in llir, llir
183+
184+
113185
def test_tle_local_ptr_rejects_non_integer_indices():
114186
with pytest.raises(CompilationError, match="local_ptr indices must use integer dtypes"):
115187
compile_musa(_local_ptr_non_integer_index_kernel, signature={"out_ptr": "*fp32"})
@@ -158,3 +230,25 @@ def test_tle_local_ptr_full_view_runtime_round_trip():
158230

159231
ref = torch.arange(0, 16, dtype=torch.float32) + 7.0
160232
torch.testing.assert_close(out.cpu(), ref, rtol=0, atol=0)
233+
234+
235+
# Runtime check for the atomic-add kernel; skipped when no MUSA device is available.
# Launches a single program (grid (1,), num_warps=1) with BLOCK=16 and an output
# buffer of 2*BLOCK int32 values, then requires exact equality (rtol=0, atol=0):
#   - out[:block]  == zeros            (values returned by the atomic add)
#   - out[block:]  == 1..block         (buffer contents after the add)
@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available")
236+
def test_tle_local_ptr_atomic_add_runtime_round_trip():
237+
block = 16
238+
out = torch.empty((block * 2, ), device="musa", dtype=torch.int32)
239+
240+
_local_ptr_atomic_add_kernel[(1, )](out, BLOCK=block, num_warps=1)
241+
242+
ref_old = torch.zeros((block, ), dtype=torch.int32)
243+
ref_after = torch.arange(1, block + 1, dtype=torch.int32)
244+
torch.testing.assert_close(out[:block].cpu(), ref_old, rtol=0, atol=0)
245+
torch.testing.assert_close(out[block:].cpu(), ref_after, rtol=0, atol=0)
246+
247+
248+
# Runtime check for the CAS "update" kernel; skipped when no MUSA device is available.
# Launches a single program (grid (1,), num_warps=1) and requires the single
# int32 output to be exactly 9, i.e. the value the kernel loads after its CAS.
@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available")
249+
def test_tle_local_ptr_atomic_cas_runtime_round_trip():
250+
out = torch.empty((1, ), device="musa", dtype=torch.int32)
251+
252+
_local_ptr_atomic_cas_update_kernel[(1, )](out, num_warps=1)
253+
254+
torch.testing.assert_close(out.cpu(), torch.tensor([9], dtype=torch.int32), rtol=0, atol=0)

0 commit comments

Comments
 (0)