Skip to content

Commit afe168b

Browse files
committed
[TIR] Add cooperative_tensor builtins and metal.cooperative_tensor storage scope
Add TIR builtins for Metal cooperative_tensor operations (MetalPerformancePrimitives): - cooperative_tensor_fill: fill a cooperative_tensor with a value - cooperative_tensor_load: load from device/threadgroup memory - cooperative_tensor_store: store to device/threadgroup memory - cooperative_tensor_multiply_accumulate: matrix multiply-accumulate via matmul2d Add metal.cooperative_tensor storage scope (StorageRank::kMetalCooperativeTensor) for buffers backed by MPP cooperative_tensor registers, analogous to the existing metal.simdgroup scope but targeting the Metal 4 tensor operations API. These primitives enable code generation for MetalPerformancePrimitives matmul2d, which routes to NAX tensor cores on Apple M5 and falls back to simdgroup matrix instructions on M1-M4.
1 parent 882a774 commit afe168b

6 files changed

Lines changed: 336 additions & 3 deletions

File tree

include/tvm/tir/builtin.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,51 @@ TVM_DLL const Op& simdgroup_store();
806806
*/
807807
TVM_DLL const Op& simdgroup_multiply_accumulate();
808808

809+
// Metal cooperative_tensor intrinsics (MetalPerformancePrimitives / Metal 4)
810+
811+
/*!
812+
* \brief Fill a cooperative_tensor with a given value.
813+
*
814+
* void cooperative_tensor_fill(Var d, PrimExpr index, PrimExpr value,
815+
* int rows, int cols);
816+
*/
817+
TVM_DLL const Op& cooperative_tensor_fill();
818+
819+
/*!
820+
* \brief Load data from device or threadgroup memory into a cooperative_tensor.
821+
*
822+
* void cooperative_tensor_load(Var d, PrimExpr index, PrimExpr ptr,
823+
* PrimExpr stride, int rows, int cols,
824+
* bool transpose_matrix,
825+
* int mma_M, int mma_N, int mma_K,
826+
* int operand_role);
827+
* operand_role: 0=left(A), 1=right(B), 2=destination(C)
828+
*/
829+
TVM_DLL const Op& cooperative_tensor_load();
830+
831+
/*!
832+
* \brief Store data from a cooperative_tensor to device or threadgroup memory.
833+
*
834+
* void cooperative_tensor_store(Var d, PrimExpr index, PrimExpr ptr,
835+
* PrimExpr stride, int rows, int cols,
836+
* bool transpose_matrix,
837+
* int mma_M, int mma_N, int mma_K,
838+
* int operand_role);
839+
* operand_role: 0=left(A), 1=right(B), 2=destination(C)
840+
*/
841+
TVM_DLL const Op& cooperative_tensor_store();
842+
843+
/*!
844+
* \brief Multiply and accumulate two matrices using cooperative_tensor
845+
* (MetalPerformancePrimitives matmul2d).
846+
*
847+
* void cooperative_tensor_multiply_accumulate(
848+
* Var d, PrimExpr index_d, Var a, PrimExpr index_a,
849+
* Var b, PrimExpr index_b, Var c, PrimExpr index_c,
850+
* int M, int N, int K, bool transpose_a, bool transpose_b);
851+
*/
852+
TVM_DLL const Op& cooperative_tensor_multiply_accumulate();
853+
809854
// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
810855
/*!
811856
* \brief Get the high level half of the vector

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,159 +1430,313 @@ def func(
14301430

14311431
return func
14321432

1433+
14331434
if TYPE_CHECKING:
1435+
14341436
class int8: ...
1437+
14351438
class int16: ...
1439+
14361440
class int32: ...
1441+
14371442
class int64: ...
1443+
14381444
class int8x4: ...
1445+
14391446
class int16x4: ...
1447+
14401448
class int32x4: ...
1449+
14411450
class int64x4: ...
1451+
14421452
class int8x8: ...
1453+
14431454
class int16x8: ...
1455+
14441456
class int32x8: ...
1457+
14451458
class int64x8: ...
1459+
14461460
class int8x16: ...
1461+
14471462
class int16x16: ...
1463+
14481464
class int32x16: ...
1465+
14491466
class int64x16: ...
1467+
14501468
class int8x32: ...
1469+
14511470
class int16x32: ...
1471+
14521472
class int32x32: ...
1473+
14531474
class int64x32: ...
1475+
14541476
class int8x64: ...
1477+
14551478
class int16x64: ...
1479+
14561480
class int32x64: ...
1481+
14571482
class int64x64: ...
1483+
14581484
class uint8: ...
1485+
14591486
class uint16: ...
1487+
14601488
class uint32: ...
1489+
14611490
class uint64: ...
1491+
14621492
class uint8x4: ...
1493+
14631494
class uint16x4: ...
1495+
14641496
class uint32x4: ...
1497+
14651498
class uint64x4: ...
1499+
14661500
class uint8x8: ...
1501+
14671502
class uint16x8: ...
1503+
14681504
class uint32x8: ...
1505+
14691506
class uint64x8: ...
1507+
14701508
class uint8x16: ...
1509+
14711510
class uint16x16: ...
1511+
14721512
class uint32x16: ...
1513+
14731514
class uint64x16: ...
1515+
14741516
class uint8x32: ...
1517+
14751518
class uint16x32: ...
1519+
14761520
class uint32x32: ...
1521+
14771522
class uint64x32: ...
1523+
14781524
class uint8x64: ...
1525+
14791526
class uint16x64: ...
1527+
14801528
class uint32x64: ...
1529+
14811530
class uint64x64: ...
1531+
14821532
class float16: ...
1533+
14831534
class float32: ...
1535+
14841536
class float64: ...
1537+
14851538
class float16x2: ...
1539+
14861540
class float32x2: ...
1541+
14871542
class float64x2: ...
1543+
14881544
class float16x4: ...
1545+
14891546
class float32x4: ...
1547+
14901548
class float64x4: ...
1549+
14911550
class float16x8: ...
1551+
14921552
class float32x8: ...
1553+
14931554
class float64x8: ...
1555+
14941556
class float16x16: ...
1557+
14951558
class float32x16: ...
1559+
14961560
class float64x16: ...
1561+
14971562
class float16x32: ...
1563+
14981564
class float32x32: ...
1565+
14991566
class float64x32: ...
1567+
15001568
class float16x64: ...
1569+
15011570
class float32x64: ...
1571+
15021572
class float64x64: ...
1573+
15031574
class float8_e3m4: ...
1575+
15041576
class float8_e3m4x2: ...
1577+
15051578
class float8_e3m4x4: ...
1579+
15061580
class float8_e3m4x8: ...
1581+
15071582
class float8_e3m4x16: ...
1583+
15081584
class float8_e3m4x32: ...
1585+
15091586
class float8_e3m4x64: ...
1587+
15101588
class float8_e4m3: ...
1589+
15111590
class float8_e4m3x2: ...
1591+
15121592
class float8_e4m3x4: ...
1593+
15131594
class float8_e4m3x8: ...
1595+
15141596
class float8_e4m3x16: ...
1597+
15151598
class float8_e4m3x32: ...
1599+
15161600
class float8_e4m3x64: ...
1601+
15171602
class float8_e4m3b11fnuz: ...
1603+
15181604
class float8_e4m3b11fnuzx2: ...
1605+
15191606
class float8_e4m3b11fnuzx4: ...
1607+
15201608
class float8_e4m3b11fnuzx8: ...
1609+
15211610
class float8_e4m3b11fnuzx16: ...
1611+
15221612
class float8_e4m3b11fnuzx32: ...
1613+
15231614
class float8_e4m3b11fnuzx64: ...
1615+
15241616
class float8_e4m3fn: ...
1617+
15251618
class float8_e4m3fnx2: ...
1619+
15261620
class float8_e4m3fnx4: ...
1621+
15271622
class float8_e4m3fnx8: ...
1623+
15281624
class float8_e4m3fnx16: ...
1625+
15291626
class float8_e4m3fnx32: ...
1627+
15301628
class float8_e4m3fnx64: ...
1629+
15311630
class float8_e4m3fnuz: ...
1631+
15321632
class float8_e4m3fnuzx2: ...
1633+
15331634
class float8_e4m3fnuzx4: ...
1635+
15341636
class float8_e4m3fnuzx8: ...
1637+
15351638
class float8_e4m3fnuzx16: ...
1639+
15361640
class float8_e4m3fnuzx32: ...
1641+
15371642
class float8_e4m3fnuzx64: ...
1643+
15381644
class float8_e5m2: ...
1645+
15391646
class float8_e5m2x2: ...
1647+
15401648
class float8_e5m2x4: ...
1649+
15411650
class float8_e5m2x8: ...
1651+
15421652
class float8_e5m2x16: ...
1653+
15431654
class float8_e5m2x32: ...
1655+
15441656
class float8_e5m2x64: ...
1657+
15451658
class float8_e5m2fnuz: ...
1659+
15461660
class float8_e5m2fnuzx2: ...
1661+
15471662
class float8_e5m2fnuzx4: ...
1663+
15481664
class float8_e5m2fnuzx8: ...
1665+
15491666
class float8_e5m2fnuzx16: ...
1667+
15501668
class float8_e5m2fnuzx32: ...
1669+
15511670
class float8_e5m2fnuzx64: ...
1671+
15521672
class float8_e8m0fnu: ...
1673+
15531674
class float8_e8m0fnux2: ...
1675+
15541676
class float8_e8m0fnux4: ...
1677+
15551678
class float8_e8m0fnux8: ...
1679+
15561680
class float8_e8m0fnux16: ...
1681+
15571682
class float8_e8m0fnux32: ...
1683+
15581684
class float8_e8m0fnux64: ...
1685+
15591686
class float6_e2m3fn: ...
1687+
15601688
class float6_e2m3fnx2: ...
1689+
15611690
class float6_e2m3fnx4: ...
1691+
15621692
class float6_e2m3fnx8: ...
1693+
15631694
class float6_e2m3fnx16: ...
1695+
15641696
class float6_e2m3fnx32: ...
1697+
15651698
class float6_e2m3fnx64: ...
1699+
15661700
class float6_e3m2fn: ...
1701+
15671702
class float6_e3m2fnx2: ...
1703+
15681704
class float6_e3m2fnx4: ...
1705+
15691706
class float6_e3m2fnx8: ...
1707+
15701708
class float6_e3m2fnx16: ...
1709+
15711710
class float6_e3m2fnx32: ...
1711+
15721712
class float6_e3m2fnx64: ...
1713+
15731714
class float4_e2m1fn: ...
1715+
15741716
class float4_e2m1fnx2: ...
1717+
15751718
class float4_e2m1fnx4: ...
1719+
15761720
class float4_e2m1fnx8: ...
1721+
15771722
class float4_e2m1fnx16: ...
1723+
15781724
class float4_e2m1fnx32: ...
1725+
15791726
class float4_e2m1fnx64: ...
1727+
15801728
class bfloat16: ...
1729+
15811730
class bfloat16x2: ...
1731+
15821732
class bfloat16x4: ...
1733+
15831734
class bfloat16x8: ...
1735+
15841736
class bfloat16x16: ...
1737+
15851738
class bfloat16x32: ...
1739+
15861740
class bfloat16x64: ...
15871741
else:
15881742
# pylint: disable=invalid-name
@@ -2202,6 +2356,10 @@ def wrapped(*args, **kwargs):
22022356
simdgroup_load = _op_wrapper(_tir_op.simdgroup_load)
22032357
simdgroup_store = _op_wrapper(_tir_op.simdgroup_store)
22042358
simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate)
2359+
cooperative_tensor_fill = _op_wrapper(_tir_op.cooperative_tensor_fill)
2360+
cooperative_tensor_load = _op_wrapper(_tir_op.cooperative_tensor_load)
2361+
cooperative_tensor_store = _op_wrapper(_tir_op.cooperative_tensor_store)
2362+
cooperative_tensor_multiply_accumulate = _op_wrapper(_tir_op.cooperative_tensor_multiply_accumulate)
22052363
create_barriers = _op_wrapper(_tir_op.create_barriers)
22062364
assume = _op_wrapper(_tir_op.assume)
22072365
undef = _op_wrapper(_tir_op.undef)
@@ -2500,6 +2658,10 @@ def wrapped(*args, **kwargs):
25002658
"simdgroup_load",
25012659
"simdgroup_store",
25022660
"simdgroup_multiply_accumulate",
2661+
"cooperative_tensor_fill",
2662+
"cooperative_tensor_load",
2663+
"cooperative_tensor_store",
2664+
"cooperative_tensor_multiply_accumulate",
25032665
"create_barriers",
25042666
"mma_store",
25052667
"mma_fill",

0 commit comments

Comments
 (0)