Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions tools/clang/lib/Headers/hlsl/dx/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,17 +271,15 @@ class Matrix {
}

// Loads a matrix from a read-only ByteAddressBuffer.
//
// Declares a local Matrix, passes its __handle together with Res,
// StartOffset, Stride, Layout and Align to
// __builtin_LinAlg_MatrixLoadFromDescriptor, and returns the local by value.
//
// NOTE(review): this is a rendered diff hunk with the +/- markers stripped.
// The two lines `MatrixLayoutEnum Layout,` / `uint Align = sizeof(ElementType)) {`
// appear to be the removed signature tail, and the single line ending in
// `uint Align = 128) {` its replacement (default Align changed from
// sizeof(ElementType) to 128) -- confirm against the real header.
static Matrix Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
MatrixLayoutEnum Layout,
uint Align = sizeof(ElementType)) {
MatrixLayoutEnum Layout, uint Align = 128) {
Matrix Result;
__builtin_LinAlg_MatrixLoadFromDescriptor(Result.__handle, Res, StartOffset,
Stride, Layout, Align);
return Result;
}

static Matrix Load(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
MatrixLayoutEnum Layout,
uint Align = sizeof(ElementType)) {
MatrixLayoutEnum Layout, uint Align = 128) {
Matrix Result;
__builtin_LinAlg_MatrixLoadFromDescriptor(Result.__handle, Res, StartOffset,
Stride, Layout, Align);
Expand Down Expand Up @@ -331,7 +329,7 @@ class Matrix {
}

// Stores this matrix to a RWByteAddressBuffer.
//
// Passes this object's __handle together with Res, StartOffset, Stride,
// Layout and Align to __builtin_LinAlg_MatrixStoreToDescriptor; returns
// nothing.
//
// NOTE(review): this is a rendered diff hunk with the +/- markers stripped.
// The line ending in `uint Align = sizeof(ElementType)) {` appears to be the
// removed signature tail, and the line ending in `uint Align = 128) {` its
// replacement (default Align changed from sizeof(ElementType) to 128) --
// confirm against the real header.
void Store(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
MatrixLayoutEnum Layout, uint Align = sizeof(ElementType)) {
MatrixLayoutEnum Layout, uint Align = 128) {
__builtin_LinAlg_MatrixStoreToDescriptor(__handle, Res, StartOffset, Stride,
Layout, Align);
}
Expand All @@ -351,8 +349,7 @@ class Matrix {
typename hlsl::enable_if<Use == MatrixUse::Accumulator && UseLocal == Use,
void>::type
InterlockedAccumulate(RWByteAddressBuffer Res, uint StartOffset, uint Stride,
MatrixLayoutEnum Layout,
uint Align = sizeof(ElementType)) {
MatrixLayoutEnum Layout, uint Align = 128) {
__builtin_LinAlg_MatrixAccumulateToDescriptor(__handle, Res, StartOffset,
Stride, Layout, Align);
}
Expand Down Expand Up @@ -409,8 +406,7 @@ class Matrix<ComponentTy, M, N, Use, MatrixScope::Thread> {
template <MatrixLayoutEnum Layout, MatrixUseEnum UseLocal = Use>
static typename hlsl::enable_if<Use == MatrixUse::A && UseLocal == Use,
Matrix>::type
Load(ByteAddressBuffer Res, uint StartOffset, uint Stride,
uint Align = sizeof(ElementType)) {
Load(ByteAddressBuffer Res, uint StartOffset, uint Stride, uint Align = 128) {
Matrix Result;
__builtin_LinAlg_MatrixLoadFromDescriptor(Result.__handle, Res, StartOffset,
Stride, Layout, Align);
Expand Down
16 changes: 8 additions & 8 deletions tools/clang/test/CodeGenDXIL/hlsl/linalg/api/matrix-class.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@ void main(uint ID : SV_GroupID)
//
// CHECK: %[[MATA2:.*]] = call %dx.types.LinAlgMatrixC9M4N4U0S1
// CHECK-SAME: @dx.op.linAlgMatrixLoadFromDescriptor.mC9M4N4U0S1(i32 -2147483634,
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 4)
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 128)
// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
MatrixATy MatA2 = MatrixATy::Load(BAB, 0, 16, MatrixLayoutEnum::ColMajor);

// Matrix::Load from RWByteAddressBuffer
//
// CHECK: %[[MATB2:.*]] = call %dx.types.LinAlgMatrixC9M4N4U1S1
// CHECK-SAME: @dx.op.linAlgMatrixLoadFromDescriptor.mC9M4N4U1S1(i32 -2147483634,
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 256, i32 16, i32 1, i32 4)
// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 256, i32 16, i32 1, i32 128)
// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
MatrixBTy MatB2;
MatB2 = MatrixBTy::Load(RWBAB, 256, 16, MatrixLayoutEnum::ColMajor);

Expand All @@ -87,7 +87,7 @@ void main(uint ID : SV_GroupID)
// Matrix::GetCoordinate
//
// CHECK: call <2 x i32> @dx.op.linAlgMatrixGetCoordinate.mC9M4N4U1S1(i32 -2147483631,
// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB1]], i32 %[[GROUP_ID]])
// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB1]], i32 %[[GROUP_ID]])
// CHECK-SAME:; LinAlgMatrixGetCoordinate(matrix,threadLocalIndex)
uint2 coord = MatB1.GetCoordinate(ID);

Expand All @@ -110,7 +110,7 @@ void main(uint ID : SV_GroupID)
//
// CHECK: call void @dx.op.linAlgMatrixStoreToDescriptor.mC9M4N4U1S1(i32 -2147483628,
// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB1_2]], %dx.types.Handle %{{[0-9]+}},
// CHECK-SAME: i32 256, i32 16, i32 1, i32 4) ;
// CHECK-SAME: i32 256, i32 16, i32 1, i32 128) ;
// CHECK-SAME: LinAlgMatrixStoreToDescriptor(matrix,handle,offset,stride,layout,align)
MatB1.Store(RWBAB, 256, 16, MatrixLayoutEnum::ColMajor);

Expand All @@ -129,7 +129,7 @@ void main(uint ID : SV_GroupID)
// Matrix::InterlockedAccumulate to resource descriptor
//
// CHECK: call void @dx.op.linAlgMatrixAccumulateToDescriptor.mC9M4N4U2S1(i32 -2147483621,
// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %[[ACCUM0]], %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 4)
// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %[[ACCUM0]], %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 128)
// CHECK-SAME: ; LinAlgMatrixAccumulateToDescriptor(matrix,handle,offset,stride,layout,align)
AccMat1.InterlockedAccumulate(RWBAB, 0, 16, MatrixLayoutEnum::ColMajor);

Expand Down Expand Up @@ -160,7 +160,7 @@ void main(uint ID : SV_GroupID)
// CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB2]])
// CHECK-SAME: ; LinAlgMatrixAccumulate(matrixLHS,matrixRHS)
AccMat2.Accumulate(MatB2);

// Matrix::MultiplyAccumulate
//
// CHECK: %[[ACCUM4:.*]] = call %dx.types.LinAlgMatrixC9M4N4U2S1
Expand All @@ -174,7 +174,7 @@ void main(uint ID : SV_GroupID)
// Matrix::Load for thread-scope matrix
//
// CHECK: %[[TSMATA:.*]] = call %dx.types.LinAlgMatrixC9M4N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC9M4N4U0S0(
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 4)
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 128)
// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
TSMatrixATy TSMatA = TSMatrixATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 16);

