Skip to content

Commit ae0618c

Browse files
committed
[TIR] Add cooperative_tensor builtins and metal.cooperative_tensor storage scope
Add TIR builtins for Metal cooperative_tensor operations (MetalPerformancePrimitives): - cooperative_tensor_fill: fill a cooperative_tensor with a value - cooperative_tensor_load: load from device/threadgroup memory - cooperative_tensor_store: store to device/threadgroup memory - cooperative_tensor_multiply_accumulate: matrix multiply-accumulate via matmul2d Add metal.cooperative_tensor storage scope (StorageRank::kMetalCooperativeTensor) for buffers backed by MPP cooperative_tensor registers, analogous to the existing metal.simdgroup scope but targeting the Metal 4 tensor operations API. These primitives enable code generation for MetalPerformancePrimitives matmul2d, which routes to NAX tensor cores on Apple M5 and falls back to simdgroup matrix instructions on M1-M4.
1 parent ca68bef commit ae0618c

5 files changed

Lines changed: 176 additions & 0 deletions

File tree

include/tvm/tirx/builtin.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,51 @@ TVM_DLL const Op& simdgroup_store();
788788
*/
789789
TVM_DLL const Op& simdgroup_multiply_accumulate();
790790

791+
// Metal cooperative_tensor intrinsics (MetalPerformancePrimitives / Metal 4)
792+
793+
/*!
794+
* \brief Fill a cooperative_tensor with a given value.
795+
*
796+
* void cooperative_tensor_fill(Var d, PrimExpr index, PrimExpr value,
797+
* int rows, int cols);
798+
*/
799+
TVM_DLL const Op& cooperative_tensor_fill();
800+
801+
/*!
802+
* \brief Load data from device or threadgroup memory into a cooperative_tensor.
803+
*
804+
* void cooperative_tensor_load(Var d, PrimExpr index, PrimExpr ptr,
805+
* PrimExpr stride, int rows, int cols,
806+
* bool transpose_matrix,
807+
* int mma_M, int mma_N, int mma_K,
808+
* int operand_role);
809+
* operand_role: 0=left(A), 1=right(B), 2=destination(C)
810+
*/
811+
TVM_DLL const Op& cooperative_tensor_load();
812+
813+
/*!
814+
* \brief Store data from a cooperative_tensor to device or threadgroup memory.
815+
*
816+
* void cooperative_tensor_store(Var d, PrimExpr index, PrimExpr ptr,
817+
* PrimExpr stride, int rows, int cols,
818+
* bool transpose_matrix,
819+
* int mma_M, int mma_N, int mma_K,
820+
* int operand_role);
821+
* operand_role: 0=left(A), 1=right(B), 2=destination(C)
822+
*/
823+
TVM_DLL const Op& cooperative_tensor_store();
824+
825+
/*!
826+
* \brief Multiply and accumulate two matrices using cooperative_tensor
827+
* (MetalPerformancePrimitives matmul2d).
828+
*
829+
* void cooperative_tensor_multiply_accumulate(
830+
* Var d, PrimExpr index_d, Var a, PrimExpr index_a,
831+
* Var b, PrimExpr index_b, Var c, PrimExpr index_c,
832+
* int M, int N, int K, bool transpose_a, bool transpose_b);
833+
*/
834+
TVM_DLL const Op& cooperative_tensor_multiply_accumulate();
835+
791836
// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
792837
/*!
793838
* \brief Get the high level half of the vector

python/tvm/script/ir_builder/tirx/ir.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,6 +1962,10 @@ def wrapped(*args, **kwargs) -> T:
19621962
simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
19631963
simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
19641964
simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
1965+
cooperative_tensor_fill = _op_wrapper(_tir_op.cooperative_tensor_fill)
1966+
cooperative_tensor_load = _op_wrapper(_tir_op.cooperative_tensor_load)
1967+
cooperative_tensor_store = _op_wrapper(_tir_op.cooperative_tensor_store)
1968+
cooperative_tensor_multiply_accumulate = _op_wrapper(_tir_op.cooperative_tensor_multiply_accumulate)
19651969
create_barriers = _op_wrapper(_tir_op.create_barriers)
19661970
assume = _op_wrapper(_tir_op.assume)
19671971
undef = _op_wrapper(_tir_op.undef)
@@ -2252,6 +2256,10 @@ def wrapped(*args, **kwargs):
22522256
"simdgroup_load",
22532257
"simdgroup_store",
22542258
"simdgroup_multiply_accumulate",
2259+
"cooperative_tensor_fill",
2260+
"cooperative_tensor_load",
2261+
"cooperative_tensor_store",
2262+
"cooperative_tensor_multiply_accumulate",
22552263
"create_barriers",
22562264
"mma_store",
22572265
"mma_fill",

python/tvm/tirx/op.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,6 +1792,110 @@ def simdgroup_multiply_accumulate(
17921792
)
17931793

17941794

1795+
def cooperative_tensor_fill(
1796+
d: Var,
1797+
index: PrimExpr,
1798+
value: PrimExpr,
1799+
rows: int,
1800+
cols: int,
1801+
):
1802+
return call_intrin("handle", "tirx.cooperative_tensor_fill", d, index, value, rows, cols)
1803+
1804+
1805+
def cooperative_tensor_load(
1806+
d: Var,
1807+
index: PrimExpr,
1808+
ptr: PrimExpr,
1809+
stride: PrimExpr,
1810+
rows: int,
1811+
cols: int,
1812+
transpose_matrix: bool = False,
1813+
mma_M: int = 0,
1814+
mma_N: int = 0,
1815+
mma_K: int = 0,
1816+
operand_role: int = 0,
1817+
):
1818+
return call_intrin(
1819+
"handle",
1820+
"tirx.cooperative_tensor_load",
1821+
d,
1822+
index,
1823+
ptr,
1824+
stride,
1825+
rows,
1826+
cols,
1827+
transpose_matrix,
1828+
mma_M,
1829+
mma_N,
1830+
mma_K,
1831+
operand_role,
1832+
)
1833+
1834+
1835+
def cooperative_tensor_store(
1836+
d: PrimExpr,
1837+
index: PrimExpr,
1838+
ptr: PrimExpr,
1839+
stride: PrimExpr,
1840+
rows: int,
1841+
cols: int,
1842+
transpose_matrix: bool = False,
1843+
mma_M: int = 0,
1844+
mma_N: int = 0,
1845+
mma_K: int = 0,
1846+
operand_role: int = 0,
1847+
):
1848+
return call_intrin(
1849+
"handle",
1850+
"tirx.cooperative_tensor_store",
1851+
d,
1852+
index,
1853+
ptr,
1854+
stride,
1855+
rows,
1856+
cols,
1857+
transpose_matrix,
1858+
mma_M,
1859+
mma_N,
1860+
mma_K,
1861+
operand_role,
1862+
)
1863+
1864+
1865+
def cooperative_tensor_multiply_accumulate(
1866+
d: Var,
1867+
index_d: PrimExpr,
1868+
a: Var,
1869+
index_a: PrimExpr,
1870+
b: Var,
1871+
index_b: PrimExpr,
1872+
c: Var,
1873+
index_c: PrimExpr,
1874+
M: int,
1875+
N: int,
1876+
K: int,
1877+
transpose_a: bool = False,
1878+
transpose_b: bool = False,
1879+
):
1880+
return call_intrin(
1881+
"handle",
1882+
"tirx.cooperative_tensor_multiply_accumulate",
1883+
d,
1884+
index_d,
1885+
a,
1886+
index_a,
1887+
b,
1888+
index_b,
1889+
c,
1890+
index_c,
1891+
M,
1892+
N,
1893+
K,
1894+
transpose_a,
1895+
transpose_b,
1896+
)
1897+
1898+
17951899
def vectorlow(dtype, vec):
17961900
"""Get the low level half of the vector
17971901

src/runtime/thread_storage_scope.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ enum class StorageRank {
7171
kMMAMatrixC = 11,
7272
/*! \brief Metal SIMD group memory */
7373
kMetalSimdGroup = 12,
74+
/*! \brief Metal cooperative_tensor memory (MetalPerformancePrimitives) */
75+
kMetalCooperativeTensor = 13,
7476
};
7577

