Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ python/triton/_C/*.pdb
python/triton/_C/*.exe
python/triton/_C/*.ilk
python/triton/FileCheck
python/triton/FLAGTREE_BACKEND

third_party/mthreads/python/triton/_C/*.so
third_party/mthreads/python/triton/FileCheck
third_party/mthreads/python/*.egg-info

# Backends copied from submodules
python/triton/backends/*
Expand Down
10 changes: 9 additions & 1 deletion python/triton/experimental/tle/language/gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional, Sequence
from enum import Enum
from . import types as tle
from .mthreads import copy as mthreads_copy
from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values

from triton.language.core import (
Expand Down Expand Up @@ -360,6 +361,7 @@ def copy(
TMA copy with offsets:
tle.copy(tma_desc, local_buf, [64, 64], [x_offset, y_offset])
"""
mthreads_enabled = mthreads_copy.enabled()

def normcopy(
src: tl.tensor,
Expand All @@ -368,6 +370,8 @@ def normcopy(
direction,
_semantic=None,
) -> None:
if mthreads_enabled:
mthreads_copy.validate_normal_copy(src, dst, shape, direction)

# Semantic analysis
try:
Expand All @@ -389,8 +393,10 @@ def normcopy(

try:
if direction == CopyDirection.GM_TO_LOCAL:
# None fills the FlagTree hints slot; TLE copy has no hints to pass.
load_extra_args = () if mthreads_enabled else (None, )
tt_load = _semantic.load(src, mask, other, boundary_check, padding_option, cache_modifier,
eviction_policy, volatile, None)
eviction_policy, volatile, *load_extra_args)
local_ptrs = local_ptr(dst, _make_full_indices(dst, _semantic), _semantic=_semantic)
_semantic.store(local_ptrs, tt_load, mask, boundary_check, cache_modifier, eviction_policy)
else:
Expand Down Expand Up @@ -492,6 +498,8 @@ def tmacopy(
raise ValueError(f"Shape parameter must be tuple or list, but got {type(shape)}")
if is_normcopy:
return normcopy(src, dst, shape, direction, _semantic)
if mthreads_enabled:
return mthreads_copy.tmacopy(src, dst, direction, shape, offsets, _semantic)
else:
return tmacopy(src, dst, direction, shape, offsets, _semantic)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from . import copy

__all__ = ["copy"]
103 changes: 103 additions & 0 deletions python/triton/experimental/tle/language/gpu/mthreads/copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os

import triton.language.core as tl

from .. import types as tle

try:
from triton._flagtree_backend import FLAGTREE_BACKEND
except ModuleNotFoundError:
FLAGTREE_BACKEND = os.environ.get("FLAGTREE_BACKEND", "")


def _has_mthreads_libtriton() -> bool:
try:
from triton._C import libtriton
except ImportError:
return False
return hasattr(libtriton, "mthreads")


def enabled() -> bool:
    """Return whether the mthreads copy path should be used.

    The path is enabled when ``FLAGTREE_BACKEND`` equals ``"mthreads"``;
    otherwise the answer is whatever ``_has_mthreads_libtriton()`` reports.
    """
    if FLAGTREE_BACKEND == "mthreads":
        return True
    return _has_mthreads_libtriton()


def normalize_copy_shape(shape) -> tuple[int, ...]:
    """Return *shape* as a tuple of plain ``int`` dimensions.

    Each dimension is passed through ``tl._unwrap_if_constexpr`` and then
    converted with ``int``.
    """
    dims = []
    for extent in shape:
        dims.append(int(tl._unwrap_if_constexpr(extent)))
    return tuple(dims)


def validate_copy_buffer(buffer: tle.buffered_tensor, shape: tuple[int, ...]) -> None:
    """Check that *buffer* is a valid local operand for a copy of *shape*.

    Args:
        buffer: The local buffer taking part in the copy.
        shape: The copy shape as a tuple of ints; it is compared against the
            buffer's own shape with tuple equality.

    Raises:
        ValueError: If *buffer* is not a ``tle.buffered_tensor``, if its
            storage is not ``tle.smem``, or if its shape differs from *shape*.
    """
    if not isinstance(buffer, tle.buffered_tensor):
        raise ValueError(f"buffer must be a tle.gpu.buffered_tensor, but got {type(buffer)}")

    storage = buffer.type.storage
    if storage != tle.smem:
        raise ValueError("MUSA TLE copy only supports tle.gpu.smem buffers")

    # Same unwrap-and-int conversion that normalize_copy_shape applies to the copy shape.
    actual_shape = normalize_copy_shape(buffer.type.shape)
    if actual_shape != shape:
        raise ValueError(f"copy shape {shape} must match buffer shape {actual_shape}")


def tensor_shape(value: tl.tensor) -> tuple[int, ...]:
    """Return the shape of *value* as a tuple of ints.

    When ``value.type.is_block()`` is false the result is the empty tuple;
    otherwise each entry of ``value.shape`` is unwrapped with
    ``tl._unwrap_if_constexpr`` and converted to ``int``.
    """
    if value.type.is_block():
        return tuple(int(tl._unwrap_if_constexpr(extent)) for extent in value.shape)
    return ()


def tensor_pointer_element_ty(value: tl.tensor):
    """Return the element type pointed to by *value*'s dtype.

    Raises:
        ValueError: If ``value.dtype.is_ptr()`` is false.
    """
    dtype = value.dtype
    if dtype.is_ptr():
        return dtype.element_ty
    raise ValueError("tle.gpu.copy tensor operands must be pointer tensors")


def validate_normal_copy(src, dst, shape, direction) -> None:
    """Validate the operands of a pointer-tensor based copy.

    Args:
        src: Copy source.
        dst: Copy destination.
        shape: Requested copy shape; normalized with ``normalize_copy_shape``.
        direction: Copy direction; only its ``name`` attribute is inspected.

    Raises:
        ValueError: If the local buffer fails ``validate_copy_buffer``, if the
            pointer tensor's shape differs from the copy shape, if the tensor
            is not a pointer tensor, or if its pointee type differs from the
            buffer dtype.
    """
    expected = normalize_copy_shape(shape)

    # For GM_TO_LOCAL the source is the global pointer tensor and the
    # destination is the local buffer; any other direction swaps the roles.
    if direction.name == "GM_TO_LOCAL":
        global_tensor, local_buffer = src, dst
    else:
        global_tensor, local_buffer = dst, src

    validate_copy_buffer(local_buffer, expected)

    pointer_shape = tensor_shape(global_tensor)
    if pointer_shape != expected:
        raise ValueError(f"copy shape {expected} must match tensor pointer shape {pointer_shape}")

    pointee_ty = tensor_pointer_element_ty(global_tensor)
    if pointee_ty != local_buffer.dtype:
        raise ValueError(f"copy dtype mismatch: tensor points to {pointee_ty}, buffer stores {local_buffer.dtype}")


def normalize_offsets(offsets, rank: int):
    """Return *offsets* as a tuple containing exactly *rank* entries.

    *offsets* is first unwrapped with ``tl._unwrap_if_constexpr``. A
    ``tl.tuple`` contributes its ``values``; any other object that is a
    tuple, a list, or has ``__iter__`` is converted with ``tuple``.

    Raises:
        ValueError: If *offsets* is ``None``, is not iterable, or does not
            contain exactly *rank* entries.
    """
    offsets = tl._unwrap_if_constexpr(offsets)
    if offsets is None:
        raise ValueError("descriptor-based tle.gpu.copy requires offsets")

    if isinstance(offsets, tl.tuple):
        values = tuple(offsets.values)
    elif isinstance(offsets, (tuple, list)) or hasattr(offsets, "__iter__"):
        values = tuple(offsets)
    else:
        raise ValueError(f"offsets must be a tuple or list, but got {type(offsets)}")

    count = len(values)
    if count != rank:
        raise ValueError(f"offsets must provide {rank} values, got {count}")
    return values


def tmacopy(src, dst, direction, shape, offsets, _semantic) -> None:
    """Validate a descriptor-based copy and emit it through the IR builder.

    Args:
        src: Copy source.
        dst: Copy destination.
        direction: Copy direction; only its ``name`` attribute is inspected.
        shape: Requested copy shape; normalized with ``normalize_copy_shape``.
        offsets: Offsets for the copy, one per descriptor block dimension.
        _semantic: Semantic object providing ``_convert_to_ir_values`` and
            ``builder``.

    Raises:
        ValueError: If the buffer fails ``validate_copy_buffer``, if the
            descriptor block shape or dtype does not match, or if *offsets*
            is rejected by ``normalize_offsets``.
        RuntimeError: If the builder has no ``create_tma_copy`` attribute.
    """
    shape = normalize_copy_shape(shape)

    # For GM_TO_LOCAL the descriptor is the source and the buffer is the
    # destination; any other direction swaps the roles.
    if direction.name == "GM_TO_LOCAL":
        desc, buffer = src, dst
    else:
        desc, buffer = dst, src

    validate_copy_buffer(buffer, shape)

    block_dims = tuple(int(tl._unwrap_if_constexpr(extent)) for extent in desc.block_shape)
    if block_dims != shape:
        raise ValueError(f"copy shape {shape} must match tensor descriptor block shape {block_dims}")
    if desc.dtype != buffer.dtype:
        raise ValueError(f"copy dtype mismatch: descriptor stores {desc.dtype}, buffer stores {buffer.dtype}")

    rank = len(block_dims)
    ir_offsets = _semantic._convert_to_ir_values(normalize_offsets(offsets, rank), require_i64=False)

    builder = _semantic.builder
    if not hasattr(builder, "create_tma_copy"):
        raise RuntimeError("TLE TMA copy builder binding is not available")
    # Operands are passed in the caller's src/dst order, independent of direction.
    builder.create_tma_copy(src.handle, dst.handle, ir_offsets)
26 changes: 26 additions & 0 deletions third_party/mthreads/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,43 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/musa/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/musa/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/tle/dialect/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/tle/dialect/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/tle/frontend/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/tle/frontend/include)
if(FLAGTREE_MTHREADS_TLE)
add_subdirectory(tle)
endif()
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(musa)
if(TRITON_BUILD_PYTHON_MODULE)
if(FLAGTREE_MTHREADS_TLE)
set(_MTHREADS_TLE_PLUGIN_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/tle/dialect/triton_mthreads_tle.cc
${CMAKE_CURRENT_SOURCE_DIR}/tle/frontend/triton_mthreads_frontend.cc)
set(_MTHREADS_TLE_PLUGIN_LIBS
MUSATLEIR MUSATLETransforms MUSATLEFrontendTransforms)
set(_MTHREADS_TLE_PLUGIN_DEPS
MUSATLETableGen
MUSATLETransforms
MUSATLEFrontendTransformsIncGen
MUSATLEFrontendTransforms)
else()
set(_MTHREADS_TLE_PLUGIN_SOURCES "")
set(_MTHREADS_TLE_PLUGIN_LIBS "")
set(_MTHREADS_TLE_PLUGIN_DEPS "")
endif()
add_triton_plugin(TritonMthreads ${CMAKE_CURRENT_SOURCE_DIR}/triton_mthreads.cc
${_MTHREADS_TLE_PLUGIN_SOURCES}
LINK_LIBS TritonMUSAGPUToLLVM MTGPUToLLVM
TritonMUSAGPUTransforms
${_MTHREADS_TLE_PLUGIN_LIBS}
MLIRMTVMToLLVMIRTranslation)
add_dependencies(TritonMthreads
MUSATableGen
MUSAAttrDefsIncGen
${_MTHREADS_TLE_PLUGIN_DEPS}
MTGPUTableGen
MTGPUTypesIncGen
MTGPUConversionPassIncGen
Expand Down
12 changes: 12 additions & 0 deletions third_party/mthreads/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,18 @@ def make_ttgir(mod, metadata, opt, arch, capability):
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)

if hasattr(mthreads.passes.ttgpuir, "add_tle_optimize_local_pointer_async_stores"):
mthreads.passes.ttgpuir.add_tle_optimize_local_pointer_async_stores(pm)
if hasattr(mthreads.passes.ttgpuir, "add_tle_early_assign_memory_space"):
mthreads.passes.ttgpuir.add_tle_early_assign_memory_space(pm)
if hasattr(mthreads.passes.ttgpuir, "add_tle_select_encodings"):
mthreads.passes.ttgpuir.add_tle_select_encodings(pm)
if hasattr(mthreads.passes.ttgpuir, "add_tle_insert_local_pointer_barriers"):
mthreads.passes.ttgpuir.add_tle_insert_local_pointer_barriers(pm)
if hasattr(mthreads.passes.ttgpuir, "add_tle_optimize_local_pointer_loads"):
mthreads.passes.ttgpuir.add_tle_optimize_local_pointer_loads(pm)
if hasattr(mthreads.passes.ttgpuir, "add_tle_optimize_local_pointer_stores"):
mthreads.passes.ttgpuir.add_tle_optimize_local_pointer_stores(pm)
mthreads.passes.ttgpuir.add_accelerate_matmul(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
mthreads.passes.ttgpuir.add_optimize_dot_operands(pm)
Expand Down
6 changes: 6 additions & 0 deletions third_party/mthreads/bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

#include "Dialect/MTGPU/IR/Dialect.h"
#include "Dialect/MUSA/IR/Dialect.h"
#ifdef __TLE__
#include "Dialect/MUSATLE/IR/Dialect.h"
#endif
#include "MTGPUToLLVM/Passes.h"
#include "TritonMUSAGPUToLLVM/Passes.h"
#include "TritonMUSAGPUTransforms/Passes.h"
Expand Down Expand Up @@ -122,6 +125,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::gpu::TritonGPUDialect,
mlir::triton::instrument::TritonInstrumentDialect,
mlir::triton::musa::MUSADialect, mlir::triton::mtgpu::MTGPUDialect,
#ifdef __TLE__
mlir::triton::musa_tle::MUSATLEDialect,
#endif
mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect,
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
mlir::triton::nvgpu::NVGPUDialect, mlir::triton::nvws::NVWSDialect,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ namespace triton {
struct GlobalMemory : public SideEffects::Resource::Base<GlobalMemory> {
StringRef getName() final { return "<GlobalMemory>"; }
};
#ifdef __TLE__
// Side-effect resource for shared memory, declared alongside GlobalMemory.
// Only compiled when __TLE__ is defined.
struct SharedMemory : public SideEffects::Resource::Base<SharedMemory> {
// Name reported for this resource.
StringRef getName() final { return "<SharedMemory>"; }
};
#endif

class DialectInferLayoutInterface
: public DialectInterface::Base<DialectInferLayoutInterface> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
// Interfaces
//
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
#ifdef __TLE__
def SharedMemory : Resource<"::mlir::triton::SharedMemory">;
#endif // __TLE__

//
// Op Base
Expand Down Expand Up @@ -350,8 +353,13 @@ def TT_StoreOp : TT_Op<"store", [
def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
#ifdef __TLE__
TypesMatchWith<"value type matches ptr type", "ptr", "val",
"getPointeeType($_self)">,
#else
TypesMatchWith<"ptr type matches value type", "val", "ptr",
"getPointerTypeSameShape($_self)">,
#endif // __TLE__
TypesMatchWith<"mask type matches value type",
"val", "mask", "getI1SameShape($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">
Expand All @@ -366,7 +374,12 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [

let arguments = (ins
TT_AtomicRMWAttr:$atomic_rmw_op,
#ifdef __TLE__
Arg<TT_PtrLike, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>,
MemRead<SharedMemory>, MemWrite<SharedMemory>]>:$ptr,
#else
Arg<TT_PtrLike, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$ptr,
#endif // __TLE__
TT_Type:$val,
Optional<TT_BoolLike>:$mask,
TT_MemSemanticAttr:$sem,
Expand All @@ -386,10 +399,17 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
def TT_AtomicCASOp : TT_Op<"atomic_cas", [
SameOperandsAndResultShape,
SameOperandsAndResultEncoding,
#ifdef __TLE__
TypesMatchWith<"cmp type matches ptr type", "ptr", "cmp",
"getPointeeType($_self)">,
TypesMatchWith<"value type matches ptr type", "ptr", "val",
"getPointeeType($_self)">
#else
TypesMatchWith<"ptr type matches cmp type", "cmp", "ptr",
"getPointerTypeSameShape($_self)">,
TypesMatchWith<"ptr type matches value type", "val", "ptr",
"getPointerTypeSameShape($_self)">
#endif // __TLE__
]> {
let summary = "atomic cas";

Expand All @@ -404,7 +424,12 @@ def TT_AtomicCASOp : TT_Op<"atomic_cas", [
}];

let arguments = (ins
#ifdef __TLE__
Arg<TT_PtrLike, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>,
MemRead<SharedMemory>, MemWrite<SharedMemory>]>:$ptr,
#else
Arg<TT_PtrLike, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$ptr,
#endif // __TLE__
TT_Type:$cmp,
TT_Type:$val,
TT_MemSemanticAttr:$sem,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})

set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td)
if(FLAGTREE_MTHREADS_TLE)
set(_TLE_TABLEGEN_DEFS -D__TLE__)
else()
set(_TLE_TABLEGEN_DEFS "")
endif()
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=ttg)
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=ttg)
mlir_tablegen(Ops.h.inc -gen-op-decls)
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
mlir_tablegen(Ops.h.inc -gen-op-decls ${_TLE_TABLEGEN_DEFS})
mlir_tablegen(Ops.cpp.inc -gen-op-defs ${_TLE_TABLEGEN_DEFS})
mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=ttg)
mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=ttg)
add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ include "mlir/Interfaces/ViewLikeInterface.td"
def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
def SharedMemory : Resource<"::mlir::triton::gpu::SharedMemory">;

#ifdef __TLE__
def TTG_TMACopyOperand : AnyTypeOf<[TT_TensorDescType, TTG_MemDescType]>;
#endif // __TLE__

class TTG_Op<string mnemonic, list<Trait> traits = []> :
Op<TritonGPU_Dialect, mnemonic,
!listconcat(traits, [VerifyTensorLayoutsTrait])> {
Expand Down Expand Up @@ -721,6 +725,29 @@ def TTG_BarrierOp : TTG_Op<"barrier"> {
}];
}

#ifdef __TLE__
// Descriptor-based copy op; only defined when __TLE__ is set.
// Carries the MemRead and MemWrite memory-effect traits.
def TTG_TMACopyOp : TTG_Op<"tma_copy", [MemoryEffects<[MemRead, MemWrite]>]> {
let summary = "Pseudo op for descriptor-based copy between global tensor descriptor and shared memdesc.";

let description = [{
`ttg.tma_copy` represents an explicit copy between a global tensor
descriptor and a shared-memory memdesc. Backend-specific TME/TMA lowering
replaces it with the target hardware copy and synchronization operations.
}];

// $src and $dst are both typed TTG_TMACopyOperand; $indices is a variadic
// list of I32 values.
let arguments = (ins
TTG_TMACopyOperand:$src,
TTG_TMACopyOperand:$dst,
Variadic<I32>:$indices
);

// Textual form: src, dst, [indices] attr-dict : type(src), type(dst)
let assemblyFormat =
"$src `,` $dst `,` `[` $indices `]` attr-dict `:` type($src) `,` type($dst)";

// Requests a hand-written verifier; its implementation is not part of this definition.
let hasVerifier = 1;
}
#endif // __TLE__

def TTG_WarpIdOp : TTG_Op<"warp_id", [Pure]> {
let summary = "Return the GPU warp ID";

Expand Down
Loading
Loading