From 1acbb5a7a2879f8ae4c5cc03d3e79ca67e313eef Mon Sep 17 00:00:00 2001 From: QiLin Gai Date: Sat, 23 May 2026 16:45:43 +0800 Subject: [PATCH 01/10] [CHORE] Ignore FLAGTREE_BACKEND and mthreads build artifacts --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index ec72ee8b0f..48cdf0ef0b 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,11 @@ python/triton/_C/*.pdb python/triton/_C/*.exe python/triton/_C/*.ilk python/triton/FileCheck +python/triton/FLAGTREE_BACKEND + +third_party/mthreads/python/triton/_C/*.so +third_party/mthreads/python/triton/FileCheck +third_party/mthreads/python/*.egg-info # Backends copied from submodules python/triton/backends/* From 503a9134d31dbf0c09aeb90eb26699baada4f92c Mon Sep 17 00:00:00 2001 From: QiLin Gai Date: Sat, 23 May 2026 17:21:40 +0800 Subject: [PATCH 02/10] [TLE][MTHREADS] Support TLE memory_space on mthreads backend --- third_party/mthreads/backend/compiler.py | 2 + .../include/TritonMUSAGPUTransforms/Passes.td | 17 + .../TritonMUSAGPUTransforms/CMakeLists.txt | 1 + .../TLE/EarlyAssignMemorySpace.cpp | 307 ++++++++++++++++++ .../mthreads/python/test/unit/tle/test_tle.py | 21 ++ .../test/unit/tle/test_tle_memory_space.py | 108 ++++++ third_party/mthreads/triton_mthreads.cc | 2 + 7 files changed, 458 insertions(+) create mode 100644 third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TLE/EarlyAssignMemorySpace.cpp create mode 100644 third_party/mthreads/python/test/unit/tle/test_tle_memory_space.py diff --git a/third_party/mthreads/backend/compiler.py b/third_party/mthreads/backend/compiler.py index cb7310bf7b..f76458555f 100644 --- a/third_party/mthreads/backend/compiler.py +++ b/third_party/mthreads/backend/compiler.py @@ -740,6 +740,8 @@ def make_ttgir(mod, metadata, opt, arch, capability): passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_thread_locality(pm) + if hasattr(mthreads.passes.ttgpuir, "add_tle_early_assign_memory_space"): + 
mthreads.passes.ttgpuir.add_tle_early_assign_memory_space(pm) mthreads.passes.ttgpuir.add_accelerate_matmul(pm) passes.ttgpuir.add_remove_layout_conversions(pm) mthreads.passes.ttgpuir.add_optimize_dot_operands(pm) diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td index afce112fec..4f316eeb25 100644 --- a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td +++ b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td @@ -191,6 +191,23 @@ def TritonMUSAGPUMarkInplaceLoads } #ifdef __TLE__ +def TritonMUSAGPUTLEEarlyAssignMemorySpace + : Pass<"tritonmusa-tle-early-assign-memory-space", "mlir::ModuleOp"> { + let summary = "Materialize TLE memory-space annotations for MUSA"; + let description = [{ + Rewrite tensors marked with `tt.memory_space = "shared_memory"` into + explicit shared-memory memdesc traffic before MUSA-specific layout and MMA + transforms run. Legal loads are lowered through MUSA async copy plus + commit/wait; all other producers use initialized local_alloc/local_load + materialization to preserve tensor semantics. 
+ }]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect" + ]; +} + def TritonMUSAGPUTLELowerAsyncLoad : Pass<"tritonmusa-tle-lower-async-load", "mlir::ModuleOp"> { let summary = "Lower TLE async load hints to MUSA async copies"; diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt index 7d7625c0c5..94f436452a 100644 --- a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt @@ -1,5 +1,6 @@ if(FLAGTREE_MTHREADS_TLE) set(_TLE_SOURCES + TLE/EarlyAssignMemorySpace.cpp TLE/LowerAsyncLoad.cpp ) else() diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TLE/EarlyAssignMemorySpace.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TLE/EarlyAssignMemorySpace.cpp new file mode 100644 index 0000000000..1c4510b2f2 --- /dev/null +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TLE/EarlyAssignMemorySpace.cpp @@ -0,0 +1,307 @@ +#include "TritonMUSACommon/SqmmaAttrUtils.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/StringRef.h" +#include + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace { + +constexpr llvm::StringLiteral kMemorySpaceAttr = "tt.memory_space"; +constexpr llvm::StringLiteral kSharedMemory = "shared_memory"; + +static bool 
isSplatConstantTrue(Value value) { + auto splat = value.getDefiningOp(); + if (!splat) + return false; + return isConstantIntValue(splat.getSrc(), 1); +} + +static bool hasDynamicSplatMask(tt::LoadOp op) { + Value mask = op.getMask(); + if (!mask) + return false; + auto splat = mask.getDefiningOp(); + return splat && !isSplatConstantTrue(mask); +} + +static bool hasSupportedElementType(RankedTensorType type) { + Type elemTy = type.getElementType(); + if (!elemTy.isIntOrFloat()) + return false; + unsigned bitWidth = elemTy.getIntOrFloatBitWidth(); + return bitWidth == 8 || bitWidth == 16 || bitWidth == 32 || bitWidth == 64; +} + +static bool hasTensorPointerSource(tt::LoadOp op) { + if (isLoadFromTensorPtr(op)) + return false; + + auto ptrTy = dyn_cast(op.getPtr().getType()); + if (!ptrTy) + return false; + + auto elemPtrTy = dyn_cast(ptrTy.getElementType()); + return elemPtrTy && elemPtrTy.getAddressSpace() == 1; +} + +static Operation *getFirstUseInSameBlock(tt::LoadOp op) { + Operation *firstUse = nullptr; + Block *block = op->getBlock(); + for (Operation *user : op->getUsers()) { + if (user->getBlock() != block) + return nullptr; + if (!firstUse || user->isBeforeInBlock(firstUse)) + firstUse = user; + } + return firstUse; +} + +static unsigned +getAsyncLoadContiguity(tt::LoadOp op, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + Value ptr = op.getPtr(); + unsigned contiguity = axisInfoAnalysis.getContiguity(ptr); + if (Value mask = op.getMask()) + contiguity = + std::min(contiguity, axisInfoAnalysis.getMaskAlignment(mask)); + return std::max(1u, contiguity); +} + +static bool +canLowerAsyncMemorySpaceLoad(tt::LoadOp op, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + if (op->use_empty()) + return false; + auto resultTy = dyn_cast(op.getType()); + if (!resultTy || !hasSupportedElementType(resultTy)) + return false; + if (!hasTensorPointerSource(op)) + return false; + if (op.getIsVolatile()) + return false; + if (hasDynamicSplatMask(op)) + return false; + 
if (!getFirstUseInSameBlock(op)) + return false; + + return tt::canBeAsyncLoad(op) && + tt::canBeConvertedToAsyncLoad(op, axisInfoAnalysis); +} + +static bool hasNonZeroOther(tt::LoadOp op) { + return op.getOther() && !isZeroConst(op.getOther()); +} + +static bool getSqmmaAttrValue(Operation *op, llvm::StringRef name, + Attribute &value) { + value = op->getAttr(name); + return static_cast(value); +} + +static bool haveSameSqmmaAttrs(Operation *lhs, Operation *rhs) { + for (auto name : tt::musa::kSqmmaAttrNames) { + Attribute lhsAttr; + Attribute rhsAttr; + bool lhsHas = getSqmmaAttrValue(lhs, name, lhsAttr); + bool rhsHas = getSqmmaAttrValue(rhs, name, rhsAttr); + if (lhsHas != rhsHas) + return false; + if (lhsHas && lhsAttr != rhsAttr) + return false; + } + return true; +} + +struct ForwardedLocalAllocInfo { + bool canForward = true; + std::optional attrSource; +}; + +static ForwardedLocalAllocInfo +getForwardedLocalAllocInfo(tt::LoadOp op, ttg::MemDescType dstTy) { + ForwardedLocalAllocInfo info; + std::optional attrSource; + + for (Operation *user : op->getUsers()) { + auto localAlloc = dyn_cast(user); + if (!localAlloc) + continue; + + auto userTy = localAlloc.getType(); + if (userTy.getEncoding() != dstTy.getEncoding()) + continue; + + if (!attrSource) { + attrSource = localAlloc; + continue; + } + if (!haveSameSqmmaAttrs(attrSource->getOperation(), + localAlloc.getOperation())) { + info.canForward = false; + return info; + } + } + + info.attrSource = attrSource; + return info; +} + +static ttg::SharedEncodingTrait getSharedEncodingFor(Operation *op, + RankedTensorType type) { + if (isa(op)) + return tt::getSharedEncoding(op); + return tt::getSharedEncoding(type); +} + +static ttg::MemDescType getSharedMemDescType(Operation *op, + RankedTensorType type) { + auto sharedEncoding = getSharedEncodingFor(op, type); + auto sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(type.getContext()); + return ttg::MemDescType::get(type.getShape(), type.getElementType(), + 
sharedEncoding, sharedMemorySpace, + /*mutableMemory=*/true); +} + +static ttg::LocalAllocOp createLocalAllocForLoad(OpBuilder &builder, + tt::LoadOp op, + bool &canForwardLocalAllocs) { + auto resultTy = cast(op.getType()); + auto memDescTy = getSharedMemDescType(op.getOperation(), resultTy); + auto alloc = ttg::LocalAllocOp::create(builder, op.getLoc(), memDescTy); + auto forwardInfo = getForwardedLocalAllocInfo(op, memDescTy); + canForwardLocalAllocs = forwardInfo.canForward; + if (forwardInfo.attrSource) + tt::musa::copySqmmaAttrs(forwardInfo.attrSource->getOperation(), + alloc.getOperation()); + return alloc; +} + +static void materializeViaInitializedSharedAlloc(Operation *op, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Location loc = op->getLoc(); + OpResult result = op->getResult(0); + auto type = cast(result.getType()); + auto memDescTy = getSharedMemDescType(op, type); + + builder.setInsertionPointAfter(op); + auto alloc = ttg::LocalAllocOp::create(builder, loc, memDescTy, result); + auto localLoad = + ttg::LocalLoadOp::create(builder, loc, type, alloc.getResult()); + result.replaceUsesWithIf(localLoad.getResult(), [&](OpOperand &use) { + return use.getOwner() != alloc.getOperation(); + }); + op->removeAttr(kMemorySpaceAttr); +} + +static void +lowerLoadViaAsyncSharedCopy(tt::LoadOp op, RewriterBase &rewriter, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + OpBuilder::InsertionGuard guard(rewriter); + Location loc = op.getLoc(); + rewriter.setInsertionPoint(op); + + bool canForwardLocalAllocs = true; + auto alloc = createLocalAllocForLoad(rewriter, op, canForwardLocalAllocs); + auto copy = ttg::AsyncCopyGlobalToLocalOp::create( + rewriter, loc, op.getPtr(), alloc.getResult(), op.getMask(), + op.getOther(), op.getCache(), op.getEvict(), op.getIsVolatile(), + getAsyncLoadContiguity(op, axisInfoAnalysis)); + auto commit = ttg::AsyncCommitGroupOp::create(rewriter, loc, copy.getToken()); + + Operation *firstUse = 
getFirstUseInSameBlock(op); + assert(firstUse && "memory_space async load should have a same-block use"); + rewriter.setInsertionPoint(firstUse); + auto wait = ttg::AsyncWaitOp::create(rewriter, loc, commit.getResult(), 0); + + if (hasNonZeroOther(op) && op.getMask()) { + auto localLoad = ttg::LocalLoadOp::create( + rewriter, loc, op.getType(), alloc.getResult(), wait.getResult()); + auto select = + arith::SelectOp::create(rewriter, loc, op.getType(), op.getMask(), + localLoad.getResult(), op.getOther()); + op.getResult().replaceAllUsesWith(select.getResult()); + } else if (canForwardLocalAllocs) { + tt::replaceUsesWithLocalLoad(rewriter, op->getResult(0), alloc.getResult(), + wait.getResult()); + } else { + auto localLoad = ttg::LocalLoadOp::create( + rewriter, loc, op.getType(), alloc.getResult(), wait.getResult()); + op.getResult().replaceAllUsesWith(localLoad.getResult()); + } + + rewriter.eraseOp(op); +} + +} // namespace + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUTLEEARLYASSIGNMEMORYSPACE +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +struct TritonMUSAGPUTLEEarlyAssignMemorySpacePass + : impl::TritonMUSAGPUTLEEarlyAssignMemorySpaceBase< + TritonMUSAGPUTLEEarlyAssignMemorySpacePass> { + void runOnOperation() override { + ModuleOp mod = getOperation(); + IRRewriter rewriter(&getContext()); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + + SmallVector ops; + mod.walk([&](Operation *op) { + if (op->hasAttr(kMemorySpaceAttr)) + ops.push_back(op); + }); + + bool failed = false; + for (Operation *op : ops) { + auto memorySpace = op->getAttrOfType(kMemorySpaceAttr); + if (!memorySpace || memorySpace.getValue() != kSharedMemory) { + op->emitError("unsupported MUSA TLE memory space: ") + << (memorySpace ? 
memorySpace.getValue() : ""); + failed = true; + continue; + } + + if (op->getNumResults() != 1 || + !isa(op->getResult(0).getType())) { + op->emitError("MUSA TLE shared memory_space expects one ranked tensor " + "result"); + failed = true; + continue; + } + + if (op->getResult(0).use_empty()) { + op->removeAttr(kMemorySpaceAttr); + continue; + } + + if (auto load = dyn_cast(op); + load && canLowerAsyncMemorySpaceLoad(load, axisInfoAnalysis)) { + lowerLoadViaAsyncSharedCopy(load, rewriter, axisInfoAnalysis); + continue; + } + + materializeViaInitializedSharedAlloc(op, rewriter); + } + + if (failed) + signalPassFailure(); + } +}; + +} // namespace mlir diff --git a/third_party/mthreads/python/test/unit/tle/test_tle.py b/third_party/mthreads/python/test/unit/tle/test_tle.py index fa241dee9a..ec53b04b10 100644 --- a/third_party/mthreads/python/test/unit/tle/test_tle.py +++ b/third_party/mthreads/python/test/unit/tle/test_tle.py @@ -51,6 +51,12 @@ def test_tle_language_import_exports_load_signature(): "is_async", "_semantic", ] + assert list(inspect.signature(tle.gpu.memory_space).parameters) == [ + "input", + "space", + "_builder", + "_semantic", + ] def test_tle_load_sets_async_bool_attr(): @@ -79,6 +85,21 @@ def tle_kernel(src, dst, ASYNC: tl.constexpr): assert "tt.load.async = true" in async_ttir +def test_tle_gpu_memory_space_sets_shared_memory_string_attr(): + + @triton.jit + def kernel(src, dst): + offsets = tl.arange(0, 16) + values = tle.load(src + offsets) + values = tle.gpu.memory_space(values, "shared_memory") + tl.store(dst + offsets, values) + + ttir = _compile_to_ttir(kernel, {"src": "*fp32", "dst": "*fp32"}) + + assert ttir.count(" = tt.load ") == 1 + assert 'tt.memory_space = "shared_memory"' in ttir + + def test_tle_load_forwards_tl_load_options(): @triton.jit diff --git a/third_party/mthreads/python/test/unit/tle/test_tle_memory_space.py b/third_party/mthreads/python/test/unit/tle/test_tle_memory_space.py new file mode 100644 index 
0000000000..8c3687a7ca --- /dev/null +++ b/third_party/mthreads/python/test/unit/tle/test_tle_memory_space.py @@ -0,0 +1,108 @@ +import os + +import pytest +import torch +import triton +import triton.language as tl +import triton.experimental.tle.language as tle +from triton._C import libtriton +from triton.backends.compiler import GPUTarget +from triton.compiler import ASTSource + +if not hasattr(libtriton, "mthreads"): + pytest.skip("mthreads backend not built in libtriton", allow_module_level=True) + + +def _musa_target(): + arch = os.environ.get("TRITON_OVERRIDE_ARCH") or os.environ.get("TRITON_MUSA_ARCH") or "ph1" + return GPUTarget("musa", arch, 32) + + +def _compile_musa(fn, signature, constexprs=None): + src = ASTSource(fn=fn, signature=signature, constexprs=constexprs or {}) + return triton.compile(src, target=_musa_target()) + + +@triton.jit +def _memory_space_load_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr): + offs = tl.arange(0, BLOCK) + vals = tle.load(x_ptr + offs) + vals = tle.gpu.memory_space(vals, "shared_memory") + tl.store(out_ptr + offs, vals) + + +@triton.jit +def _memory_space_non_load_kernel(out_ptr, BLOCK: tl.constexpr): + offs = tl.arange(0, BLOCK) + vals = offs.to(tl.float32) + 3.0 + vals = tle.gpu.memory_space(vals, "shared_memory") + tl.store(out_ptr + offs, vals) + + +@triton.jit +def _memory_space_unsupported_kernel(out_ptr, BLOCK: tl.constexpr): + offs = tl.arange(0, BLOCK) + vals = offs.to(tl.float32) + vals = tle.gpu.memory_space(vals, "tensor_memory") + tl.store(out_ptr + offs, vals) + + +def test_tle_memory_space_load_uses_shared_async_copy(): + compiled = _compile_musa( + _memory_space_load_kernel, + signature={"x_ptr": "*fp32", "out_ptr": "*fp32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 64}, + ) + + ttgir = compiled.asm["ttgir"] + assert "tt.memory_space" not in ttgir + assert "ttg.async_copy_global_to_local" in ttgir, ttgir + assert "ttg.local_alloc" in ttgir, ttgir + assert "ttg.local_load" in ttgir, ttgir + assert 
"llvm.musa.memcpy.g2s" in compiled.asm["llir"] + + +def test_tle_memory_space_non_load_uses_initialized_shared_alloc(): + compiled = _compile_musa( + _memory_space_non_load_kernel, + signature={"out_ptr": "*fp32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 64}, + ) + + ttgir = compiled.asm["ttgir"] + assert "tt.memory_space" not in ttgir + assert "ttg.async_copy_global_to_local" not in ttgir + assert "ttg.local_alloc" in ttgir, ttgir + assert "ttg.local_load" in ttgir, ttgir + + +def test_tle_memory_space_rejects_unsupported_space(capfd): + with pytest.raises(RuntimeError, match="PassManager::run failed"): + _compile_musa( + _memory_space_unsupported_kernel, + signature={"out_ptr": "*fp32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 64}, + ) + assert "unsupported MUSA TLE memory space: tensor_memory" in capfd.readouterr().err + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_memory_space_runtime_matches_input(): + block = 64 + x = torch.arange(block, device="musa", dtype=torch.float32) + out = torch.empty_like(x) + + _memory_space_load_kernel[(1, )](x, out, BLOCK=block, num_warps=1) + + torch.testing.assert_close(out.cpu(), x.cpu(), rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_memory_space_non_load_runtime(): + block = 64 + out = torch.empty((block, ), device="musa", dtype=torch.float32) + + _memory_space_non_load_kernel[(1, )](out, BLOCK=block, num_warps=1) + + ref = torch.arange(block, dtype=torch.float32) + 3.0 + torch.testing.assert_close(out.cpu(), ref, rtol=0, atol=0) diff --git a/third_party/mthreads/triton_mthreads.cc b/third_party/mthreads/triton_mthreads.cc index 65960ae09f..3a2d23b8a4 100644 --- a/third_party/mthreads/triton_mthreads.cc +++ b/third_party/mthreads/triton_mthreads.cc @@ -139,6 +139,8 @@ void init_triton_musa_passes_ttgpuir(py::module &&m) { 
ADD_PASS_WRAPPER_0("add_optimize_sqmma_accumulator_layout", mlir::createTritonMUSAGPUOptimizeSqmmaAccumulatorLayout); #ifdef __TLE__ + ADD_PASS_WRAPPER_0("add_tle_early_assign_memory_space", + mlir::createTritonMUSAGPUTLEEarlyAssignMemorySpace); ADD_PASS_WRAPPER_0("add_tle_lower_async_load", mlir::createTritonMUSAGPUTLELowerAsyncLoad); #endif // __TLE__ From 7654831a35136d6f76f0477b83393d43083e2d82 Mon Sep 17 00:00:00 2001 From: QiLin Gai Date: Sat, 23 May 2026 17:29:25 +0800 Subject: [PATCH 03/10] [TEST][MTHREADS] Deduplicate mthreads TLE test helpers --- .../mthreads/python/test/unit/tle/test_tle.py | 41 +++------------ .../test/unit/tle/test_tle_async_load.py | 24 ++------- .../test/unit/tle/test_tle_memory_space.py | 24 ++------- .../python/test/unit/tle/test_tle_utils.py | 51 +++++++++++++++++++ 4 files changed, 67 insertions(+), 73 deletions(-) create mode 100644 third_party/mthreads/python/test/unit/tle/test_tle_utils.py diff --git a/third_party/mthreads/python/test/unit/tle/test_tle.py b/third_party/mthreads/python/test/unit/tle/test_tle.py index ec53b04b10..1c99ca48d2 100644 --- a/third_party/mthreads/python/test/unit/tle/test_tle.py +++ b/third_party/mthreads/python/test/unit/tle/test_tle.py @@ -1,39 +1,10 @@ import inspect -import os -import pytest import triton import triton.language as tl import triton.experimental.tle.language as tle -from triton._C import libtriton -from triton._C.libtriton import ir -from triton.backends import backends -from triton.backends.compiler import GPUTarget -from triton.compiler import ASTSource - -def _get_musa_backend(): - if not hasattr(libtriton, "mthreads"): - pytest.skip("mthreads backend not built in libtriton") - if "mthreads" not in backends: - pytest.skip("mthreads backend not discovered") - target = GPUTarget("musa", - os.environ.get("TRITON_OVERRIDE_ARCH") or os.environ.get("TRITON_MUSA_ARCH") or "ph1", 32) - return target, backends["mthreads"].compiler(target) - - -def _compile_to_ttir(fn, signature, 
constexprs=None): - target, backend = _get_musa_backend() - - context = ir.context() - ir.load_dialects(context) - backend.load_dialects(context) - - options = backend.parse_options({}) - module_map = backend.get_module_map() - codegen_fns = backend.get_codegen_implementation(options) - src = ASTSource(fn=fn, signature=signature, constexprs=constexprs or {}) - return src.make_ir(target, options, codegen_fns, module_map, context).str_nodebug() +from test_tle_utils import compile_to_ttir def test_tle_language_import_exports_load_signature(): @@ -74,9 +45,9 @@ def tle_kernel(src, dst, ASYNC: tl.constexpr): tl.store(dst + offsets, values) signature = {"src": "*fp32", "dst": "*fp32", "ASYNC": "constexpr"} - tl_ttir = _compile_to_ttir(tl_kernel, {"src": "*fp32", "dst": "*fp32"}) - non_async_ttir = _compile_to_ttir(tle_kernel, signature, {"ASYNC": False}) - async_ttir = _compile_to_ttir(tle_kernel, signature, {"ASYNC": True}) + tl_ttir = compile_to_ttir(tl_kernel, {"src": "*fp32", "dst": "*fp32"}) + non_async_ttir = compile_to_ttir(tle_kernel, signature, {"ASYNC": False}) + async_ttir = compile_to_ttir(tle_kernel, signature, {"ASYNC": True}) assert "tt.load.async" not in tl_ttir assert tl_ttir.count(" = tt.load ") == 1 @@ -94,7 +65,7 @@ def kernel(src, dst): values = tle.gpu.memory_space(values, "shared_memory") tl.store(dst + offsets, values) - ttir = _compile_to_ttir(kernel, {"src": "*fp32", "dst": "*fp32"}) + ttir = compile_to_ttir(kernel, {"src": "*fp32", "dst": "*fp32"}) assert ttir.count(" = tt.load ") == 1 assert 'tt.memory_space = "shared_memory"' in ttir @@ -110,7 +81,7 @@ def kernel(src, dst): offsets = tl.arange(0, 16) tl.store(dst + offsets, values) - ttir = _compile_to_ttir(kernel, {"src": "*fp32", "dst": "*fp32"}) + ttir = compile_to_ttir(kernel, {"src": "*fp32", "dst": "*fp32"}) assert "tt.load.async = true" in ttir assert "boundaryCheck = array" in ttir diff --git a/third_party/mthreads/python/test/unit/tle/test_tle_async_load.py 
b/third_party/mthreads/python/test/unit/tle/test_tle_async_load.py index 5fa2630700..b0733aa990 100644 --- a/third_party/mthreads/python/test/unit/tle/test_tle_async_load.py +++ b/third_party/mthreads/python/test/unit/tle/test_tle_async_load.py @@ -1,26 +1,12 @@ -import os - import pytest import torch import triton import triton.language as tl import triton.experimental.tle.language as tle -from triton._C import libtriton -from triton.backends.compiler import GPUTarget -from triton.compiler import ASTSource - -if not hasattr(libtriton, "mthreads"): - pytest.skip("musa backend not built in libtriton", allow_module_level=True) - - -def _musa_target(): - arch = os.environ.get("TRITON_OVERRIDE_ARCH") or os.environ.get("TRITON_MUSA_ARCH") or "ph1" - return GPUTarget("musa", arch, 32) +from test_tle_utils import compile_musa, require_mthreads_libtriton -def _compile_musa(fn, signature, constexprs=None): - src = ASTSource(fn=fn, signature=signature, constexprs=constexprs or {}) - return triton.compile(src, target=_musa_target()) +require_mthreads_libtriton() @triton.jit @@ -56,7 +42,7 @@ def _tle_load_mask_other_kernel(x_ptr, out_ptr, n: tl.constexpr, BLOCK: tl.const @pytest.mark.parametrize("is_async", [False, True]) def test_tle_load_async_copy_codegen(is_async): - compiled = _compile_musa( + compiled = compile_musa( _tle_load_asm_kernel, signature={"x_ptr": "*fp32", "out_ptr": "*fp32", "BLOCK": "constexpr", "IS_ASYNC": "constexpr"}, constexprs={"BLOCK": 64, "IS_ASYNC": is_async}, @@ -80,7 +66,7 @@ def test_tle_load_async_copy_codegen(is_async): ], ) def test_tle_load_async_unsupported_widths_fall_back(signature, block, expect_async): - compiled = _compile_musa( + compiled = compile_musa( _tle_load_hinted_asm_kernel, signature={"x_ptr": signature, "out_ptr": signature, "BLOCK": "constexpr"}, constexprs={"BLOCK": block}, @@ -92,7 +78,7 @@ def test_tle_load_async_unsupported_widths_fall_back(signature, block, expect_as def test_tle_load_async_survives_block_ptr_rewrite(): 
- compiled = _compile_musa( + compiled = compile_musa( _tle_load_block_ptr_asm_kernel, signature={"x_ptr": "*fp32", "out_ptr": "*fp32"}, ) diff --git a/third_party/mthreads/python/test/unit/tle/test_tle_memory_space.py b/third_party/mthreads/python/test/unit/tle/test_tle_memory_space.py index 8c3687a7ca..9a7f721644 100644 --- a/third_party/mthreads/python/test/unit/tle/test_tle_memory_space.py +++ b/third_party/mthreads/python/test/unit/tle/test_tle_memory_space.py @@ -1,26 +1,12 @@ -import os - import pytest import torch import triton import triton.language as tl import triton.experimental.tle.language as tle -from triton._C import libtriton -from triton.backends.compiler import GPUTarget -from triton.compiler import ASTSource - -if not hasattr(libtriton, "mthreads"): - pytest.skip("mthreads backend not built in libtriton", allow_module_level=True) - - -def _musa_target(): - arch = os.environ.get("TRITON_OVERRIDE_ARCH") or os.environ.get("TRITON_MUSA_ARCH") or "ph1" - return GPUTarget("musa", arch, 32) +from test_tle_utils import compile_musa, require_mthreads_libtriton -def _compile_musa(fn, signature, constexprs=None): - src = ASTSource(fn=fn, signature=signature, constexprs=constexprs or {}) - return triton.compile(src, target=_musa_target()) +require_mthreads_libtriton() @triton.jit @@ -48,7 +34,7 @@ def _memory_space_unsupported_kernel(out_ptr, BLOCK: tl.constexpr): def test_tle_memory_space_load_uses_shared_async_copy(): - compiled = _compile_musa( + compiled = compile_musa( _memory_space_load_kernel, signature={"x_ptr": "*fp32", "out_ptr": "*fp32", "BLOCK": "constexpr"}, constexprs={"BLOCK": 64}, @@ -63,7 +49,7 @@ def test_tle_memory_space_load_uses_shared_async_copy(): def test_tle_memory_space_non_load_uses_initialized_shared_alloc(): - compiled = _compile_musa( + compiled = compile_musa( _memory_space_non_load_kernel, signature={"out_ptr": "*fp32", "BLOCK": "constexpr"}, constexprs={"BLOCK": 64}, @@ -78,7 +64,7 @@ def 
test_tle_memory_space_non_load_uses_initialized_shared_alloc(): def test_tle_memory_space_rejects_unsupported_space(capfd): with pytest.raises(RuntimeError, match="PassManager::run failed"): - _compile_musa( + compile_musa( _memory_space_unsupported_kernel, signature={"out_ptr": "*fp32", "BLOCK": "constexpr"}, constexprs={"BLOCK": 64}, diff --git a/third_party/mthreads/python/test/unit/tle/test_tle_utils.py b/third_party/mthreads/python/test/unit/tle/test_tle_utils.py new file mode 100644 index 0000000000..d7219e87f2 --- /dev/null +++ b/third_party/mthreads/python/test/unit/tle/test_tle_utils.py @@ -0,0 +1,51 @@ +import os + +import pytest +import triton +from triton._C import libtriton +from triton._C.libtriton import ir +from triton.backends import backends +from triton.backends.compiler import GPUTarget +from triton.compiler import ASTSource + + +def require_mthreads_backend(): + if not hasattr(libtriton, "mthreads"): + pytest.skip("mthreads backend not built in libtriton") + if "mthreads" not in backends: + pytest.skip("mthreads backend not discovered") + + +def require_mthreads_libtriton(): + if not hasattr(libtriton, "mthreads"): + pytest.skip("mthreads backend not built in libtriton", allow_module_level=True) + + +def musa_target(): + arch = os.environ.get("TRITON_OVERRIDE_ARCH") or os.environ.get("TRITON_MUSA_ARCH") or "ph1" + return GPUTarget("musa", arch, 32) + + +def mthreads_backend(): + require_mthreads_backend() + target = musa_target() + return target, backends["mthreads"].compiler(target) + + +def compile_musa(fn, signature, constexprs=None): + src = ASTSource(fn=fn, signature=signature, constexprs=constexprs or {}) + return triton.compile(src, target=musa_target()) + + +def compile_to_ttir(fn, signature, constexprs=None): + target, backend = mthreads_backend() + + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + + options = backend.parse_options({}) + module_map = backend.get_module_map() + codegen_fns = 
backend.get_codegen_implementation(options) + src = ASTSource(fn=fn, signature=signature, constexprs=constexprs or {}) + return src.make_ir(target, options, codegen_fns, module_map, context).str_nodebug() From b1c1478d50d1daaef7dd563929148acc0eac67fc Mon Sep 17 00:00:00 2001 From: QiLin Gai Date: Sat, 23 May 2026 19:20:07 +0800 Subject: [PATCH 04/10] [TLE][MTHREADS] Support TLE alloc on mthreads backend --- .../TritonToTritonGPU/TritonGPUConversion.cpp | 5 + .../TritonToTritonGPUPass.cpp | 3 + third_party/mthreads/python/src/ir.cc | 18 ++- third_party/mthreads/python/src/ir.h | 13 ++ .../mthreads/python/test/unit/tle/test_tle.py | 27 ++++ .../python/test/unit/tle/test_tle_alloc.py | 139 ++++++++++++++++++ third_party/mthreads/triton_mthreads.cc | 121 +++++++++++++++ 7 files changed, 323 insertions(+), 3 deletions(-) create mode 100644 third_party/mthreads/python/test/unit/tle/test_tle_alloc.py diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 129b86fba6..edbba2fd83 100644 --- a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -86,6 +86,11 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( addIllegalOp(); +#ifdef __TLE__ + addDynamicallyLegalOp( + [&](Operation *op) { return isDynamicallyLegal(op, typeConverter); }); +#endif + addDynamicallyLegalDialect( diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 780d0fd5a0..d8fbfb5d10 100644 --- a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -580,6 +580,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, 
TritonScanPattern, GenericOpPattern, GenericOpPattern, +#ifdef __TLE__ + GenericOpPattern, +#endif TritonExpandDimsPattern, TritonTransPattern, TritonDotPattern, diff --git a/third_party/mthreads/python/src/ir.cc b/third_party/mthreads/python/src/ir.cc index bcd6ebc258..195d9edf36 100644 --- a/third_party/mthreads/python/src/ir.cc +++ b/third_party/mthreads/python/src/ir.cc @@ -41,6 +41,15 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/SourceMgr.h" +#ifdef __TLE__ +// Pointer to the TritonOpBuilder class, used to register IR ops for third-party +// dialects. +static py::class_ *builderClassPtr = nullptr; +namespace ir { +py::class_ *getBuilderClass() { return builderClassPtr; } +} // namespace ir +#endif + namespace { namespace py = pybind11; @@ -838,9 +847,12 @@ void init_triton_ir(py::module &&m) { py::class_(m, "InsertPoint", py::module_local()); - py::class_(m, "builder", py::module_local(), - py::dynamic_attr()) - .def(py::init()) + static py::class_ builderClass( + m, "builder", py::module_local(), py::dynamic_attr()); +#ifdef __TLE__ + builderClassPtr = &builderClass; +#endif + builderClass.def(py::init()) .def("get_op_builder", &TritonOpBuilder::getBuilder, ret::reference) // getters .def("create_module", diff --git a/third_party/mthreads/python/src/ir.h b/third_party/mthreads/python/src/ir.h index 499dd9e8a9..b36c3cd5d8 100644 --- a/third_party/mthreads/python/src/ir.h +++ b/third_party/mthreads/python/src/ir.h @@ -3,6 +3,13 @@ #include "triton/Tools/Sys/GetEnv.hpp" #include +#ifdef __TLE__ +#include +#include +#include +namespace py = pybind11; +#endif + // A custom op builder that keeps track of the last location class TritonOpBuilder { public: @@ -98,3 +105,9 @@ class TritonOpBuilder { return builder->getUnknownLoc(); } }; + +#ifdef __TLE__ +namespace ir { +extern py::class_ *getBuilderClass(); +} // namespace ir +#endif diff --git a/third_party/mthreads/python/test/unit/tle/test_tle.py 
b/third_party/mthreads/python/test/unit/tle/test_tle.py index 1c99ca48d2..1a3d3c1b4e 100644 --- a/third_party/mthreads/python/test/unit/tle/test_tle.py +++ b/third_party/mthreads/python/test/unit/tle/test_tle.py @@ -28,6 +28,15 @@ def test_tle_language_import_exports_load_signature(): "_builder", "_semantic", ] + assert list(inspect.signature(tle.gpu.alloc).parameters) == [ + "shape", + "dtype", + "layout", + "scope", + "init_value", + "nv_mma_shared_layout", + "_semantic", + ] def test_tle_load_sets_async_bool_attr(): @@ -71,6 +80,24 @@ def kernel(src, dst): assert 'tt.memory_space = "shared_memory"' in ttir +def test_tle_gpu_alloc_emits_local_alloc_in_ttir(): + + @triton.jit(noinline=True) + def consume_alloc(buf, out): + tl.store(out, 0.0) + + @triton.jit + def kernel(out): + buf = tle.gpu.alloc((16, ), dtype=tl.float32, nv_mma_shared_layout=False) + consume_alloc(buf, out) + + ttir = compile_to_ttir(kernel, {"out": "*fp32"}) + + assert "ttg.local_alloc" in ttir, ttir + assert "!ttg.memdesc<16xf32" in ttir, ttir + assert "#smem" in ttir, ttir + + def test_tle_load_forwards_tl_load_options(): @triton.jit diff --git a/third_party/mthreads/python/test/unit/tle/test_tle_alloc.py b/third_party/mthreads/python/test/unit/tle/test_tle_alloc.py new file mode 100644 index 0000000000..b462eb2edf --- /dev/null +++ b/third_party/mthreads/python/test/unit/tle/test_tle_alloc.py @@ -0,0 +1,139 @@ +import pytest +import torch +import triton +import triton.language as tl +import triton.experimental.tle.language as tle +from triton.compiler.errors import CompilationError + +from test_tle_utils import compile_musa, require_mthreads_libtriton + +require_mthreads_libtriton() + + +@triton.jit(noinline=True) +def _consume_alloc(buf, out_ptr): + tl.store(out_ptr, 0.0) + + +@triton.jit +def _alloc_kernel(out_ptr): + buf = tle.gpu.alloc((16, ), dtype=tl.float32, nv_mma_shared_layout=False) + _consume_alloc(buf, out_ptr) + + +@triton.jit +def _alloc_nv_mma_kernel(out_ptr): + buf = 
tle.gpu.alloc((16, ), dtype=tl.float32, nv_mma_shared_layout=True) + _consume_alloc(buf, out_ptr) + + +@triton.jit +def _alloc_default_kernel(out_ptr): + buf = tle.gpu.alloc((16, ), dtype=tl.float32) + _consume_alloc(buf, out_ptr) + + +@triton.jit +def _alloc_init_kernel(out_ptr): + init = tl.full((16, ), 1.0, tl.float32) + buf = tle.gpu.alloc((16, ), dtype=tl.float32, init_value=init, nv_mma_shared_layout=False) + _consume_alloc(buf, out_ptr) + + +@triton.jit +def _alloc_tmem_kernel(out_ptr): + buf = tle.gpu.alloc((16, 16), dtype=tl.float32, scope=tle.gpu.tmem) + _consume_alloc(buf, out_ptr) + + +@triton.jit +def _alloc_explicit_layout_kernel(out_ptr, LAYOUT: tl.constexpr): + buf = tle.gpu.alloc((16, ), dtype=tl.float32, layout=LAYOUT) + _consume_alloc(buf, out_ptr) + + +@triton.jit +def _alloc_explicit_swizzled_layout_roundtrip_kernel(src_ptr, out_ptr, LAYOUT: tl.constexpr): + offs = tl.arange(0, 16) + values = tl.load(src_ptr + offs) + buf = tle.gpu.alloc((16, ), dtype=tl.float32, layout=LAYOUT, init_value=values) + _consume_alloc(buf, out_ptr) + tl.store(out_ptr + offs, values + 1.0) + + +def test_tle_alloc_ttgir_emits_smem_memdesc(): + compiled = compile_musa(_alloc_kernel, signature={"out_ptr": "*fp32"}) + ttgir = compiled.asm["ttgir"] + + assert "ttg.local_alloc" in ttgir, ttgir + assert "!ttg.memdesc<16xf32" in ttgir, ttgir + assert "#smem" in ttgir, ttgir + assert "#ttg.swizzled_shared" in ttgir, ttgir + assert "#ttg.nvmma_shared" not in ttgir, ttgir + assert "tensor_memory" not in ttgir, ttgir + + +def test_tle_alloc_nv_mma_shared_layout_true_raises(): + with pytest.raises(CompilationError, match="mthreads TLE alloc does not support nv_mma_shared_layout=True"): + compile_musa(_alloc_nv_mma_kernel, signature={"out_ptr": "*fp32"}) + + +def test_tle_alloc_default_nv_mma_shared_layout_raises(): + with pytest.raises(CompilationError, match="mthreads TLE alloc does not support nv_mma_shared_layout=True"): + compile_musa(_alloc_default_kernel, 
signature={"out_ptr": "*fp32"}) + + +def test_tle_alloc_explicit_swizzled_shared_layout_ttgir_emits_smem_memdesc(): + layout = tle.gpu.swizzled_shared_layout.make_default(rank=1) + compiled = compile_musa(_alloc_explicit_layout_kernel, signature={"out_ptr": "*fp32", "LAYOUT": "constexpr"}, + constexprs={"LAYOUT": layout}) + ttgir = compiled.asm["ttgir"] + + assert "ttg.local_alloc" in ttgir, ttgir + assert "!ttg.memdesc<16xf32" in ttgir, ttgir + assert "#smem" in ttgir, ttgir + assert "#ttg.swizzled_shared" in ttgir, ttgir + assert "#ttg.nvmma_shared" not in ttgir, ttgir + assert "tensor_memory" not in ttgir, ttgir + + +def test_tle_alloc_explicit_nv_mma_shared_layout_raises(): + layout = tle.gpu.nv_mma_shared_layout.make_default((16, ), tl.float32) + with pytest.raises(CompilationError, match="mthreads TLE alloc does not support nv_mma_shared_layout=True"): + compile_musa(_alloc_explicit_layout_kernel, signature={"out_ptr": "*fp32", "LAYOUT": "constexpr"}, + constexprs={"LAYOUT": layout}) + + +def test_tle_alloc_with_init_value_ttgir_emits_initialized_alloc(): + compiled = compile_musa(_alloc_init_kernel, signature={"out_ptr": "*fp32"}) + ttgir = compiled.asm["ttgir"] + + assert "ttg.local_alloc %" in ttgir, ttgir + assert "(tensor<16xf32" in ttgir, ttgir + assert "!ttg.memdesc<16xf32" in ttgir, ttgir + assert "tensor_memory" not in ttgir, ttgir + + +def test_tle_alloc_tmem_scope_raises(): + with pytest.raises(CompilationError, match="mthreads TLE alloc does not support tmem storage"): + compile_musa(_alloc_tmem_kernel, signature={"out_ptr": "*fp32"}) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_alloc_runtime_launch_smoke(): + out = torch.empty((1, ), device="musa", dtype=torch.float32) + + _alloc_kernel[(1, )](out, num_warps=4) + + torch.testing.assert_close(out.cpu(), torch.zeros((1, ), dtype=torch.float32), rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device 
is not available") +def test_tle_alloc_explicit_swizzled_shared_layout_runtime_precision(): + src = torch.arange(16, device="musa", dtype=torch.float32) + out = torch.empty((16, ), device="musa", dtype=torch.float32) + layout = tle.gpu.swizzled_shared_layout.make_default(rank=1) + + _alloc_explicit_swizzled_layout_roundtrip_kernel[(1, )](src, out, layout, num_warps=4) + + torch.testing.assert_close(out.cpu(), torch.arange(16, dtype=torch.float32) + 1.0, rtol=0, atol=0) diff --git a/third_party/mthreads/triton_mthreads.cc b/third_party/mthreads/triton_mthreads.cc index 3a2d23b8a4..c54341f342 100644 --- a/third_party/mthreads/triton_mthreads.cc +++ b/third_party/mthreads/triton_mthreads.cc @@ -3,6 +3,9 @@ #include "MTGPUToLLVM/Passes.h" #include "TritonMUSAGPUToLLVM/Passes.h" #include "TritonMUSAGPUTransforms/Passes.h" +#ifdef __TLE__ +#include "ir.h" +#endif #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/MLIRContext.h" @@ -11,6 +14,9 @@ #include "mlir/Target/LLVMIR/Dialect/MTVM/MTVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "passes.h" +#ifdef __TLE__ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#endif #include "llvm/IR/CallingConv.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -19,6 +25,10 @@ #include #include #include +#ifdef __TLE__ +#include +#include +#endif namespace py = pybind11; @@ -100,8 +110,115 @@ bool moduleUsesMulhiHelper(const llvm::Module &module) { return false; } +#ifdef __TLE__ +namespace ttg = mlir::triton::gpu; + +void checkCtaRank(llvm::ArrayRef order, + llvm::ArrayRef ctasPerCGA, + llvm::ArrayRef ctaSplitNum, + llvm::ArrayRef ctaOrder) { + if (order.size() != ctasPerCGA.size() || order.size() != ctaSplitNum.size() || + order.size() != ctaOrder.size()) + throw py::value_error("shared layout rank mismatch in CTA parameters"); +} + +ttg::CGAEncodingAttr makeCgaLayout(mlir::MLIRContext *context, + llvm::ArrayRef 
ctasPerCGA, + llvm::ArrayRef ctaSplitNum, + llvm::ArrayRef ctaOrder) { + return ttg::CGAEncodingAttr::fromSplitParams(context, ctasPerCGA, ctaSplitNum, + ctaOrder); +} + +mlir::Attribute getSharedMemorySpace(mlir::MLIRContext *context, + const std::string &storage) { + if (storage == "smem" || storage == "share_memory" || + storage == "shared_memory") + return ttg::SharedMemorySpaceAttr::get(context); + if (storage == "tmem" || storage == "tensor_memory") + throw py::value_error("mthreads TLE alloc does not support tmem storage"); + throw py::value_error("mthreads TLE alloc only supports smem storage"); +} +#endif // __TLE__ + } // namespace +#ifdef __TLE__ +void init_triton_mthreads_ir(py::module &&m) { + (void)m; + + auto *builderClsPtr = ir::getBuilderClass(); + if (!builderClsPtr) + throw std::runtime_error("triton IR builder class is not initialized"); + + auto &builderCls = *builderClsPtr; + builderCls + .def("make_swizzled_shared_encoding_attr", + [](TritonOpBuilder &self, unsigned vectorSize, unsigned perPhase, + unsigned maxPhase, std::vector order, + std::vector CTAsPerCGA, + std::vector CTASplitNum, + std::vector CTAOrder) -> mlir::Attribute { + checkCtaRank(order, CTAsPerCGA, CTASplitNum, CTAOrder); + auto *context = self.getBuilder().getContext(); + auto cgaLayout = + makeCgaLayout(context, CTAsPerCGA, CTASplitNum, CTAOrder); + return ttg::SwizzledSharedEncodingAttr::get( + context, vectorSize, perPhase, maxPhase, order, cgaLayout); + }) + .def("make_nv_mma_shared_encoding_attr", + [](TritonOpBuilder &, std::vector, std::vector, + mlir::Type &, std::vector, std::vector, + std::vector, bool, bool) -> mlir::Attribute { + throw py::value_error("mthreads TLE alloc does not support " + "nv_mma_shared_layout=True"); + }) + .def("make_tensor_memory_encoding_attr", + [](TritonOpBuilder &, unsigned, unsigned, unsigned, unsigned, + unsigned, bool) -> mlir::Attribute { + throw py::value_error( + "mthreads TLE alloc does not support tmem storage"); + }) + 
.def("create_local_alloc", + [](TritonOpBuilder &self, std::vector shape, + mlir::Type &elementType, + mlir::Attribute &encoding) -> mlir::Value { + auto *context = self.getBuilder().getContext(); + auto memorySpace = ttg::SharedMemorySpaceAttr::get(context); + auto memDesc = ttg::MemDescType::get(shape, elementType, encoding, + memorySpace, + /*mutableMemory=*/true); + return self.create(memDesc); + }) + .def("create_local_alloc", + [](TritonOpBuilder &self, mlir::Type resultTy, + mlir::Value value) -> mlir::Value { + return self.create(resultTy, value); + }) + .def("get_memdesc_type", + [](TritonOpBuilder &self, std::vector shape, + mlir::Type &elementType, mlir::Attribute &encoding, + std::string storage) -> mlir::Type { + auto *context = self.getBuilder().getContext(); + auto memorySpace = getSharedMemorySpace(context, storage); + return ttg::MemDescType::get(shape, elementType, encoding, + memorySpace, + /*mutableMemory=*/true); + }) + .def("get_memdesc_type", + [](TritonOpBuilder &self, std::vector shape, + mlir::Type &elementType, mlir::Attribute &encoding, + std::string storage, + std::vector allocShape) -> mlir::Type { + auto *context = self.getBuilder().getContext(); + auto memorySpace = getSharedMemorySpace(context, storage); + return ttg::MemDescType::get(shape, elementType, encoding, + memorySpace, + /*mutableMemory=*/true, allocShape); + }); +} +#endif // __TLE__ + void init_triton_musa_passes_ttgpuir(py::module &&m) { using namespace mlir::triton; m.def("add_mtgpu_to_llvm", [](mlir::PassManager &pm, int32_t capability) { @@ -147,6 +264,10 @@ void init_triton_musa_passes_ttgpuir(py::module &&m) { } void init_triton_mthreads(py::module &&m) { +#ifdef __TLE__ + init_triton_mthreads_ir(m.def_submodule("ir")); +#endif // __TLE__ + auto passes = m.def_submodule("passes"); init_triton_musa_passes_ttgpuir(passes.def_submodule("ttgpuir")); From a454f6a61d06cd8f5c9f067aa86476a021827743 Mon Sep 17 00:00:00 2001 From: QiLin Gai Date: Sun, 24 May 2026 17:49:28 
+0800 Subject: [PATCH 05/10] [TLE][MTHREADS] Support TLE local_ptr on mthreads backend --- third_party/mthreads/CMakeLists.txt | 18 + third_party/mthreads/backend/compiler.py | 8 + .../mthreads/bin/RegisterTritonDialects.h | 6 + .../TritonToTritonGPU/CMakeLists.txt | 10 + .../TritonToTritonGPU/TritonGPUConversion.cpp | 5 + .../TritonToTritonGPUPass.cpp | 15 + .../include/TritonMUSAGPUTransforms/Passes.td | 2 + .../lib/TritonMUSAGPUToLLVM/CMakeLists.txt | 7 + .../TritonMUSAGPUToLLVM/TritonGPUToLLVM.cpp | 11 + .../test/unit/tle/test_tle_local_ptr.py | 160 +++ third_party/mthreads/tle/CMakeLists.txt | 1 + .../mthreads/tle/dialect/CMakeLists.txt | 2 + .../tle/dialect/include/CMakeLists.txt | 1 + .../MUSATLEToLLVM/LocalPointersOpToLLVM.h | 16 + .../include/Dialect/MUSATLE/CMakeLists.txt | 1 + .../include/Dialect/MUSATLE/IR/CMakeLists.txt | 8 + .../include/Dialect/MUSATLE/IR/Dialect.h | 21 + .../Dialect/MUSATLE/IR/MUSATLEDialect.td | 22 + .../include/Dialect/MUSATLE/IR/MUSATLEOps.td | 26 + .../include/MUSATLE/Transforms/Passes.td | 50 + .../mthreads/tle/dialect/lib/CMakeLists.txt | 3 + .../Conversion/MUSATLEToLLVM/CMakeLists.txt | 11 + .../MUSATLEToLLVM/LocalPointersOpToLLVM.cpp | 283 ++++++ .../tle/dialect/lib/IR/CMakeLists.txt | 10 + .../mthreads/tle/dialect/lib/IR/Dialect.cpp | 134 +++ .../tle/dialect/lib/Transforms/CMakeLists.txt | 16 + .../Transforms/InsertLocalPointerBarriers.cpp | 452 +++++++++ .../Transforms/OptimizeLocalPointerLoads.cpp | 205 ++++ .../Transforms/OptimizeLocalPointerStores.cpp | 218 +++++ .../lib/Transforms/SelectEncodings.cpp | 922 ++++++++++++++++++ third_party/mthreads/triton_mthreads.cc | 136 +-- third_party/mthreads/triton_mthreads_tle.cc | 182 ++++ 32 files changed, 2842 insertions(+), 120 deletions(-) create mode 100644 third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py create mode 100644 third_party/mthreads/tle/CMakeLists.txt create mode 100644 third_party/mthreads/tle/dialect/CMakeLists.txt create mode 100644 
third_party/mthreads/tle/dialect/include/CMakeLists.txt create mode 100644 third_party/mthreads/tle/dialect/include/Conversion/MUSATLEToLLVM/LocalPointersOpToLLVM.h create mode 100644 third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/CMakeLists.txt create mode 100644 third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/CMakeLists.txt create mode 100644 third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/Dialect.h create mode 100644 third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/MUSATLEDialect.td create mode 100644 third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/MUSATLEOps.td create mode 100644 third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/Passes.td create mode 100644 third_party/mthreads/tle/dialect/lib/CMakeLists.txt create mode 100644 third_party/mthreads/tle/dialect/lib/Conversion/MUSATLEToLLVM/CMakeLists.txt create mode 100644 third_party/mthreads/tle/dialect/lib/Conversion/MUSATLEToLLVM/LocalPointersOpToLLVM.cpp create mode 100644 third_party/mthreads/tle/dialect/lib/IR/CMakeLists.txt create mode 100644 third_party/mthreads/tle/dialect/lib/IR/Dialect.cpp create mode 100644 third_party/mthreads/tle/dialect/lib/Transforms/CMakeLists.txt create mode 100644 third_party/mthreads/tle/dialect/lib/Transforms/InsertLocalPointerBarriers.cpp create mode 100644 third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerLoads.cpp create mode 100644 third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerStores.cpp create mode 100644 third_party/mthreads/tle/dialect/lib/Transforms/SelectEncodings.cpp create mode 100644 third_party/mthreads/triton_mthreads_tle.cc diff --git a/third_party/mthreads/CMakeLists.txt b/third_party/mthreads/CMakeLists.txt index 77fc07180f..fca39c3527 100644 --- a/third_party/mthreads/CMakeLists.txt +++ b/third_party/mthreads/CMakeLists.txt @@ -4,17 +4,35 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) 
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/musa/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/musa/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/tle/dialect/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/tle/dialect/include) +if(FLAGTREE_MTHREADS_TLE) + add_subdirectory(tle) +endif() add_subdirectory(include) add_subdirectory(lib) add_subdirectory(musa) if(TRITON_BUILD_PYTHON_MODULE) + if(FLAGTREE_MTHREADS_TLE) + set(_MTHREADS_TLE_PLUGIN_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/triton_mthreads_tle.cc) + set(_MTHREADS_TLE_PLUGIN_LIBS MUSATLEIR MUSATLETransforms) + set(_MTHREADS_TLE_PLUGIN_DEPS MUSATLETableGen MUSATLETransforms) + else() + set(_MTHREADS_TLE_PLUGIN_SOURCES "") + set(_MTHREADS_TLE_PLUGIN_LIBS "") + set(_MTHREADS_TLE_PLUGIN_DEPS "") + endif() add_triton_plugin(TritonMthreads ${CMAKE_CURRENT_SOURCE_DIR}/triton_mthreads.cc + ${_MTHREADS_TLE_PLUGIN_SOURCES} LINK_LIBS TritonMUSAGPUToLLVM MTGPUToLLVM TritonMUSAGPUTransforms + ${_MTHREADS_TLE_PLUGIN_LIBS} MLIRMTVMToLLVMIRTranslation) add_dependencies(TritonMthreads MUSATableGen MUSAAttrDefsIncGen + ${_MTHREADS_TLE_PLUGIN_DEPS} MTGPUTableGen MTGPUTypesIncGen MTGPUConversionPassIncGen diff --git a/third_party/mthreads/backend/compiler.py b/third_party/mthreads/backend/compiler.py index f76458555f..8350ce996a 100644 --- a/third_party/mthreads/backend/compiler.py +++ b/third_party/mthreads/backend/compiler.py @@ -742,6 +742,14 @@ def make_ttgir(mod, metadata, opt, arch, capability): if hasattr(mthreads.passes.ttgpuir, "add_tle_early_assign_memory_space"): mthreads.passes.ttgpuir.add_tle_early_assign_memory_space(pm) + if hasattr(mthreads.passes.ttgpuir, "add_tle_select_encodings"): + mthreads.passes.ttgpuir.add_tle_select_encodings(pm) + if hasattr(mthreads.passes.ttgpuir, "add_tle_insert_local_pointer_barriers"): + mthreads.passes.ttgpuir.add_tle_insert_local_pointer_barriers(pm) + if hasattr(mthreads.passes.ttgpuir, 
"add_tle_optimize_local_pointer_loads"): + mthreads.passes.ttgpuir.add_tle_optimize_local_pointer_loads(pm) + if hasattr(mthreads.passes.ttgpuir, "add_tle_optimize_local_pointer_stores"): + mthreads.passes.ttgpuir.add_tle_optimize_local_pointer_stores(pm) mthreads.passes.ttgpuir.add_accelerate_matmul(pm) passes.ttgpuir.add_remove_layout_conversions(pm) mthreads.passes.ttgpuir.add_optimize_dot_operands(pm) diff --git a/third_party/mthreads/bin/RegisterTritonDialects.h b/third_party/mthreads/bin/RegisterTritonDialects.h index 8bd8ac4da4..7f7f5384c0 100644 --- a/third_party/mthreads/bin/RegisterTritonDialects.h +++ b/third_party/mthreads/bin/RegisterTritonDialects.h @@ -8,6 +8,9 @@ #include "Dialect/MTGPU/IR/Dialect.h" #include "Dialect/MUSA/IR/Dialect.h" +#ifdef __TLE__ +#include "Dialect/MUSATLE/IR/Dialect.h" +#endif #include "MTGPUToLLVM/Passes.h" #include "TritonMUSAGPUToLLVM/Passes.h" #include "TritonMUSAGPUTransforms/Passes.h" @@ -122,6 +125,9 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::gpu::TritonGPUDialect, mlir::triton::instrument::TritonInstrumentDialect, mlir::triton::musa::MUSADialect, mlir::triton::mtgpu::MTGPUDialect, +#ifdef __TLE__ + mlir::triton::musa_tle::MUSATLEDialect, +#endif mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect, mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect, mlir::triton::nvws::NVWSDialect, diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index ed879c7dd5..de600e21b5 100644 --- a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -1,3 +1,11 @@ +if(FLAGTREE_MTHREADS_TLE) + set(_MUSATLE_DEPS MUSATLETableGen) + set(_MUSATLE_LIBS MUSATLEIR) +else() + set(_MUSATLE_DEPS "") + set(_MUSATLE_LIBS "") +endif() + 
add_triton_library(TritonToTritonGPU RelayoutTritonGPU.cpp TritonGPUConversion.cpp @@ -5,6 +13,7 @@ add_triton_library(TritonToTritonGPU DEPENDS TritonConversionPassIncGen + ${_MUSATLE_DEPS} LINK_LIBS PUBLIC MLIRIR @@ -13,4 +22,5 @@ add_triton_library(TritonToTritonGPU TritonIR ProtonIR TritonGPUIR + ${_MUSATLE_LIBS} ) diff --git a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index edbba2fd83..2211ddc2cb 100644 --- a/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/third_party/mthreads/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -6,6 +6,9 @@ #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/Support/LLVM.h" +#ifdef __TLE__ +#include "Dialect/MUSATLE/IR/Dialect.h" +#endif #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" @@ -89,6 +92,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( #ifdef __TLE__ addDynamicallyLegalOp( [&](Operation *op) { return isDynamicallyLegal(op, typeConverter); }); + addDynamicallyLegalDialect( + [&](Operation *op) { return isDynamicallyLegal(op, typeConverter); }); #endif addDynamicallyLegalDialect(typeConverter, context); } +#ifdef __TLE__ +void populateMUSATlePatterns(TritonGPUTypeConverter &typeConverter, + RewritePatternSet &patterns) { + MLIRContext *context = patterns.getContext(); + patterns.add>( + typeConverter, context); +} +#endif + class ConvertTritonToTritonGPU : public triton::impl::ConvertTritonToTritonGPUBase< ConvertTritonToTritonGPU> { @@ -826,6 +838,9 @@ class ConvertTritonToTritonGPU // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? 
populateSCFPatterns(typeConverter, patterns); populateCFPatterns(typeConverter, patterns); +#ifdef __TLE__ + populateMUSATlePatterns(typeConverter, patterns); +#endif patterns.insert>(typeConverter, context); Builder b(&getContext()); diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td index 4f316eeb25..ab394f8f58 100644 --- a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td +++ b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td @@ -224,6 +224,8 @@ def TritonMUSAGPUTLELowerAsyncLoad "mlir::triton::gpu::TritonGPUDialect" ]; } + +include "MUSATLE/Transforms/Passes.td" #endif // __TLE__ #endif // TRITONMUSAGPU_PASSES diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/CMakeLists.txt b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/CMakeLists.txt index e89315cb95..fa77303005 100644 --- a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/CMakeLists.txt +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/CMakeLists.txt @@ -1,3 +1,9 @@ +if(FLAGTREE_MTHREADS_TLE) + set(_TLE_LIBS MUSATLEToLLVM) +else() + set(_TLE_LIBS "") +endif() + add_triton_library(TritonMUSAGPUToLLVM AllocateSharedMemory.cpp BarrierOpToLLVM.cpp @@ -30,5 +36,6 @@ add_triton_library(TritonMUSAGPUToLLVM MLIRUBToLLVM MTGPUIR MUSAIR + ${_TLE_LIBS} MLIRGPUToMTVMTransforms ) diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TritonGPUToLLVM.cpp index 69d44427de..117239005f 100644 --- a/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUToLLVM/TritonGPUToLLVM.cpp @@ -4,6 +4,10 @@ #include "TritonMUSAGPUToLLVM/Passes.h" #include "TritonMUSAGPUToLLVM/TargetInfo.h" #include "TritonMUSAGPUToLLVM/Utility.h" +#ifdef __TLE__ +#include "Conversion/MUSATLEToLLVM/LocalPointersOpToLLVM.h" +#include 
"Dialect/MUSATLE/IR/Dialect.h" +#endif #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/GPUToMTVM/GPUToMTVMPass.h" @@ -120,6 +124,9 @@ class TritonLLVMConversionTarget : public ConversionTarget { addIllegalDialect(); addIllegalDialect(); addIllegalDialect(); +#ifdef __TLE__ + addIllegalDialect(); +#endif addLegalOp(); addLegalOp(); @@ -358,6 +365,10 @@ struct ConvertTritonMUSAGPUToLLVM RewritePatternSet patterns(context); int benefit = patternBenefitPrioritizeOverLLVMConversions; +#ifdef __TLE__ + mlir::triton::musa_tle::populateMUSATLEToLLVMPatterns( + typeConverter, targetInfo, patterns, benefit); +#endif mlir::triton::MUSA::populateConvertLayoutOpToLLVMPatterns( typeConverter, targetInfo, patterns, benefit); mlir::triton::MUSA::populateDotOpToLLVMPatterns(typeConverter, patterns, diff --git a/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py b/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py new file mode 100644 index 0000000000..49f8063d35 --- /dev/null +++ b/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py @@ -0,0 +1,160 @@ +import pytest +import torch +import triton +import triton.language as tl +import triton.experimental.tle.language as tle +from triton.compiler.errors import CompilationError + +from test_tle_utils import compile_musa, require_mthreads_libtriton + +require_mthreads_libtriton() + + +@triton.jit +def _local_ptr_subview_kernel(out_ptr, BLOCK: tl.constexpr): + init = tl.arange(0, 64).to(tl.float32) + 1.0 + smem = tle.gpu.alloc((64, ), dtype=tl.float32, init_value=init, nv_mma_shared_layout=False) + offsets = tl.arange(0, BLOCK) * 2 + ptrs = tle.gpu.local_ptr(smem, (offsets, )) + values = tl.load(ptrs) + tl.store(out_ptr + tl.arange(0, BLOCK), values) + + +@triton.jit +def _local_ptr_scalar_kernel(out_ptr): + init = tl.full((16, ), 0.0, tl.float32) + smem = tle.gpu.alloc((16, ), dtype=tl.float32, 
init_value=init, nv_mma_shared_layout=False) + ptr = tle.gpu.local_ptr(smem, (5, )) + tl.store(ptr, 42.0) + value = tl.load(ptr) + tl.store(out_ptr, value) + + +@triton.jit +def _local_ptr_full_view_kernel(out_ptr): + smem = tle.gpu.alloc((16, ), dtype=tl.float32, nv_mma_shared_layout=False) + values = tl.arange(0, 16).to(tl.float32) + 7.0 + ptrs = tle.gpu.local_ptr(smem) + tl.store(ptrs, values) + loaded = tl.load(ptrs) + tl.store(out_ptr + tl.arange(0, 16), loaded) + + +@triton.jit +def _local_ptr_non_integer_index_kernel(out_ptr): + smem = tle.gpu.alloc((16, ), dtype=tl.float32, nv_mma_shared_layout=False) + idx = tl.arange(0, 16).to(tl.float32) + ptrs = tle.gpu.local_ptr(smem, (idx, )) + values = tl.load(ptrs) + tl.store(out_ptr + tl.arange(0, 16), values) + + +@triton.jit +def _local_ptr_mixed_scalar_tensor_index_kernel(out_ptr): + smem = tle.gpu.alloc((4, 4), dtype=tl.float32, nv_mma_shared_layout=False) + cols = tl.arange(0, 4) + ptrs = tle.gpu.local_ptr(smem, (0, cols)) + values = tl.load(ptrs) + tl.store(out_ptr + cols, values) + + +@triton.jit +def _local_ptr_wrong_rank_index_kernel(out_ptr): + smem = tle.gpu.alloc((4, 4), dtype=tl.float32, nv_mma_shared_layout=False) + rows = tl.arange(0, 4) + ptrs = tle.gpu.local_ptr(smem, (rows, )) + values = tl.load(ptrs) + tl.store(out_ptr + rows, values) + + +@triton.jit +def _local_ptr_tmem_kernel(out_ptr): + smem = tle.gpu.alloc((16, 16), dtype=tl.float32, scope=tle.gpu.tmem) + idx = tl.arange(0, 16) + ptrs = tle.gpu.local_ptr(smem, (idx, idx)) + values = tl.load(ptrs) + tl.store(out_ptr + idx, values) + + +def test_tle_local_ptr_subview_lowers_through_mthreads_llvm(): + compiled = compile_musa( + _local_ptr_subview_kernel, + signature={"out_ptr": "*fp32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 16}, + ) + + ttgir = compiled.asm["ttgir"] + llir = compiled.asm["llir"] + assert "musa_tle.local_pointers" in ttgir, ttgir + assert "tensor<16x!tt.ptr" in ttgir, ttgir + assert "musa_tle.local_pointers" not in llir, 
llir + + +def test_tle_local_ptr_scalar_lowers_through_mthreads_llvm(): + compiled = compile_musa(_local_ptr_scalar_kernel, signature={"out_ptr": "*fp32"}) + + ttgir = compiled.asm["ttgir"] + llir = compiled.asm["llir"] + assert "musa_tle.local_pointers" in ttgir, ttgir + assert "-> !tt.ptr" in ttgir, ttgir + assert "musa_tle.local_pointers" not in llir, llir + + +def test_tle_local_ptr_full_view_store_load_rewrites_to_memdesc_ops(): + compiled = compile_musa(_local_ptr_full_view_kernel, signature={"out_ptr": "*fp32"}) + + ttgir = compiled.asm["ttgir"] + llir = compiled.asm["llir"] + assert "ttg.local_store" in ttgir, ttgir + assert "ttg.local_load" in ttgir, ttgir + assert "musa_tle.local_pointers" not in llir, llir + + +def test_tle_local_ptr_rejects_non_integer_indices(): + with pytest.raises(CompilationError, match="local_ptr indices must use integer dtypes"): + compile_musa(_local_ptr_non_integer_index_kernel, signature={"out_ptr": "*fp32"}) + + +def test_tle_local_ptr_rejects_mixed_scalar_tensor_indices(): + with pytest.raises(CompilationError, match="local_ptr indices must be either all scalar or all tensors"): + compile_musa(_local_ptr_mixed_scalar_tensor_index_kernel, signature={"out_ptr": "*fp32"}) + + +def test_tle_local_ptr_rejects_wrong_index_rank(): + with pytest.raises(CompilationError, match="local_ptr indices must provide 2 tensors, got 1"): + compile_musa(_local_ptr_wrong_rank_index_kernel, signature={"out_ptr": "*fp32"}) + + +def test_tle_local_ptr_unsupported_storage_keeps_mthreads_error(): + with pytest.raises(CompilationError, match="mthreads TLE alloc does not support tmem storage"): + compile_musa(_local_ptr_tmem_kernel, signature={"out_ptr": "*fp32"}) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_subview_runtime_loads_shared_values(): + block = 16 + out = torch.empty((block, ), device="musa", dtype=torch.float32) + + _local_ptr_subview_kernel[(1, )](out, BLOCK=block, 
num_warps=1) + + ref = torch.arange(0, block * 2, 2, dtype=torch.float32) + 1.0 + torch.testing.assert_close(out.cpu(), ref, rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_scalar_runtime_store_load(): + out = torch.empty((1, ), device="musa", dtype=torch.float32) + + _local_ptr_scalar_kernel[(1, )](out, num_warps=1) + + torch.testing.assert_close(out.cpu(), torch.tensor([42.0], dtype=torch.float32), rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_full_view_runtime_round_trip(): + out = torch.empty((16, ), device="musa", dtype=torch.float32) + + _local_ptr_full_view_kernel[(1, )](out, num_warps=1) + + ref = torch.arange(0, 16, dtype=torch.float32) + 7.0 + torch.testing.assert_close(out.cpu(), ref, rtol=0, atol=0) diff --git a/third_party/mthreads/tle/CMakeLists.txt b/third_party/mthreads/tle/CMakeLists.txt new file mode 100644 index 0000000000..562832e921 --- /dev/null +++ b/third_party/mthreads/tle/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(dialect) diff --git a/third_party/mthreads/tle/dialect/CMakeLists.txt b/third_party/mthreads/tle/dialect/CMakeLists.txt new file mode 100644 index 0000000000..8a43d93a8b --- /dev/null +++ b/third_party/mthreads/tle/dialect/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(include) +add_subdirectory(lib) diff --git a/third_party/mthreads/tle/dialect/include/CMakeLists.txt b/third_party/mthreads/tle/dialect/include/CMakeLists.txt new file mode 100644 index 0000000000..a4913e0480 --- /dev/null +++ b/third_party/mthreads/tle/dialect/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect/MUSATLE) diff --git a/third_party/mthreads/tle/dialect/include/Conversion/MUSATLEToLLVM/LocalPointersOpToLLVM.h b/third_party/mthreads/tle/dialect/include/Conversion/MUSATLEToLLVM/LocalPointersOpToLLVM.h new file mode 100644 index 0000000000..a5ba5141fa --- /dev/null 
+++ b/third_party/mthreads/tle/dialect/include/Conversion/MUSATLEToLLVM/LocalPointersOpToLLVM.h @@ -0,0 +1,16 @@ +#ifndef MTHREADS_MUSATLE_CONVERSION_MUSATLETOLLVM_LOCALPOINTERSOPTOLLVM_H +#define MTHREADS_MUSATLE_CONVERSION_MUSATLETOLLVM_LOCALPOINTERSOPTOLLVM_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" + +namespace mlir::triton::musa_tle { + +void populateMUSATLEToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, + PatternBenefit benefit); + +} // namespace mlir::triton::musa_tle + +#endif // MTHREADS_MUSATLE_CONVERSION_MUSATLETOLLVM_LOCALPOINTERSOPTOLLVM_H diff --git a/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/CMakeLists.txt b/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/CMakeLists.txt b/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/CMakeLists.txt new file mode 100644 index 0000000000..50b3308e7c --- /dev/null +++ b/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/CMakeLists.txt @@ -0,0 +1,8 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS MUSATLEOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=musa_tle -D__TLE__) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=musa_tle -D__TLE__) +mlir_tablegen(Ops.h.inc -gen-op-decls -D__TLE__) +mlir_tablegen(Ops.cpp.inc -gen-op-defs -D__TLE__) +add_public_tablegen_target(MUSATLETableGen) diff --git a/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/Dialect.h b/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/Dialect.h new file mode 100644 index 0000000000..a150dc1198 --- /dev/null +++ 
b/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/Dialect.h @@ -0,0 +1,21 @@ +#ifndef TRITON_DIALECT_MUSATLE_IR_DIALECT_H_ +#define TRITON_DIALECT_MUSATLE_IR_DIALECT_H_ + +#ifdef __TLE__ + +#include "mlir/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" + +// clang-format off +#include "Dialect/MUSATLE/IR/Dialect.h.inc" +// clang-format on + +#define GET_OP_CLASSES +#include "Dialect/MUSATLE/IR/Ops.h.inc" + +#endif // __TLE__ + +#endif // TRITON_DIALECT_MUSATLE_IR_DIALECT_H_ diff --git a/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/MUSATLEDialect.td b/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/MUSATLEDialect.td new file mode 100644 index 0000000000..9031769797 --- /dev/null +++ b/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/MUSATLEDialect.td @@ -0,0 +1,22 @@ +#ifndef MUSA_TLE_DIALECT +#define MUSA_TLE_DIALECT + +include "mlir/IR/OpBase.td" + +#ifdef __TLE__ +def MUSATLE_Dialect : Dialect { + let name = "musa_tle"; + let cppNamespace = "::mlir::triton::musa_tle"; + let description = [{ + MUSA backend-local Triton Language Extension dialect. 
+ }]; + + let dependentDialects = [ + "triton::TritonDialect", + "triton::gpu::TritonGPUDialect", + ]; + let usePropertiesForAttributes = 1; +} +#endif // __TLE__ + +#endif // MUSA_TLE_DIALECT diff --git a/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/MUSATLEOps.td b/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/MUSATLEOps.td new file mode 100644 index 0000000000..5e61c99634 --- /dev/null +++ b/third_party/mthreads/tle/dialect/include/Dialect/MUSATLE/IR/MUSATLEOps.td @@ -0,0 +1,26 @@ +#ifndef MUSA_TLE_OPS +#define MUSA_TLE_OPS + +#ifdef __TLE__ +include "mlir/Interfaces/SideEffectInterfaces.td" +include "MUSATLEDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" + +class MUSATLE_Op traits = []> + : Op {} + +def MUSATLE_LocalPointerResultType : AnyTypeOf<[TT_Tensor, TT_Ptr]>; +def MUSATLE_LocalPointerIndexType : AnyTypeOf<[TT_Tensor, TT_Int]>; + +def MUSATLE_LocalPointersOp : MUSATLE_Op<"local_pointers", [Pure]> { + let arguments = (ins TTG_MemDescType:$src, + Variadic:$indices); + let results = (outs MUSATLE_LocalPointerResultType:$result); + let hasVerifier = 1; +} +#endif // __TLE__ + +#endif // MUSA_TLE_OPS diff --git a/third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/Passes.td b/third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/Passes.td new file mode 100644 index 0000000000..c55ed3d065 --- /dev/null +++ b/third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/Passes.td @@ -0,0 +1,50 @@ +#ifndef MUSA_TLE_TRANSFORMS_PASSES +#define MUSA_TLE_TRANSFORMS_PASSES + +def TritonMUSAGPUTLESelectEncodings + : Pass<"tritonmusa-tle-select-encodings", "mlir::ModuleOp"> { + let summary = "Select MUSA encodings for local pointer users"; + let description = [{ + Select stable tensor encodings for `musa_tle.local_pointers` and + dependent load/store users before MUSA LLVM lowering. 
+ }]; + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::musa_tle::MUSATLEDialect" + ]; +} + +def TritonMUSAGPUTLEInsertLocalPointerBarriers + : Pass<"tritonmusa-tle-insert-local-pointer-barriers", "mlir::ModuleOp"> { + let summary = "Insert barriers between local pointer stores and loads"; + let dependentDialects = [ + "mlir::gpu::GPUDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::musa_tle::MUSATLEDialect" + ]; +} + +def TritonMUSAGPUTLEOptimizeLocalPointerLoads + : Pass<"tritonmusa-tle-optimize-local-pointer-loads", "mlir::ModuleOp"> { + let summary = "Rewrite full-view local pointer loads to ttg.local_load"; + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::musa_tle::MUSATLEDialect" + ]; +} + +def TritonMUSAGPUTLEOptimizeLocalPointerStores + : Pass<"tritonmusa-tle-optimize-local-pointer-stores", "mlir::ModuleOp"> { + let summary = "Rewrite full-view local pointer stores to ttg.local_store"; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::musa_tle::MUSATLEDialect" + ]; +} + +#endif // MUSA_TLE_TRANSFORMS_PASSES diff --git a/third_party/mthreads/tle/dialect/lib/CMakeLists.txt b/third_party/mthreads/tle/dialect/lib/CMakeLists.txt new file mode 100644 index 0000000000..579da2029d --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(IR) +add_subdirectory(Conversion/MUSATLEToLLVM) +add_subdirectory(Transforms) diff --git a/third_party/mthreads/tle/dialect/lib/Conversion/MUSATLEToLLVM/CMakeLists.txt b/third_party/mthreads/tle/dialect/lib/Conversion/MUSATLEToLLVM/CMakeLists.txt new file mode 100644 index 0000000000..88f34328d3 --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/Conversion/MUSATLEToLLVM/CMakeLists.txt @@ -0,0 
+1,11 @@ +add_triton_library(MUSATLEToLLVM + LocalPointersOpToLLVM.cpp + + DEPENDS + MUSATLETableGen + + LINK_LIBS PUBLIC + MUSATLEIR + TritonGPUToLLVM + TritonGPUIR +) diff --git a/third_party/mthreads/tle/dialect/lib/Conversion/MUSATLEToLLVM/LocalPointersOpToLLVM.cpp b/third_party/mthreads/tle/dialect/lib/Conversion/MUSATLEToLLVM/LocalPointersOpToLLVM.cpp new file mode 100644 index 0000000000..6f01104838 --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/Conversion/MUSATLEToLLVM/LocalPointersOpToLLVM.cpp @@ -0,0 +1,283 @@ +#ifdef __TLE__ + +#include "Conversion/MUSATLEToLLVM/LocalPointersOpToLLVM.h" + +#include "Dialect/MUSATLE/IR/Dialect.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Tools/LayoutUtils.h" +#include "llvm/ADT/STLExtras.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace musa_tle = mlir::triton::musa_tle; + +struct LocalPointersOpConversion + : public ConvertOpToLLVMPattern { + LocalPointersOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(musa_tle::LocalPointersOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = op.getContext(); + auto typeConverter = getTypeConverter(); + auto reportFailure = [&](StringRef msg) -> LogicalResult { + return op.emitOpError() << msg; + }; + + auto memDescTy = cast(op.getSrc().getType()); + auto resultTensorTy = 
dyn_cast(op.getResult().getType()); + auto resultPtrTy = dyn_cast(op.getResult().getType()); + if (!resultTensorTy && !resultPtrTy) + return reportFailure("local_pointers result must be tensor or ptr"); + auto ptrTy = + resultTensorTy + ? cast(resultTensorTy.getElementType()) + : resultPtrTy; + auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType()); + auto llvmPtrTy = + cast(typeConverter->convertType(ptrTy)); + if (llvmPtrTy.getAddressSpace() != + static_cast(targetInfo.getSharedAddressSpace())) + return reportFailure("local_pointers must lower to shared addrspace"); + + auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), + llvmElemTy, rewriter); + auto i32Ty = rewriter.getIntegerType(32); + auto ensureI32 = [&](Value v) -> Value { + if (v.getType() == i32Ty) + return v; + if (auto intTy = dyn_cast(v.getType())) { + if (intTy.getWidth() > 32) + return LLVM::TruncOp::create(rewriter, loc, i32Ty, v); + if (intTy.isUnsigned()) + return LLVM::ZExtOp::create(rewriter, loc, i32Ty, v); + return LLVM::SExtOp::create(rewriter, loc, i32Ty, v); + } + return Value(); + }; + + auto sharedEnc = cast(memDescTy.getEncoding()); + auto kReg = str_attr("register"); + auto kOffset = str_attr("offset"); + LinearLayout regLayout; + if (resultTensorTy) { + if (!resultTensorTy.getEncoding()) + return reportFailure( + "tensor local_pointers result must carry an encoding"); + regLayout = ttg::toLinearLayout(resultTensorTy); + } + for (Value operand : op.getIndices()) { + if (resultTensorTy) { + auto idxTy = dyn_cast(operand.getType()); + if (!idxTy) + return reportFailure("tensor result requires ranked-tensor indices"); + if (resultTensorTy.getEncoding() && idxTy.getEncoding() && + resultTensorTy.getEncoding() != idxTy.getEncoding()) + return reportFailure( + "indices tensor encoding must match result encoding"); + } else if (!isa(operand.getType())) { + return reportFailure("scalar result requires scalar integer indices"); + } + } + + const size_t 
outSize = resultTensorTy ? regLayout.getInDimSize(kReg) : 1; + SmallVector outVals(outSize, Value()); + + TritonLLVMOpBuilder b(loc, rewriter); + int elemBits = llvmElemTy.getIntOrFloatBitWidth(); + assert(elemBits % 8 == 0 && "element bitwidth must be byte addressable"); + int elemBytes = elemBits / 8; + Value elemBytesVal = + elemBytes > 1 ? b.i32_val(static_cast(elemBytes)) : Value(); + auto i8Ty = IntegerType::get(ctx, 8); + auto i8PtrTy = LLVM::LLVMPointerType::get(ctx, llvmPtrTy.getAddressSpace()); + + SmallVector bufferShape; + for (int64_t dim : memDescTy.getShape()) + bufferShape.push_back(static_cast(dim)); + auto bufferRank = bufferShape.size(); + auto smemOffsets = smemObj.getOffsets(); + const bool isRank0BackingMemDesc = + bufferRank == 1 && memDescTy.getShape().front() == 1; + const bool isLogicalRank0Scalar = + !resultTensorTy && op.getIndices().empty() && + (bufferRank == 0 || isRank0BackingMemDesc); + if (!isLogicalRank0Scalar && smemOffsets.size() != bufferRank) + return reportFailure("shared memory offsets rank mismatch"); + + auto indexVals = adaptor.getIndices(); + const bool hasExplicitIndices = !indexVals.empty(); + if (hasExplicitIndices) { + if (indexVals.size() != bufferRank) + return reportFailure("indices must provide buffer-rank values"); + } else { + if (!resultTensorTy && !isLogicalRank0Scalar) + return reportFailure( + "zero-index scalar local_pointers requires rank-0 buffer"); + if (resultTensorTy && resultTensorTy.getShape() != memDescTy.getShape()) + return reportFailure( + "zero-index tensor local_pointers requires full buffer shape"); + } + + SmallVector> indexElems; + if (hasExplicitIndices) { + indexElems.reserve(indexVals.size()); + for (Value indexVal : indexVals) { + if (resultTensorTy) { + auto elems = unpackLLElements(loc, indexVal, rewriter); + if (elems.size() != outVals.size()) + return reportFailure( + "indices tensors must match local_pointers result shape"); + indexElems.push_back(std::move(elems)); + } else { + 
Value scalar = ensureI32(indexVal); + if (!scalar) + return reportFailure("scalar indices must lower to i32 values"); + indexElems.push_back(SmallVector{scalar}); + } + } + } else if (resultTensorTy) { + auto fullCoords = + emitIndices(loc, rewriter, targetInfo, resultTensorTy.getEncoding(), + resultTensorTy, + /*withCTAOffset=*/false); + if (fullCoords.size() != outVals.size()) + return reportFailure( + "failed to synthesize full indices for local_pointers"); + indexElems.assign(bufferRank, SmallVector{}); + for (size_t idx = 0; idx < fullCoords.size(); ++idx) { + if (fullCoords[idx].size() != bufferRank) + return reportFailure("synthesized full indices rank mismatch"); + for (size_t dim = 0; dim < bufferRank; ++dim) { + Value coord = ensureI32(fullCoords[idx][dim]); + if (!coord) + return reportFailure( + "synthesized full indices must lower to i32 values"); + indexElems[dim].push_back(coord); + } + } + } + + for (size_t idx = 0; idx < outVals.size(); ++idx) { + SmallVector idxCoords; + idxCoords.reserve(bufferRank); + for (size_t dim = 0; dim < indexElems.size(); ++dim) { + Value val = ensureI32(indexElems[dim][idx]); + if (!val) + return reportFailure("indices must lower to i32 scalars"); + Value offset = smemOffsets[dim]; + Value offVal = ensureI32(offset); + if (!offVal) + return reportFailure("shared memory offsets must be i32"); + idxCoords.push_back(b.add(val, offVal)); + } + + Value elemOffset; + if (isLogicalRank0Scalar || bufferRank == 0) { + elemOffset = b.i32_val(0); + } else if (isa(sharedEnc)) { + auto order = ttg::getOrder(sharedEnc, memDescTy.getShape()); + elemOffset = + LLVM::linearize(rewriter, loc, idxCoords, bufferShape, order); + } else { + auto dimNames = standardOutDimNames(ctx, bufferRank); + SmallVector> logicalOffsets; + logicalOffsets.reserve(bufferRank); + for (auto [dim, offset] : llvm::zip_equal(dimNames, idxCoords)) + logicalOffsets.push_back({dim, offset}); + LinearLayout sharedLayout = ttg::toLinearLayout(memDescTy); + 
sharedLayout = sharedLayout.sublayout({kOffset}, dimNames); + LinearLayout invSharedLayout = sharedLayout.invert(); + + SmallVector> orderedLogicalOffsets; + orderedLogicalOffsets.reserve(invSharedLayout.getNumInDims()); + for (StringAttr inDim : invSharedLayout.getInDimNames()) { + bool found = false; + for (auto &logical : logicalOffsets) { + if (logical.first == inDim) { + orderedLogicalOffsets.push_back(logical); + found = true; + break; + } + } + if (!found) + return reportFailure( + "missing logical offset for inverted shared-layout in-dim"); + } + + auto remappedOffsets = applyLinearLayout(loc, rewriter, invSharedLayout, + orderedLogicalOffsets); + if (remappedOffsets.empty()) + return reportFailure("failed to remap shared-memory linear offsets"); + + bool foundOffset = false; + for (auto &mapped : remappedOffsets) { + if (mapped.first == kOffset) { + elemOffset = mapped.second; + foundOffset = true; + break; + } + } + if (!foundOffset) + return reportFailure( + "remapped shared layout does not contain offset"); + } + + Value byteOffset = elemOffset; + if (elemBytes > 1) + byteOffset = b.mul(byteOffset, elemBytesVal); + if (auto paddedEnc = dyn_cast(sharedEnc)) { + auto shifts = getPaddedSharedShifts(paddedEnc, elemBits, + /*offsetInBytes=*/true); + byteOffset = applyPadding(loc, rewriter, byteOffset, shifts); + } + + Value ptrI8 = b.bitcast(smemObj.getBase(), i8PtrTy); + Value advanced = b.gep(i8PtrTy, i8Ty, ptrI8, byteOffset, + LLVM::GEPNoWrapFlags::inbounds); + outVals[idx] = b.bitcast(advanced, llvmPtrTy); + } + + if (resultTensorTy) { + Value result = + packLLElements(loc, typeConverter, outVals, rewriter, resultTensorTy); + rewriter.replaceOp(op, result); + } else { + rewriter.replaceOp(op, outVals.front()); + } + return success(); + } + +private: + const TargetInfoBase &targetInfo; +}; + +} // namespace + +namespace mlir::triton::musa_tle { + +void populateMUSATLEToLLVMPatterns(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + 
RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} + +} // namespace mlir::triton::musa_tle + +#endif // __TLE__ diff --git a/third_party/mthreads/tle/dialect/lib/IR/CMakeLists.txt b/third_party/mthreads/tle/dialect/lib/IR/CMakeLists.txt new file mode 100644 index 0000000000..bce9e20c7e --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_triton_library(MUSATLEIR + Dialect.cpp + + DEPENDS + MUSATLETableGen + + LINK_LIBS PUBLIC + TritonIR + TritonGPUIR +) diff --git a/third_party/mthreads/tle/dialect/lib/IR/Dialect.cpp b/third_party/mthreads/tle/dialect/lib/IR/Dialect.cpp new file mode 100644 index 0000000000..05d2151a72 --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/IR/Dialect.cpp @@ -0,0 +1,134 @@ +#ifdef __TLE__ + +#include "Dialect/MUSATLE/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +// clang-format off +#include "Dialect/MUSATLE/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +namespace ttg = mlir::triton::gpu; + +namespace mlir::triton::musa_tle { + +void MUSATLEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "Dialect/MUSATLE/IR/Ops.cpp.inc" + >(); +} + +} // namespace mlir::triton::musa_tle + +#define GET_OP_CLASSES +#include "Dialect/MUSATLE/IR/Ops.cpp.inc" + +namespace mlir::triton::musa_tle { +namespace { +constexpr int kSharedMemoryAddressSpace = 3; + +static bool isRank0BackingMemDesc(ttg::MemDescType memDescTy) { + return memDescTy.getShape().size() == 1 && memDescTy.getShape().front() == 1; +} +} // namespace + +LogicalResult LocalPointersOp::verify() { + auto memDescTy = dyn_cast(getSrc().getType()); + if (!memDescTy) + return emitOpError() << "expects src operand to be a ttg.memdesc"; + if (!isa(memDescTy.getMemorySpace())) + return emitOpError() << 
"expects src memdesc to live in shared memory"; + if (!isa(memDescTy.getEncoding())) + return emitOpError() << "expects src memdesc to use a shared encoding"; + + auto resultTensorTy = dyn_cast(getResult().getType()); + auto resultPtrTy = dyn_cast(getResult().getType()); + if (!resultTensorTy && !resultPtrTy) + return emitOpError() + << "expects result to be either tensor> or tt.ptr"; + + auto ptrTy = + resultTensorTy + ? dyn_cast(resultTensorTy.getElementType()) + : resultPtrTy; + if (!ptrTy) + return emitOpError() << "expects result element type to be tt.ptr"; + + if (ptrTy.getPointeeType() != memDescTy.getElementType()) + return emitOpError() << "expects pointer pointee type " + << ptrTy.getPointeeType() + << " to match memdesc element type " + << memDescTy.getElementType(); + + if (ptrTy.getAddressSpace() != kSharedMemoryAddressSpace) + return emitOpError() << "expects pointers to live in shared memory"; + + auto indices = getIndices(); + if (indices.empty()) { + if (resultTensorTy) { + if (resultTensorTy.getShape() != memDescTy.getShape()) + return emitOpError() + << "zero-index local_pointers expects tensor result shape to " + "match buffer shape"; + return success(); + } + if (!memDescTy.getShape().empty() && !isRank0BackingMemDesc(memDescTy)) + return emitOpError() + << "zero-index scalar local_pointers is only valid for rank-0 " + "buffers"; + return success(); + } + + if (indices.size() != memDescTy.getShape().size()) + return emitOpError() << "expects indices count to match buffer rank"; + + if (resultTensorTy) { + auto resultShape = resultTensorTy.getShape(); + Attribute resultEncoding = resultTensorTy.getEncoding(); + + ArrayRef indexShape; + for (Value val : indices) { + auto indexTy = dyn_cast(val.getType()); + if (!indexTy) + return emitOpError() + << "tensor result expects indices to be ranked tensors"; + if (!indexTy.getElementType().isInteger()) + return emitOpError() << "expects indices return tensors to have " + "integer element types"; + if 
(indexShape.empty()) + indexShape = indexTy.getShape(); + else if (indexTy.getShape() != indexShape) + return emitOpError() + << "expects indices return tensors to have identical shapes"; + if (resultEncoding && indexTy.getEncoding() && + resultEncoding != indexTy.getEncoding()) + return emitOpError() + << "expects indices return tensors to match result encoding"; + } + + if (indexShape != resultShape) + return emitOpError() + << "expects indices return tensor shape to match result shape"; + return success(); + } + + for (Value val : indices) { + if (auto indexTy = dyn_cast(val.getType())) { + if (!indexTy.isSignlessInteger()) + return emitOpError() + << "expects scalar indices to be signless integers"; + continue; + } + return emitOpError() << "scalar result expects scalar integer indices"; + } + + return success(); +} + +} // namespace mlir::triton::musa_tle + +#endif // __TLE__ diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/CMakeLists.txt b/third_party/mthreads/tle/dialect/lib/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..d98285836e --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/Transforms/CMakeLists.txt @@ -0,0 +1,16 @@ +add_triton_library(MUSATLETransforms + InsertLocalPointerBarriers.cpp + OptimizeLocalPointerLoads.cpp + OptimizeLocalPointerStores.cpp + SelectEncodings.cpp + + DEPENDS + MUSATLETableGen + TritonMUSAGPUTransformsIncGen + + LINK_LIBS PUBLIC + MUSATLEIR + TritonIR + TritonGPUIR + TritonGPUTransforms +) diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/InsertLocalPointerBarriers.cpp b/third_party/mthreads/tle/dialect/lib/Transforms/InsertLocalPointerBarriers.cpp new file mode 100644 index 0000000000..70ffd7403a --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/Transforms/InsertLocalPointerBarriers.cpp @@ -0,0 +1,452 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this 
software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// flagtree tle + +#ifdef __TLE__ + +#include "TritonMUSAGPUTransforms/Passes.h" + +#include "Dialect/MUSATLE/IR/Dialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallPtrSet.h" +#include + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUTLEINSERTLOCALPOINTERBARRIERS +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +namespace { + +constexpr StringLiteral kBarrierGroupAttr = "musa_tle.barrier_group"; + +namespace ttg = mlir::triton::gpu; + +static Value stripConvertLayouts(Value value) { + Value current = 
value; + while (auto cvt = current.getDefiningOp()) + current = cvt.getSrc(); + return current; +} + +static Value stripIndexValueWrappers(Value value) { + Value current = value; + while (true) { + if (auto cvt = current.getDefiningOp()) { + current = cvt.getSrc(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto trunc = current.getDefiningOp()) { + current = trunc.getIn(); + continue; + } + if (auto cast = current.getDefiningOp()) { + current = cast.getIn(); + continue; + } + break; + } + return current; +} + +static bool matchZeroStartMakeRange(Value value, int64_t extent) { + Value current = stripIndexValueWrappers(value); + auto range = current.getDefiningOp(); + return range && range.getStart() == 0 && range.getEnd() == extent; +} + +static bool matchFullIndexTensorForAxis(Value index, size_t axis, + ArrayRef shape) { + auto indexTy = dyn_cast(index.getType()); + if (!indexTy || !indexTy.getElementType().isInteger()) + return false; + if (indexTy.getShape() != shape) + return false; + + Value current = stripIndexValueWrappers(index); + if (shape.size() == 1) + return matchZeroStartMakeRange(current, shape.front()); + + auto bcast = current.getDefiningOp(); + if (!bcast) + return false; + + auto bcastSrcTy = dyn_cast(bcast.getSrc().getType()); + if (!bcastSrcTy || bcastSrcTy.getRank() != static_cast(shape.size())) + return false; + for (auto [dim, dimSize] : llvm::enumerate(shape)) { + const int64_t expected = dim == axis ? 
dimSize : 1; + if (bcastSrcTy.getShape()[dim] != expected) + return false; + } + + current = stripIndexValueWrappers(bcast.getSrc()); + while (auto expand = current.getDefiningOp()) + current = stripIndexValueWrappers(expand.getSrc()); + + auto rangeTy = dyn_cast(current.getType()); + if (!rangeTy || rangeTy.getRank() != 1) + return false; + if (rangeTy.getShape()[0] != shape[axis]) + return false; + + return matchZeroStartMakeRange(current, shape[axis]); +} + +static std::optional matchFullViewMemDesc(triton::LoadOp load) { + if (load.getMask() || load.getOther() || load.getIsVolatile()) + return std::nullopt; + if (load.getCache() != triton::CacheModifier::NONE || + load.getEvict() != triton::EvictionPolicy::NORMAL) + return std::nullopt; + + auto loadTy = dyn_cast(load.getType()); + if (!loadTy) + return std::nullopt; + + Value ptr = stripConvertLayouts(load.getPtr()); + auto localPointers = ptr.getDefiningOp(); + if (!localPointers) + return std::nullopt; + + auto ptrTy = dyn_cast(localPointers.getResult().getType()); + auto memDescTy = dyn_cast(localPointers.getSrc().getType()); + if (!ptrTy || !memDescTy) + return std::nullopt; + + auto memDescShape = memDescTy.getShape(); + if (loadTy.getShape() != memDescShape || ptrTy.getShape() != memDescShape) + return std::nullopt; + if (loadTy.getElementType() != memDescTy.getElementType()) + return std::nullopt; + + auto indices = localPointers.getIndices(); + if (indices.empty()) + return localPointers.getSrc(); + if (indices.size() != memDescShape.size()) + return std::nullopt; + + for (auto [axis, index] : llvm::enumerate(indices)) + if (!matchFullIndexTensorForAxis(index, axis, memDescShape)) + return std::nullopt; + + return localPointers.getSrc(); +} + +static void createLocalBarrier(OpBuilder &builder, Location loc) { + ttg::BarrierOp::create(builder, loc, ttg::AddrSpace::Local); +} + +static bool hasOnlyDotOperandUses(Value value, + llvm::SmallPtrSetImpl &seen) { + for (OpOperand &use : value.getUses()) { + 
Operation *user = use.getOwner(); + if (!seen.insert(user).second) + continue; + + if (auto cvt = dyn_cast(user)) { + if (!hasOnlyDotOperandUses(cvt.getResult(), seen)) + return false; + continue; + } + + auto dot = dyn_cast(user); + if (!dot) + return false; + if (dot.getA() != value && dot.getB() != value) + return false; + } + return true; +} + +static bool isFullViewLoadUsedOnlyByDotOperands(triton::LoadOp load) { + if (!matchFullViewMemDesc(load)) + return false; + llvm::SmallPtrSet seen; + return hasOnlyDotOperandUses(load.getResult(), seen); +} + +static bool isCudaTargetAtLeast(ModuleOp module, int minCapability) { + auto target = module->getAttrOfType("ttg.target"); + if (!target) + return false; + + StringRef value = target.getValue(); + if (!value.consume_front("cuda:")) + return false; + + int capability = 0; + if (value.getAsInteger(10, capability)) + return false; + return capability >= minCapability; +} + +class InsertLocalPointerBarriersPass + : public impl::TritonMUSAGPUTLEInsertLocalPointerBarriersBase< + InsertLocalPointerBarriersPass> { + void runOnOperation() override { + ModuleOp module = getOperation(); + pointerGroups.clear(); + allowDotOperandBarrierElision = isCudaTargetAtLeast(module, 90); + collectTrackedPointers(module); + + if (pointerGroups.empty()) + return; + + for (Operation &op : module.getBody()->getOperations()) + processOperation(op); + } + + void collectTrackedPointers(ModuleOp module) { + llvm::SmallVector worklist; + module.walk([&](triton::musa_tle::LocalPointersOp op) { + auto groupAttr = op->getAttrOfType(kBarrierGroupAttr); + if (!groupAttr) + return; + Value ptr = op.getResult(); + int64_t group = groupAttr.getInt(); + if (pointerGroups.try_emplace(ptr, group).second) + worklist.push_back(ptr); + }); + + auto tryTrackDerived = [&](Operation *op, Value src, Value derived) { + auto it = pointerGroups.find(src); + if (it == pointerGroups.end()) + return; + if (pointerGroups.try_emplace(derived, it->second).second) + 
worklist.push_back(derived); + }; + + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (auto convert = dyn_cast(owner)) { + tryTrackDerived(owner, convert.getSrc(), convert.getResult()); + } else if (auto splat = dyn_cast(owner)) { + tryTrackDerived(owner, splat.getSrc(), splat.getResult()); + } else if (auto bcast = dyn_cast(owner)) { + tryTrackDerived(owner, bcast.getSrc(), bcast.getResult()); + } else if (auto expand = dyn_cast(owner)) { + tryTrackDerived(owner, expand.getSrc(), expand.getResult()); + } else if (auto reshape = dyn_cast(owner)) { + tryTrackDerived(owner, reshape.getSrc(), reshape.getResult()); + } else if (auto addptr = dyn_cast(owner)) { + // Only propagate along the pointer operand. + if (use.getOperandNumber() == 0) + tryTrackDerived(owner, addptr.getPtr(), addptr.getResult()); + } else if (auto call = dyn_cast(owner)) { + auto it = pointerGroups.find(current); + if (it == pointerGroups.end()) + continue; + unsigned operandIdx = use.getOperandNumber(); + auto callee = module.lookupSymbol(call.getCallee()); + if (!callee || operandIdx >= callee.getNumArguments()) + continue; + Value calleeArg = callee.getArgument(operandIdx); + if (pointerGroups.try_emplace(calleeArg, it->second).second) + worklist.push_back(calleeArg); + } + } + } + } + + void processOperation(Operation &op) { + for (Region ®ion : op.getRegions()) + processRegion(region); + } + + void processRegion(Region ®ion) { + for (Block &block : region) + processBlock(block); + } + + void processBlock(Block &block) { + llvm::DenseMap dirtyGroups; + for (Operation &op : block) { + if (!dirtyGroups.empty() && op.getNumRegions() > 0) { + bool handledByIfSpecialization = false; + if (auto ifOp = dyn_cast(&op)) + handledByIfSpecialization = tryHandleUniformIf(ifOp, dirtyGroups); + + if (!handledByIfSpecialization && + opHasLoadNeedingBarrier(op, dirtyGroups)) { + OpBuilder 
builder(&op); + createLocalBarrier(builder, op.getLoc()); + dirtyGroups.clear(); + } + } + + if (auto store = dyn_cast(&op)) { + if (auto group = lookupPointerGroup(store.getPtr())) + dirtyGroups[*group] = true; + } else if (auto load = dyn_cast(&op)) { + auto group = lookupPointerGroup(load.getPtr()); + if (!group || !dirtyGroups.lookup(*group)) + continue; + if (allowDotOperandBarrierElision && + isFullViewLoadUsedOnlyByDotOperands(load)) + continue; + OpBuilder builder(load); + createLocalBarrier(builder, load.getLoc()); + // A CTA barrier synchronizes all shared-memory groups, not only the + // group used by this load. Clearing all dirty groups avoids emitting + // redundant back-to-back barriers for consecutive loads from different + // tracked groups. + dirtyGroups.clear(); + } else if (isa(&op)) { + dirtyGroups.clear(); + } + + for (Region &nested : op.getRegions()) + processRegion(nested); + + // Propagate write hazards from nested regions to the parent block. + // Without this, a store inside scf.if/scf.for may not mark parent state + // dirty, so a subsequent outer load can miss the required barrier. 
+ markGroupsWrittenByNestedRegions(op, dirtyGroups); + } + } + + bool tryHandleUniformIf(scf::IfOp ifOp, + const llvm::DenseMap &dirtyGroups) { + if (!isUniformCondition(ifOp.getCondition())) + return false; + + for (Region ®ion : ifOp->getRegions()) { + if (!regionHasLoadNeedingBarrier(region, dirtyGroups)) + continue; + if (region.empty() || region.front().empty()) + continue; + + Block &entry = region.front(); + if (isa(entry.front())) + continue; + + OpBuilder builder(&entry, entry.begin()); + createLocalBarrier(builder, ifOp.getLoc()); + } + return true; + } + + bool isUniformCondition(Value cond) const { + if (isa_and_nonnull(cond.getDefiningOp())) + return true; + + auto reduce = cond.getDefiningOp(); + if (!reduce || !cond.getType().isInteger(1)) + return false; + + Operation *combiner = reduce.getSingleCombiner(); + return combiner && isa(combiner); + } + + bool regionHasLoadNeedingBarrier( + Region ®ion, const llvm::DenseMap &dirtyGroups) const { + for (Block &block : region) { + for (Operation &nestedOp : block) { + if (auto load = dyn_cast(&nestedOp)) { + if (auto group = lookupPointerGroup(load.getPtr()); + group && dirtyGroups.lookup(*group) && + !(allowDotOperandBarrierElision && + isFullViewLoadUsedOnlyByDotOperands(load))) + return true; + } + if (nestedOp.getNumRegions() > 0 && + opHasLoadNeedingBarrier(nestedOp, dirtyGroups)) + return true; + } + } + return false; + } + + bool opHasLoadNeedingBarrier( + Operation &op, const llvm::DenseMap &dirtyGroups) const { + for (Region ®ion : op.getRegions()) { + if (regionHasLoadNeedingBarrier(region, dirtyGroups)) + return true; + } + return false; + } + + void markGroupsWrittenByNestedRegions( + Operation &op, llvm::DenseMap &dirtyGroups) const { + if (op.getNumRegions() == 0) + return; + llvm::DenseSet writtenGroups; + for (Region ®ion : op.getRegions()) + collectWrittenGroups(region, writtenGroups); + for (int64_t group : writtenGroups) + dirtyGroups[group] = true; + } + + void 
collectWrittenGroups(Region ®ion, + llvm::DenseSet &writtenGroups) const { + for (Block &block : region) { + for (Operation &nestedOp : block) { + if (auto store = dyn_cast(&nestedOp)) { + if (auto group = lookupPointerGroup(store.getPtr())) + writtenGroups.insert(*group); + } + for (Region &deeperRegion : nestedOp.getRegions()) + collectWrittenGroups(deeperRegion, writtenGroups); + } + } + } + + std::optional lookupPointerGroup(Value ptr) const { + auto it = pointerGroups.find(ptr); + if (it == pointerGroups.end()) + return std::nullopt; + return it->second; + } + + llvm::DenseMap pointerGroups; + bool allowDotOperandBarrierElision = false; +}; + +} // namespace +} // namespace mlir + +#endif // __TLE__ diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerLoads.cpp b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerLoads.cpp new file mode 100644 index 0000000000..821ac7c21a --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerLoads.cpp @@ -0,0 +1,205 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifdef __TLE__ + +#include "Dialect/MUSATLE/IR/Dialect.h" +#include "TritonMUSAGPUTransforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUTLEOPTIMIZELOCALPOINTERLOADS +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +namespace { + +namespace ttg = mlir::triton::gpu; + +static Value stripConvertLayouts(Value value) { + Value current = value; + while (auto cvt = current.getDefiningOp()) + current = cvt.getSrc(); + return current; +} + +static Value stripIndexValueWrappers(Value value) { + Value current = value; + while (true) { + if (auto cvt = current.getDefiningOp()) { + current = cvt.getSrc(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto trunc = current.getDefiningOp()) { + current = trunc.getIn(); + continue; + } + if (auto cast = current.getDefiningOp()) { + current = cast.getIn(); + continue; + } + break; + } + return current; +} + +static bool matchZeroStartMakeRange(Value value, int64_t extent) { + Value current = stripIndexValueWrappers(value); + auto range = current.getDefiningOp(); + if (!range) + return false; + return range.getStart() == 0 && range.getEnd() == extent; +} + +static bool matchFullIndexTensorForAxis(Value index, size_t axis, + ArrayRef shape) { + auto indexTy = dyn_cast(index.getType()); + if (!indexTy || 
!indexTy.getElementType().isInteger()) + return false; + if (indexTy.getShape() != shape) + return false; + + Value current = stripIndexValueWrappers(index); + if (shape.size() == 1) + return matchZeroStartMakeRange(current, shape.front()); + + auto bcast = current.getDefiningOp(); + if (!bcast) + return false; + + auto bcastSrcTy = dyn_cast(bcast.getSrc().getType()); + if (!bcastSrcTy || bcastSrcTy.getRank() != static_cast(shape.size())) + return false; + for (auto [dim, dimSize] : llvm::enumerate(shape)) { + const int64_t expected = dim == axis ? dimSize : 1; + if (bcastSrcTy.getShape()[dim] != expected) + return false; + } + + current = stripIndexValueWrappers(bcast.getSrc()); + while (auto expand = current.getDefiningOp()) + current = stripIndexValueWrappers(expand.getSrc()); + + auto rangeTy = dyn_cast(current.getType()); + if (!rangeTy || rangeTy.getRank() != 1) + return false; + if (rangeTy.getShape()[0] != shape[axis]) + return false; + + return matchZeroStartMakeRange(current, shape[axis]); +} + +static std::optional matchFullViewMemDesc(triton::LoadOp load) { + if (load.getMask() || load.getOther()) + return std::nullopt; + if (load.getIsVolatile()) + return std::nullopt; + if (load.getCache() != triton::CacheModifier::NONE || + load.getEvict() != triton::EvictionPolicy::NORMAL) + return std::nullopt; + + auto loadTy = dyn_cast(load.getType()); + if (!loadTy) + return std::nullopt; + + Value ptr = stripConvertLayouts(load.getPtr()); + auto localPointers = ptr.getDefiningOp(); + if (!localPointers) + return std::nullopt; + + auto ptrTy = dyn_cast(localPointers.getResult().getType()); + if (!ptrTy) + return std::nullopt; + + auto memDescTy = dyn_cast(localPointers.getSrc().getType()); + if (!memDescTy) + return std::nullopt; + + auto memDescShape = memDescTy.getShape(); + if (loadTy.getShape() != memDescShape || ptrTy.getShape() != memDescShape) + return std::nullopt; + if (loadTy.getElementType() != memDescTy.getElementType()) + return std::nullopt; + + 
auto indices = localPointers.getIndices(); + if (indices.empty()) + return localPointers.getSrc(); + if (indices.size() != memDescShape.size()) + return std::nullopt; + + for (auto [axis, index] : llvm::enumerate(indices)) + if (!matchFullIndexTensorForAxis(index, axis, memDescShape)) + return std::nullopt; + + return localPointers.getSrc(); +} + +class OptimizeLocalPointerLoadsPass + : public impl::TritonMUSAGPUTLEOptimizeLocalPointerLoadsBase< + OptimizeLocalPointerLoadsPass> { + void runOnOperation() override { + ModuleOp module = getOperation(); + + struct RewriteItem { + triton::LoadOp load; + Value memDesc; + }; + SmallVector rewrites; + + module.walk([&](triton::LoadOp load) { + if (auto memDesc = matchFullViewMemDesc(load)) + rewrites.push_back({load, *memDesc}); + }); + + for (RewriteItem &item : rewrites) { + if (!item.load || !item.memDesc) + continue; + OpBuilder builder(item.load); + auto localLoad = ttg::LocalLoadOp::create( + builder, item.load.getLoc(), item.load.getType(), item.memDesc); + item.load.replaceAllUsesWith(localLoad.getResult()); + item.load.erase(); + } + } +}; + +} // namespace +} // namespace mlir + +#endif // __TLE__ diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerStores.cpp b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerStores.cpp new file mode 100644 index 0000000000..a598054388 --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerStores.cpp @@ -0,0 +1,218 @@ +// MIT License +// +// Copyright (c) 2025 The FlagOS Contributors +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, 
subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifdef __TLE__ + +#include "Dialect/MUSATLE/IR/Dialect.h" +#include "TritonMUSAGPUTransforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUTLEOPTIMIZELOCALPOINTERSTORES +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +namespace { + +namespace ttg = mlir::triton::gpu; + +static Value stripConvertLayouts(Value value) { + Value current = value; + while (auto cvt = current.getDefiningOp()) + current = cvt.getSrc(); + return current; +} + +static Value stripIndexValueWrappers(Value value) { + Value current = value; + while (true) { + if (auto cvt = current.getDefiningOp()) { + current = cvt.getSrc(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto trunc = current.getDefiningOp()) { + current = trunc.getIn(); + continue; + } + if (auto cast = current.getDefiningOp()) { + current = cast.getIn(); + continue; + } + break; + } + return current; +} + +static bool 
matchZeroStartMakeRange(Value value, int64_t extent) { + Value current = stripIndexValueWrappers(value); + auto range = current.getDefiningOp(); + return range && range.getStart() == 0 && range.getEnd() == extent; +} + +static bool matchFullIndexTensorForAxis(Value index, size_t axis, + ArrayRef shape) { + auto indexTy = dyn_cast(index.getType()); + if (!indexTy || !indexTy.getElementType().isInteger()) + return false; + if (indexTy.getShape() != shape) + return false; + + Value current = stripIndexValueWrappers(index); + if (shape.size() == 1) + return matchZeroStartMakeRange(current, shape.front()); + + auto bcast = current.getDefiningOp(); + if (!bcast) + return false; + + auto bcastSrcTy = dyn_cast(bcast.getSrc().getType()); + if (!bcastSrcTy || bcastSrcTy.getRank() != static_cast(shape.size())) + return false; + for (auto [dim, dimSize] : llvm::enumerate(shape)) { + const int64_t expected = dim == axis ? dimSize : 1; + if (bcastSrcTy.getShape()[dim] != expected) + return false; + } + + current = stripIndexValueWrappers(bcast.getSrc()); + while (auto expand = current.getDefiningOp()) + current = stripIndexValueWrappers(expand.getSrc()); + + auto rangeTy = dyn_cast(current.getType()); + if (!rangeTy || rangeTy.getRank() != 1) + return false; + if (rangeTy.getShape()[0] != shape[axis]) + return false; + + return matchZeroStartMakeRange(current, shape[axis]); +} + +static std::optional matchFullViewMemDesc(triton::StoreOp store) { + if (!store.getBoundaryCheck().empty()) + return std::nullopt; + + auto valueTy = dyn_cast(store.getValue().getType()); + if (!valueTy) + return std::nullopt; + + Value ptr = stripConvertLayouts(store.getPtr()); + auto localPointers = ptr.getDefiningOp(); + if (!localPointers) + return std::nullopt; + + auto ptrTy = dyn_cast(localPointers.getResult().getType()); + auto memDescTy = dyn_cast(localPointers.getSrc().getType()); + if (!ptrTy || !memDescTy) + return std::nullopt; + + auto memDescShape = memDescTy.getShape(); + if 
(valueTy.getShape() != memDescShape || ptrTy.getShape() != memDescShape) + return std::nullopt; + if (valueTy.getElementType() != memDescTy.getElementType()) + return std::nullopt; + + auto indices = localPointers.getIndices(); + if (indices.empty()) + return localPointers.getSrc(); + if (indices.size() != memDescShape.size()) + return std::nullopt; + + for (auto [axis, index] : llvm::enumerate(indices)) + if (!matchFullIndexTensorForAxis(index, axis, memDescShape)) + return std::nullopt; + + return localPointers.getSrc(); +} + +class OptimizeLocalPointerStoresPass + : public impl::TritonMUSAGPUTLEOptimizeLocalPointerStoresBase< + OptimizeLocalPointerStoresPass> { + void runOnOperation() override { + ModuleOp module = getOperation(); + + struct RewriteItem { + triton::StoreOp store; + Value memDesc; + }; + SmallVector rewrites; + + module.walk([&](triton::StoreOp store) { + if (auto memDesc = matchFullViewMemDesc(store)) + rewrites.push_back({store, *memDesc}); + }); + + for (RewriteItem &item : rewrites) { + if (!item.store || !item.memDesc) + continue; + + OpBuilder builder(item.store); + Value valueToStore = item.store.getValue(); + auto valueTy = cast(valueToStore.getType()); + + if (Value mask = item.store.getMask()) { + auto maskTy = dyn_cast(mask.getType()); + if (!maskTy || maskTy.getShape() != valueTy.getShape()) + continue; + if (maskTy.getEncoding() != valueTy.getEncoding()) { + auto targetMaskTy = + RankedTensorType::get(maskTy.getShape(), maskTy.getElementType(), + valueTy.getEncoding()); + mask = ttg::ConvertLayoutOp::create(builder, item.store.getLoc(), + targetMaskTy, mask) + .getResult(); + } + Value oldValue = ttg::LocalLoadOp::create(builder, item.store.getLoc(), + valueTy, item.memDesc) + .getResult(); + valueToStore = arith::SelectOp::create(builder, item.store.getLoc(), + mask, valueToStore, oldValue) + .getResult(); + } + + ttg::LocalStoreOp::create(builder, item.store.getLoc(), valueToStore, + item.memDesc); + item.store.erase(); + } + } +}; 
+ +} // namespace +} // namespace mlir + +#endif // __TLE__ diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/SelectEncodings.cpp b/third_party/mthreads/tle/dialect/lib/Transforms/SelectEncodings.cpp new file mode 100644 index 0000000000..8827fb1a16 --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/Transforms/SelectEncodings.cpp @@ -0,0 +1,922 @@ +// MIT License + +// Copyright (c) 2025 The FlagOS Contributors + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +// flagtree tle + +#ifdef __TLE__ + +#include "Dialect/MUSATLE/IR/Dialect.h" +#include "TritonMUSAGPUTransforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUTLESELECTENCODINGS +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +namespace { + +// Triton shared-memory pointers use LLVM address space 3 (NVVM shared). +constexpr int kSharedMemoryAddressSpace = 3; +constexpr StringLiteral kBarrierGroupAttr = "musa_tle.barrier_group"; +constexpr StringLiteral kTTContiguityAttr = "tt.contiguity"; +constexpr StringLiteral kTTDivisibilityAttr = "tt.divisibility"; +constexpr StringLiteral kTTConstancyAttr = "tt.constancy"; + +static Value stripConvertLayouts(Value value) { + Value current = value; + while (auto convert = current.getDefiningOp()) + current = convert.getSrc(); + return current; +} + +static Attribute getStrippedTensorEncoding(Value value) { + Value stripped = stripConvertLayouts(value); + auto strippedTy = dyn_cast(stripped.getType()); + if (!strippedTy) + return Attribute(); + return strippedTy.getEncoding(); +} + +static bool isConstantLikeTensorValue(Value value) { + Value cur = stripConvertLayouts(value); + if (!isa(cur.getType())) + return false; + if (isa_and_nonnull(cur.getDefiningOp())) + return true; + if (auto splat = cur.getDefiningOp()) { + Value src = splat.getSrc(); + if (isa_and_nonnull(src.getDefiningOp())) + return true; + } + return false; +} + +static Value stripIndexValueWrappers(Value value) { + Value current = value; + while (true) { + if (auto convert = current.getDefiningOp()) { + current = convert.getSrc(); + continue; + } + if 
(auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto trunc = current.getDefiningOp()) { + current = trunc.getIn(); + continue; + } + if (auto cast = current.getDefiningOp()) { + current = cast.getIn(); + continue; + } + break; + } + return current; +} + +static bool matchZeroStartMakeRange(Value value, int64_t extent) { + Value current = stripIndexValueWrappers(value); + auto range = current.getDefiningOp(); + if (!range) + return false; + return range.getStart() == 0 && range.getEnd() == extent; +} + +static bool matchFullIndexTensorForAxis(Value index, size_t axis, + ArrayRef shape) { + auto indexTy = dyn_cast(index.getType()); + if (!indexTy || !indexTy.getElementType().isInteger()) + return false; + if (indexTy.getShape() != shape) + return false; + + Value current = stripIndexValueWrappers(index); + if (shape.size() == 1) + return matchZeroStartMakeRange(current, shape.front()); + + auto bcast = current.getDefiningOp(); + if (!bcast) + return false; + + auto bcastSrcTy = dyn_cast(bcast.getSrc().getType()); + if (!bcastSrcTy || bcastSrcTy.getRank() != static_cast(shape.size())) + return false; + for (auto [dim, dimSize] : llvm::enumerate(shape)) { + const int64_t expected = dim == axis ? dimSize : 1; + if (bcastSrcTy.getShape()[dim] != expected) + return false; + } + + current = stripIndexValueWrappers(bcast.getSrc()); + while (auto expand = current.getDefiningOp()) + current = stripIndexValueWrappers(expand.getSrc()); + + auto rangeTy = dyn_cast(current.getType()); + if (!rangeTy || rangeTy.getRank() != 1) + return false; + if (rangeTy.getShape()[0] != shape[axis]) + return false; + + return matchZeroStartMakeRange(current, shape[axis]); +} + +// Loads of full-view local_pointers are later rewritten to ttg.local_load. +// They should not bias local_pointers encoding inference toward load layouts. 
+static bool isRewritableFullViewLocalPointerLoad(triton::LoadOp load) { + if (load.getMask() || load.getOther()) + return false; + if (load.getIsVolatile()) + return false; + if (load.getCache() != triton::CacheModifier::NONE || + load.getEvict() != triton::EvictionPolicy::NORMAL) + return false; + + auto loadTy = dyn_cast(load.getType()); + if (!loadTy) + return false; + + Value ptr = stripConvertLayouts(load.getPtr()); + auto localPointers = ptr.getDefiningOp(); + if (!localPointers) + return false; + + auto ptrTy = dyn_cast(localPointers.getResult().getType()); + if (!ptrTy) + return false; + + auto memDescTy = + dyn_cast(localPointers.getSrc().getType()); + if (!memDescTy) + return false; + + auto memDescShape = memDescTy.getShape(); + if (loadTy.getShape() != memDescShape || ptrTy.getShape() != memDescShape) + return false; + if (loadTy.getElementType() != memDescTy.getElementType()) + return false; + + auto indices = localPointers.getIndices(); + if (indices.empty()) + return true; + if (indices.size() != memDescShape.size()) + return false; + + for (auto [axis, index] : llvm::enumerate(indices)) + if (!matchFullIndexTensorForAxis(index, axis, memDescShape)) + return false; + return true; +} + +static int64_t getScfLoopDepth(Operation *op) { + int64_t depth = 0; + for (Operation *cur = op; cur; cur = cur->getParentOp()) + if (isa(cur)) + ++depth; + return depth; +} + +static bool valueFeedsDot(Value root) { + llvm::SmallVector worklist; + llvm::DenseSet visited; + auto enqueue = [&](Value v) { + if (v && visited.insert(v).second) + worklist.push_back(v); + }; + enqueue(root); + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (isa(owner)) + return true; + if (auto convert = dyn_cast(owner)) { + enqueue(convert.getResult()); + continue; + } + if (auto trans = dyn_cast(owner)) { + enqueue(trans.getResult()); + continue; + } + if (auto bcast = 
dyn_cast(owner)) { + enqueue(bcast.getResult()); + continue; + } + if (auto expand = dyn_cast(owner)) { + enqueue(expand.getResult()); + continue; + } + if (auto reshape = dyn_cast(owner)) { + enqueue(reshape.getResult()); + continue; + } + } + } + return false; +} + +struct EncodingVote { + Attribute encoding; + int64_t score; +}; + +using CachedConversionKey = std::pair; +using CachedConversionMap = + llvm::DenseMap>; + +static Value getOrCreateCachedConvertLayout(OpBuilder &builder, + Operation *insertBefore, Value v, + Attribute encoding, + CachedConversionMap &cache) { + Value stripped = stripConvertLayouts(v); + auto strippedTy = dyn_cast(stripped.getType()); + if (strippedTy && strippedTy.getEncoding() == encoding) + return stripped; + + auto vTy = dyn_cast(v.getType()); + if (!vTy) + return v; + if (vTy.getEncoding() == encoding) + return v; + + CachedConversionKey key{v, encoding}; + auto it = cache.find(key); + if (it != cache.end()) { + for (Value candidate : it->second) { + Operation *def = candidate.getDefiningOp(); + if (!def) + continue; + if (def->getBlock() != insertBefore->getBlock()) + continue; + if (def->isBeforeInBlock(insertBefore)) + return candidate; + } + } + + auto convertedTy = + RankedTensorType::get(vTy.getShape(), vTy.getElementType(), encoding); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(insertBefore); + auto converted = triton::gpu::ConvertLayoutOp::create( + builder, insertBefore->getLoc(), convertedTy, v); + Value convertedValue = converted.getResult(); + cache[key].push_back(convertedValue); + return convertedValue; +} + +static Operation *peelAxisInfoCarrier(Value value) { + llvm::DenseSet visited; + Value current = value; + while (current && visited.insert(current).second) { + Operation *def = current.getDefiningOp(); + if (!def) + break; + if (auto convert = dyn_cast(def)) { + current = convert.getSrc(); + continue; + } + if (auto bcast = dyn_cast(def)) { + current = bcast.getSrc(); + continue; + } 
+ if (auto expand = dyn_cast(def)) { + current = expand.getSrc(); + continue; + } + if (auto reshape = dyn_cast(def)) { + current = reshape.getSrc(); + continue; + } + return def; + } + return current ? current.getDefiningOp() : nullptr; +} + +static void copyAxisInfoAttrs(Operation *src, Operation *dst) { + if (!src || !dst) + return; + auto tryCopy = [&](StringRef name) { + if (dst->getDiscardableAttr(name)) + return; + if (auto attr = src->getDiscardableAttr(name)) + dst->setDiscardableAttr(name, attr); + }; + tryCopy(kTTContiguityAttr); + tryCopy(kTTDivisibilityAttr); + tryCopy(kTTConstancyAttr); +} + +static void +collectConsumerEncodingVotes(Value root, + llvm::SmallVectorImpl &votes) { + auto rootLocal = stripConvertLayouts(root) + .getDefiningOp(); + bool preferMaskForScalarLocalPointers = false; + if (rootLocal) { + if (auto memDescTy = + dyn_cast(rootLocal.getSrc().getType())) { + int64_t elemCount = 1; + for (int64_t dim : memDescTy.getShape()) { + if (dim <= 0) { + elemCount = 0; + break; + } + elemCount *= dim; + } + preferMaskForScalarLocalPointers = (elemCount == 1); + } + } + + llvm::SmallVector worklist; + llvm::DenseSet visited; + auto enqueue = [&](Value v) { + if (!v) + return; + if (!visited.insert(v).second) + return; + worklist.push_back(v); + }; + + enqueue(root); + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + for (OpOperand &use : current.getUses()) { + Operation *owner = use.getOwner(); + if (auto load = dyn_cast(owner)) { + if (isRewritableFullViewLocalPointerLoad(load)) + continue; + if (Attribute loadEncoding = + getStrippedTensorEncoding(load.getResult())) { + const int64_t depthFactor = 1 + getScfLoopDepth(owner); + int64_t score = 8 * depthFactor; + if (valueFeedsDot(load.getResult())) + score += 128 * depthFactor; + votes.push_back({loadEncoding, score}); + } + continue; + } + if (auto store = dyn_cast(owner)) { + if (Attribute valueEncoding = + getStrippedTensorEncoding(store.getValue())) { + const 
int64_t depthFactor = 1 + getScfLoopDepth(owner); + int64_t score = 2 * depthFactor; + if (Operation *def = store.getValue().getDefiningOp(); + def && isa(def)) + score += 8 * depthFactor; + votes.push_back({valueEncoding, score}); + } + if (Value mask = store.getMask()) + if (Attribute maskEncoding = getStrippedTensorEncoding(mask)) + votes.push_back({maskEncoding, 2 * (1 + getScfLoopDepth(owner))}); + continue; + } + if (auto atomic = dyn_cast(owner)) { + const int64_t depthFactor = 1 + getScfLoopDepth(owner); + const int64_t valScore = + (preferMaskForScalarLocalPointers ? 8 : 24) * depthFactor; + const int64_t maskScoreBase = + (preferMaskForScalarLocalPointers ? 48 : 12) * depthFactor; + const int64_t resultScore = + (preferMaskForScalarLocalPointers ? 0 : 12) * depthFactor; + if (Attribute valEncoding = getStrippedTensorEncoding(atomic.getVal())) + votes.push_back({valEncoding, valScore}); + if (Value mask = atomic.getMask()) { + if (Attribute maskEncoding = getStrippedTensorEncoding(mask)) { + int64_t maskScore = maskScoreBase; + if (preferMaskForScalarLocalPointers && + isConstantLikeTensorValue(mask)) + maskScore = depthFactor; + votes.push_back({maskEncoding, maskScore}); + } + } + if (resultScore > 0) + if (Attribute resultEncoding = + getStrippedTensorEncoding(atomic.getResult())) + votes.push_back({resultEncoding, resultScore}); + continue; + } + if (auto cas = dyn_cast(owner)) { + const int64_t depthFactor = 1 + getScfLoopDepth(owner); + const int64_t valScore = + (preferMaskForScalarLocalPointers ? 8 : 24) * depthFactor; + const int64_t cmpScore = + (preferMaskForScalarLocalPointers ? 48 : 12) * depthFactor; + const int64_t resultScore = + (preferMaskForScalarLocalPointers ? 
0 : 12) * depthFactor; + if (Attribute cmpEncoding = getStrippedTensorEncoding(cas.getCmp())) + votes.push_back({cmpEncoding, cmpScore}); + if (Attribute valEncoding = getStrippedTensorEncoding(cas.getVal())) + votes.push_back({valEncoding, valScore}); + if (resultScore > 0) + if (Attribute resultEncoding = + getStrippedTensorEncoding(cas.getResult())) + votes.push_back({resultEncoding, resultScore}); + continue; + } + if (auto convert = dyn_cast(owner)) { + enqueue(convert.getResult()); + continue; + } + if (auto bcast = dyn_cast(owner)) { + enqueue(bcast.getResult()); + continue; + } + if (auto expand = dyn_cast(owner)) { + enqueue(expand.getResult()); + continue; + } + if (auto reshape = dyn_cast(owner)) { + enqueue(reshape.getResult()); + continue; + } + } + } +} + +static Attribute pickDominantEncoding(ArrayRef votes, + Attribute fallback) { + if (votes.empty()) + return fallback; + + llvm::DenseMap scoreByEncoding; + llvm::SmallVector order; + for (const EncodingVote &vote : votes) { + if (!vote.encoding) + continue; + auto [it, inserted] = scoreByEncoding.try_emplace(vote.encoding, 0); + if (inserted) + order.push_back(vote.encoding); + it->second += vote.score; + } + if (order.empty()) + return fallback; + + Attribute best = order.front(); + int64_t bestScore = scoreByEncoding.lookup(best); + for (Attribute encoding : order) { + int64_t score = scoreByEncoding.lookup(encoding); + if (score > bestScore) { + best = encoding; + bestScore = score; + continue; + } + if (score == bestScore && encoding == fallback) + best = encoding; + } + return best; +} + +static bool isPointerTensorType(Type type) { + auto tensorTy = dyn_cast(type); + if (!tensorTy) + return false; + return isa(tensorTy.getElementType()); +} + +static void bridgeResultTypeToOldEncoding(Value result, Type oldType, + OpBuilder &builder) { + if (result.getType() == oldType) + return; + auto oldTensorTy = dyn_cast(oldType); + if (!oldTensorTy) + return; + Operation *def = result.getDefiningOp(); + 
if (!def) + return; + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(def); + auto bridge = triton::gpu::ConvertLayoutOp::create(builder, def->getLoc(), + oldTensorTy, result); + result.replaceAllUsesExcept(bridge.getResult(), bridge.getOperation()); +} + +static bool tryFoldPointerConvertLayout(triton::gpu::ConvertLayoutOp convert, + OpBuilder &builder, + CachedConversionMap &cache) { + auto srcTy = dyn_cast(convert.getSrc().getType()); + auto dstTy = dyn_cast(convert.getType()); + if (!srcTy || !dstTy) + return false; + if (!isa(srcTy.getElementType()) || + !isa(dstTy.getElementType())) + return false; + + Value srcPtr = convert.getSrc(); + Value convertedPtr = convert.getResult(); + Attribute srcEncoding = srcTy.getEncoding(); + auto srcElemTy = + cast(srcTy.getElementType()).getPointeeType(); + auto srcLoadTy = + RankedTensorType::get(srcTy.getShape(), srcElemTy, srcEncoding); + + SmallVector uses; + uses.reserve(convertedPtr.getNumUses()); + for (OpOperand &use : convertedPtr.getUses()) { + Operation *owner = use.getOwner(); + if (!isa(owner)) + return false; + uses.push_back(&use); + } + + auto convertOperandEncoding = [&](Operation *insertBefore, Value v, + Attribute encoding) -> Value { + return getOrCreateCachedConvertLayout(builder, insertBefore, v, encoding, + cache); + }; + + for (OpOperand *use : uses) { + Operation *owner = use->getOwner(); + use->set(srcPtr); + + if (auto load = dyn_cast(owner)) { + if (Value mask = load.getMask()) { + Value convertedMask = convertOperandEncoding(owner, mask, srcEncoding); + if (convertedMask != mask) + load.getMaskMutable().assign(convertedMask); + } + if (Value other = load.getOther()) { + Value convertedOther = + convertOperandEncoding(owner, other, srcEncoding); + if (convertedOther != other) + load.getOtherMutable().assign(convertedOther); + } + Type oldType = load.getResult().getType(); + if (oldType != srcLoadTy) { + load.getResult().setType(srcLoadTy); + 
bridgeResultTypeToOldEncoding(load.getResult(), oldType, builder); + } + continue; + } + + if (auto store = dyn_cast(owner)) { + Value value = store.getValue(); + Value convertedValue = convertOperandEncoding(owner, value, srcEncoding); + if (convertedValue != value) + store.getValueMutable().assign(convertedValue); + if (Value mask = store.getMask()) { + Value convertedMask = convertOperandEncoding(owner, mask, srcEncoding); + if (convertedMask != mask) + store.getMaskMutable().assign(convertedMask); + } + continue; + } + + if (auto atomic = dyn_cast(owner)) { + Value val = atomic.getVal(); + Value convertedVal = convertOperandEncoding(owner, val, srcEncoding); + if (convertedVal != val) + atomic.getValMutable().assign(convertedVal); + if (Value mask = atomic.getMask()) { + Value convertedMask = convertOperandEncoding(owner, mask, srcEncoding); + if (convertedMask != mask) + atomic.getMaskMutable().assign(convertedMask); + } + Type oldType = atomic.getResult().getType(); + if (oldType != srcLoadTy) { + atomic.getResult().setType(srcLoadTy); + bridgeResultTypeToOldEncoding(atomic.getResult(), oldType, builder); + } + continue; + } + + auto cas = cast(owner); + Value cmp = cas.getCmp(); + Value convertedCmp = convertOperandEncoding(owner, cmp, srcEncoding); + if (convertedCmp != cmp) + cas.getCmpMutable().assign(convertedCmp); + Value val = cas.getVal(); + Value convertedVal = convertOperandEncoding(owner, val, srcEncoding); + if (convertedVal != val) + cas.getValMutable().assign(convertedVal); + Type oldType = cas.getResult().getType(); + if (oldType != srcLoadTy) { + cas.getResult().setType(srcLoadTy); + bridgeResultTypeToOldEncoding(cas.getResult(), oldType, builder); + } + } + + if (convertedPtr.use_empty()) + convert.erase(); + return true; +} + +class SelectEncodingsPass + : public impl::TritonMUSAGPUTLESelectEncodingsBase { + void runOnOperation() override { + ModuleOp module = getOperation(); + OpBuilder builder(module.getContext()); + CachedConversionMap 
userOperandConversionCache; + CachedConversionMap indexOperandConversionCache; + module.walk([&](triton::musa_tle::LocalPointersOp op) { + // Always tag local pointer ops so barrier insertion can track hazards + // across different pointer views of the same alloc. + tagDependencyGroup(op, builder); + + auto tensorTy = dyn_cast(op.getResult().getType()); + auto scalarPtrTy = + dyn_cast(op.getResult().getType()); + if (!tensorTy && !scalarPtrTy) + return; + auto ptrTy = + tensorTy ? dyn_cast(tensorTy.getElementType()) + : scalarPtrTy; + if (!ptrTy) + return; + bool updated = false; + Type updatedResultTy = op.getResult().getType(); + const auto desiredAddrSpace = kSharedMemoryAddressSpace; + if (ptrTy.getAddressSpace() != desiredAddrSpace) { + ptrTy = + triton::PointerType::get(ptrTy.getPointeeType(), desiredAddrSpace); + updated = true; + } + + if (!tensorTy) { + if (updated) + op.getResult().setType(ptrTy); + return; + } + + auto encoding = tensorTy.getEncoding(); + SmallVector votes; + collectConsumerEncodingVotes(op.getResult(), votes); + for (Value index : op.getIndices()) { + Attribute indexEncoding = getStrippedTensorEncoding(index); + if (!indexEncoding) + continue; + const bool constantLike = isConstantLikeTensorValue(index); + int64_t elemCount = 1; + if (auto indexTy = dyn_cast(index.getType())) { + for (int64_t dim : indexTy.getShape()) { + if (dim <= 0) { + elemCount = 0; + break; + } + elemCount *= dim; + } + } + const int64_t depthFactor = 1 + getScfLoopDepth(op.getOperation()); + int64_t baseScore = constantLike ? 
1 : 12; + if (!constantLike) { + if (elemCount >= 1024) + baseScore = 192; + else if (elemCount >= 256) + baseScore = 64; + } + const int64_t score = baseScore * depthFactor; + votes.push_back({indexEncoding, score}); + } + Attribute userEncoding = pickDominantEncoding(votes, encoding); + if (userEncoding && userEncoding != encoding) { + encoding = userEncoding; + updated = true; + } + if (!encoding) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(op); + int numWarps = triton::gpu::maybeLookupNumWarps(op).value_or(1); + int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(builder); + int numCTAs = triton::gpu::lookupNumCTAs(builder); + encoding = triton::gpu::getDefaultBlockedEncoding( + module.getContext(), tensorTy.getShape(), numWarps, threadsPerWarp, + numCTAs); + updated = true; + } + + if (updated) + updatedResultTy = + RankedTensorType::get(tensorTy.getShape(), ptrTy, encoding); + + if (updated) + op.getResult().setType(updatedResultTy); + + if (updated) { + llvm::DenseSet visited; + auto updateUserResultTypes = [&](auto &&self, Value ptrVal) -> void { + if (!ptrVal || !visited.insert(ptrVal).second) + return; + auto ptrTensorTy = cast(ptrVal.getType()); + auto ptrEncoding = ptrTensorTy.getEncoding(); + auto ptrElemTy = + cast(ptrTensorTy.getElementType()) + .getPointeeType(); + auto loadTy = RankedTensorType::get(ptrTensorTy.getShape(), ptrElemTy, + ptrTensorTy.getEncoding()); + auto convertOperandEncoding = [&](Operation *insertBefore, Value v, + Attribute encoding) -> Value { + return getOrCreateCachedConvertLayout( + builder, insertBefore, v, encoding, userOperandConversionCache); + }; + for (OpOperand &use : ptrVal.getUses()) { + Operation *owner = use.getOwner(); + if (auto load = dyn_cast(owner)) { + if (isRewritableFullViewLocalPointerLoad(load)) + continue; + if (Value mask = load.getMask()) { + Value convertedMask = + convertOperandEncoding(owner, mask, ptrEncoding); + if (convertedMask != mask) + 
load.getMaskMutable().assign(convertedMask); + } + if (Value other = load.getOther()) { + Value convertedOther = + convertOperandEncoding(owner, other, ptrEncoding); + if (convertedOther != other) + load.getOtherMutable().assign(convertedOther); + } + auto oldLoadTy = + dyn_cast(load.getResult().getType()); + if (oldLoadTy != loadTy) { + load.getResult().setType(loadTy); + if (oldLoadTy) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(load); + auto bridge = triton::gpu::ConvertLayoutOp::create( + builder, load.getLoc(), oldLoadTy, load.getResult()); + load.getResult().replaceAllUsesExcept(bridge.getResult(), + bridge.getOperation()); + } + } + continue; + } + if (auto store = dyn_cast(owner)) { + auto valueTy = + dyn_cast(store.getValue().getType()); + if (valueTy) { + Value convertedValue = convertOperandEncoding( + owner, store.getValue(), ptrEncoding); + if (convertedValue != store.getValue()) + store.getValueMutable().assign(convertedValue); + } + if (Value mask = store.getMask()) { + Value convertedMask = + convertOperandEncoding(owner, mask, ptrEncoding); + if (convertedMask != mask) + store.getMaskMutable().assign(convertedMask); + } + continue; + } + if (auto atomic = dyn_cast(owner)) { + Value val = atomic.getVal(); + Value convertedVal = + convertOperandEncoding(owner, val, ptrEncoding); + if (convertedVal != val) + atomic.getValMutable().assign(convertedVal); + if (Value mask = atomic.getMask()) { + Value convertedMask = + convertOperandEncoding(owner, mask, ptrEncoding); + if (convertedMask != mask) + atomic.getMaskMutable().assign(convertedMask); + } + auto oldAtomicTy = + dyn_cast(atomic.getResult().getType()); + if (oldAtomicTy != loadTy) { + atomic.getResult().setType(loadTy); + if (oldAtomicTy) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(atomic); + auto bridge = triton::gpu::ConvertLayoutOp::create( + builder, atomic.getLoc(), oldAtomicTy, + atomic.getResult()); + 
atomic.getResult().replaceAllUsesExcept( + bridge.getResult(), bridge.getOperation()); + } + } + continue; + } + if (auto cas = dyn_cast(owner)) { + Value cmp = cas.getCmp(); + Value convertedCmp = + convertOperandEncoding(owner, cmp, ptrEncoding); + if (convertedCmp != cmp) + cas.getCmpMutable().assign(convertedCmp); + Value val = cas.getVal(); + Value convertedVal = + convertOperandEncoding(owner, val, ptrEncoding); + if (convertedVal != val) + cas.getValMutable().assign(convertedVal); + auto oldCasTy = + dyn_cast(cas.getResult().getType()); + if (oldCasTy != loadTy) { + cas.getResult().setType(loadTy); + if (oldCasTy) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(cas); + auto bridge = triton::gpu::ConvertLayoutOp::create( + builder, cas.getLoc(), oldCasTy, cas.getResult()); + cas.getResult().replaceAllUsesExcept(bridge.getResult(), + bridge.getOperation()); + } + } + continue; + } + } + }; + updateUserResultTypes(updateUserResultTypes, op.getResult()); + } + + auto desiredEncoding = + cast(updatedResultTy).getEncoding(); + if (desiredEncoding) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(op); + SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + newOperands.push_back(op.getSrc()); + bool updatedOperands = false; + for (Value operand : op.getIndices()) { + auto operandTy = dyn_cast(operand.getType()); + if (!operandTy) { + newOperands.push_back(operand); + continue; + } + if (operandTy.getEncoding() == desiredEncoding) { + newOperands.push_back(operand); + continue; + } + auto converted = getOrCreateCachedConvertLayout( + builder, op.getOperation(), operand, desiredEncoding, + indexOperandConversionCache); + newOperands.push_back(converted); + updatedOperands = (converted != operand) || updatedOperands; + } + if (updatedOperands) + op->setOperands(newOperands); + } + }); + + // Fold pointer convert_layout around local pointer users after + // encoding updates to avoid leaving 
convert chains on ptr tensors. + bool changed = true; + while (changed) { + changed = false; + SmallVector ptrConverts; + module.walk([&](triton::gpu::ConvertLayoutOp convert) { + if (isPointerTensorType(convert.getType()) && + isPointerTensorType(convert.getSrc().getType())) + ptrConverts.push_back(convert); + }); + for (triton::gpu::ConvertLayoutOp convert : ptrConverts) { + if (convert->getBlock() == nullptr) + continue; + changed |= tryFoldPointerConvertLayout(convert, builder, + userOperandConversionCache); + } + } + } + + void tagDependencyGroup(triton::musa_tle::LocalPointersOp op, + OpBuilder &builder) { + auto alloc = op.getSrc().getDefiningOp(); + if (!alloc) + return; + auto groupAttr = alloc->getAttrOfType(kBarrierGroupAttr); + if (!groupAttr) { + groupAttr = builder.getI64IntegerAttr(nextBarrierGroupId++); + alloc->setAttr(kBarrierGroupAttr, groupAttr); + } + op->setAttr(kBarrierGroupAttr, groupAttr); + } + + int64_t nextBarrierGroupId = 0; +}; + +} // namespace +} // namespace mlir + +#endif // __TLE__ diff --git a/third_party/mthreads/triton_mthreads.cc b/third_party/mthreads/triton_mthreads.cc index c54341f342..09a178804c 100644 --- a/third_party/mthreads/triton_mthreads.cc +++ b/third_party/mthreads/triton_mthreads.cc @@ -3,9 +3,6 @@ #include "MTGPUToLLVM/Passes.h" #include "TritonMUSAGPUToLLVM/Passes.h" #include "TritonMUSAGPUTransforms/Passes.h" -#ifdef __TLE__ -#include "ir.h" -#endif #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/MLIRContext.h" @@ -14,9 +11,6 @@ #include "mlir/Target/LLVMIR/Dialect/MTVM/MTVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "passes.h" -#ifdef __TLE__ -#include "triton/Dialect/TritonGPU/IR/Dialect.h" -#endif #include "llvm/IR/CallingConv.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -25,13 +19,15 @@ #include #include #include -#ifdef __TLE__ -#include -#include -#endif namespace py = pybind11; 
+#ifdef __TLE__ +void init_triton_musa_tle_ir(py::module m); +void init_triton_musa_tle_passes_ttgpuir(py::module m); +void register_triton_musa_tle_dialects(mlir::DialectRegistry &registry); +#endif + namespace { llvm::Function *findPrimaryKernel(llvm::Module &module, @@ -110,116 +106,9 @@ bool moduleUsesMulhiHelper(const llvm::Module &module) { return false; } -#ifdef __TLE__ -namespace ttg = mlir::triton::gpu; - -void checkCtaRank(llvm::ArrayRef order, - llvm::ArrayRef ctasPerCGA, - llvm::ArrayRef ctaSplitNum, - llvm::ArrayRef ctaOrder) { - if (order.size() != ctasPerCGA.size() || order.size() != ctaSplitNum.size() || - order.size() != ctaOrder.size()) - throw py::value_error("shared layout rank mismatch in CTA parameters"); -} - -ttg::CGAEncodingAttr makeCgaLayout(mlir::MLIRContext *context, - llvm::ArrayRef ctasPerCGA, - llvm::ArrayRef ctaSplitNum, - llvm::ArrayRef ctaOrder) { - return ttg::CGAEncodingAttr::fromSplitParams(context, ctasPerCGA, ctaSplitNum, - ctaOrder); -} - -mlir::Attribute getSharedMemorySpace(mlir::MLIRContext *context, - const std::string &storage) { - if (storage == "smem" || storage == "share_memory" || - storage == "shared_memory") - return ttg::SharedMemorySpaceAttr::get(context); - if (storage == "tmem" || storage == "tensor_memory") - throw py::value_error("mthreads TLE alloc does not support tmem storage"); - throw py::value_error("mthreads TLE alloc only supports smem storage"); -} -#endif // __TLE__ - } // namespace -#ifdef __TLE__ -void init_triton_mthreads_ir(py::module &&m) { - (void)m; - - auto *builderClsPtr = ir::getBuilderClass(); - if (!builderClsPtr) - throw std::runtime_error("triton IR builder class is not initialized"); - - auto &builderCls = *builderClsPtr; - builderCls - .def("make_swizzled_shared_encoding_attr", - [](TritonOpBuilder &self, unsigned vectorSize, unsigned perPhase, - unsigned maxPhase, std::vector order, - std::vector CTAsPerCGA, - std::vector CTASplitNum, - std::vector CTAOrder) -> mlir::Attribute { -
checkCtaRank(order, CTAsPerCGA, CTASplitNum, CTAOrder); - auto *context = self.getBuilder().getContext(); - auto cgaLayout = - makeCgaLayout(context, CTAsPerCGA, CTASplitNum, CTAOrder); - return ttg::SwizzledSharedEncodingAttr::get( - context, vectorSize, perPhase, maxPhase, order, cgaLayout); - }) - .def("make_nv_mma_shared_encoding_attr", - [](TritonOpBuilder &, std::vector, std::vector, - mlir::Type &, std::vector, std::vector, - std::vector, bool, bool) -> mlir::Attribute { - throw py::value_error("mthreads TLE alloc does not support " - "nv_mma_shared_layout=True"); - }) - .def("make_tensor_memory_encoding_attr", - [](TritonOpBuilder &, unsigned, unsigned, unsigned, unsigned, - unsigned, bool) -> mlir::Attribute { - throw py::value_error( - "mthreads TLE alloc does not support tmem storage"); - }) - .def("create_local_alloc", - [](TritonOpBuilder &self, std::vector shape, - mlir::Type &elementType, - mlir::Attribute &encoding) -> mlir::Value { - auto *context = self.getBuilder().getContext(); - auto memorySpace = ttg::SharedMemorySpaceAttr::get(context); - auto memDesc = ttg::MemDescType::get(shape, elementType, encoding, - memorySpace, - /*mutableMemory=*/true); - return self.create(memDesc); - }) - .def("create_local_alloc", - [](TritonOpBuilder &self, mlir::Type resultTy, - mlir::Value value) -> mlir::Value { - return self.create(resultTy, value); - }) - .def("get_memdesc_type", - [](TritonOpBuilder &self, std::vector shape, - mlir::Type &elementType, mlir::Attribute &encoding, - std::string storage) -> mlir::Type { - auto *context = self.getBuilder().getContext(); - auto memorySpace = getSharedMemorySpace(context, storage); - return ttg::MemDescType::get(shape, elementType, encoding, - memorySpace, - /*mutableMemory=*/true); - }) - .def("get_memdesc_type", - [](TritonOpBuilder &self, std::vector shape, - mlir::Type &elementType, mlir::Attribute &encoding, - std::string storage, - std::vector allocShape) -> mlir::Type { - auto *context = 
self.getBuilder().getContext(); - auto memorySpace = getSharedMemorySpace(context, storage); - return ttg::MemDescType::get(shape, elementType, encoding, - memorySpace, - /*mutableMemory=*/true, allocShape); - }); -} -#endif // __TLE__ - -void init_triton_musa_passes_ttgpuir(py::module &&m) { +void init_triton_musa_passes_ttgpuir(py::module m) { using namespace mlir::triton; m.def("add_mtgpu_to_llvm", [](mlir::PassManager &pm, int32_t capability) { pm.addPass(mlir::triton::createConvertMTGPUToLLVMPass(capability)); @@ -265,11 +154,15 @@ void init_triton_musa_passes_ttgpuir(py::module &&m) { void init_triton_mthreads(py::module &&m) { #ifdef __TLE__ - init_triton_mthreads_ir(m.def_submodule("ir")); + init_triton_musa_tle_ir(m.def_submodule("ir")); #endif // __TLE__ auto passes = m.def_submodule("passes"); - init_triton_musa_passes_ttgpuir(passes.def_submodule("ttgpuir")); + auto ttgpuir = passes.def_submodule("ttgpuir"); + init_triton_musa_passes_ttgpuir(ttgpuir); +#ifdef __TLE__ + init_triton_musa_tle_passes_ttgpuir(ttgpuir); +#endif // __TLE__ // load dialects m.def("load_dialects", [](mlir::MLIRContext &context) { @@ -277,6 +170,9 @@ void init_triton_mthreads(py::module &&m) { registry .insert(); +#ifdef __TLE__ + register_triton_musa_tle_dialects(registry); +#endif mlir::registerLLVMDialectTranslation(registry); mlir::registerMTVMDialectTranslation(registry); context.appendDialectRegistry(registry); diff --git a/third_party/mthreads/triton_mthreads_tle.cc b/third_party/mthreads/triton_mthreads_tle.cc new file mode 100644 index 0000000000..79cf0609f9 --- /dev/null +++ b/third_party/mthreads/triton_mthreads_tle.cc @@ -0,0 +1,182 @@ +#ifdef __TLE__ + +#include "Dialect/MUSATLE/IR/Dialect.h" +#include "TritonMUSAGPUTransforms/Passes.h" +#include "ir.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/ADT/ArrayRef.h" 
+#include "llvm/ADT/SmallVector.h" +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +namespace ttg = mlir::triton::gpu; + +namespace { + +void checkCtaRank(llvm::ArrayRef order, + llvm::ArrayRef ctasPerCGA, + llvm::ArrayRef ctaSplitNum, + llvm::ArrayRef ctaOrder) { + if (order.size() != ctasPerCGA.size() || order.size() != ctaSplitNum.size() || + order.size() != ctaOrder.size()) + throw py::value_error("shared layout rank mismatch in CTA parameters"); +} + +void normalizeRank0SharedLayout(std::vector &order, + std::vector &ctasPerCGA, + std::vector &ctaSplitNum, + std::vector &ctaOrder) { + if (!order.empty()) + return; + if (!ctasPerCGA.empty() || !ctaSplitNum.empty() || !ctaOrder.empty()) + throw py::value_error("rank-0 shared layout expects empty CTA parameters"); + // TritonGPU memdesc currently rejects true rank-0 descriptors. Mthreads TLE + // keeps Python-visible rank-0 semantics by backing such buffers with one + // shared element and a rank-1 shared layout. 
+ order = {0}; + ctasPerCGA = {1}; + ctaSplitNum = {1}; + ctaOrder = {0}; +} + +std::vector normalizeRank0MemDescShape(std::vector shape) { + if (shape.empty()) + return {1}; + return shape; +} + +ttg::CGAEncodingAttr makeCgaLayout(mlir::MLIRContext *context, + llvm::ArrayRef ctasPerCGA, + llvm::ArrayRef ctaSplitNum, + llvm::ArrayRef ctaOrder) { + return ttg::CGAEncodingAttr::fromSplitParams(context, ctasPerCGA, ctaSplitNum, + ctaOrder); +} + +mlir::Attribute getSharedMemorySpace(mlir::MLIRContext *context, + const std::string &storage) { + if (storage == "smem" || storage == "share_memory" || + storage == "shared_memory") + return ttg::SharedMemorySpaceAttr::get(context); + if (storage == "tmem" || storage == "tensor_memory") + throw py::value_error("mthreads TLE alloc does not support tmem storage"); + throw py::value_error("mthreads TLE alloc only supports smem storage"); +} + +} // namespace + +void init_triton_musa_tle_ir(py::module m) { + (void)m; + + auto *builderClsPtr = ir::getBuilderClass(); + if (!builderClsPtr) + throw std::runtime_error("triton IR builder class is not initialized"); + + auto &builderCls = *builderClsPtr; + builderCls + .def("make_swizzled_shared_encoding_attr", + [](TritonOpBuilder &self, unsigned vectorSize, unsigned perPhase, + unsigned maxPhase, std::vector order, + std::vector CTAsPerCGA, + std::vector CTASplitNum, + std::vector CTAOrder) -> mlir::Attribute { + normalizeRank0SharedLayout(order, CTAsPerCGA, CTASplitNum, + CTAOrder); + checkCtaRank(order, CTAsPerCGA, CTASplitNum, CTAOrder); + auto *context = self.getBuilder().getContext(); + auto cgaLayout = + makeCgaLayout(context, CTAsPerCGA, CTASplitNum, CTAOrder); + return ttg::SwizzledSharedEncodingAttr::get( + context, vectorSize, perPhase, maxPhase, order, cgaLayout); + }) + .def("make_nv_mma_shared_encoding_attr", + [](TritonOpBuilder &, std::vector, std::vector, + mlir::Type &, std::vector, std::vector, + std::vector, bool, bool) -> mlir::Attribute { + throw 
py::value_error("mthreads TLE alloc does not support " + "nv_mma_shared_layout=True"); + }) + .def("make_tensor_memory_encoding_attr", + [](TritonOpBuilder &, unsigned, unsigned, unsigned, unsigned, + unsigned, bool) -> mlir::Attribute { + throw py::value_error( + "mthreads TLE alloc does not support tmem storage"); + }) + .def("create_local_alloc", + [](TritonOpBuilder &self, std::vector shape, + mlir::Type &elementType, + mlir::Attribute &encoding) -> mlir::Value { + auto *context = self.getBuilder().getContext(); + auto memorySpace = ttg::SharedMemorySpaceAttr::get(context); + shape = normalizeRank0MemDescShape(std::move(shape)); + auto memDesc = ttg::MemDescType::get(shape, elementType, encoding, + memorySpace, + /*mutableMemory=*/true); + return self.create(memDesc); + }) + .def("create_local_alloc", + [](TritonOpBuilder &self, mlir::Type resultTy, + mlir::Value value) -> mlir::Value { + return self.create(resultTy, value); + }) + .def("get_memdesc_type", + [](TritonOpBuilder &self, std::vector shape, + mlir::Type &elementType, mlir::Attribute &encoding, + std::string storage) -> mlir::Type { + auto *context = self.getBuilder().getContext(); + auto memorySpace = getSharedMemorySpace(context, storage); + shape = normalizeRank0MemDescShape(std::move(shape)); + return ttg::MemDescType::get(shape, elementType, encoding, + memorySpace, + /*mutableMemory=*/true); + }) + .def("get_memdesc_type", + [](TritonOpBuilder &self, std::vector shape, + mlir::Type &elementType, mlir::Attribute &encoding, + std::string storage, + std::vector allocShape) -> mlir::Type { + auto *context = self.getBuilder().getContext(); + auto memorySpace = getSharedMemorySpace(context, storage); + shape = normalizeRank0MemDescShape(std::move(shape)); + allocShape = normalizeRank0MemDescShape(std::move(allocShape)); + return ttg::MemDescType::get(shape, elementType, encoding, + memorySpace, + /*mutableMemory=*/true, allocShape); + }) + .def("create_local_pointers", + [](TritonOpBuilder &self, 
mlir::Type resultTy, mlir::Value memDesc, + py::args args) -> mlir::OpState { + llvm::SmallVector indices; + indices.reserve(args.size()); + for (const auto &arg : args) + indices.push_back(py::cast(arg)); + return self.create( + resultTy, memDesc, indices); + }); +} + +void init_triton_musa_tle_passes_ttgpuir(py::module m) { + ADD_PASS_WRAPPER_0("add_tle_select_encodings", + mlir::createTritonMUSAGPUTLESelectEncodings); + ADD_PASS_WRAPPER_0("add_tle_insert_local_pointer_barriers", + mlir::createTritonMUSAGPUTLEInsertLocalPointerBarriers); + ADD_PASS_WRAPPER_0("add_tle_optimize_local_pointer_loads", + mlir::createTritonMUSAGPUTLEOptimizeLocalPointerLoads); + ADD_PASS_WRAPPER_0("add_tle_optimize_local_pointer_stores", + mlir::createTritonMUSAGPUTLEOptimizeLocalPointerStores); +} + +void register_triton_musa_tle_dialects(mlir::DialectRegistry &registry) { + registry.insert(); +} + +#endif // __TLE__ From bfb3f6ea2d065fe865deea90698db25bd0efedc8 Mon Sep 17 00:00:00 2001 From: QiLin Gai Date: Sun, 24 May 2026 19:15:21 +0800 Subject: [PATCH 06/10] [TLE][MTHREADS] Clarify mthreads TLE frontend and dialect ownership --- third_party/mthreads/CMakeLists.txt | 14 +++++-- .../include/TritonMUSAGPUTransforms/Passes.td | 34 ---------------- .../TritonMUSAGPUTransforms/CMakeLists.txt | 10 ----- third_party/mthreads/tle/CMakeLists.txt | 1 + .../{ => tle/dialect}/triton_mthreads_tle.cc | 6 ++- .../mthreads/tle/frontend/CMakeLists.txt | 2 + .../tle/frontend/include/CMakeLists.txt | 1 + .../include/MUSATLE/Frontend/CMakeLists.txt | 3 ++ .../include/MUSATLE/Frontend/Passes.h | 20 ++++++++++ .../include/MUSATLE/Frontend/Passes.td | 40 +++++++++++++++++++ .../mthreads/tle/frontend/lib/CMakeLists.txt | 1 + .../frontend/lib/Transforms/CMakeLists.txt | 23 +++++++++++ .../Transforms}/EarlyAssignMemorySpace.cpp | 4 +- .../lib/Transforms}/LowerAsyncLoad.cpp | 4 +- .../tle/frontend/triton_mthreads_frontend.cc | 21 ++++++++++ third_party/mthreads/triton_mthreads.cc | 12 ++---- 16 files changed,
136 insertions(+), 60 deletions(-) rename third_party/mthreads/{ => tle/dialect}/triton_mthreads_tle.cc (96%) create mode 100644 third_party/mthreads/tle/frontend/CMakeLists.txt create mode 100644 third_party/mthreads/tle/frontend/include/CMakeLists.txt create mode 100644 third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/CMakeLists.txt create mode 100644 third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/Passes.h create mode 100644 third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/Passes.td create mode 100644 third_party/mthreads/tle/frontend/lib/CMakeLists.txt create mode 100644 third_party/mthreads/tle/frontend/lib/Transforms/CMakeLists.txt rename third_party/mthreads/{musa/lib/TritonMUSAGPUTransforms/TLE => tle/frontend/lib/Transforms}/EarlyAssignMemorySpace.cpp (99%) rename third_party/mthreads/{musa/lib/TritonMUSAGPUTransforms/TLE => tle/frontend/lib/Transforms}/LowerAsyncLoad.cpp (98%) create mode 100644 third_party/mthreads/tle/frontend/triton_mthreads_frontend.cc diff --git a/third_party/mthreads/CMakeLists.txt b/third_party/mthreads/CMakeLists.txt index fca39c3527..b32683f35e 100644 --- a/third_party/mthreads/CMakeLists.txt +++ b/third_party/mthreads/CMakeLists.txt @@ -6,6 +6,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/musa/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/musa/include) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/tle/dialect/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/tle/dialect/include) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/tle/frontend/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/tle/frontend/include) if(FLAGTREE_MTHREADS_TLE) add_subdirectory(tle) endif() @@ -15,9 +17,15 @@ add_subdirectory(musa) if(TRITON_BUILD_PYTHON_MODULE) if(FLAGTREE_MTHREADS_TLE) set(_MTHREADS_TLE_PLUGIN_SOURCES - ${CMAKE_CURRENT_SOURCE_DIR}/triton_mthreads_tle.cc) - set(_MTHREADS_TLE_PLUGIN_LIBS MUSATLEIR MUSATLETransforms) - set(_MTHREADS_TLE_PLUGIN_DEPS MUSATLETableGen 
MUSATLETransforms) + ${CMAKE_CURRENT_SOURCE_DIR}/tle/dialect/triton_mthreads_tle.cc + ${CMAKE_CURRENT_SOURCE_DIR}/tle/frontend/triton_mthreads_frontend.cc) + set(_MTHREADS_TLE_PLUGIN_LIBS + MUSATLEIR MUSATLETransforms MUSATLEFrontendTransforms) + set(_MTHREADS_TLE_PLUGIN_DEPS + MUSATLETableGen + MUSATLETransforms + MUSATLEFrontendTransformsIncGen + MUSATLEFrontendTransforms) else() set(_MTHREADS_TLE_PLUGIN_SOURCES "") set(_MTHREADS_TLE_PLUGIN_LIBS "") diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td index ab394f8f58..baf0ed1f23 100644 --- a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td +++ b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/Passes.td @@ -191,40 +191,6 @@ def TritonMUSAGPUMarkInplaceLoads } #ifdef __TLE__ -def TritonMUSAGPUTLEEarlyAssignMemorySpace - : Pass<"tritonmusa-tle-early-assign-memory-space", "mlir::ModuleOp"> { - let summary = "Materialize TLE memory-space annotations for MUSA"; - let description = [{ - Rewrite tensors marked with `tt.memory_space = "shared_memory"` into - explicit shared-memory memdesc traffic before MUSA-specific layout and MMA - transforms run. Legal loads are lowered through MUSA async copy plus - commit/wait; all other producers use initialized local_alloc/local_load - materialization to preserve tensor semantics. - }]; - let dependentDialects = [ - "mlir::arith::ArithDialect", - "mlir::triton::TritonDialect", - "mlir::triton::gpu::TritonGPUDialect" - ]; -} - -def TritonMUSAGPUTLELowerAsyncLoad - : Pass<"tritonmusa-tle-lower-async-load", "mlir::ModuleOp"> { - let summary = "Lower TLE async load hints to MUSA async copies"; - let description = [{ - Rewrite `tt.load` operations marked with `tt.load.async = true` into - `ttg.async_copy_global_to_local` plus commit/wait/local_load when the - load can be represented by the MUSA async global-to-shared copy path. 
- Unsupported hints are dropped so `tle.load(..., is_async=True)` remains - a scheduling hint with ordinary `tt.load` correctness semantics. - }]; - let dependentDialects = [ - "mlir::arith::ArithDialect", - "mlir::triton::TritonDialect", - "mlir::triton::gpu::TritonGPUDialect" - ]; -} - include "MUSATLE/Transforms/Passes.td" #endif // __TLE__ diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt index 94f436452a..07636b167b 100644 --- a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/CMakeLists.txt @@ -1,12 +1,3 @@ -if(FLAGTREE_MTHREADS_TLE) - set(_TLE_SOURCES - TLE/EarlyAssignMemorySpace.cpp - TLE/LowerAsyncLoad.cpp - ) -else() - set(_TLE_SOURCES "") -endif() - add_triton_library(TritonMUSAGPUTransforms AccelerateMUSAMatmul.cpp CanonicalizeSqmmaResultConversions.cpp @@ -22,7 +13,6 @@ add_triton_library(TritonMUSAGPUTransforms SqmmaPipelineUtils.cpp TMEPipelineUtils.cpp TMELowering.cpp - ${_TLE_SOURCES} DEPENDS TritonMUSAGPUTransformsIncGen diff --git a/third_party/mthreads/tle/CMakeLists.txt b/third_party/mthreads/tle/CMakeLists.txt index 562832e921..54df1319c6 100644 --- a/third_party/mthreads/tle/CMakeLists.txt +++ b/third_party/mthreads/tle/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(dialect) +add_subdirectory(frontend) diff --git a/third_party/mthreads/triton_mthreads_tle.cc b/third_party/mthreads/tle/dialect/triton_mthreads_tle.cc similarity index 96% rename from third_party/mthreads/triton_mthreads_tle.cc rename to third_party/mthreads/tle/dialect/triton_mthreads_tle.cc index 79cf0609f9..bf47655d74 100644 --- a/third_party/mthreads/triton_mthreads_tle.cc +++ b/third_party/mthreads/tle/dialect/triton_mthreads_tle.cc @@ -20,6 +20,10 @@ namespace py = pybind11; namespace ttg = mlir::triton::gpu; +// Backend-local `musa_tle` dialect adapters. 
Frontend marker pass wrappers +// live in tle/frontend/triton_mthreads_frontend.cc; keep them separate from +// `musa_tle.local_pointers` builder and transform bindings. + namespace { void checkCtaRank(llvm::ArrayRef order, @@ -164,7 +168,7 @@ void init_triton_musa_tle_ir(py::module m) { }); } -void init_triton_musa_tle_passes_ttgpuir(py::module m) { +void init_triton_musa_tle_dialect_passes_ttgpuir(py::module m) { ADD_PASS_WRAPPER_0("add_tle_select_encodings", mlir::createTritonMUSAGPUTLESelectEncodings); ADD_PASS_WRAPPER_0("add_tle_insert_local_pointer_barriers", diff --git a/third_party/mthreads/tle/frontend/CMakeLists.txt b/third_party/mthreads/tle/frontend/CMakeLists.txt new file mode 100644 index 0000000000..8a43d93a8b --- /dev/null +++ b/third_party/mthreads/tle/frontend/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(include) +add_subdirectory(lib) diff --git a/third_party/mthreads/tle/frontend/include/CMakeLists.txt b/third_party/mthreads/tle/frontend/include/CMakeLists.txt new file mode 100644 index 0000000000..16113e19b1 --- /dev/null +++ b/third_party/mthreads/tle/frontend/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(MUSATLE/Frontend) diff --git a/third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/CMakeLists.txt b/third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/CMakeLists.txt new file mode 100644 index 0000000000..7a3124a426 --- /dev/null +++ b/third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name MUSATLEFrontend) +add_public_tablegen_target(MUSATLEFrontendTransformsIncGen) diff --git a/third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/Passes.h b/third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/Passes.h new file mode 100644 index 0000000000..1e32d281c0 --- /dev/null +++ b/third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/Passes.h @@ -0,0 +1,20 @@ +#ifndef 
MTHREADS_MUSATLE_FRONTEND_PASSES_H +#define MTHREADS_MUSATLE_FRONTEND_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +// Generate the pass class declarations. +#define GEN_PASS_DECL +#include "MUSATLE/Frontend/Passes.h.inc" + +} // namespace mlir + +namespace mlir { +// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "MUSATLE/Frontend/Passes.h.inc" +} // namespace mlir + +#endif // MTHREADS_MUSATLE_FRONTEND_PASSES_H diff --git a/third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/Passes.td b/third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/Passes.td new file mode 100644 index 0000000000..bd62a3b3dc --- /dev/null +++ b/third_party/mthreads/tle/frontend/include/MUSATLE/Frontend/Passes.td @@ -0,0 +1,40 @@ +#ifndef MTHREADS_MUSATLE_FRONTEND_PASSES +#define MTHREADS_MUSATLE_FRONTEND_PASSES + +include "mlir/Pass/PassBase.td" + +def TritonMUSAGPUTLEEarlyAssignMemorySpace + : Pass<"tritonmusa-tle-early-assign-memory-space", "mlir::ModuleOp"> { + let summary = "Materialize TLE memory-space annotations for MUSA"; + let description = [{ + Rewrite tensors marked with `tt.memory_space = "shared_memory"` into + explicit shared-memory memdesc traffic before MUSA-specific layout and MMA + transforms run. Legal loads are lowered through MUSA async copy plus + commit/wait; all other producers use initialized local_alloc/local_load + materialization to preserve tensor semantics. 
+ }]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect" + ]; +} + +def TritonMUSAGPUTLELowerAsyncLoad + : Pass<"tritonmusa-tle-lower-async-load", "mlir::ModuleOp"> { + let summary = "Lower TLE async load hints to MUSA async copies"; + let description = [{ + Rewrite `tt.load` operations marked with `tt.load.async = true` into + `ttg.async_copy_global_to_local` plus commit/wait/local_load when the + load can be represented by the MUSA async global-to-shared copy path. + Unsupported hints are dropped so `tle.load(..., is_async=True)` remains + a scheduling hint with ordinary `tt.load` correctness semantics. + }]; + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::triton::TritonDialect", + "mlir::triton::gpu::TritonGPUDialect" + ]; +} + +#endif // MTHREADS_MUSATLE_FRONTEND_PASSES diff --git a/third_party/mthreads/tle/frontend/lib/CMakeLists.txt b/third_party/mthreads/tle/frontend/lib/CMakeLists.txt new file mode 100644 index 0000000000..e31af32661 --- /dev/null +++ b/third_party/mthreads/tle/frontend/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Transforms) diff --git a/third_party/mthreads/tle/frontend/lib/Transforms/CMakeLists.txt b/third_party/mthreads/tle/frontend/lib/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..13596f66b1 --- /dev/null +++ b/third_party/mthreads/tle/frontend/lib/Transforms/CMakeLists.txt @@ -0,0 +1,23 @@ +add_triton_library(MUSATLEFrontendTransforms + EarlyAssignMemorySpace.cpp + LowerAsyncLoad.cpp + + DEPENDS + MUSATLEFrontendTransformsIncGen + TritonIR + TritonGPUIR + TritonGPUTransforms + MUSAIR + + LINK_LIBS PUBLIC + MUSAIR + TritonIR + TritonGPUIR + TritonGPUTransforms +) + +target_include_directories(MUSATLEFrontendTransforms + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_BINARY_DIR}/../../include +) diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TLE/EarlyAssignMemorySpace.cpp 
b/third_party/mthreads/tle/frontend/lib/Transforms/EarlyAssignMemorySpace.cpp similarity index 99% rename from third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TLE/EarlyAssignMemorySpace.cpp rename to third_party/mthreads/tle/frontend/lib/Transforms/EarlyAssignMemorySpace.cpp index 1c4510b2f2..e07a3a757f 100644 --- a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TLE/EarlyAssignMemorySpace.cpp +++ b/third_party/mthreads/tle/frontend/lib/Transforms/EarlyAssignMemorySpace.cpp @@ -1,5 +1,5 @@ +#include "MUSATLE/Frontend/Passes.h" #include "TritonMUSACommon/SqmmaAttrUtils.h" -#include "TritonMUSAGPUTransforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" @@ -251,7 +251,7 @@ lowerLoadViaAsyncSharedCopy(tt::LoadOp op, RewriterBase &rewriter, namespace mlir { #define GEN_PASS_DEF_TRITONMUSAGPUTLEEARLYASSIGNMEMORYSPACE -#include "TritonMUSAGPUTransforms/Passes.h.inc" +#include "MUSATLE/Frontend/Passes.h.inc" struct TritonMUSAGPUTLEEarlyAssignMemorySpacePass : impl::TritonMUSAGPUTLEEarlyAssignMemorySpaceBase< diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TLE/LowerAsyncLoad.cpp b/third_party/mthreads/tle/frontend/lib/Transforms/LowerAsyncLoad.cpp similarity index 98% rename from third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TLE/LowerAsyncLoad.cpp rename to third_party/mthreads/tle/frontend/lib/Transforms/LowerAsyncLoad.cpp index b7758513f3..cfffada878 100644 --- a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TLE/LowerAsyncLoad.cpp +++ b/third_party/mthreads/tle/frontend/lib/Transforms/LowerAsyncLoad.cpp @@ -1,5 +1,5 @@ +#include "MUSATLE/Frontend/Passes.h" #include "TritonMUSACommon/SqmmaAttrUtils.h" -#include "TritonMUSAGPUTransforms/Passes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Matchers.h" @@ -225,7 +225,7 @@ static void lowerAsyncLoad(tt::LoadOp op, RewriterBase &rewriter, namespace mlir { #define 
GEN_PASS_DEF_TRITONMUSAGPUTLELOWERASYNCLOAD -#include "TritonMUSAGPUTransforms/Passes.h.inc" +#include "MUSATLE/Frontend/Passes.h.inc" struct TritonMUSAGPUTLELowerAsyncLoadPass : impl::TritonMUSAGPUTLELowerAsyncLoadBase< diff --git a/third_party/mthreads/tle/frontend/triton_mthreads_frontend.cc b/third_party/mthreads/tle/frontend/triton_mthreads_frontend.cc new file mode 100644 index 0000000000..484f8822f8 --- /dev/null +++ b/third_party/mthreads/tle/frontend/triton_mthreads_frontend.cc @@ -0,0 +1,21 @@ +#ifdef __TLE__ + +#include "MUSATLE/Frontend/Passes.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include + +namespace py = pybind11; + +// Frontend marker adapters consume shared TLE markers emitted by Python +// frontend code, such as `tt.memory_space` and `tt.load.async`, before the +// mthreads/MUSA TTGIR pipeline reaches backend-local `musa_tle` dialect +// optimization. They are not `musa_tle` dialect passes. +void init_triton_musa_tle_frontend_passes_ttgpuir(py::module m) { + ADD_PASS_WRAPPER_0("add_tle_early_assign_memory_space", + mlir::createTritonMUSAGPUTLEEarlyAssignMemorySpace); + ADD_PASS_WRAPPER_0("add_tle_lower_async_load", + mlir::createTritonMUSAGPUTLELowerAsyncLoad); +} + +#endif // __TLE__ diff --git a/third_party/mthreads/triton_mthreads.cc b/third_party/mthreads/triton_mthreads.cc index 09a178804c..bd071517d5 100644 --- a/third_party/mthreads/triton_mthreads.cc +++ b/third_party/mthreads/triton_mthreads.cc @@ -24,7 +24,8 @@ namespace py = pybind11; #ifdef __TLE__ void init_triton_musa_tle_ir(py::module m); -void init_triton_musa_tle_passes_ttgpuir(py::module m); +void init_triton_musa_tle_frontend_passes_ttgpuir(py::module m); +void init_triton_musa_tle_dialect_passes_ttgpuir(py::module m); void register_triton_musa_tle_dialects(mlir::DialectRegistry &registry); #endif @@ -144,12 +145,6 @@ void init_triton_musa_passes_ttgpuir(py::module m) { mlir::createTritonMUSAGPUOptimizeDescriptorEncoding);
ADD_PASS_WRAPPER_0("add_optimize_sqmma_accumulator_layout", mlir::createTritonMUSAGPUOptimizeSqmmaAccumulatorLayout); -#ifdef __TLE__ - ADD_PASS_WRAPPER_0("add_tle_early_assign_memory_space", - mlir::createTritonMUSAGPUTLEEarlyAssignMemorySpace); - ADD_PASS_WRAPPER_0("add_tle_lower_async_load", - mlir::createTritonMUSAGPUTLELowerAsyncLoad); -#endif // __TLE__ } void init_triton_mthreads(py::module &&m) { @@ -161,7 +156,8 @@ void init_triton_mthreads(py::module &&m) { auto ttgpuir = passes.def_submodule("ttgpuir"); init_triton_musa_passes_ttgpuir(ttgpuir); #ifdef __TLE__ - init_triton_musa_tle_passes_ttgpuir(ttgpuir); + init_triton_musa_tle_frontend_passes_ttgpuir(ttgpuir); + init_triton_musa_tle_dialect_passes_ttgpuir(ttgpuir); #endif // __TLE__ // load dialects From 38a529e1fa9cc937d43f46e68aa373e0b9501275 Mon Sep 17 00:00:00 2001 From: QiLin Gai Date: Mon, 25 May 2026 15:16:31 +0800 Subject: [PATCH 07/10] [TLE][MTHREADS] Support TLE copy on mthreads backend --- .../experimental/tle/language/gpu/core.py | 10 +- .../tle/language/gpu/mthreads/__init__.py | 3 + .../tle/language/gpu/mthreads/copy.py | 103 +++++++ third_party/mthreads/backend/compiler.py | 2 + .../Dialect/TritonGPU/IR/CMakeLists.txt | 9 +- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 27 ++ .../mthreads/lib/Dialect/TritonGPU/IR/Ops.cpp | 34 ++ .../TritonMUSAGPUTransforms/CMakeLists.txt | 4 + .../TritonMUSAGPUTransforms/TMELowering.cpp | 92 ++++++ .../mthreads/python/test/unit/tle/test_tle.py | 30 ++ .../python/test/unit/tle/test_tle_copy.py | 184 +++++++++++ .../include/MUSATLE/Transforms/Passes.td | 24 ++ .../MUSATLE/Transforms/TransformAttrs.h | 15 + .../tle/dialect/lib/Transforms/CMakeLists.txt | 1 + .../OptimizeLocalPointerAsyncStores.cpp | 290 ++++++++++++++++++ .../tle/dialect/triton_mthreads_tle.cc | 8 + 16 files changed, 833 insertions(+), 3 deletions(-) create mode 100644 python/triton/experimental/tle/language/gpu/mthreads/__init__.py create mode 100644 
python/triton/experimental/tle/language/gpu/mthreads/copy.py create mode 100644 third_party/mthreads/python/test/unit/tle/test_tle_copy.py create mode 100644 third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/TransformAttrs.h create mode 100644 third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerAsyncStores.cpp diff --git a/python/triton/experimental/tle/language/gpu/core.py b/python/triton/experimental/tle/language/gpu/core.py index f756b88b69..d913c33755 100644 --- a/python/triton/experimental/tle/language/gpu/core.py +++ b/python/triton/experimental/tle/language/gpu/core.py @@ -4,6 +4,7 @@ from typing import Optional, Sequence from enum import Enum from . import types as tle +from .mthreads import copy as mthreads_copy from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values from triton.language.core import ( @@ -360,6 +361,7 @@ def copy( TMA copy with offsets: tle.copy(tma_desc, local_buf, [64, 64], [x_offset, y_offset]) """ + mthreads_enabled = mthreads_copy.enabled() def normcopy( src: tl.tensor, @@ -368,6 +370,8 @@ def normcopy( direction, _semantic=None, ) -> None: + if mthreads_enabled: + mthreads_copy.validate_normal_copy(src, dst, shape, direction) # Semantic analysis try: @@ -389,8 +393,10 @@ def normcopy( try: if direction == CopyDirection.GM_TO_LOCAL: + # None fills the FlagTree hints slot; TLE copy has no hints to pass. 
+ load_extra_args = () if mthreads_enabled else (None, ) tt_load = _semantic.load(src, mask, other, boundary_check, padding_option, cache_modifier, - eviction_policy, volatile, None) + eviction_policy, volatile, *load_extra_args) local_ptrs = local_ptr(dst, _make_full_indices(dst, _semantic), _semantic=_semantic) _semantic.store(local_ptrs, tt_load, mask, boundary_check, cache_modifier, eviction_policy) else: @@ -492,6 +498,8 @@ def tmacopy( raise ValueError(f"Shape parameter must be tuple or list, but got {type(shape)}") if is_normcopy: return normcopy(src, dst, shape, direction, _semantic) + if mthreads_enabled: + return mthreads_copy.tmacopy(src, dst, direction, shape, offsets, _semantic) else: return tmacopy(src, dst, direction, shape, offsets, _semantic) diff --git a/python/triton/experimental/tle/language/gpu/mthreads/__init__.py b/python/triton/experimental/tle/language/gpu/mthreads/__init__.py new file mode 100644 index 0000000000..4502842ecf --- /dev/null +++ b/python/triton/experimental/tle/language/gpu/mthreads/__init__.py @@ -0,0 +1,3 @@ +from . import copy + +__all__ = ["copy"] diff --git a/python/triton/experimental/tle/language/gpu/mthreads/copy.py b/python/triton/experimental/tle/language/gpu/mthreads/copy.py new file mode 100644 index 0000000000..808ab7503b --- /dev/null +++ b/python/triton/experimental/tle/language/gpu/mthreads/copy.py @@ -0,0 +1,103 @@ +import os + +import triton.language.core as tl + +from .. 
import types as tle + +try: + from triton._flagtree_backend import FLAGTREE_BACKEND +except ModuleNotFoundError: + FLAGTREE_BACKEND = os.environ.get("FLAGTREE_BACKEND", "") + + +def _has_mthreads_libtriton() -> bool: + try: + from triton._C import libtriton + except ImportError: + return False + return hasattr(libtriton, "mthreads") + + +def enabled() -> bool: + return FLAGTREE_BACKEND == "mthreads" or _has_mthreads_libtriton() + + +def normalize_copy_shape(shape) -> tuple[int, ...]: + return tuple(int(tl._unwrap_if_constexpr(dim)) for dim in shape) + + +def validate_copy_buffer(buffer: tle.buffered_tensor, shape: tuple[int, ...]) -> None: + if not isinstance(buffer, tle.buffered_tensor): + raise ValueError(f"buffer must be a tle.gpu.buffered_tensor, but got {type(buffer)}") + if buffer.type.storage != tle.smem: + raise ValueError("MUSA TLE copy only supports tle.gpu.smem buffers") + buffer_shape = tuple(int(tl._unwrap_if_constexpr(dim)) for dim in buffer.type.shape) + if buffer_shape != shape: + raise ValueError(f"copy shape {shape} must match buffer shape {buffer_shape}") + + +def tensor_shape(value: tl.tensor) -> tuple[int, ...]: + if not value.type.is_block(): + return tuple() + return tuple(int(tl._unwrap_if_constexpr(dim)) for dim in value.shape) + + +def tensor_pointer_element_ty(value: tl.tensor): + scalar_ty = value.dtype + if not scalar_ty.is_ptr(): + raise ValueError("tle.gpu.copy tensor operands must be pointer tensors") + return scalar_ty.element_ty + + +def validate_normal_copy(src, dst, shape, direction) -> None: + shape = normalize_copy_shape(shape) + if direction.name == "GM_TO_LOCAL": + global_tensor = src + local_buffer = dst + else: + global_tensor = dst + local_buffer = src + + validate_copy_buffer(local_buffer, shape) + ptr_shape = tensor_shape(global_tensor) + if ptr_shape != shape: + raise ValueError(f"copy shape {shape} must match tensor pointer shape {ptr_shape}") + elem_ty = tensor_pointer_element_ty(global_tensor) + if elem_ty != 
local_buffer.dtype: + raise ValueError(f"copy dtype mismatch: tensor points to {elem_ty}, buffer stores {local_buffer.dtype}") + + +def normalize_offsets(offsets, rank: int): + offsets = tl._unwrap_if_constexpr(offsets) + if offsets is None: + raise ValueError("descriptor-based tle.gpu.copy requires offsets") + if isinstance(offsets, tl.tuple): + offsets_tuple = tuple(offsets.values) + elif isinstance(offsets, (tuple, list)): + offsets_tuple = tuple(offsets) + elif hasattr(offsets, "__iter__"): + offsets_tuple = tuple(offsets) + else: + raise ValueError(f"offsets must be a tuple or list, but got {type(offsets)}") + if len(offsets_tuple) != rank: + raise ValueError(f"offsets must provide {rank} values, got {len(offsets_tuple)}") + return offsets_tuple + + +def tmacopy(src, dst, direction, shape, offsets, _semantic) -> None: + shape = normalize_copy_shape(shape) + desc = src if direction.name == "GM_TO_LOCAL" else dst + buffer = dst if direction.name == "GM_TO_LOCAL" else src + + validate_copy_buffer(buffer, shape) + desc_shape = tuple(int(tl._unwrap_if_constexpr(dim)) for dim in desc.block_shape) + if desc_shape != shape: + raise ValueError(f"copy shape {shape} must match tensor descriptor block shape {desc_shape}") + if desc.dtype != buffer.dtype: + raise ValueError(f"copy dtype mismatch: descriptor stores {desc.dtype}, buffer stores {buffer.dtype}") + + offset_values = normalize_offsets(offsets, len(desc_shape)) + offset_values = _semantic._convert_to_ir_values(offset_values, require_i64=False) + if not hasattr(_semantic.builder, "create_tma_copy"): + raise RuntimeError("TLE TMA copy builder binding is not available") + _semantic.builder.create_tma_copy(src.handle, dst.handle, offset_values) diff --git a/third_party/mthreads/backend/compiler.py b/third_party/mthreads/backend/compiler.py index 8350ce996a..19ddd19581 100644 --- a/third_party/mthreads/backend/compiler.py +++ b/third_party/mthreads/backend/compiler.py @@ -740,6 +740,8 @@ def make_ttgir(mod, metadata, 
opt, arch, capability): passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_optimize_thread_locality(pm) + if hasattr(mthreads.passes.ttgpuir, "add_tle_optimize_local_pointer_async_stores"): + mthreads.passes.ttgpuir.add_tle_optimize_local_pointer_async_stores(pm) if hasattr(mthreads.passes.ttgpuir, "add_tle_early_assign_memory_space"): mthreads.passes.ttgpuir.add_tle_early_assign_memory_space(pm) if hasattr(mthreads.passes.ttgpuir, "add_tle_select_encodings"): diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt index 8b44463001..5ce12376ba 100644 --- a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,10 +1,15 @@ set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) +if(FLAGTREE_MTHREADS_TLE) + set(_TLE_TABLEGEN_DEFS -D__TLE__) +else() + set(_TLE_TABLEGEN_DEFS "") +endif() mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttg) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttg) -mlir_tablegen(Ops.h.inc -gen-op-decls) -mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Ops.h.inc -gen-op-decls ${_TLE_TABLEGEN_DEFS}) +mlir_tablegen(Ops.cpp.inc -gen-op-defs ${_TLE_TABLEGEN_DEFS}) mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttg) mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttg) add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) diff --git a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 5f666b43a2..438dbf02cf 100644 --- a/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/third_party/mthreads/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -23,6 +23,10 @@ include 
"mlir/Interfaces/ViewLikeInterface.td" def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">; +#ifdef __TLE__ +def TTG_TMACopyOperand : AnyTypeOf<[TT_TensorDescType, TTG_MemDescType]>; +#endif // __TLE__ + class TTG_Op traits = []> : Op { @@ -721,6 +725,29 @@ def TTG_BarrierOp : TTG_Op<"barrier"> { }]; } +#ifdef __TLE__ +def TTG_TMACopyOp : TTG_Op<"tma_copy", [MemoryEffects<[MemRead, MemWrite]>]> { + let summary = "Pseudo op for descriptor-based copy between global tensor descriptor and shared memdesc."; + + let description = [{ + `ttg.tma_copy` represents an explicit copy between a global tensor + descriptor and a shared-memory memdesc. Backend-specific TME/TMA lowering + replaces it with the target hardware copy and synchronization operations. + }]; + + let arguments = (ins + TTG_TMACopyOperand:$src, + TTG_TMACopyOperand:$dst, + Variadic:$indices + ); + + let assemblyFormat = + "$src `,` $dst `,` `[` $indices `]` attr-dict `:` type($src) `,` type($dst)"; + + let hasVerifier = 1; +} +#endif // __TLE__ + def TTG_WarpIdOp : TTG_Op<"warp_id", [Pure]> { let summary = "Return the GPU warp ID"; diff --git a/third_party/mthreads/lib/Dialect/TritonGPU/IR/Ops.cpp b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Ops.cpp index b359868f5a..aa3dfca094 100644 --- a/third_party/mthreads/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/third_party/mthreads/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -888,6 +888,40 @@ LogicalResult AsyncCopyGlobalToLocalOp::verify() { return success(); } +#ifdef __TLE__ +LogicalResult TMACopyOp::verify() { + auto srcDescTy = dyn_cast(getSrc().getType()); + auto dstDescTy = dyn_cast(getDst().getType()); + auto srcMemDescTy = dyn_cast(getSrc().getType()); + auto dstMemDescTy = dyn_cast(getDst().getType()); + + const bool globalToLocal = srcDescTy && dstMemDescTy; + const bool localToGlobal = srcMemDescTy && dstDescTy; + if (!globalToLocal && !localToGlobal) { + return emitOpError("expects one 
tensor descriptor operand and one memdesc " + "operand"); + } + + auto descTy = globalToLocal ? srcDescTy : dstDescTy; + auto memDescTy = globalToLocal ? dstMemDescTy : srcMemDescTy; + auto blockTy = descTy.getBlockType(); + if (getIndices().size() != static_cast(blockTy.getRank())) { + return emitOpError("expects ") + << blockTy.getRank() << " indices, got " << getIndices().size(); + } + if (memDescTy.getShape() != blockTy.getShape()) { + return emitOpError("memdesc shape must match descriptor block shape"); + } + if (memDescTy.getElementType() != blockTy.getElementType()) { + return emitOpError("memdesc element type must match descriptor element " + "type"); + } + if (globalToLocal && !memDescTy.getMutableMemory()) + return emitOpError("cannot copy into immutable memdesc"); + return success(); +} +#endif // __TLE__ + LogicalResult MemDescIndexOp::verify() { auto srcTy = getSrc().getType(); auto dstTy = getType(); diff --git a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/CMakeLists.txt b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/CMakeLists.txt index e1f1e09202..346709f200 100644 --- a/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/CMakeLists.txt +++ b/third_party/mthreads/musa/include/TritonMUSAGPUTransforms/CMakeLists.txt @@ -1,10 +1,14 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) if(FLAGTREE_MTHREADS_TLE) set(_TLE_TABLEGEN_DEFS -D__TLE__) + set(_TLE_TABLEGEN_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/../../../tle/dialect/include/MUSATLE/Transforms/Passes.td) else() set(_TLE_TABLEGEN_DEFS "") + set(_TLE_TABLEGEN_DEPENDS "") endif() +set(LLVM_TARGET_DEPENDS ${_TLE_TABLEGEN_DEPENDS}) mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonMUSAGPU ${_TLE_TABLEGEN_DEFS}) add_public_tablegen_target(TritonMUSAGPUTransformsIncGen) diff --git a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMELowering.cpp b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMELowering.cpp index 153b4c43f9..c616dc20d7 100644 --- 
a/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMELowering.cpp +++ b/third_party/mthreads/musa/lib/TritonMUSAGPUTransforms/TMELowering.cpp @@ -160,6 +160,85 @@ static LogicalResult lowerDescriptorStore(tt::DescriptorStoreOp op, return success(); } +#ifdef __TLE__ +static LogicalResult lowerTMACopy(ttg::TMACopyOp op, RewriterBase &rewriter) { + auto loc = op.getLoc(); + auto srcDescTy = dyn_cast(op.getSrc().getType()); + auto dstDescTy = dyn_cast(op.getDst().getType()); + auto srcMemDescTy = dyn_cast(op.getSrc().getType()); + auto dstMemDescTy = dyn_cast(op.getDst().getType()); + + const bool globalToLocal = srcDescTy && dstMemDescTy; + const bool localToGlobal = srcMemDescTy && dstDescTy; + if (!globalToLocal && !localToGlobal) { + return op.emitOpError("expects one tensor descriptor operand and one " + "shared memdesc operand"); + } + + auto descTy = globalToLocal ? srcDescTy : dstDescTy; + auto memDescTy = globalToLocal ? dstMemDescTy : srcMemDescTy; + auto descBlockTy = descTy.getSignlessBlockType(); + if (memDescTy.getShape() != descBlockTy.getShape()) + return op.emitOpError("memdesc shape must match descriptor block shape"); + if (memDescTy.getElementType() != descBlockTy.getElementType()) + return op.emitOpError("memdesc element type must match descriptor element " + "type"); + + rewriter.setInsertionPoint(op); + auto coord = + triton::musa::materializeTMECoordValues(loc, op.getIndices(), rewriter); + if (failed(coord)) + return op.emitOpError("unsupported descriptor block rank for TME copy"); + + Value pred = arith::ConstantIntOp::create(rewriter, loc, 1, 1); + if (globalToLocal) { + auto config = triton::musa::resolveFinalTMECopyConfig( + memDescTy, descBlockTy.getShape(), + triton::musa::TMECopyKind::GlobalToLocal); + if (failed(config)) + return op.emitOpError("unable to resolve final TME load config"); + + auto barId = triton::musa::reserveFreshBarrierId(op); + if (failed(barId)) + return op.emitOpError("exhausted MUSA async barrier ids"); + 
Value barIdValue = arith::ConstantIntOp::create(rewriter, loc, *barId, 32); + Value phaseInit = arith::ConstantIntOp::create(rewriter, loc, 0, 32); + Value arriveCnt = arith::ConstantIntOp::create(rewriter, loc, 1, 32); + Value alwaysIssue = arith::ConstantIntOp::create(rewriter, loc, 1, 1); + Value totalBytes = materializeStaticTMETransactionBytes( + loc, descBlockTy.getShape(), memDescTy.getElementType(), rewriter); + if (!totalBytes) + return op.emitOpError("unable to materialize TME copy transaction " + "bytes"); + + triton::musa::InitArrivalOp::create(rewriter, loc, barIdValue, arriveCnt, + phaseInit); + triton::musa::BarrierAddTransOp::create(rewriter, loc, barIdValue, + totalBytes, alwaysIssue); + triton::musa::createAsyncTMECopyGlobalToLocal(rewriter, loc, op.getSrc(), + *coord, barIdValue, + op.getDst(), pred, *config); + triton::musa::ArriveBarrierNoRetOp::create(rewriter, loc, barIdValue, + alwaysIssue); + triton::musa::WaitBarrierOp::create(rewriter, loc, barIdValue, phaseInit); + } else { + auto config = triton::musa::resolveFinalTMECopyConfig( + memDescTy, descBlockTy.getShape(), + triton::musa::TMECopyKind::LocalToGlobal); + if (failed(config)) + return op.emitOpError("unable to resolve final TME store config"); + + triton::musa::createAsyncTMECopyLocalToGlobal( + rewriter, loc, op.getDst(), *coord, op.getSrc(), pred, *config); + triton::musa::TMEStoreCommitOp::create(rewriter, loc); + triton::musa::TMEStoreReadWaitOp::create(rewriter, loc); + } + + rewriter.eraseOp(op); + return success(); +} +#endif // __TLE__ + } // namespace namespace mlir { @@ -175,6 +254,19 @@ struct TritonMUSAGPUTMELoweringPass IRRewriter rewriter(&getContext()); for (tt::FuncOp func : mod.getOps()) { +#ifdef __TLE__ + SmallVector tmaCopyOps; + func.walk([&](ttg::TMACopyOp op) { tmaCopyOps.push_back(op); }); + for (ttg::TMACopyOp op : tmaCopyOps) { + if (!op->getBlock()) + continue; + if (failed(lowerTMACopy(op, rewriter))) { + signalPassFailure(); + return; + } + } +#endif // 
__TLE__ + SmallVector loadOps; func.walk([&](tt::DescriptorLoadOp op) { loadOps.push_back(op); }); for (tt::DescriptorLoadOp op : loadOps) { diff --git a/third_party/mthreads/python/test/unit/tle/test_tle.py b/third_party/mthreads/python/test/unit/tle/test_tle.py index 1a3d3c1b4e..25fd2520d9 100644 --- a/third_party/mthreads/python/test/unit/tle/test_tle.py +++ b/third_party/mthreads/python/test/unit/tle/test_tle.py @@ -37,6 +37,36 @@ def test_tle_language_import_exports_load_signature(): "nv_mma_shared_layout", "_semantic", ] + assert list(inspect.signature(tle.gpu.copy).parameters) == [ + "src", + "dst", + "shape", + "offsets", + "_semantic", + ] + assert hasattr(tle.gpu, "copy") + + +def test_tle_copy_mthreads_bindings_are_backend_local(): + from triton._C import libtriton + from triton._C.libtriton import ir + + from test_tle_utils import mthreads_backend + + _, backend = mthreads_backend() + context = ir.context() + ir.load_dialects(context) + backend.load_dialects(context) + builder = ir.builder(context) + + assert hasattr(builder, "create_local_pointers") + assert hasattr(builder, "create_tma_copy") + assert hasattr(libtriton, "mthreads") + assert not hasattr(libtriton, "tle") + assert hasattr( + libtriton.mthreads.passes.ttgpuir, + "add_tle_optimize_local_pointer_async_stores", + ) def test_tle_load_sets_async_bool_attr(): diff --git a/third_party/mthreads/python/test/unit/tle/test_tle_copy.py b/third_party/mthreads/python/test/unit/tle/test_tle_copy.py new file mode 100644 index 0000000000..6799843a0a --- /dev/null +++ b/third_party/mthreads/python/test/unit/tle/test_tle_copy.py @@ -0,0 +1,184 @@ +import pytest +import torch +import triton +import triton.language as tl +import triton.experimental.tle.language as tle +from triton.tools.tensor_descriptor import TensorDescriptor + +from test_tle_utils import compile_musa, require_mthreads_libtriton + +require_mthreads_libtriton() + + +@triton.jit +def _normal_copy_roundtrip_kernel(src, dst, BLOCK: 
tl.constexpr): + offsets = tl.arange(0, BLOCK) + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.float32, nv_mma_shared_layout=False) + tle.gpu.copy(src + offsets, smem, (BLOCK, )) + tle.gpu.copy(smem, dst + offsets, (BLOCK, )) + + +@triton.jit +def _normal_copy_shape_mismatch_kernel(src, dst, BLOCK: tl.constexpr): + offsets = tl.arange(0, BLOCK) + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.float32, nv_mma_shared_layout=False) + tle.gpu.copy(src + offsets, smem, (BLOCK // 2, )) + tle.gpu.copy(smem, dst + offsets, (BLOCK, )) + + +@triton.jit +def _tma_copy_desc_to_smem_kernel(desc, dst, BLOCK: tl.constexpr): + offsets = tl.arange(0, BLOCK) + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.float16, nv_mma_shared_layout=False) + tle.gpu.copy(desc, smem, (BLOCK, ), (0, )) + values = tl.load(tle.gpu.local_ptr(smem)) + tl.store(dst + offsets, values) + + +@triton.jit +def _tma_copy_smem_to_desc_kernel(desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + rows = tl.arange(0, BLOCK_M)[:, None] + cols = tl.arange(0, BLOCK_N)[None, :] + values = (rows * 10 + cols).to(tl.float16) + smem = tle.gpu.alloc((BLOCK_M, BLOCK_N), dtype=tl.float16, nv_mma_shared_layout=False) + tl.store(tle.gpu.local_ptr(smem), values) + tle.gpu.copy(smem, desc, (BLOCK_M, BLOCK_N), (0, 0)) + + +@triton.jit +def _tma_copy_missing_offsets_kernel(desc, BLOCK: tl.constexpr): + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.float16, nv_mma_shared_layout=False) + tle.gpu.copy(desc, smem, (BLOCK, )) + + +@triton.jit +def _tma_copy_wrong_offset_rank_kernel(desc, BLOCK: tl.constexpr): + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.float16, nv_mma_shared_layout=False) + tle.gpu.copy(desc, smem, (BLOCK, ), (0, 0)) + + +def test_tle_copy_normal_gmem_to_smem_lowers_to_async_copy(): + compiled = compile_musa( + _normal_copy_roundtrip_kernel, + signature={"src": "*fp32", "dst": "*fp32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 64}, + ) + + ttgir = compiled.asm["ttgir"] + assert "ttg.async_copy_global_to_local" in ttgir, ttgir + assert 
"musa_tle.local_ptr_async_store" in ttgir, ttgir + assert "ttg.local_load" in ttgir, ttgir + assert "musa_tle.local_pointers" not in compiled.asm["llir"], compiled.asm["llir"] + + +def test_tle_copy_normal_rejects_shape_mismatch(): + from triton.compiler.errors import CompilationError + + with pytest.raises(CompilationError, match="copy shape .* must match"): + compile_musa( + _normal_copy_shape_mismatch_kernel, + signature={"src": "*fp32", "dst": "*fp32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 64}, + ) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_copy_normal_roundtrip_runtime(): + block = 64 + src = torch.arange(0, block, device="musa", dtype=torch.float32) + dst = torch.empty((block, ), device="musa", dtype=torch.float32) + + _normal_copy_roundtrip_kernel[(1, )](src, dst, BLOCK=block, num_warps=1) + + torch.testing.assert_close(dst.cpu(), src.cpu(), rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_copy_tma_desc_to_smem_lowers_to_musa_tme(): + block = 128 + src = torch.arange(0, block, device="musa", dtype=torch.float16) + dst = torch.empty((block, ), device="musa", dtype=torch.float16) + desc = TensorDescriptor.from_tensor(src, [block]) + + compiled = _tma_copy_desc_to_smem_kernel.warmup(desc, dst, BLOCK=block, grid=(1, ), num_warps=4) + ttgir = compiled.asm["ttgir"] + llir = compiled.asm["llir"] + + assert "ttg.tma_copy" not in ttgir, ttgir + assert "ttmg.async_tme_copy_global_to_local" in ttgir, ttgir + assert "ttmg.wait_barrier" in ttgir, ttgir + assert "llvm.musa.tme.ld.tile.1d" in llir, llir + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_copy_tma_rejects_missing_offsets(): + from triton.compiler.errors import CompilationError + + block = 128 + src = torch.empty((block, ), device="musa", dtype=torch.float16) + desc = TensorDescriptor.from_tensor(src, 
[block]) + + with pytest.raises(CompilationError, match="descriptor-based tle.gpu.copy requires offsets"): + _tma_copy_missing_offsets_kernel.warmup(desc, BLOCK=block, grid=(1, ), num_warps=4) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_copy_tma_rejects_wrong_offset_rank(): + from triton.compiler.errors import CompilationError + + block = 128 + src = torch.empty((block, ), device="musa", dtype=torch.float16) + desc = TensorDescriptor.from_tensor(src, [block]) + + with pytest.raises(CompilationError, match="offsets must provide 1 values, got 2"): + _tma_copy_wrong_offset_rank_kernel.warmup(desc, BLOCK=block, grid=(1, ), num_warps=4) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_copy_tma_desc_to_smem_runtime(): + block = 128 + src = torch.arange(0, block, device="musa", dtype=torch.float16) + dst = torch.empty((block, ), device="musa", dtype=torch.float16) + desc = TensorDescriptor.from_tensor(src, [block]) + + _tma_copy_desc_to_smem_kernel[(1, )](desc, dst, BLOCK=block, num_warps=4) + + torch.testing.assert_close(dst.cpu(), src.cpu(), rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_copy_tma_smem_to_desc_lowers_to_musa_tme(): + block_m = 16 + block_n = 32 + dst = torch.empty((block_m, block_n), device="musa", dtype=torch.float16) + desc = TensorDescriptor.from_tensor(dst, [block_m, block_n]) + + compiled = _tma_copy_smem_to_desc_kernel.warmup( + desc, + BLOCK_M=block_m, + BLOCK_N=block_n, + grid=(1, ), + num_warps=4, + ) + ttgir = compiled.asm["ttgir"] + llir = compiled.asm["llir"] + + assert "ttg.tma_copy" not in ttgir, ttgir + assert "ttmg.async_tme_copy_local_to_global" in ttgir, ttgir + assert "ttmg.tme_store_commit" in ttgir, ttgir + assert "llvm.musa.tme.st.2d" in llir, llir + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not 
available") +def test_tle_copy_tma_smem_to_desc_runtime(): + block_m = 16 + block_n = 32 + dst = torch.empty((block_m, block_n), device="musa", dtype=torch.float16) + desc = TensorDescriptor.from_tensor(dst, [block_m, block_n]) + + _tma_copy_smem_to_desc_kernel[(1, )](desc, BLOCK_M=block_m, BLOCK_N=block_n, num_warps=4) + + rows = torch.arange(0, block_m, dtype=torch.float16)[:, None] + cols = torch.arange(0, block_n, dtype=torch.float16)[None, :] + ref = rows * 10 + cols + torch.testing.assert_close(dst.cpu(), ref, rtol=0, atol=0) diff --git a/third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/Passes.td b/third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/Passes.td index c55ed3d065..f7808ecb4e 100644 --- a/third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/Passes.td +++ b/third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/Passes.td @@ -47,4 +47,28 @@ def TritonMUSAGPUTLEOptimizeLocalPointerStores ]; } +def TritonMUSAGPUTLEOptimizeLocalPointerAsyncStores + : Pass<"tritonmusa-tle-optimize-local-pointer-async-stores", "mlir::ModuleOp"> { + let summary = + "Rewrite gmem load plus musa_tle.local_pointers subview store into async_copy"; + + let description = [{ + Match conservative staging patterns of the form: + + %v = tt.load %gptr ... + tt.store %local_ptr_subview, %v ... + + where `%local_ptr_subview` is produced by `musa_tle.local_pointers` and + denotes a static full view or static subview of a backing shared memdesc. + Rewrite the pair into `ttg.async_copy_global_to_local` plus commit/wait. 
+ }]; + + let dependentDialects = [ + "mlir::arith::ArithDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::musa_tle::MUSATLEDialect" + ]; +} + #endif // MUSA_TLE_TRANSFORMS_PASSES diff --git a/third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/TransformAttrs.h b/third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/TransformAttrs.h new file mode 100644 index 0000000000..dbd52caf5d --- /dev/null +++ b/third_party/mthreads/tle/dialect/include/MUSATLE/Transforms/TransformAttrs.h @@ -0,0 +1,15 @@ +#ifndef MTHREADS_MUSATLE_TRANSFORMS_TRANSFORM_ATTRS_H +#define MTHREADS_MUSATLE_TRANSFORMS_TRANSFORM_ATTRS_H + +#include "llvm/ADT/StringRef.h" + +#ifdef __TLE__ +namespace mlir::triton::musa_tle { + +inline constexpr llvm::StringLiteral + kMUSATLELocalPointerAsyncStoreAttr("musa_tle.local_ptr_async_store"); + +} // namespace mlir::triton::musa_tle +#endif // __TLE__ + +#endif // MTHREADS_MUSATLE_TRANSFORMS_TRANSFORM_ATTRS_H diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/CMakeLists.txt b/third_party/mthreads/tle/dialect/lib/Transforms/CMakeLists.txt index d98285836e..ca44639783 100644 --- a/third_party/mthreads/tle/dialect/lib/Transforms/CMakeLists.txt +++ b/third_party/mthreads/tle/dialect/lib/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(MUSATLETransforms InsertLocalPointerBarriers.cpp + OptimizeLocalPointerAsyncStores.cpp OptimizeLocalPointerLoads.cpp OptimizeLocalPointerStores.cpp SelectEncodings.cpp diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerAsyncStores.cpp b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerAsyncStores.cpp new file mode 100644 index 0000000000..d9a359cf23 --- /dev/null +++ b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerAsyncStores.cpp @@ -0,0 +1,290 @@ +#ifdef __TLE__ + +#include "Dialect/MUSATLE/IR/Dialect.h" +#include "MUSATLE/Transforms/TransformAttrs.h" +#include 
"TritonMUSAGPUTransforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +#include + +namespace mlir { + +#define GEN_PASS_DEF_TRITONMUSAGPUTLEOPTIMIZELOCALPOINTERASYNCSTORES +#include "TritonMUSAGPUTransforms/Passes.h.inc" + +namespace { + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +static Value stripConvertLayouts(Value value) { + Value current = value; + while (auto cvt = current.getDefiningOp()) + current = cvt.getSrc(); + return current; +} + +static Value stripStoreValueWrappers(Value value) { + Value current = value; + while (auto cvt = current.getDefiningOp()) + current = cvt.getSrc(); + return current; +} + +static bool isGlobalPointerTensor(Value value) { + auto tensorTy = dyn_cast(value.getType()); + if (!tensorTy) + return false; + auto ptrTy = dyn_cast(tensorTy.getElementType()); + if (!ptrTy) + return false; + return ptrTy.getAddressSpace() == 1; +} + +static Value stripIndexValueWrappers(Value value) { + Value current = value; + while (true) { + if (auto cvt = current.getDefiningOp()) { + current = cvt.getSrc(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto ext = current.getDefiningOp()) { + current = ext.getIn(); + continue; + } + if (auto trunc = current.getDefiningOp()) { + current = trunc.getIn(); + continue; + } + if (auto cast = current.getDefiningOp()) { + current = cast.getIn(); + continue; + } + break; + } + return current; +} + +static std::optional getConstantIntLike(Value value) { + value = stripIndexValueWrappers(value); + if (auto splat = value.getDefiningOp()) + return getConstantIntLike(splat.getSrc()); + if (auto cst = 
value.getDefiningOp()) { + if (auto dense = dyn_cast(cst.getValue())) { + if (dense.isSplat()) + return dense.getSplatValue().getSExtValue(); + } + } + if (auto cst = value.getDefiningOp()) + return cst.value(); + if (auto cst = value.getDefiningOp()) + return cst.value(); + return std::nullopt; +} + +static bool matchRangeWithStaticOffset(Value value, int64_t extent, + int64_t &offset) { + Value current = stripIndexValueWrappers(value); + if (auto range = current.getDefiningOp()) { + offset = range.getStart(); + return range.getEnd() - range.getStart() == extent; + } + + auto add = current.getDefiningOp(); + if (!add) + return false; + + auto tryMatch = [&](Value lhs, Value rhs) -> bool { + Value lhsStripped = stripIndexValueWrappers(lhs); + auto range = lhsStripped.getDefiningOp(); + if (!range) + return false; + std::optional cst = getConstantIntLike(rhs); + if (!cst) + return false; + offset = range.getStart() + *cst; + return range.getEnd() - range.getStart() == extent; + }; + + return tryMatch(add.getLhs(), add.getRhs()) || + tryMatch(add.getRhs(), add.getLhs()); +} + +static bool matchFullIndexTensorForAxis(Value index, size_t axis, + ArrayRef shape, + int64_t &offset) { + auto indexTy = dyn_cast(index.getType()); + if (!indexTy || !indexTy.getElementType().isInteger()) + return false; + if (indexTy.getShape() != shape) + return false; + + Value current = stripIndexValueWrappers(index); + if (shape.size() == 1) + return matchRangeWithStaticOffset(current, shape.front(), offset); + + auto bcast = current.getDefiningOp(); + if (!bcast) + return false; + + auto bcastSrcTy = dyn_cast(bcast.getSrc().getType()); + if (!bcastSrcTy || bcastSrcTy.getRank() != static_cast(shape.size())) + return false; + for (auto [dim, dimSize] : llvm::enumerate(shape)) { + const int64_t expected = dim == axis ? 
dimSize : 1; + if (bcastSrcTy.getShape()[dim] != expected) + return false; + } + + current = stripIndexValueWrappers(bcast.getSrc()); + while (auto expand = current.getDefiningOp()) + current = stripIndexValueWrappers(expand.getSrc()); + + auto rangeTy = dyn_cast(current.getType()); + if (!rangeTy || rangeTy.getRank() != 1) + return false; + if (rangeTy.getShape()[0] != shape[axis]) + return false; + + return matchRangeWithStaticOffset(current, shape[axis], offset); +} + +struct StaticSubviewMatch { + Value baseMemDesc; + SmallVector offsets; + RankedTensorType valueType; +}; + +static std::optional +matchStaticSubviewMemDesc(tt::StoreOp store) { + Value ptr = stripConvertLayouts(store.getPtr()); + auto localPointers = ptr.getDefiningOp(); + if (!localPointers) + return std::nullopt; + + auto valueTy = dyn_cast(store.getValue().getType()); + auto ptrTy = dyn_cast(localPointers.getResult().getType()); + auto memDescTy = dyn_cast(localPointers.getSrc().getType()); + if (!valueTy || !ptrTy || !memDescTy) + return std::nullopt; + if (valueTy.getShape() != ptrTy.getShape()) + return std::nullopt; + if (valueTy.getElementType() != memDescTy.getElementType()) + return std::nullopt; + + auto indices = localPointers.getIndices(); + SmallVector offsets(memDescTy.getRank(), 0); + if (indices.empty()) { + if (llvm::equal(valueTy.getShape(), memDescTy.getShape())) + return StaticSubviewMatch{localPointers.getSrc(), std::move(offsets), + valueTy}; + return std::nullopt; + } + if (indices.size() != static_cast(memDescTy.getRank())) + return std::nullopt; + + for (auto [axis, index] : llvm::enumerate(indices)) { + int64_t offset = 0; + if (!matchFullIndexTensorForAxis(index, axis, valueTy.getShape(), offset)) + return std::nullopt; + offsets[axis] = static_cast(offset); + } + + return StaticSubviewMatch{localPointers.getSrc(), std::move(offsets), + valueTy}; +} + +static Value createSubviewForStore(OpBuilder &builder, Location loc, + StaticSubviewMatch match) { + auto memDescTy = 
cast(match.baseMemDesc.getType()); + bool isFullView = + llvm::equal(match.valueType.getShape(), memDescTy.getShape()) && + llvm::all_of(match.offsets, [](int32_t offset) { return offset == 0; }); + if (isFullView) + return match.baseMemDesc; + + auto subTy = ttg::MemDescType::get( + match.valueType.getShape(), match.valueType.getElementType(), + memDescTy.getEncoding(), memDescTy.getMemorySpace(), + memDescTy.getMutableMemory(), memDescTy.getAllocShape()); + return ttg::MemDescSubsliceOp::create(builder, loc, subTy, match.baseMemDesc, + match.offsets); +} + +class OptimizeLocalPointerAsyncStoresPass + : public impl::TritonMUSAGPUTLEOptimizeLocalPointerAsyncStoresBase< + OptimizeLocalPointerAsyncStoresPass> { + void runOnOperation() override { + ModuleOp module = getOperation(); + + SmallVector stores; + module.walk([&](tt::StoreOp store) { stores.push_back(store); }); + + for (tt::StoreOp store : stores) { + if (!store) + continue; + Value strippedStoreValue = stripStoreValueWrappers(store.getValue()); + auto load = strippedStoreValue.getDefiningOp(); + if (!load || !load->hasOneUse()) + continue; + if (load.getIsVolatile()) + continue; + if (!isa(load.getType())) + continue; + if (!isGlobalPointerTensor(load.getPtr())) + continue; + auto match = matchStaticSubviewMemDesc(store); + if (!match) + continue; + if (cast(load.getType()).getShape() != + match->valueType.getShape()) + continue; + if (cast(load.getType()).getElementType() != + match->valueType.getElementType()) + continue; + + OpBuilder builder(store); + Value dst = createSubviewForStore(builder, store.getLoc(), *match); + auto asyncCopy = ttg::AsyncCopyGlobalToLocalOp::create( + builder, store.getLoc(), load.getPtr(), dst, load.getMask(), + load.getOther(), load.getCache(), load.getEvict(), + load.getIsVolatile()); + asyncCopy->setAttr(triton::musa_tle::kMUSATLELocalPointerAsyncStoreAttr, + builder.getUnitAttr()); + auto commit = ttg::AsyncCommitGroupOp::create(builder, store.getLoc(), + 
asyncCopy.getToken()); + ttg::AsyncWaitOp::create(builder, store.getLoc(), commit.getResult(), 0); + + Value originalStoreValue = store.getValue(); + store.erase(); + for (Value current = originalStoreValue; current != load.getResult();) { + Operation *def = current.getDefiningOp(); + auto cvt = dyn_cast_or_null(def); + if (!cvt || !cvt->use_empty()) + break; + current = cvt.getSrc(); + cvt.erase(); + } + load.erase(); + } + } +}; + +} // namespace +} // namespace mlir + +#endif // __TLE__ diff --git a/third_party/mthreads/tle/dialect/triton_mthreads_tle.cc b/third_party/mthreads/tle/dialect/triton_mthreads_tle.cc index bf47655d74..5aad452c3e 100644 --- a/third_party/mthreads/tle/dialect/triton_mthreads_tle.cc +++ b/third_party/mthreads/tle/dialect/triton_mthreads_tle.cc @@ -156,6 +156,11 @@ void init_triton_musa_tle_ir(py::module m) { memorySpace, /*mutableMemory=*/true, allocShape); }) + .def("create_tma_copy", + [](TritonOpBuilder &self, mlir::Value src, mlir::Value dst, + std::vector indices) -> void { + self.create(src, dst, indices); + }) .def("create_local_pointers", [](TritonOpBuilder &self, mlir::Type resultTy, mlir::Value memDesc, py::args args) -> mlir::OpState { @@ -177,6 +182,9 @@ void init_triton_musa_tle_dialect_passes_ttgpuir(py::module m) { mlir::createTritonMUSAGPUTLEOptimizeLocalPointerLoads); ADD_PASS_WRAPPER_0("add_tle_optimize_local_pointer_stores", mlir::createTritonMUSAGPUTLEOptimizeLocalPointerStores); + ADD_PASS_WRAPPER_0( + "add_tle_optimize_local_pointer_async_stores", + mlir::createTritonMUSAGPUTLEOptimizeLocalPointerAsyncStores); } void register_triton_musa_tle_dialects(mlir::DialectRegistry ®istry) { From 4b4f2c2ed9ca99bdabdc71ebbacc68d33aa960ac Mon Sep 17 00:00:00 2001 From: QiLin Gai Date: Mon, 25 May 2026 16:14:29 +0800 Subject: [PATCH 08/10] [TLE][MTHREADS] Support atomic operands --- .../triton/Dialect/Triton/IR/Dialect.h | 5 + .../triton/Dialect/Triton/IR/TritonOps.td | 25 +++++ .../test/unit/tle/test_tle_local_ptr.py | 96 
++++++++++++++++++- 3 files changed, 125 insertions(+), 1 deletion(-) diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h b/third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h index db54dcd708..5ab85ec07f 100644 --- a/third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/Dialect.h @@ -27,6 +27,11 @@ namespace triton { struct GlobalMemory : public SideEffects::Resource::Base { StringRef getName() final { return ""; } }; +#ifdef __TLE__ +struct SharedMemory : public SideEffects::Resource::Base { + StringRef getName() final { return ""; } +}; +#endif class DialectInferLayoutInterface : public DialectInterface::Base { diff --git a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td index 91f8aff7a6..94848b2303 100644 --- a/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/third_party/mthreads/include/triton/Dialect/Triton/IR/TritonOps.td @@ -20,6 +20,9 @@ include "triton/Dialect/Triton/IR/TritonOpInterfaces.td" // Interfaces // def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; +#ifdef __TLE__ +def SharedMemory : Resource<"::mlir::triton::SharedMemory">; +#endif // __TLE__ // // Op Base @@ -350,8 +353,13 @@ def TT_StoreOp : TT_Op<"store", [ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ SameOperandsAndResultShape, SameOperandsAndResultEncoding, +#ifdef __TLE__ + TypesMatchWith<"value type matches ptr type", "ptr", "val", + "getPointeeType($_self)">, +#else TypesMatchWith<"ptr type matches value type", "val", "ptr", "getPointerTypeSameShape($_self)">, +#endif // __TLE__ TypesMatchWith<"mask type matches value type", "val", "mask", "getI1SameShape($_self)", "($_op.getOperands().size() <= 2) || std::equal_to<>()"> @@ -366,7 +374,12 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ let arguments = (ins TT_AtomicRMWAttr:$atomic_rmw_op, +#ifdef __TLE__ + Arg, 
MemWrite, + MemRead, MemWrite]>:$ptr, +#else Arg, MemWrite]>:$ptr, +#endif // __TLE__ TT_Type:$val, Optional:$mask, TT_MemSemanticAttr:$sem, @@ -386,10 +399,17 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [ def TT_AtomicCASOp : TT_Op<"atomic_cas", [ SameOperandsAndResultShape, SameOperandsAndResultEncoding, +#ifdef __TLE__ + TypesMatchWith<"cmp type matches ptr type", "ptr", "cmp", + "getPointeeType($_self)">, + TypesMatchWith<"value type matches ptr type", "ptr", "val", + "getPointeeType($_self)"> +#else TypesMatchWith<"ptr type matches cmp type", "cmp", "ptr", "getPointerTypeSameShape($_self)">, TypesMatchWith<"ptr type matches value type", "val", "ptr", "getPointerTypeSameShape($_self)"> +#endif // __TLE__ ]> { let summary = "atomic cas"; @@ -404,7 +424,12 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [ }]; let arguments = (ins +#ifdef __TLE__ + Arg, MemWrite, + MemRead, MemWrite]>:$ptr, +#else Arg, MemWrite]>:$ptr, +#endif // __TLE__ TT_Type:$cmp, TT_Type:$val, TT_MemSemanticAttr:$sem, diff --git a/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py b/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py index 49f8063d35..d5398af49b 100644 --- a/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py +++ b/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py @@ -5,7 +5,7 @@ import triton.experimental.tle.language as tle from triton.compiler.errors import CompilationError -from test_tle_utils import compile_musa, require_mthreads_libtriton +from test_tle_utils import compile_musa, compile_to_ttir, require_mthreads_libtriton require_mthreads_libtriton() @@ -40,6 +40,40 @@ def _local_ptr_full_view_kernel(out_ptr): tl.store(out_ptr + tl.arange(0, 16), loaded) +@triton.jit +def _local_ptr_atomic_add_kernel(out_ptr, BLOCK: tl.constexpr): + offsets = tl.arange(0, BLOCK) + init = tl.full((BLOCK, ), 0, tl.int32) + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.int32, init_value=init, nv_mma_shared_layout=False) + ptrs = 
tle.gpu.local_ptr(smem, (offsets, )) + increments = offsets.to(tl.int32) + 1 + old = tl.atomic_add(ptrs, increments, sem="relaxed", scope="cta") + after = tl.load(ptrs) + tl.store(out_ptr + offsets, old) + tl.store(out_ptr + BLOCK + offsets, after) + + +@triton.jit +def _local_ptr_atomic_cas_kernel(out_ptr): + init = tl.full((1, ), 3, tl.int32) + smem = tle.gpu.alloc((1, ), dtype=tl.int32, init_value=init, nv_mma_shared_layout=False) + ptr = tle.gpu.local_ptr(smem, (0, )) + old = tl.atomic_cas(ptr, 3, 9, sem="relaxed", scope="cta") + after = tl.load(ptr) + tl.store(out_ptr, old) + tl.store(out_ptr + 1, after) + + +@triton.jit +def _local_ptr_atomic_cas_update_kernel(out_ptr): + init = tl.full((1, ), 3, tl.int32) + smem = tle.gpu.alloc((1, ), dtype=tl.int32, init_value=init, nv_mma_shared_layout=False) + ptr = tle.gpu.local_ptr(smem, (0, )) + tl.atomic_cas(ptr, 3, 9, sem="relaxed", scope="cta") + after = tl.load(ptr) + tl.store(out_ptr, after) + + @triton.jit def _local_ptr_non_integer_index_kernel(out_ptr): smem = tle.gpu.alloc((16, ), dtype=tl.float32, nv_mma_shared_layout=False) @@ -110,6 +144,44 @@ def test_tle_local_ptr_full_view_store_load_rewrites_to_memdesc_ops(): assert "musa_tle.local_pointers" not in llir, llir +def test_tle_local_ptr_atomic_ops_accept_addrspace3_ttir(): + add_ttir = compile_to_ttir( + _local_ptr_atomic_add_kernel, + signature={"out_ptr": "*i32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 16}, + ) + cas_ttir = compile_to_ttir(_local_ptr_atomic_cas_kernel, signature={"out_ptr": "*i32"}) + + assert "tt.atomic_rmw add, relaxed, cta" in add_ttir, add_ttir + assert ("(tensor<16x!tt.ptr>, tensor<16xi32>, tensor<16xi1>) -> tensor<16xi32>" in add_ttir), add_ttir + assert "tt.atomic_cas relaxed, cta" in cas_ttir, cas_ttir + assert "(!tt.ptr, i32, i32) -> i32" in cas_ttir, cas_ttir + + +def test_tle_local_ptr_atomic_add_lowers_through_mthreads_llvm(): + compiled = compile_musa( + _local_ptr_atomic_add_kernel, + signature={"out_ptr": "*i32", 
"BLOCK": "constexpr"}, + constexprs={"BLOCK": 16}, + ) + + ttgir = compiled.asm["ttgir"] + llir = compiled.asm["llir"] + assert "tt.atomic_rmw" in ttgir, ttgir + assert "tensor<16x!tt.ptr" in ttgir, ttgir + assert "musa_tle.local_pointers" not in llir, llir + + +def test_tle_local_ptr_atomic_cas_lowers_through_mthreads_llvm(): + compiled = compile_musa(_local_ptr_atomic_cas_kernel, signature={"out_ptr": "*i32"}) + + ttgir = compiled.asm["ttgir"] + llir = compiled.asm["llir"] + assert "tt.atomic_cas" in ttgir, ttgir + assert "-> !tt.ptr" in ttgir, ttgir + assert "musa_tle.local_pointers" not in llir, llir + + def test_tle_local_ptr_rejects_non_integer_indices(): with pytest.raises(CompilationError, match="local_ptr indices must use integer dtypes"): compile_musa(_local_ptr_non_integer_index_kernel, signature={"out_ptr": "*fp32"}) @@ -158,3 +230,25 @@ def test_tle_local_ptr_full_view_runtime_round_trip(): ref = torch.arange(0, 16, dtype=torch.float32) + 7.0 torch.testing.assert_close(out.cpu(), ref, rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_atomic_add_runtime_round_trip(): + block = 16 + out = torch.empty((block * 2, ), device="musa", dtype=torch.int32) + + _local_ptr_atomic_add_kernel[(1, )](out, BLOCK=block, num_warps=1) + + ref_old = torch.zeros((block, ), dtype=torch.int32) + ref_after = torch.arange(1, block + 1, dtype=torch.int32) + torch.testing.assert_close(out[:block].cpu(), ref_old, rtol=0, atol=0) + torch.testing.assert_close(out[block:].cpu(), ref_after, rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_atomic_cas_runtime_round_trip(): + out = torch.empty((1, ), device="musa", dtype=torch.int32) + + _local_ptr_atomic_cas_update_kernel[(1, )](out, num_warps=1) + + torch.testing.assert_close(out.cpu(), torch.tensor([9], dtype=torch.int32), rtol=0, atol=0) From 
64ccc8a9aba9a94a85e09f799ef3d9e4f6fb1b9f Mon Sep 17 00:00:00 2001 From: QiLin Gai Date: Mon, 25 May 2026 17:54:23 +0800 Subject: [PATCH 09/10] [TLE][MTHREADS] Avoid illegal fp16 local pointer async copies --- .../test/unit/tle/test_tle_local_ptr.py | 575 ++++++++++++++++++ .../OptimizeLocalPointerAsyncStores.cpp | 37 +- 2 files changed, 611 insertions(+), 1 deletion(-) diff --git a/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py b/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py index d5398af49b..8ac21e576a 100644 --- a/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py +++ b/third_party/mthreads/python/test/unit/tle/test_tle_local_ptr.py @@ -1,3 +1,5 @@ +import re + import pytest import torch import triton @@ -40,6 +42,239 @@ def _local_ptr_full_view_kernel(out_ptr): tl.store(out_ptr + tl.arange(0, 16), loaded) +@triton.jit +def _local_ptr_axpy_kernel(x_ptr, y_ptr, out_ptr, numel, alpha, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < numel + + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.float32, nv_mma_shared_layout=False) + ptrs = tle.gpu.local_ptr(smem, (tl.arange(0, BLOCK), )) + + x_vals = tl.load(x_ptr + offsets, mask=mask, other=0.0) + tl.store(ptrs, x_vals, mask=mask) + + shared_values = tl.load(ptrs, mask=mask, other=0.0) + y_values = tl.load(y_ptr + offsets, mask=mask, other=0.0) + updated = shared_values * alpha + y_values + + tl.store(ptrs, updated, mask=mask) + out_vals = tl.load(ptrs, mask=mask, other=0.0) + tl.store(out_ptr + offsets, out_vals, mask=mask) + + +@triton.jit +def _local_ptr_constant_store_kernel(out_ptr, numel, value, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < numel + + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.float32, nv_mma_shared_layout=False) + ptrs = tle.gpu.local_ptr(smem, (tl.arange(0, BLOCK), )) + + init = tl.full((BLOCK, ), value, tl.float32) + 
tl.store(ptrs, init, mask=mask) + out_vals = tl.load(ptrs, mask=mask, other=0.0) + tl.store(out_ptr + offsets, out_vals, mask=mask) + + +@triton.jit +def _local_ptr_full_view_tail_mask_kernel(out_ptr, numel, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < numel + + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.int32, nv_mma_shared_layout=False) + ptrs = tle.gpu.local_ptr(smem) + + vals = tl.arange(0, BLOCK) + tl.store(ptrs, vals) + + out_vals = tl.load(ptrs, mask=mask, other=-1) + tl.store(out_ptr + offsets, out_vals, mask=mask) + + +@triton.jit +def _local_ptr_full_view_2d_copy_kernel(x_ptr, out_ptr, stride_xm, stride_xn, stride_om, stride_on, ROWS: tl.constexpr, + COLS: tl.constexpr): + smem = tle.gpu.alloc((ROWS, COLS), dtype=tl.float32, nv_mma_shared_layout=False) + rows = tl.arange(0, ROWS)[:, None] + cols = tl.arange(0, COLS)[None, :] + x_tile = x_ptr + rows * stride_xm + cols * stride_xn + tle.gpu.copy(x_tile, smem, (ROWS, COLS)) + + full_ptrs = tle.gpu.local_ptr(smem) + vals = tl.load(full_ptrs) + + out_tile = out_ptr + rows * stride_om + cols * stride_on + tl.store(out_tile, vals) + + +@triton.jit +def _local_ptr_local_load_none_kernel(out_ptr, BLOCK: tl.constexpr): + idx = tl.arange(0, BLOCK) + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.int32, nv_mma_shared_layout=False) + ptrs = tle.gpu.local_ptr(smem) + tl.store(ptrs, idx + 3) + vals = tl.load(ptrs) + tl.store(out_ptr + idx, vals) + + +@triton.jit +def _local_ptr_local_load_full_indices_kernel(out_ptr, BLOCK: tl.constexpr): + idx = tl.arange(0, BLOCK) + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.int32, nv_mma_shared_layout=False) + ptrs = tle.gpu.local_ptr(smem, (idx, )) + tl.store(ptrs, idx + 5) + vals = tl.load(ptrs) + tl.store(out_ptr + idx, vals) + + +@triton.jit +def _local_ptr_conditional_mask_store_kernel(out_ptr, numel, BLOCK: tl.constexpr): + pid = tl.program_id(0) + idx = tl.arange(0, BLOCK) + mask = idx < numel + + smem = 
tle.gpu.alloc((BLOCK, ), dtype=tl.int32, nv_mma_shared_layout=False) + ptrs = tle.gpu.local_ptr(smem, (idx, )) + + if pid == 0: + tl.store(ptrs, idx, mask=mask) + + vals = tl.load(ptrs, mask=mask, other=-1) + tl.store(out_ptr + idx, vals, mask=mask) + + +@triton.jit +def _local_ptr_looped_elementwise_kernel(x_ptr, y_ptr, out_ptr, numel, alpha, BLOCK: tl.constexpr, CHUNKS: tl.constexpr, + SLICES: tl.constexpr, SLICE_SIZE: tl.constexpr): + pid = tl.program_id(0) + base = pid * BLOCK * CHUNKS + + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.float32, nv_mma_shared_layout=False) + ptrs = tle.gpu.local_ptr(smem, (tl.arange(0, BLOCK), )) + assert BLOCK % SLICE_SIZE == 0, "BLOCK must be divisible by SLICE_SIZE" + slice_indices = tl.arange(0, SLICE_SIZE) + + for chunk in range(CHUNKS): + offsets = base + chunk * BLOCK + tl.arange(0, BLOCK) + mask = offsets < numel + x_vals = tl.load(x_ptr + offsets, mask=mask, other=0.0) + tl.store(ptrs, x_vals, mask=mask) + + for slice_idx in range(SLICES): + block_offset = slice_idx * SLICE_SIZE + slice_ptr = tle.gpu.local_ptr(smem, (block_offset + slice_indices, )) + slice_offsets = base + chunk * BLOCK + block_offset + slice_indices + slice_mask = slice_offsets < numel + shared_vals = tl.load(slice_ptr, mask=slice_mask, other=0.0) + y_vals = tl.load(y_ptr + slice_offsets, mask=slice_mask, other=0.0) + updated = shared_vals * alpha + y_vals + tl.store(slice_ptr, updated, mask=slice_mask) + + out_vals = tl.load(ptrs, mask=mask, other=0.0) + tl.store(out_ptr + offsets, out_vals, mask=mask) + + +@triton.jit +def _local_ptr_axis_gather_kernel(x_ptr, out_ptr, stride_xm, stride_xn, stride_om, stride_on, ROWS: tl.constexpr, + COLS: tl.constexpr, SLICE: tl.constexpr): + smem = tle.gpu.alloc((ROWS, COLS), dtype=tl.float32, nv_mma_shared_layout=False) + offs_m = tl.arange(0, ROWS)[:, None] + offs_n = tl.arange(0, COLS)[None, :] + x_tile = x_ptr + offs_m * stride_xm + offs_n * stride_xn + tle.gpu.copy(x_tile, smem, (ROWS, COLS)) + + row_ids = 
tl.broadcast_to(offs_m, (ROWS, SLICE)) + col_ids = tl.broadcast_to(1 + tl.arange(0, SLICE)[None, :], (ROWS, SLICE)) + slice_ptrs = tle.gpu.local_ptr(smem, (row_ids, col_ids)) + vals = tl.load(slice_ptrs) + + out_tile = out_ptr + offs_m * stride_om + tl.arange(0, SLICE)[None, :] * stride_on + tl.store(out_tile, vals) + + +@triton.jit +def _local_ptr_dynamic_scalar_load_after_vector_store_kernel(out_ptr, BLOCK: tl.constexpr): + smem = tle.gpu.alloc((BLOCK, ), dtype=tl.int32, nv_mma_shared_layout=False) + vec_idx = tl.arange(0, BLOCK) + vec_ptr = tle.gpu.local_ptr(smem, (vec_idx, )) + tl.store(vec_ptr, vec_idx + 1) + zero = tl.program_id(0) * 0 + + for i in range(BLOCK): + scalar_idx = zero + i + scalar_ptr = tle.gpu.local_ptr(smem, (scalar_idx, )) + scalar_val = tl.load(scalar_ptr) + tl.store(out_ptr + i, scalar_val) + + +@triton.jit +def _local_ptr_tiled_matmul_kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, + stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + NUM_K_TILES: tl.constexpr, SLICE_PARTS: tl.constexpr, SLICE_WIDTH: tl.constexpr): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + smem_a = tle.gpu.alloc((BLOCK_M, BLOCK_K), dtype=tl.float16, nv_mma_shared_layout=False) + smem_b = tle.gpu.alloc((BLOCK_K, BLOCK_N), dtype=tl.float16, nv_mma_shared_layout=False) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + slice_parts = int(SLICE_PARTS) + slice_width = int(SLICE_WIDTH) + assert BLOCK_K % slice_parts == 0, "BLOCK_K must divide slice_parts" + + for k_tile in range(NUM_K_TILES): + k_offsets = k_tile * BLOCK_K + tl.arange(0, BLOCK_K) + a_tile = a_ptr + offs_m[:, None] * stride_am + k_offsets[None, :] * stride_ak + b_tile = b_ptr + k_offsets[:, None] * stride_bk + offs_n[None, :] * stride_bn + tle.gpu.copy(a_tile, smem_a, (BLOCK_M, BLOCK_K)) + tle.gpu.copy(b_tile, smem_b, 
(BLOCK_K, BLOCK_N)) + + for slice_idx in range(slice_parts): + k_start = slice_idx * slice_width + a_rows = tl.arange(0, BLOCK_M)[:, None] + a_cols = tl.arange(0, SLICE_WIDTH)[None, :] + k_start + a_rows = tl.broadcast_to(a_rows, (BLOCK_M, SLICE_WIDTH)) + a_cols = tl.broadcast_to(a_cols, (BLOCK_M, SLICE_WIDTH)) + a_slice = tle.gpu.local_ptr(smem_a, (a_rows, a_cols)) + + b_rows = tl.arange(0, SLICE_WIDTH)[:, None] + k_start + b_cols = tl.arange(0, BLOCK_N)[None, :] + b_rows = tl.broadcast_to(b_rows, (SLICE_WIDTH, BLOCK_N)) + b_cols = tl.broadcast_to(b_cols, (SLICE_WIDTH, BLOCK_N)) + b_slice = tle.gpu.local_ptr(smem_b, (b_rows, b_cols)) + a_vals = tl.load(a_slice) + b_vals = tl.load(b_slice) + acc += tl.dot(a_vals, b_vals, out_dtype=tl.float32) + + c_tile = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn + tl.store(c_tile, acc) + + +@triton.jit +def _local_ptr_full_view_dot_kernel(a_ptr, out_ptr, stride_ai, stride_aj, stride_oi, stride_oj, BLOCK: tl.constexpr): + offs_i = tl.arange(0, BLOCK)[:, None] + offs_j = tl.arange(0, BLOCK)[None, :] + + a_tile_ptr = a_ptr + offs_i * stride_ai + offs_j * stride_aj + a_tile = tl.load(a_tile_ptr) + + smem = tle.gpu.alloc((BLOCK, BLOCK), dtype=tl.float16, nv_mma_shared_layout=False) + smem_ptr = tle.gpu.local_ptr(smem) + tl.store(smem_ptr, a_tile) + + staged = tl.load(smem_ptr) + acc = tl.dot(staged, tl.trans(staged), out_dtype=tl.float32) + + out_ptrs = out_ptr + offs_i * stride_oi + offs_j * stride_oj + tl.store(out_ptrs, acc.to(tl.float16)) + + @triton.jit def _local_ptr_atomic_add_kernel(out_ptr, BLOCK: tl.constexpr): offsets = tl.arange(0, BLOCK) @@ -144,6 +379,179 @@ def test_tle_local_ptr_full_view_store_load_rewrites_to_memdesc_ops(): assert "musa_tle.local_pointers" not in llir, llir +def test_tle_local_ptr_full_view_tail_mask_lowers_through_mthreads_llvm(): + compiled = compile_musa( + _local_ptr_full_view_tail_mask_kernel, + signature={"out_ptr": "*i32", "numel": "i32", "BLOCK": "constexpr"}, + 
constexprs={"BLOCK": 128}, + ) + + ttgir = compiled.asm["ttgir"] + llir = compiled.asm["llir"] + assert "ttg.local_store" in ttgir, ttgir + assert "musa_tle.local_pointers" not in llir, llir + + +def test_tle_local_ptr_full_view_and_indices_load_rewrite_to_local_load(): + none_compiled = compile_musa( + _local_ptr_local_load_none_kernel, + signature={"out_ptr": "*i32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 64}, + ) + indices_compiled = compile_musa( + _local_ptr_local_load_full_indices_kernel, + signature={"out_ptr": "*i32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 64}, + ) + + none_ttgir = none_compiled.asm["ttgir"] + indices_ttgir = indices_compiled.asm["ttgir"] + assert "ttg.local_load" in none_ttgir, none_ttgir + assert "ttg.local_load" in indices_ttgir, indices_ttgir + assert "musa_tle.local_pointers" not in none_compiled.asm["llir"], none_compiled.asm["llir"] + assert "musa_tle.local_pointers" not in indices_compiled.asm["llir"], indices_compiled.asm["llir"] + + +def test_tle_local_ptr_2d_copy_and_axis_gather_lower_through_mthreads_llvm(): + copy_compiled = compile_musa( + _local_ptr_full_view_2d_copy_kernel, + signature={ + "x_ptr": "*fp32", + "out_ptr": "*fp32", + "stride_xm": "i32", + "stride_xn": "i32", + "stride_om": "i32", + "stride_on": "i32", + "ROWS": "constexpr", + "COLS": "constexpr", + }, + constexprs={"ROWS": 16, "COLS": 32}, + ) + gather_compiled = compile_musa( + _local_ptr_axis_gather_kernel, + signature={ + "x_ptr": "*fp32", + "out_ptr": "*fp32", + "stride_xm": "i32", + "stride_xn": "i32", + "stride_om": "i32", + "stride_on": "i32", + "ROWS": "constexpr", + "COLS": "constexpr", + "SLICE": "constexpr", + }, + constexprs={"ROWS": 8, "COLS": 8, "SLICE": 4}, + ) + + copy_ttgir = copy_compiled.asm["ttgir"] + gather_ttgir = gather_compiled.asm["ttgir"] + assert "ttg.async_copy_global_to_local" in copy_ttgir, copy_ttgir + assert "ttg.local_load" in copy_ttgir, copy_ttgir + assert "ttg.async_copy_global_to_local" in gather_ttgir, 
gather_ttgir + assert "musa_tle.local_pointers" not in copy_compiled.asm["llir"], copy_compiled.asm["llir"] + assert "musa_tle.local_pointers" not in gather_compiled.asm["llir"], gather_compiled.asm["llir"] + + +def test_tle_local_ptr_conditional_mask_store_compiles(): + compiled = compile_musa( + _local_ptr_conditional_mask_store_kernel, + signature={"out_ptr": "*i32", "numel": "i32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 512}, + ) + + assert compiled.asm["llir"], compiled.asm + assert "musa_tle.local_pointers" not in compiled.asm["llir"], compiled.asm["llir"] + + +def test_tle_local_ptr_dynamic_scalar_index_inserts_barrier(): + compiled = compile_musa( + _local_ptr_dynamic_scalar_load_after_vector_store_kernel, + signature={"out_ptr": "*i32", "BLOCK": "constexpr"}, + constexprs={"BLOCK": 64}, + ) + + ttgir = compiled.asm["ttgir"] + assert "ttg.barrier local" in ttgir, ttgir + assert '"tt.reduce"' not in ttgir, ttgir + assert "musa_tle.local_pointers" not in compiled.asm["llir"], compiled.asm["llir"] + + +def test_tle_local_ptr_tiled_matmul_matches_torch(): + block_m = 32 + block_n = 32 + block_k = 32 + num_k_tiles = 2 + m = block_m + n = block_n + k = block_k * num_k_tiles + + a = torch.randn((m, k), device="musa", dtype=torch.float16) + b = torch.randn((k, n), device="musa", dtype=torch.float16) + c = torch.empty((m, n), device="musa", dtype=torch.float32) + + slice_parts = 2 + slice_width = block_k // slice_parts + grid = (m // block_m, n // block_n) + _local_ptr_tiled_matmul_kernel[grid]( + a, + b, + c, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + NUM_K_TILES=num_k_tiles, + SLICE_PARTS=slice_parts, + SLICE_WIDTH=slice_width, + num_warps=4, + ) + + expected = a.float() @ b.float() + torch.testing.assert_close(c.cpu(), expected.cpu(), rtol=5e-3, atol=5e-3) + + +def test_tle_local_ptr_full_view_dot_avoids_pointer_convert_layout(): + block = 32 + a = 
torch.randn((block, block), device="musa", dtype=torch.float16) + out = torch.empty_like(a) + + compiled = compile_musa( + _local_ptr_full_view_dot_kernel, + signature={ + "a_ptr": "*fp16", + "out_ptr": "*fp16", + "stride_ai": "i32", + "stride_aj": "i32", + "stride_oi": "i32", + "stride_oj": "i32", + "BLOCK": "constexpr", + }, + constexprs={"BLOCK": block}, + ) + ttgir = compiled.asm["ttgir"] + assert "ttg.local_load" in ttgir, ttgir + assert re.search(r"ttg\.convert_layout .*-> tensor<.*!tt\.ptr", ttgir) is None + + _local_ptr_full_view_dot_kernel[(1, )]( + a, + out, + a.stride(0), + a.stride(1), + out.stride(0), + out.stride(1), + BLOCK=block, + num_warps=4, + num_stages=2, + ) + expected = (a.float() @ a.float().T).to(torch.float16) + torch.testing.assert_close(out.cpu(), expected.cpu(), rtol=2e-1, atol=2e-1) + + def test_tle_local_ptr_atomic_ops_accept_addrspace3_ttir(): add_ttir = compile_to_ttir( _local_ptr_atomic_add_kernel, @@ -232,6 +640,173 @@ def test_tle_local_ptr_full_view_runtime_round_trip(): torch.testing.assert_close(out.cpu(), ref, rtol=0, atol=0) +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_axpy_runtime_matches_torch(): + torch.manual_seed(0) + block = 64 + numel = block * 4 + alpha = 1.5 + x = torch.randn((numel, ), device="musa", dtype=torch.float32) + y = torch.randn_like(x) + out = torch.empty_like(x) + + grid = (triton.cdiv(numel, block), ) + _local_ptr_axpy_kernel[grid](x, y, out, numel, alpha, BLOCK=block, num_warps=1) + + expected = (alpha * x + y).cpu() + torch.testing.assert_close(out.cpu(), expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_constant_store_runtime_populates_value(): + block = 64 + numel = block * 4 + value = 2.25 + out = torch.empty((numel, ), device="musa", dtype=torch.float32) + + grid = (triton.cdiv(numel, block), ) + 
_local_ptr_constant_store_kernel[grid](out, numel, value, BLOCK=block, num_warps=1) + + expected = torch.full((numel, ), value, dtype=torch.float32) + torch.testing.assert_close(out.cpu(), expected, rtol=0, atol=1e-7) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_full_view_tail_mask_runtime_preserves_tail(): + block = 128 + numel = block - 9 + out = torch.full((block, ), -1, device="musa", dtype=torch.int32) + + _local_ptr_full_view_tail_mask_kernel[(1, )](out, numel, BLOCK=block, num_warps=4) + + expected = torch.arange(0, block, dtype=torch.int32) + expected[numel:] = -1 + torch.testing.assert_close(out.cpu(), expected, rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_full_view_2d_copy_runtime_matches_input(): + rows = 16 + cols = 32 + x = torch.randn((rows, cols), device="musa", dtype=torch.float32) + out = torch.empty_like(x) + + _local_ptr_full_view_2d_copy_kernel[(1, )]( + x, + out, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + ROWS=rows, + COLS=cols, + num_warps=4, + ) + + torch.testing.assert_close(out.cpu(), x.cpu(), rtol=1e-6, atol=1e-6) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_full_view_load_runtime_uses_local_load(): + block = 64 + out = torch.empty((block, ), device="musa", dtype=torch.int32) + + _local_ptr_local_load_none_kernel[(1, )](out, BLOCK=block, num_warps=4) + + expected = torch.arange(0, block, dtype=torch.int32) + 3 + torch.testing.assert_close(out.cpu(), expected, rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_full_indices_load_runtime_uses_local_load(): + block = 64 + out = torch.empty((block, ), device="musa", dtype=torch.int32) + + _local_ptr_local_load_full_indices_kernel[(1, )](out, BLOCK=block, 
num_warps=4) + + expected = torch.arange(0, block, dtype=torch.int32) + 5 + torch.testing.assert_close(out.cpu(), expected, rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_conditional_mask_store_runtime_preserves_tail(): + block = 512 + numel = block - 7 + out = torch.full((block, ), -1, device="musa", dtype=torch.int32) + + _local_ptr_conditional_mask_store_kernel[(1, )](out, numel, BLOCK=block, num_warps=8) + + expected = torch.arange(0, block, dtype=torch.int32) + expected[numel:] = -1 + torch.testing.assert_close(out.cpu(), expected, rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_looped_elementwise_runtime_matches_torch(): + torch.manual_seed(0) + block = 64 + chunks = 4 + numel = block * chunks * 3 + alpha = 0.75 + slices = 4 + slice_size = block // slices + x = torch.randn((numel, ), device="musa", dtype=torch.float32) + y = torch.randn_like(x) + out = torch.empty_like(x) + + grid = (triton.cdiv(numel, block * chunks), ) + _local_ptr_looped_elementwise_kernel[grid]( + x, + y, + out, + numel, + alpha, + BLOCK=block, + CHUNKS=chunks, + SLICES=slices, + SLICE_SIZE=slice_size, + num_warps=1, + ) + + expected = (alpha * x + y).cpu() + torch.testing.assert_close(out.cpu(), expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_axis_gather_runtime_matches_slice(): + rows = 8 + cols = 8 + slice_width = 4 + x = torch.arange(0, rows * cols, device="musa", dtype=torch.float32).reshape(rows, cols) + out = torch.empty((rows, slice_width), device="musa", dtype=torch.float32) + + _local_ptr_axis_gather_kernel[(1, )]( + x, + out, + x.stride(0), + x.stride(1), + out.stride(0), + out.stride(1), + ROWS=rows, + COLS=cols, + SLICE=slice_width, + num_warps=4, + ) + + torch.testing.assert_close(out.cpu(), x[:, 1:1 + 
slice_width].cpu(), rtol=1e-6, atol=1e-6) + + +@pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") +def test_tle_local_ptr_dynamic_scalar_index_runtime_matches_vector_store(): + block = 64 + out = torch.empty((block, ), device="musa", dtype=torch.int32) + + _local_ptr_dynamic_scalar_load_after_vector_store_kernel[(1, )](out, BLOCK=block, num_warps=4) + + expected = torch.arange(1, block + 1, dtype=torch.int32) + torch.testing.assert_close(out.cpu(), expected, rtol=0, atol=0) + + @pytest.mark.skipif(not torch.musa.is_available(), reason="MUSA device is not available") def test_tle_local_ptr_atomic_add_runtime_round_trip(): block = 16 diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerAsyncStores.cpp b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerAsyncStores.cpp index d9a359cf23..7f0358eba9 100644 --- a/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerAsyncStores.cpp +++ b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerAsyncStores.cpp @@ -7,14 +7,17 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "triton/Analysis/AxisInfo.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include #include namespace mlir { @@ -225,11 +228,41 @@ static Value createSubviewForStore(OpBuilder &builder, Location loc, match.offsets); } +static unsigned +getAsyncCopyContiguity(tt::LoadOp load, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + unsigned contiguity = axisInfoAnalysis.getContiguity(load.getPtr()); + if (Value mask = load.getMask()) + contiguity = + std::min(contiguity, axisInfoAnalysis.getMaskAlignment(mask)); + return 
std::max(1u, contiguity); +} + +static bool +canUseAsyncCopyForStore(tt::LoadOp load, const StaticSubviewMatch &match, + tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) { + if (!tt::canBeConvertedToAsyncLoad(load, axisInfoAnalysis)) + return false; + + auto loadTy = dyn_cast(load.getType()); + auto memDescTy = dyn_cast(match.baseMemDesc.getType()); + if (!loadTy || !memDescTy) + return false; + + auto sharedEncoding = + dyn_cast(memDescTy.getEncoding()); + if (!sharedEncoding) + return false; + + return tt::getCopyVecBytes(loadTy, sharedEncoding) >= 4; +} + class OptimizeLocalPointerAsyncStoresPass : public impl::TritonMUSAGPUTLEOptimizeLocalPointerAsyncStoresBase< OptimizeLocalPointerAsyncStoresPass> { void runOnOperation() override { ModuleOp module = getOperation(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(module); SmallVector stores; module.walk([&](tt::StoreOp store) { stores.push_back(store); }); @@ -256,13 +289,15 @@ class OptimizeLocalPointerAsyncStoresPass if (cast(load.getType()).getElementType() != match->valueType.getElementType()) continue; + if (!canUseAsyncCopyForStore(load, *match, axisInfoAnalysis)) + continue; OpBuilder builder(store); Value dst = createSubviewForStore(builder, store.getLoc(), *match); auto asyncCopy = ttg::AsyncCopyGlobalToLocalOp::create( builder, store.getLoc(), load.getPtr(), dst, load.getMask(), load.getOther(), load.getCache(), load.getEvict(), - load.getIsVolatile()); + load.getIsVolatile(), getAsyncCopyContiguity(load, axisInfoAnalysis)); asyncCopy->setAttr(triton::musa_tle::kMUSATLELocalPointerAsyncStoreAttr, builder.getUnitAttr()); auto commit = ttg::AsyncCommitGroupOp::create(builder, store.getLoc(), From d1080e2bf9da56b9f428ed95be1df4db0cf97a4e Mon Sep 17 00:00:00 2001 From: QiLin Gai Date: Tue, 26 May 2026 11:02:13 +0800 Subject: [PATCH 10/10] [TLE][MTHREADS] Remove comments --- .../Transforms/InsertLocalPointerBarriers.cpp | 24 ------------------- .../Transforms/OptimizeLocalPointerLoads.cpp | 22 ----------------- 
.../Transforms/OptimizeLocalPointerStores.cpp | 22 ----------------- .../lib/Transforms/SelectEncodings.cpp | 24 ------------------- 4 files changed, 92 deletions(-) diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/InsertLocalPointerBarriers.cpp b/third_party/mthreads/tle/dialect/lib/Transforms/InsertLocalPointerBarriers.cpp index 70ffd7403a..f6d92e8058 100644 --- a/third_party/mthreads/tle/dialect/lib/Transforms/InsertLocalPointerBarriers.cpp +++ b/third_party/mthreads/tle/dialect/lib/Transforms/InsertLocalPointerBarriers.cpp @@ -1,27 +1,3 @@ -// MIT License -// -// Copyright (c) 2025 The FlagOS Contributors -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. 
- -// flagtree tle - #ifdef __TLE__ #include "TritonMUSAGPUTransforms/Passes.h" diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerLoads.cpp b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerLoads.cpp index 821ac7c21a..4872fa83ad 100644 --- a/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerLoads.cpp +++ b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerLoads.cpp @@ -1,25 +1,3 @@ -// MIT License -// -// Copyright (c) 2025 The FlagOS Contributors -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. 
- #ifdef __TLE__ #include "Dialect/MUSATLE/IR/Dialect.h" diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerStores.cpp b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerStores.cpp index a598054388..d30142e7f8 100644 --- a/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerStores.cpp +++ b/third_party/mthreads/tle/dialect/lib/Transforms/OptimizeLocalPointerStores.cpp @@ -1,25 +1,3 @@ -// MIT License -// -// Copyright (c) 2025 The FlagOS Contributors -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. 
- #ifdef __TLE__ #include "Dialect/MUSATLE/IR/Dialect.h" diff --git a/third_party/mthreads/tle/dialect/lib/Transforms/SelectEncodings.cpp b/third_party/mthreads/tle/dialect/lib/Transforms/SelectEncodings.cpp index 8827fb1a16..30dba07b56 100644 --- a/third_party/mthreads/tle/dialect/lib/Transforms/SelectEncodings.cpp +++ b/third_party/mthreads/tle/dialect/lib/Transforms/SelectEncodings.cpp @@ -1,27 +1,3 @@ -// MIT License - -// Copyright (c) 2025 The FlagOS Contributors - -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: - -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. - -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// flagtree tle - #ifdef __TLE__ #include "Dialect/MUSATLE/IR/Dialect.h"