Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions include/dxc/dxcapi.internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ static const BYTE IA_R = 0xf0;
static const BYTE IA_C = 0xf1;
static const BYTE IA_R2 = 0xf2;
static const BYTE IA_C2 = 0xf3;
static const BYTE IA_SPECIAL_SLOTS = 4;
static const BYTE IA_R3 = 0xf4;
static const BYTE IA_C3 = 0xf5;
static const BYTE IA_SPECIAL_SLOTS = 6;

struct HLSL_INTRINSIC_ARGUMENT {
LPCSTR
Expand All @@ -180,10 +182,12 @@ struct HLSL_INTRINSIC_ARGUMENT {
BYTE uLegalComponentTypes; // A LEGAL_INTRINSIC_COMPTYPES value for allowed
// components.

BYTE uRows; // Required number of rows, or one of IA_R/IA_C/IA_R2/IA_C2 for
// matching input constraints.
BYTE uCols; // Required number of cols, or one of IA_R/IA_C/IA_R2/IA_C2 for
// matching input constraints.
BYTE uRows; // Required number of rows, or one of
// IA_R/IA_C/IA_R2/IA_C2/IA_R3/IA_C3 for matching input
// constraints.
BYTE uCols; // Required number of cols, or one of
// IA_R/IA_C/IA_R2/IA_C2/IA_R3/IA_C3 for matching input
// constraints.
};

// HLSL_INTRINSIC flags
Expand Down
23 changes: 15 additions & 8 deletions tools/clang/lib/Headers/hlsl/dx/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,26 +519,29 @@ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
return Result;
}

template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasInterp,
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
vector<OutputElTy, M> >::type
// clang-format on
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
vector<InputElTy, K> Vec, VectorRef<BiasElTy, M> BiasRef) {
vector<InputElTy, K> Vec, VectorRef<BiasInterp, M> BiasRef) {
using BiasVecTy =
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, M>;
vector<typename __detail::ComponentTypeTraits<BiasInterp>::Type,
(M + __detail::ComponentTypeTraits<BiasInterp>::ElementsPerScalar -
1) /
__detail::ComponentTypeTraits<BiasInterp>::ElementsPerScalar>;
BiasVecTy BiasVec = BiasRef.Buf.template Load<BiasVecTy>(BiasRef.Offset);
vector<OutputElTy, M> Result;
__builtin_LinAlg_MatrixVectorMultiplyAdd(Result, MatrixA.__handle,
hlsl::is_signed<OutputElTy>::value,
Vec, MatrixDT, BiasVec, BiasElTy);
Vec, MatrixDT, BiasVec, BiasInterp);
return Result;
}

template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
ComponentEnum BiasInterp, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
ComponentEnum MatrixDT>
// clang-format off
typename hlsl::enable_if<
Expand All @@ -547,14 +550,18 @@ typename hlsl::enable_if<
// clang-format on
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
VectorRef<BiasElTy, M> BiasRef) {
VectorRef<BiasInterp, M> BiasRef) {
using BiasVecTy =
vector<typename __detail::ComponentTypeTraits<BiasElTy>::Type, M>;
vector<typename __detail::ComponentTypeTraits<BiasInterp>::Type,
(M + __detail::ComponentTypeTraits<BiasInterp>::ElementsPerScalar -
1) /
__detail::ComponentTypeTraits<BiasInterp>::ElementsPerScalar>;

BiasVecTy BiasVec = BiasRef.Buf.template Load<BiasVecTy>(BiasRef.Offset);
vector<OutputElTy, M> Result;
__builtin_LinAlg_MatrixVectorMultiplyAdd(
Result, MatrixA.__handle, hlsl::is_signed<OutputElTy>::value,
InterpVec.Data, InterpVec.Interpretation, BiasVec, BiasElTy);
InterpVec.Data, InterpVec.Interpretation, BiasVec, BiasInterp);
return Result;
}

Expand Down
35 changes: 34 additions & 1 deletion tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using MatrixATy = Matrix<ComponentType::F16, 8, 4, MatrixUse::A, MatrixScope::Th
using MatrixAccum_8_8_Ty = Matrix<ComponentType::F16, 8, 8, MatrixUse::Accumulator, MatrixScope::Thread>;
using MatrixAccum_8_4_Ty = Matrix<ComponentType::F16, 8, 4, MatrixUse::Accumulator, MatrixScope::Thread>;
using Matrix_7_15_ATy = Matrix<ComponentType::F16, 7, 15, MatrixUse::A, MatrixScope::Thread>;
using MatrixPacked_7_15_ATy = Matrix<ComponentType::F8_E4M3FN, 7, 15, MatrixUse::A, MatrixScope::Thread>;

ByteAddressBuffer BAB : register(t0);