7678
/*!
@@ -129,6 +131,8 @@ struct StorageScope {
129131
return "m16n8k8.matrixC" + tag;
130132
case StorageRank::kMetalSimdGroup:
131133
return "metal.simdgroup" + tag;
134+
case StorageRank::kMetalCooperativeTensor:
135+
return "metal.cooperative_tensor" + tag;
132136
default:
133137
TVM_FFI_THROW(InternalError) << "unknown storage scope";
134138
return "";
@@ -182,6 +186,9 @@ struct StorageScope {
182186
} else if (s.compare(0, 15, "metal.simdgroup") == 0) {
183187
r.rank = StorageRank::kMetalSimdGroup;
184188
r.tag = s.substr(15, std::string::npos);
189+
} else if (s.compare(0, 24, "metal.cooperative_tensor") == 0) {
190+
r.rank = StorageRank::kMetalCooperativeTensor;
191+
r.tag = s.substr(24, std::string::npos);
185192
} else {
186193
TVM_FFI_THROW(InternalError) << "unknown storage scope " << s;
187194
}

src/tirx/op/builtin.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,18 @@ TIR_DEFINE_BUILTIN_FUNC(simdgroup_store)
348348
TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate)
349349
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
350350

351+
TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_fill)
352+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
353+
354+
TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_load)
355+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
356+
357+
TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_store)
358+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
359+
360+
TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_multiply_accumulate)
361+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
362+
351363
TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
352364
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
353365
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",

0 commit comments

Comments
 (0)