@@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) {
514514bool SYCLGenBase::emitBuiltinType (const InlineAsmBuiltinType *T) {
515515 switch (T->getKind ()) {
516516 // clang-format off
517+ case InlineAsmBuiltinType::b1: OS () << " uint8_t" ; break ;
517518 case InlineAsmBuiltinType::b8: OS () << " uint8_t" ; break ;
518519 case InlineAsmBuiltinType::b16: OS () << " uint16_t" ; break ;
519520 case InlineAsmBuiltinType::b32: OS () << " uint32_t" ; break ;
520521 case InlineAsmBuiltinType::b64: OS () << " uint64_t" ; break ;
522+ case InlineAsmBuiltinType::u4: OS () << " uint8_t" ; break ;
521523 case InlineAsmBuiltinType::u8 : OS () << " uint8_t" ; break ;
522524 case InlineAsmBuiltinType::u16 : OS () << " uint16_t" ; break ;
523525 case InlineAsmBuiltinType::u32 : OS () << " uint32_t" ; break ;
524526 case InlineAsmBuiltinType::u64 : OS () << " uint64_t" ; break ;
527+ case InlineAsmBuiltinType::s4: OS () << " int8_t" ; break ;
525528 case InlineAsmBuiltinType::s8: OS () << " int8_t" ; break ;
526529 case InlineAsmBuiltinType::s16: OS () << " int16_t" ; break ;
527530 case InlineAsmBuiltinType::s32: OS () << " int32_t" ; break ;
@@ -559,6 +562,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
559562 case InlineAsmVectorType::x1:
560563 OS () << 1 ;
561564 break ;
565+ case InlineAsmVectorType::v1:
566+ OS () << 1 ;
567+ break ;
562568 case InlineAsmVectorType::v2:
563569 case InlineAsmVectorType::x2:
564570 OS () << 2 ;
@@ -1370,6 +1376,167 @@ class SYCLGen : public SYCLGenBase {
13701376 return SYCLGenSuccess ();
13711377 }
13721378
1379+ bool handle_mma (const InlineAsmInstruction *Inst) override {
1380+ if (Inst->getNumInputOperands () != 3 )
1381+ return SYCLGenError ();
1382+
1383+ const InlineAsmVectorExpr *DMatVE =
1384+ dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand ());
1385+ if (!DMatVE)
1386+ return SYCLGenError ();
1387+
1388+ // Only row Layout is supported for of A matrix and
1389+ // only col Layout is supported for of B matrix
1390+ if (Inst->getAttr (3 ) != InstAttr::row || Inst->getAttr (4 ) != InstAttr::col)
1391+ return SYCLGenError ();
1392+
1393+ // Data types of D, A, B & C matrices respectively in the PTX instruction
1394+ const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (0 ));
1395+ const auto *AType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (1 ));
1396+ const auto *BType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (2 ));
1397+ const auto *CType = dyn_cast<InlineAsmBuiltinType>(Inst->getType (3 ));
1398+
1399+ if (!(AType && BType && CType && DType))
1400+ return SYCLGenError ();
1401+
1402+ // Data types of matrix elements for A&B and C&D matrices should be same
1403+ if ((AType->getKind () != BType->getKind ()) ||
1404+ (CType->getKind () != DType->getKind ()))
1405+ return SYCLGenError ();
1406+
1407+ // Check the validity of AB & CD types
1408+ std::string ABType, CDType;
1409+ if (tryEmitType (ABType, AType))
1410+ return SYCLGenError ();
1411+
1412+ if (tryEmitType (CDType, CType))
1413+ return SYCLGenError ();
1414+
1415+ // Register sizes for vector elements of A, B, C & D matrices
1416+ unsigned NumVecElements[4 ] = {0 };
1417+
1418+ // Sizes of A & B matrices
1419+ std::string M, N, K;
1420+
1421+ // Data types of A, B & C matrices respectively in the PTX arguments
1422+ std::string InMatrixType[3 ];
1423+
1424+ if (Inst->hasAttr (InstAttr::m16n8k16)) {
1425+ M = " 16" ;
1426+ N = " 8" ;
1427+ K = " 16" ;
1428+
1429+ // Only f16/s8 types are supported for A and B matrices of m16n8k16
1430+ if (AType->getKind () == InlineAsmBuiltinType::f16 ) {
1431+ InMatrixType[0 ] = " uint32_t" ; // A type is .f16x2
1432+ InMatrixType[1 ] = " uint32_t" ; // B type is .f16x2
1433+
1434+ // If A matrix type is f16, then C&D matrix types can only be f32
1435+ if (CType->getKind () == InlineAsmBuiltinType::f32 ) {
1436+ NumVecElements[0 ] = 4 ; // A
1437+ NumVecElements[1 ] = 2 ; // B
1438+ NumVecElements[2 ] = 4 ; // C
1439+ NumVecElements[3 ] = 4 ; // D
1440+ } else
1441+ return SYCLGenError ();
1442+ } else if (AType->getKind () == InlineAsmBuiltinType::s8) {
1443+ InMatrixType[0 ] = " uint32_t" ; // A type is .f16x2
1444+ InMatrixType[1 ] = " uint32_t" ; // B type is .f16x2
1445+
1446+ // If A matrix type is s8, then C&D matrix types can only be s32
1447+ if (CType->getKind () == InlineAsmBuiltinType::s32) {
1448+ NumVecElements[0 ] = 2 ; // A
1449+ NumVecElements[1 ] = 1 ; // B
1450+ NumVecElements[2 ] = 4 ; // C
1451+ NumVecElements[3 ] = 4 ; // D
1452+ } else
1453+ return SYCLGenError ();
1454+ } else
1455+ return SYCLGenError ();
1456+ } else
1457+ return SYCLGenError ();
1458+
1459+ InMatrixType[2 ] = CDType;
1460+
1461+ // Check the register sizes for vector elements of A, B, C & D matrices
1462+ for (unsigned InputOp = 0 ; InputOp < Inst->getNumInputOperands ();
1463+ InputOp++) {
1464+ if (auto VE =
1465+ dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1466+ if (VE ->getNumElements () != NumVecElements[InputOp])
1467+ return SYCLGenError ();
1468+ } else
1469+ return SYCLGenError ();
1470+ }
1471+ if (DMatVE->getNumElements () != NumVecElements[3 ])
1472+ return SYCLGenError ();
1473+
1474+ // Declare and init an array for storing the addresses of D matrix elements
1475+ OS () << " {\n " ;
1476+ OS () << " volatile " << CDType << " *d_mat_frag_ct1["
1477+ << DMatVE->getNumElements () << " ] = { " ;
1478+ for (unsigned Inst = 0 ; Inst != DMatVE->getNumElements (); ++Inst) {
1479+ if (isa<InlineAsmDiscardExpr>(DMatVE->getElement (Inst)))
1480+ continue ;
1481+ OS () << " &" ;
1482+ if (emitStmt (DMatVE->getElement (Inst)))
1483+ return SYCLGenError ();
1484+ if ((Inst + 1 ) != DMatVE->getNumElements ())
1485+ OS () << " , " ;
1486+ }
1487+ OS () << " }" ;
1488+ endstmt ();
1489+
1490+ // Declare and init vectors for storing the values of A, B & C matrix
1491+ // elements
1492+ std::string InMatrixName[3 ] = {" a" , " b" , " c" };
1493+ for (unsigned InputOp = 0 ; InputOp < Inst->getNumInputOperands ();
1494+ InputOp++) {
1495+ if (auto VE =
1496+ dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand (InputOp))) {
1497+ OS () << " sycl::vec<" << InMatrixType[InputOp] << " , "
1498+ << VE ->getNumElements () << " > " << InMatrixName[InputOp]
1499+ << " _mat_frag_ct1(" ;
1500+ for (unsigned Inst = 0 ; Inst != VE ->getNumElements (); ++Inst) {
1501+ if (isa<InlineAsmDiscardExpr>(VE ->getElement (Inst)))
1502+ continue ;
1503+ if (emitStmt (VE ->getElement (Inst)))
1504+ return SYCLGenError ();
1505+ if ((Inst + 1 ) != VE ->getNumElements ())
1506+ OS () << " , " ;
1507+ }
1508+ OS () << " )" ;
1509+ endstmt ();
1510+ } else {
1511+ return SYCLGenError ();
1512+ }
1513+ }
1514+
1515+ OS () << MapNames::getDpctNamespace () << " experimental::matrix::mma" ;
1516+ OS () << " <" ;
1517+ OS () << M << " , " << N << " , " << K << " , " ;
1518+ OS () << ABType << " , " << CDType;
1519+ OS () << " >(" ;
1520+
1521+ OS () << " reinterpret_cast<volatile void **>(d_mat_frag_ct1)" ;
1522+ for (int i = 0 ; i < 3 ; i++)
1523+ OS () << " , &" << InMatrixName[i] << " _mat_frag_ct1" ;
1524+ OS () << " )" ;
1525+ endstmt ();
1526+ OS () << " }" ;
1527+ endstmt ();
1528+
1529+ const auto *KernelDecl = getImmediateOuterFuncDecl (GAS );
1530+ if (KernelDecl) {
1531+ auto FuncInfo = DeviceFunctionDecl::LinkRedecls (KernelDecl);
1532+ if (FuncInfo)
1533+ FuncInfo->addSubGroupSizeRequest (32 , GAS ->getBeginLoc (),
1534+ DpctGlobalInfo::getSubGroup (GAS ));
1535+ }
1536+
1537+ return SYCLGenSuccess ();
1538+ }
1539+
13731540 bool handle_prefetch (const InlineAsmInstruction *Inst) override {
13741541 if (!DpctGlobalInfo::useExtPrefetch () || Inst->getNumInputOperands () != 1 )
13751542 return SYCLGenError ();
@@ -2595,11 +2762,10 @@ class SYCLGen : public SYCLGenBase {
25952762 Op = std::move (NewOp);
25962763 }
25972764
2598- bool HasHalfOrBfloat16 =
2599- SrcType->getKind () == InlineAsmBuiltinType::f16 ||
2600- DesType->getKind () == InlineAsmBuiltinType::f16 ||
2601- SrcType->getKind () == InlineAsmBuiltinType::bf16 ||
2602- DesType->getKind () == InlineAsmBuiltinType::bf16 ;
2765+ bool HasHalfOrBfloat16 = SrcType->getKind () == InlineAsmBuiltinType::f16 ||
2766+ DesType->getKind () == InlineAsmBuiltinType::f16 ||
2767+ SrcType->getKind () == InlineAsmBuiltinType::bf16 ||
2768+ DesType->getKind () == InlineAsmBuiltinType::bf16 ;
26032769 if (DpctGlobalInfo::useIntelDeviceMath () && HasHalfOrBfloat16) {
26042770 insertHeader (HeaderType::HT_SYCL_Math);
26052771 if (SrcNeedBitCast)
0 commit comments