Expand Down Expand Up @@ -83,7 +84,7 @@ void main(uint ID : SV_GroupID) {
half16 srcF16 = BAB.Load<half16>(128);
InterpretedVector<uint, 4, ComponentEnum::F8_E4M3FN> convertedPacked = Convert<ComponentEnum::F8_E4M3FN, ComponentEnum::F16>(srcF16);

// CHECK: call <1 x i32> @dx.op.linAlgConvert.v1i32.v3f16(i32 -2147483618, <3 x half> %25, i32 8, i32 21)
// CHECK: call <1 x i32> @dx.op.linAlgConvert.v1i32.v3f16(i32 -2147483618, <3 x half> %{{[0-9]+}}, i32 8, i32 21)
// CHECK-SAME: ; LinAlgConvert(inputVector,inputInterpretation,outputInterpretation)
half3 ThreeF16 = BAB.Load<half3>(256);
InterpretedVector<uint, 1, ComponentEnum::F8_E4M3FN> convertedPacked2 =
Expand Down Expand Up @@ -144,4 +145,36 @@ void main(uint ID : SV_GroupID) {
// CHECK-SAME: %dx.types.LinAlgMatrixC8M7N15U0S0 %[[MAT_7_15]], i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %[[MEM_BIAS3]], i32 8)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec12 = MultiplyAdd<half>(Mat_7_15, interpVecH15Packed, memBias7);

// Test Convert and MultiplyAdd with odd sizes and packed types

// CHECK: %[[MAT_7_15_PACKED:.*]] = call %dx.types.LinAlgMatrixC21M7N15U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC21M7N15U0S0(i32 -2147483634,
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 128) ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
MatrixPacked_7_15_ATy MatF8_7_15 = MatrixPacked_7_15_ATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 16);

// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC21M7N15U0S0.v15f16.v7f16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC21M7N15U0S0 %[[MAT_7_15_PACKED]], i1 true, <15 x half> %{{[0-9]+}}, i32 21, <7 x half> %{{[0-9]+}}, i32 21)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec21 = MultiplyAdd<half>(MatF8_7_15, vecH15, vecH7);

// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC21M7N15U0S0.v4i32.v7f16(i32 -2147483622, %dx.types.LinAlgMatrixC21M7N15U0S0 %[[MAT_7_15_PACKED]],
// CHECK-SAME: i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <7 x half> %{{[0-9]+}}, i32 21)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec22 = MultiplyAdd<half>(MatF8_7_15, interpVecH15Packed, vecH7);

// CHECK: %[[LOAD4:.*]] = call %dx.types.ResRet.v2i32 @dx.op.rawBufferVectorLoad.v2i32(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 4)
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
// CHECK: %[[MEM_BIAS_PACKED1:.*]] = extractvalue %dx.types.ResRet.v2i32 %[[LOAD4]], 0
// CHECK: call <7 x half> @dx.op.linAlgMatVecMulAdd.v7f16.mC21M7N15U0S0.v15f16.v2i32(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC21M7N15U0S0 %[[MAT_7_15_PACKED]], i1 true, <15 x half> %{{[0-9]+}}, i32 21, <2 x i32> %[[MEM_BIAS_PACKED1]], i32 21)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
VectorRef<ComponentType::F8_E4M3FN, 7> memBias7Packed = {BAB, 512};
vector<half, 7> vec23 = MultiplyAdd<half>(MatF8_7_15, vecH15, memBias7Packed);

// CHECK: [[LOAD5:.8]] = call %dx.types.ResRet.v2i32 @dx.op.rawBufferVectorLoad.v2i32(i32 303, %dx.types.Handle %{{[0-9]+}}, i32 512, i32 undef, i32 4)
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)
// CHECK: %[[MEM_BIAS_PACKED1:.*]] = extractvalue %dx.types.ResRet.v2i32 %[[LOAD5]], 0
// CHECK-NEXT: %dx.types.LinAlgMatrixC21M7N15U0S0 %[[MAT_7_15_PACKED]], i1 true, <4 x i32> %[[INTERP_VEC_H15_PACKED]], i32 21, <2 x i32> %[[MEM_BIAS_PACKED1]], i32 21)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
vector<half, 7> vec24 = MultiplyAdd<half>(MatF8_7_15, interpVecH15Packed, memBias7Packed);
}
2 changes: 1 addition & 1 deletion utils/hct/gen_intrin_main.txt
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matr
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(ref LinAlgMatrix matrixR, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB, in LinAlgMatrix matrixC);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(ref LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp, in numeric<c> bias, in uint biasInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<c> ret, in LinAlgMatrix mat, in bool isOutputSigned, in numeric<c2> input, in uint inputInterp, in numeric<c3> bias, in uint biasInterp);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout, in uint align);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, groupshared numeric[] memory, in uint offset, in uint stride, in uint layout);
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<c> vecA, in numeric<c2> vecB);
Expand Down
2 changes: 1 addition & 1 deletion utils/hct/hctdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9467,7 +9467,7 @@ def __init__(self, intrinsic_defs, opcode_data):
"LinAlg": "LICOMPTYPE_LINALG",
}

self.trans_rowcol = {"r": "IA_R", "c": "IA_C", "r2": "IA_R2", "c2": "IA_C2"}
self.trans_rowcol = {"r": "IA_R", "c": "IA_C", "r2": "IA_R2", "c2": "IA_C2", "r3": "IA_R3", "c3": "IA_C3"}
self.param_qual = {
"in": "AR_QUAL_IN",
"inout": "AR_QUAL_IN | AR_QUAL_OUT",
Expand Down
Loading