Skip to content

Commit aefe236

Browse files
Fixed format & addressed comments
1 parent 6c61f6d commit aefe236

3 files changed

Lines changed: 14 additions & 10 deletions

File tree

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,13 +1486,16 @@ class SYCLGen : public SYCLGenBase {
14861486
OS() << " }";
14871487
endstmt();
14881488

1489-
// Declare and init vectors for storing the values of A, B & C matrix elements
1489+
// Declare and init vectors for storing the values of A, B & C matrix
1490+
// elements
14901491
std::string InMatrixName[3] = {"A", "B", "C"};
14911492
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
14921493
InputOp++) {
14931494
if (auto VE =
14941495
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
1495-
OS() << "sycl::vec<" << InMatrixType[InputOp] << ", " << VE->getNumElements() << "> " << InMatrixName[InputOp] << "Matrix_ct1(";
1496+
OS() << "sycl::vec<" << InMatrixType[InputOp] << ", "
1497+
<< VE->getNumElements() << "> " << InMatrixName[InputOp]
1498+
<< "Matrix_ct1(";
14961499
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
14971500
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
14981501
continue;
@@ -1516,7 +1519,8 @@ class SYCLGen : public SYCLGenBase {
15161519

15171520
OS() << "DMatrix_ct1";
15181521
for (int i = 0; i < 3; i++)
1519-
OS() << ", reinterpret_cast<" << InMatrixType[i] << " *>(&" << InMatrixName[i] << "Matrix_ct1)";
1522+
OS() << ", reinterpret_cast<" << InMatrixType[i] << " *>(&"
1523+
<< InMatrixName[i] << "Matrix_ct1)";
15201524
OS() << ")";
15211525
endstmt();
15221526
OS() << "}";

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -757,9 +757,9 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
757757
// Vector size must be 2, 4, or 8.
758758
switch (Vec.size()) {
759759
case 1:
760-
Kind = InlineAsmVectorType::v1;
761-
break;
762-
case 2:
760+
Kind = InlineAsmVectorType::v1;
761+
break;
762+
case 2:
763763
Kind = InlineAsmVectorType::v2;
764764
break;
765765
case 4:

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2227,23 +2227,23 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22272227
/// \tparam [in] MulType The type used to multiply A and B matrix elements as
22282228
/// \tparam [in] ABType The type of the input matrix (A & B) elements
22292229
/// \tparam [in] CDType The type of the output matrix (C & D) elements
2230-
/// \param [in] d The elements of the output D matrix to store the result to
2230+
/// \param [out] d The elements of the output D matrix to store the result to
22312231
/// \param [in] a The elements of the input A matrix to be multiplied with B
22322232
/// matrix elements
22332233
/// \param [in] b The elements of the input B matrix to be multiplied with A
22342234
/// matrix elements
22352235
/// \param [in] c The elements of the input C matrix to be added with the result
22362236
/// of A * B
22372237
template <int M, int N, int K, typename MulType, typename ABType,
2238-
typename CDType, typename Op = sycl::bit_and<>>
2239-
void mma(CDType **d, ABType *a, ABType *b, CDType *c, Op op = Op{}) {
2238+
typename CDType>
2239+
void mma(CDType **d, ABType *a, ABType *b, CDType *c) {
22402240
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
22412241
int lane = sg.get_local_linear_id();
22422242

22432243
short ROW_LOAD_OFFSET = 4 * (lane >> 2);
22442244
short COL_LOAD_OFFSET = 8 * (lane % 4);
22452245

2246-
if (M == 16 && N == 8 && K == 16) {
2246+
if constexpr (M == 16 && N == 8 && K == 16) {
22472247
if constexpr (std::is_floating_point_v<CDType>) {
22482248
// f32.f16.f16.f32
22492249
for (int i = 0; i < 4; i++) {

0 commit comments

Comments
 (0)