Expand Down
14 changes: 7 additions & 7 deletions tools/clang/test/CodeGenDXIL/hlsl/linalg/api/vectors.hlsl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// REQUIRES: dxil-1-10
// RUN: %dxc -I %hlsl_headers -enable-16bit-types -T cs_6_10 %s | FileCheck %s
// RUN: %dxc -I %hlsl_headers -enable-16bit-types -T cs_6_10 %s | FileCheck %s

#include <dx/linalg.h>
using namespace dx::linalg;
Expand All @@ -12,16 +12,16 @@ ByteAddressBuffer BAB : register(t0);

[numthreads(4, 4, 4)]
void main(uint ID : SV_GroupID) {

// CHECK: %[[MAT1:.*]] = call %dx.types.LinAlgMatrixC8M8N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC8M8N4U0S0(
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %{{[0-9]+}}, i32 0, i32 8, i32 1, i32 2)
// CHECK-SAME: i32 -2147483634, %dx.types.Handle %{{[0-9]+}}, i32 0, i32 8, i32 1, i32 128)
// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
MatrixATy Mat1 = MatrixATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0, 8);

vector<half, 4> vec1 = 10.3f;

// CHECK: %[[VEC2:.*]] = call <8 x half> @dx.op.linAlgMatVecMul.v8f16.mC8M8N4U0S0.v4f16(i32 -2147483623,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> <half 0xH4926, half 0xH4926, half 0xH4926,
// CHECK-SAME: half 0xH4926>, i32 8) ; LinAlgMatVecMul(matrix,isOutputSigned,inputVector,interpretation)
vector<half, 8> vec2 = Multiply<half>(Mat1, vec1);

Expand All @@ -42,9 +42,9 @@ void main(uint ID : SV_GroupID) {

// CHECK: %[[RAWLOAD:.*]] = call %dx.types.ResRet.v8i16 @dx.op.rawBufferVectorLoad.v8i16(i32 303,
// CHECK-SAME: %dx.types.Handle %{{[0-9]+}}, i32 4096, i32 undef, i32 2) ; RawBufferVectorLoad(buf,index,elementOffset,alignment)

// CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0

// CHECK: %[[VEC5:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8i16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
// CHECK-SAME:; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
Expand All @@ -56,7 +56,7 @@ void main(uint ID : SV_GroupID) {
// CHECK-SAME: ; RawBufferVectorLoad(buf,index,elementOffset,alignment)

// CHECK: %[[VEC_BIAS:.*]] = extractvalue %dx.types.ResRet.v8i16 %[[RAWLOAD]], 0

// CHECK: %[[VEC6:.*]] = call <8 x half> @dx.op.linAlgMatVecMulAdd.v8f16.mC8M8N4U0S0.v4f16.v8i16(i32 -2147483622,
// CHECK-SAME: %dx.types.LinAlgMatrixC8M8N4U0S0 %[[MAT1]], i1 true, <4 x half> %[[VEC20]], i32 8, <8 x i16> %[[VEC_BIAS]], i32 2)
// CHECK-SAME: ; LinAlgMatVecMulAdd(matrix,isOutputSigned,inputVector,inputInterpretation,biasVector,biasInterpretation)
Expand Down