diff --git a/tools/clang/unittests/HLSLExec/HlslExecTestUtils.cpp b/tools/clang/unittests/HLSLExec/HlslExecTestUtils.cpp index 10dfc63b37..8715ea1ac4 100644 --- a/tools/clang/unittests/HLSLExec/HlslExecTestUtils.cpp +++ b/tools/clang/unittests/HLSLExec/HlslExecTestUtils.cpp @@ -691,7 +691,30 @@ void addUAVBuffer(st::ShaderOp *Op, const char *Name, UINT64 Width, Op->Resources.push_back(Res); } -void addRootUAV(st::ShaderOp *Op, UINT Index, const char *ResName) { +void addSRVBuffer(st::ShaderOp *Op, const char *Name, UINT64 Width, + const char *Init) { + st::ShaderOpResource Res = {}; + Res.Name = Op->Strings.insert(Name); + Res.Init = Op->Strings.insert(Init); + Res.ReadBack = FALSE; + + Res.HeapProperties.Type = D3D12_HEAP_TYPE_DEFAULT; + Res.HeapFlags = D3D12_HEAP_FLAG_NONE; + Res.Desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + Res.Desc.Width = Width; + Res.Desc.Height = 1; + Res.Desc.DepthOrArraySize = 1; + Res.Desc.MipLevels = 1; + Res.Desc.SampleDesc.Count = 1; + Res.Desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + Res.Desc.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS; + Res.InitialResourceState = D3D12_RESOURCE_STATE_COPY_DEST; + Res.TransitionTo = D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE; + + Op->Resources.push_back(Res); +} + +void addRootView(st::ShaderOp *Op, UINT Index, const char *ResName) { st::ShaderOpRootValue RV = {}; RV.ResName = Op->Strings.insert(ResName); RV.HeapName = nullptr; @@ -751,7 +774,7 @@ void compileShader(dxc::SpecificDllLoader &DxcSupport, const char *Source, if (VerboseLogging) { hlsl_test::LogCommentFmt(L"Shader Source:"); - hlsl_test::LogCommentFmt(L"%c", Source); + hlsl_test::LogCommentFmt(L"%S", Source); } hlsl_test::LogCommentFmt(LogFlags.str().c_str()); diff --git a/tools/clang/unittests/HLSLExec/HlslExecTestUtils.h b/tools/clang/unittests/HLSLExec/HlslExecTestUtils.h index 3e64171478..435a091a54 100644 --- a/tools/clang/unittests/HLSLExec/HlslExecTestUtils.h +++ b/tools/clang/unittests/HLSLExec/HlslExecTestUtils.h @@ -88,8 +88,12 @@ createComputeOp(const char *Source, const char *Target, const char *RootSig, void addUAVBuffer(st::ShaderOp *Op, const char *Name, UINT64 Width, bool ReadBack, const char *Init = "zero"); -/// Bind a resource to a root UAV parameter by index. -void addRootUAV(st::ShaderOp *Op, UINT Index, const char *ResName); +/// Add a SRV buffer resource to a ShaderOp. +void addSRVBuffer(st::ShaderOp *Op, const char *Name, UINT64 Width, + const char *Init = "zero"); + +/// Bind a resource to a root view parameter by index. +void addRootView(st::ShaderOp *Op, UINT Index, const char *ResName); /// Run a programmatically-built ShaderOp and return the result. std::shared_ptr diff --git a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp index da32f553c4..d9d22863f1 100644 --- a/tools/clang/unittests/HLSLExec/LinAlgTests.cpp +++ b/tools/clang/unittests/HLSLExec/LinAlgTests.cpp @@ -199,38 +199,48 @@ static bool verifyComponentBuffer(ComponentType CompType, const void *Actual, } static bool fillInputBuffer(LPCSTR Name, std::vector &Data, - ComponentType CompType, size_t NumElements) { + ComponentType CompType, size_t NumElements, + size_t StartingVal = 1, bool Increment = true) { if (_stricmp(Name, "Input") != 0) return true; switch (CompType) { - case ComponentType::F32: { - float *Ptr = reinterpret_cast(Data.data()); - for (size_t I = 0; I < NumElements; I++) - Ptr[I] = static_cast(I + 1); - return true; - } - case ComponentType::I32: { - int32_t *Ptr = reinterpret_cast(Data.data()); - for (size_t I = 0; I < NumElements; I++) - Ptr[I] = static_cast(I + 1); - return true; - } - case ComponentType::F16: { - HLSLHalf_t *Ptr = reinterpret_cast(Data.data()); - for (size_t I = 0; I < NumElements; I++) - Ptr[I] = HLSLHalf_t(static_cast(I + 1)); - return true; + case ComponentType::F32: + case ComponentType::I32: + case ComponentType::F16: + break; + default: + return false; } + + for (size_t I = 0; I < NumElements; ++I) { + size_t Value = StartingVal + (Increment ? I : 0); + switch (CompType) { + case ComponentType::F32: { + float *Ptr = reinterpret_cast(Data.data()); + Ptr[I] = static_cast(Value); + break; + } + case ComponentType::I32: { + int32_t *Ptr = reinterpret_cast(Data.data()); + Ptr[I] = static_cast(Value); + break; + } + case ComponentType::F16: { + HLSLHalf_t *Ptr = reinterpret_cast(Data.data()); + Ptr[I] = HLSLHalf_t(static_cast(Value)); + break; + } + } } - return false; + return true; } -static VariantCompType makeExpected(ComponentType CompType, MatrixDim M, - MatrixDim N, float StartingVal, - bool Increment = true, - bool Transpose = false) { +static VariantCompType makeExpectedMat(ComponentType CompType, MatrixDim M, + MatrixDim N, float StartingVal, + bool Increment = true, + bool Transpose = false) { const size_t NumElements = M * N; std::vector Floats(NumElements); std::vector Ints(NumElements); @@ -281,6 +291,13 @@ static VariantCompType makeExpected(ComponentType CompType, MatrixDim M, } } +static VariantCompType makeExpectedVec(ComponentType CompType, + MatrixDim NumElements, float StartingVal, + bool Increment = true) { + return makeExpectedMat(CompType, 1, NumElements, StartingVal, Increment, + false); +} + class DxilConf_SM610_LinAlg { public: BEGIN_TEST_CLASS(DxilConf_SM610_LinAlg) @@ -299,14 +316,40 @@ class DxilConf_SM610_LinAlg { TEST_CLASS_SETUP(setupClass); TEST_METHOD_SETUP(setupMethod); - // Load/Store - TEST_METHOD(LoadStoreRoundtrip_Wave_16x16_F16); - - // Splat Store + // Load/Store/Accumulate Descriptor + TEST_METHOD(LoadStoreDescriptor_Wave_16x16_F16); TEST_METHOD(SplatStore_Wave_16x16_F16); + TEST_METHOD(AccumulateDescriptor_Wave_16x16_F16); + TEST_METHOD(AccumulateDescriptor_Thread_16x16_F16); + + // Load/Store/Accumulate Memory + TEST_METHOD(LoadMemory_Wave_16x16_F16); + TEST_METHOD(StoreMemory_Wave_16x16_F16); + TEST_METHOD(AccumulateMemory_Wave_16x16_F16); // Element access TEST_METHOD(ElementAccess_Wave_16x16_F16); + TEST_METHOD(ElementSet_Wave_16x16_F16); + + // Cast/Convert + TEST_METHOD(CopyConvert_Wave_16x16_F16); + TEST_METHOD(CopyConvert_Wave_16x16_F16_Transpose); + + // Matrix Matrix Arithmetic + TEST_METHOD(MatMatMul_Wave_16x16x16_F16); + TEST_METHOD(MatMatMulAccum_Wave_16x16x16_F16); + TEST_METHOD(MatAccum_Wave_16x16_F16); + + // Matrix Vector Arithmetic + TEST_METHOD(MatVecMul_Thread_16x16_F16); + TEST_METHOD(MatVecMulAdd_Thread_16x16_F16); + TEST_METHOD(OuterProduct_Thread_16x16_F16); + + // Query Accumulator Layout + TEST_METHOD(QueryAccumLayout); + + // Convert + TEST_METHOD(Convert); private: CComPtr D3DDevice; @@ -358,14 +401,14 @@ bool DxilConf_SM610_LinAlg::setupMethod() { return D3D12SDK->createDevice(&D3DDevice, D3D_SHADER_MODEL_6_10, false); } -static const char LoadStoreShader[] = R"( +static const char LoadStoreDescriptorShader[] = R"( RWByteAddressBuffer Input : register(u0); RWByteAddressBuffer Output : register(u1); [WaveSize(4, 64)] [numthreads(NUMTHREADS, 1, 1)] - void main(uint threadID : SV_GroupIndex) { - if (WaveReadLaneFirst(threadID) != 0) + void main() { + if (GetGroupWaveIndex() != 0) return; __builtin_LinAlgMatrix @@ -378,9 +421,9 @@ static const char LoadStoreShader[] = R"( } )"; -static void runLoadStoreRoundtrip(ID3D12Device *Device, - dxc::SpecificDllLoader &DxcSupport, - const MatrixParams &Params, bool Verbose) { +static void runLoadStoreDescriptor(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose) { const size_t NumElements = Params.totalElements(); const size_t BufferSize = Params.totalBytes(); @@ -390,17 +433,18 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device, std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); - compileShader(DxcSupport, LoadStoreShader, "cs_6_10", Args, Verbose); + compileShader(DxcSupport, LoadStoreDescriptorShader, "cs_6_10", Args, + Verbose); - auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1); + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, 1); // Construct the ShaderOp: two UAV buffers, load from one, store to other. - auto Op = createComputeOp(LoadStoreShader, "cs_6_10", "UAV(u0), UAV(u1)", - Args.c_str()); + auto Op = createComputeOp(LoadStoreDescriptorShader, "cs_6_10", + "UAV(u0), UAV(u1)", Args.c_str()); addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname"); addUAVBuffer(Op.get(), "Output", BufferSize, true); - addRootUAV(Op.get(), 0, "Input"); - addRootUAV(Op.get(), 1, "Output"); + addRootView(Op.get(), 0, "Input"); + addRootView(Op.get(), 1, "Output"); auto Result = runShaderOp(Device, DxcSupport, std::move(Op), @@ -418,7 +462,7 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device, Expected, NumElements, Verbose)); } -void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() { +void DxilConf_SM610_LinAlg::LoadStoreDescriptor_Wave_16x16_F16() { MatrixParams Params = {}; Params.CompType = ComponentType::F16; Params.M = 16; @@ -428,7 +472,7 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() { Params.Layout = LinalgMatrixLayout::RowMajor; Params.NumThreads = 64; Params.Enable16Bit = true; - runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging); + runLoadStoreDescriptor(D3DDevice, DxcSupport, Params, VerboseLogging); } static const char SplatStoreShader[] = R"( @@ -436,8 +480,8 @@ static const char SplatStoreShader[] = R"( [WaveSize(4, 64)] [numthreads(NUMTHREADS, 1, 1)] - void main(uint threadID : SV_GroupIndex) { - if (WaveReadLaneFirst(threadID) != 0) + void main() { + if (GetGroupWaveIndex() != 0) return; __builtin_LinAlgMatrix @@ -464,12 +508,12 @@ static void runSplatStore(ID3D12Device *Device, compileShader(DxcSupport, SplatStoreShader, "cs_6_10", Args, Verbose); auto Expected = - makeExpected(Params.CompType, Params.M, Params.N, FillValue, false); + makeExpectedMat(Params.CompType, Params.M, Params.N, FillValue, false); auto Op = createComputeOp(SplatStoreShader, "cs_6_10", "UAV(u0)", Args.c_str()); addUAVBuffer(Op.get(), "Output", BufferSize, true); - addRootUAV(Op.get(), 0, "Output"); + addRootView(Op.get(), 0, "Output"); auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); @@ -493,6 +537,95 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() { runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging); } +static const char AccumulateDescriptorShader[] = R"( + #define USE_ACC 2 + + ByteAddressBuffer Input : register(t0); + RWByteAddressBuffer Output : register(u1); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main() { + if (GetGroupWaveIndex() != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]] + Mat; + __builtin_LinAlg_MatrixLoadFromDescriptor( + Mat, Input, 0, STRIDE, LAYOUT, 128); + __builtin_LinAlg_MatrixAccumulateToDescriptor( + Mat, Output, 0, STRIDE, LAYOUT, 128); + __builtin_LinAlg_MatrixAccumulateToDescriptor( + Mat, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runAccumulateDescriptor(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, int FillValue, + bool Verbose) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::string Args = buildCompilerArgs(Params); + + compileShader(DxcSupport, AccumulateDescriptorShader, "cs_6_10", Args, + Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, + static_cast(FillValue) * 2, false); + + auto Op = createComputeOp(AccumulateDescriptorShader, "cs_6_10", + "SRV(t0), UAV(u1)", Args.c_str()); + addSRVBuffer(Op.get(), "Input", BufferSize, "byname"); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Input"); + addRootView(Op.get(), 1, "Output"); + + auto Result = runShaderOp( + Device, DxcSupport, std::move(Op), + [NumElements, Params, FillValue](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumElements, + /*StartingVal=*/FillValue, + /*Increment=*/false), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::AccumulateDescriptor_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::Accumulator; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 12, VerboseLogging); +} + +void DxilConf_SM610_LinAlg::AccumulateDescriptor_Thread_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::Accumulator; + Params.Scope = MatrixScope::Thread; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 1; + Params.Enable16Bit = true; + runAccumulateDescriptor(D3DDevice, DxcSupport, Params, 19, VerboseLogging); +} + static const char ElementAccessShader[] = R"( RWByteAddressBuffer Input : register(u0); RWByteAddressBuffer Output : register(u1); @@ -506,7 +639,7 @@ static const char ElementAccessShader[] = R"( [WaveSize(4, 64)] [numthreads(NUMTHREADS, 1, 1)] void main(uint threadID : SV_GroupIndex) { - if (WaveReadLaneFirst(threadID) != 0) + if (GetGroupWaveIndex() != 0) return; __builtin_LinAlgMatrix @@ -537,28 +670,23 @@ static void runElementAccess(ID3D12Device *Device, const MatrixParams &Params, bool Verbose) { const size_t NumElements = Params.totalElements(); const size_t NumThreads = Params.NumThreads; - const size_t InputBufSize = Params.totalBytes(); - const size_t ElementSize = elementSize(Params.CompType); - - // Output: ElementSize bytes per element - // 1 element for each mat idx - // 1 uint for each thread's length - const size_t OutputBufSize = - NumElements * ElementSize + NumThreads * sizeof(uint32_t); + const size_t MatrixSize = Params.totalBytes(); + // OutputBuf needs to fit the Matrix plus one uint per thread + const size_t OutputBufSize = MatrixSize + NumThreads * sizeof(uint32_t); std::stringstream ExtraDefs; std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose); - auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1); + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, 1); auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)", Args.c_str()); - addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname"); + addUAVBuffer(Op.get(), "Input", MatrixSize, false, "byname"); addUAVBuffer(Op.get(), "Output", OutputBufSize, true); - addRootUAV(Op.get(), 0, "Input"); - addRootUAV(Op.get(), 1, "Output"); + addRootView(Op.get(), 0, "Input"); + addRootView(Op.get(), 1, "Output"); auto Result = runShaderOp(Device, DxcSupport, std::move(Op), @@ -579,9 +707,8 @@ static void runElementAccess(ID3D12Device *Device, // Verify the end of the buffer is NumThreads number of lengths, whose // sum is greater than or equal to NumElements const BYTE *Out = static_cast(OutData.data()); - size_t MatrixEndOffset = NumElements * ElementSize; const uint32_t *Lengths = - reinterpret_cast(Out + MatrixEndOffset); + reinterpret_cast(Out + MatrixSize); uint32_t TotalLength = 0; for (size_t I = 0; I < NumThreads; ++I) TotalLength += Lengths[I]; @@ -602,4 +729,996 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() { runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging); } +static const char ElementSetShader[] = R"( + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main() { + if (GetGroupWaveIndex() != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] + Mat; + __builtin_LinAlg_MatrixLoadFromDescriptor( + Mat, Input, 0, STRIDE, LAYOUT, 128); + + // Increment every element by 5 + for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) { + ELEM_TYPE Elem; + __builtin_LinAlg_MatrixGetElement(Elem, Mat, I); + Elem = Elem + 5; + __builtin_LinAlg_MatrixSetElement(Mat, Mat, I, Elem); + } + + __builtin_LinAlg_MatrixStoreToDescriptor( + Mat, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runElementSet(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose) { + const size_t NumElements = Params.totalElements(); + const size_t MatrixSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, ElementSetShader, "cs_6_10", Args, Verbose); + + // Start counting from 6 since each element was increased by 5 + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, 6); + + auto Op = createComputeOp(ElementSetShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", MatrixSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", MatrixSize, true); + addRootView(Op.get(), 0, "Input"); + addRootView(Op.get(), 1, "Output"); + + auto Result = + runShaderOp(Device, DxcSupport, std::move(Op), + [NumElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, + NumElements), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + // Verify the front of the buffer is a list of elements of the expected type + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::ElementSet_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::Accumulator; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runElementSet(D3DDevice, DxcSupport, Params, VerboseLogging); +} + +static const char CopyConvertShader[] = R"( + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main() { + if (GetGroupWaveIndex() != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] + Src; + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, N_DIM, M_DIM, USE, SCOPE)]] + Dst; + + __builtin_LinAlg_MatrixLoadFromDescriptor( + Src, Input, 0, STRIDE, LAYOUT, 128); + __builtin_LinAlg_CopyConvertMatrix(Dst, Src, TRANSPOSE); + __builtin_LinAlg_MatrixStoreToDescriptor( + Dst, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runCopyConvert(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + bool Transpose) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DTRANSPOSE=" << Transpose; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, CopyConvertShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, 1, + /*Increment=*/true, Transpose); + + // Construct the ShaderOp: two UAV buffers, load from one, store to other. + auto Op = createComputeOp(CopyConvertShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Input"); + addRootView(Op.get(), 1, "Output"); + + auto Result = + runShaderOp(Device, DxcSupport, std::move(Op), + [NumElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, + NumElements), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::A; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runCopyConvert(D3DDevice, DxcSupport, Params, VerboseLogging, + /*Transpose=*/false); +} + +void DxilConf_SM610_LinAlg::CopyConvert_Wave_16x16_F16_Transpose() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::A; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runCopyConvert(D3DDevice, DxcSupport, Params, VerboseLogging, + /*Transpose=*/true); +} + +static const char MatMatMulShader[] = R"( + #define USE_A 0 + #define USE_B 1 + #define USE_ACC 2 + + RWByteAddressBuffer Output : register(u0); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main() { + if (GetGroupWaveIndex() != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, K_DIM, USE_A, SCOPE)]] + MatA; + __builtin_LinAlg_FillMatrix(MatA, A_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, K_DIM, N_DIM, USE_B, SCOPE)]] + MatB; + __builtin_LinAlg_FillMatrix(MatB, B_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]] + MatC; + __builtin_LinAlg_MatrixMatrixMultiply(MatC, MatA, MatB); + + __builtin_LinAlg_MatrixStoreToDescriptor( + MatC, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runMatMatMul(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, MatrixDim K, + float AFill, float BFill) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DK_DIM=" << K; + ExtraDefs << " -DA_FILL=" << AFill; + ExtraDefs << " -DB_FILL=" << BFill; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatMatMulShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, + AFill * BFill * K, /*Increment=*/false); + + auto Op = + createComputeOp(MatMatMulShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatMatMul_Wave_16x16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runMatMatMul(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16, + /*AFill=*/2.0f, /*BFill=*/3.0f); +} + +static const char MatMatMulAccumShader[] = R"( + #define USE_A 0 + #define USE_B 1 + #define USE_ACC 2 + + RWByteAddressBuffer Output : register(u0); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main() { + if (GetGroupWaveIndex() != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, K_DIM, USE_A, SCOPE)]] + MatA; + __builtin_LinAlg_FillMatrix(MatA, A_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, K_DIM, N_DIM, USE_B, SCOPE)]] + MatB; + __builtin_LinAlg_FillMatrix(MatB, B_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]] + MatC; + __builtin_LinAlg_FillMatrix(MatC, C_FILL); + + __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(MatC, MatA, MatB, MatC); + + __builtin_LinAlg_MatrixStoreToDescriptor( + MatC, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runMatMatMulAccum(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + MatrixDim K, float AFill, float BFill, + float CFill) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DK_DIM=" << K; + ExtraDefs << " -DA_FILL=" << AFill; + ExtraDefs << " -DB_FILL=" << BFill; + ExtraDefs << " -DC_FILL=" << CFill; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatMatMulAccumShader, "cs_6_10", Args, Verbose); + + auto Expected = + makeExpectedMat(Params.CompType, Params.M, Params.N, + AFill * BFill * K + CFill, /*Increment=*/false); + + auto Op = + createComputeOp(MatMatMulAccumShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatMatMulAccum_Wave_16x16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runMatMatMulAccum(D3DDevice, DxcSupport, Params, VerboseLogging, /*K=*/16, + /*AFill=*/2.0f, /*BFill=*/3.0f, /*CFill=*/4.0f); +} + +static const char MatAccumShader[] = R"( + #define USE_A 0 + #define USE_ACC 2 + + RWByteAddressBuffer Output : register(u0); + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main() { + if (GetGroupWaveIndex() != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_ACC, SCOPE)]] + MatLHS; + __builtin_LinAlg_FillMatrix(MatLHS, LHS_FILL); + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE)]] + MatRHS; + __builtin_LinAlg_FillMatrix(MatRHS, RHS_FILL); + + __builtin_LinAlg_MatrixAccumulate(MatLHS, MatLHS, MatRHS); + + __builtin_LinAlg_MatrixStoreToDescriptor( + MatLHS, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runMatAccum(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, float LHSFill, + float RHSFill) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DLHS_FILL=" << LHSFill; + ExtraDefs << " -DRHS_FILL=" << RHSFill; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatAccumShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, + LHSFill + RHSFill, /*Increment=*/false); + + auto Op = createComputeOp(MatAccumShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatAccum_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runMatAccum(D3DDevice, DxcSupport, Params, VerboseLogging, + /*LHSFill=*/2.0f, /*RHSFill=*/3.0f); +} + +static const char MatVecMulShader[] = R"( + #define USE_A 0 + #define SCOPE_THREAD 0 + + ByteAddressBuffer Input : register(t0); + RWByteAddressBuffer Output : register(u1); + + [numthreads(NUMTHREADS, 1, 1)] + void main() { + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]] + Mat; + __builtin_LinAlg_MatrixLoadFromDescriptor( + Mat, Input, 0, STRIDE, LAYOUT, 128); + + vector InVec; + for (uint I = 0; I < M_DIM; ++I) { + InVec[I] = Input.Load(I * ELEM_SIZE); + } + + vector OutVec; + __builtin_LinAlg_MatrixVectorMultiply( + OutVec, Mat, OUTPUT_SIGNED, InVec, IN_INTERP); + + for (uint I = 0; I < M_DIM; ++I) { + Output.Store(I * ELEM_SIZE, OutVec[I]); + } + } +)"; + +static void runMatVecMul(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + int FillValue, bool OutputSigned, + ComponentType InputInterp) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DOUTPUT_SIGNED=" << OutputSigned; + ExtraDefs << " -DIN_INTERP=" << static_cast(InputInterp); + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatVecMulShader, "cs_6_10", Args, Verbose); + + auto Expected = + makeExpectedVec(Params.CompType, Params.M, + static_cast(FillValue * FillValue * Params.N), + /*Increment=*/false); + + auto Op = createComputeOp(MatVecMulShader, "cs_6_10", "SRV(t0), UAV(u1)", + Args.c_str()); + addSRVBuffer(Op.get(), "Input", BufferSize, "byname"); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Input"); + addRootView(Op.get(), 1, "Output"); + + auto Result = runShaderOp( + Device, DxcSupport, std::move(Op), + [NumElements, Params, FillValue](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumElements, + /*StartingVal=*/FillValue, + /*Increment=*/false), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, Params.M, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatVecMul_Thread_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Thread; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 1; + Params.Enable16Bit = true; + runMatVecMul(D3DDevice, DxcSupport, Params, VerboseLogging, + /*FillValue=*/2, /*OutputSigned=*/true, ComponentType::F16); +} + +static const char MatVecMulAddShader[] = R"( + #define USE_A 0 + #define SCOPE_THREAD 0 + + ByteAddressBuffer Input : register(t0); + RWByteAddressBuffer Output : register(u1); + + [numthreads(NUMTHREADS, 1, 1)] + void main() { + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]] + Mat; + __builtin_LinAlg_MatrixLoadFromDescriptor( + Mat, Input, 0, STRIDE, LAYOUT, 128); + + vector InVec; + for (uint I = 0; I < M_DIM; ++I) { + InVec[I] = Input.Load(I * ELEM_SIZE); + } + + vector BiasVec; + for (uint I = 0; I < M_DIM; ++I) { + BiasVec[I] = Input.Load(I * ELEM_SIZE); + } + + vector OutVec; + __builtin_LinAlg_MatrixVectorMultiplyAdd( + OutVec, Mat, OUTPUT_SIGNED, InVec, IN_INTERP, BiasVec, BIAS_INTERP); + + for (uint I = 0; I < M_DIM; ++I) { + Output.Store(I * ELEM_SIZE, OutVec[I]); + } + } +)"; + +static void runMatVecMulAdd(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + int FillValue, bool OutputSigned, + ComponentType InputInterp, + ComponentType BiasInterp) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DOUTPUT_SIGNED=" << OutputSigned; + ExtraDefs << " -DIN_INTERP=" << static_cast(InputInterp); + ExtraDefs << " -DBIAS_INTERP=" << static_cast(BiasInterp); + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, MatVecMulAddShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedVec( + Params.CompType, Params.M, + static_cast(FillValue * FillValue * Params.N + FillValue), + /*Increment=*/false); + + auto Op = createComputeOp(MatVecMulAddShader, "cs_6_10", "SRV(t0), UAV(u1)", + Args.c_str()); + addSRVBuffer(Op.get(), "Input", BufferSize, "byname"); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Input"); + addRootView(Op.get(), 1, "Output"); + + auto Result = runShaderOp( + Device, DxcSupport, std::move(Op), + [NumElements, Params, FillValue](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, NumElements, + /*StartingVal=*/FillValue, + /*Increment=*/false), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, Params.M, Verbose)); +} + +void DxilConf_SM610_LinAlg::MatVecMulAdd_Thread_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Thread; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 1; + Params.Enable16Bit = true; + runMatVecMulAdd(D3DDevice, DxcSupport, Params, VerboseLogging, + /*FillValue=*/2, /*OutputSigned=*/true, ComponentType::F16, + ComponentType::F16); +} + +static const char OuterProductShader[] = R"( + #define USE_A 0 + #define SCOPE_THREAD 0 + + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + + [numthreads(NUMTHREADS, 1, 1)] + void main() { + vector VecA; + for (uint I = 0; I < M_DIM; ++I) { + VecA[I] = Input.Load(I * ELEM_SIZE); + } + + uint EndVecA = M_DIM * ELEM_SIZE; + + vector VecB; + for (uint I = 0; I < N_DIM; ++I) { + VecB[I] = Input.Load(EndVecA + I * ELEM_SIZE); + } + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE_A, SCOPE_THREAD)]] + Mat; + __builtin_LinAlg_MatrixOuterProduct(Mat, VecA, VecB); + + __builtin_LinAlg_MatrixAccumulateToDescriptor( + Mat, Output, 0, STRIDE, LAYOUT, 128); + } +)"; + +static void runOuterProduct(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose) { + const size_t NumVecElements = Params.M + Params.N; + const size_t InBuffSize = NumVecElements * elementSize(Params.CompType); + const size_t NumMatElements = Params.totalElements(); + const size_t OutBufferSize = Params.totalBytes(); + + std::string Args = buildCompilerArgs(Params); + + compileShader(DxcSupport, OuterProductShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, 4, + /*Increment=*/false); + + auto Op = createComputeOp(OuterProductShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", InBuffSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", OutBufferSize, true); + addRootView(Op.get(), 0, "Input"); + addRootView(Op.get(), 1, "Output"); + + auto Result = runShaderOp( + Device, DxcSupport, std::move(Op), + [NumVecElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, + NumVecElements, + /*StartingVal=*/2, /*Increment=*/false), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumMatElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::OuterProduct_Thread_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Scope = MatrixScope::Thread; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 1; + Params.Enable16Bit = true; + runOuterProduct(D3DDevice, DxcSupport, Params, VerboseLogging); +} + +static const char QueryAccumLayoutShader[] = R"( + RWByteAddressBuffer Output : register(u0); + + [numthreads(1, 1, 1)] + void main() { + uint Layout = __builtin_LinAlg_MatrixQueryAccumulatorLayout(); + Output.Store(0, Layout); + } +)"; + +static void runQueryAccumLayout(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + bool Verbose) { + std::string Args = "-HV 202x"; + size_t BufferSize = elementSize(ComponentType::I32); + + compileShader(DxcSupport, QueryAccumLayoutShader, "cs_6_10", Args, Verbose); + + auto Op = createComputeOp(QueryAccumLayoutShader, "cs_6_10", "UAV(u0)", + Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + const uint32_t *Out = static_cast(OutData.data()); + + // Accum Layout must be A or B + VERIFY_IS_TRUE(Out[0] == static_cast(MatrixUse::A) || + Out[0] == static_cast(MatrixUse::B)); + if (Verbose) + hlsl_test::LogCommentFmt(L"AccumulatorLayout = %u", Out[0]); +} + +void DxilConf_SM610_LinAlg::QueryAccumLayout() { + runQueryAccumLayout(D3DDevice, DxcSupport, VerboseLogging); +} + +static const char LoadMemoryShader[] = R"( + RWByteAddressBuffer Input : register(u0); + RWByteAddressBuffer Output : register(u1); + groupshared ELEM_TYPE GsData[M_DIM * N_DIM]; + + #define ELEM_PER_THREAD (M_DIM * N_DIM / NUMTHREADS) + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + for (uint I = 0; I < ELEM_PER_THREAD; ++I) { + uint Index = threadID * ELEM_PER_THREAD + I; + GsData[Index] = Input.Load(Index * ELEM_SIZE); + } + + GroupMemoryBarrierWithGroupSync(); + + if (GetGroupWaveIndex() != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] + Mat; + __builtin_LinAlg_MatrixLoadFromMemory( + Mat, GsData, OFFSET, STRIDE, LAYOUT); + __builtin_LinAlg_MatrixStoreToDescriptor( + Mat, Output, OFFSET, STRIDE, LAYOUT, 128); + } +)"; + +static void runLoadMemory(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DOFFSET=" << 0; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, LoadMemoryShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, 1); + + auto Op = createComputeOp(LoadMemoryShader, "cs_6_10", "UAV(u0), UAV(u1)", + Args.c_str()); + addUAVBuffer(Op.get(), "Input", BufferSize, false, "byname"); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Input"); + addRootView(Op.get(), 1, "Output"); + + auto Result = + runShaderOp(Device, DxcSupport, std::move(Op), + [NumElements, Params](LPCSTR Name, std::vector &Data, + st::ShaderOp *) { + VERIFY_IS_TRUE(fillInputBuffer(Name, Data, Params.CompType, + NumElements), + "Saw unsupported component type"); + }); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::LoadMemory_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::A; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runLoadMemory(D3DDevice, DxcSupport, Params, VerboseLogging); +} + +static const char StoreMemoryShader[] = R"( + RWByteAddressBuffer Output : register(u0); + groupshared ELEM_TYPE GsData[M_DIM * N_DIM]; + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main() { + if (GetGroupWaveIndex() != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] + Mat; + __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE); + + __builtin_LinAlg_MatrixStoreToMemory( + Mat, GsData, OFFSET, STRIDE, LAYOUT); + + for (uint I = 0; I < M_DIM*N_DIM; ++I) { + Output.Store(I*ELEM_SIZE, GsData[I]); + } + } +)"; + +static void runStoreMemory(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + float FillValue) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DOFFSET=" << 0; + ExtraDefs << " -DFILL_VALUE=" << FillValue; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, StoreMemoryShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, + FillValue, /*Increment=*/false); + + auto Op = + createComputeOp(StoreMemoryShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::StoreMemory_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::A; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runStoreMemory(D3DDevice, DxcSupport, Params, VerboseLogging, + /*FillValue=*/7.0f); +} + +static const char AccumulateMemoryShader[] = R"( + RWByteAddressBuffer Output : register(u0); + groupshared ELEM_TYPE GsData[M_DIM * N_DIM]; + + #define ELEM_PER_THREAD (M_DIM * N_DIM / NUMTHREADS) + + [WaveSize(4, 64)] + [numthreads(NUMTHREADS, 1, 1)] + void main(uint threadID : SV_GroupIndex) { + ELEM_TYPE fill = FILL_VALUE; + for (uint I = 0; I < ELEM_PER_THREAD; ++I) { + uint Index = threadID * ELEM_PER_THREAD + I; + GsData[Index] = fill; + } + + GroupMemoryBarrierWithGroupSync(); + + if (GetGroupWaveIndex() != 0) + return; + + __builtin_LinAlgMatrix + [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]] + Mat; + __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE); + + __builtin_LinAlg_MatrixAccumulateToMemory( + Mat, GsData, OFFSET, STRIDE, LAYOUT); + + for (uint I = 0; I < M_DIM*N_DIM; ++I) { + Output.Store(I*ELEM_SIZE, GsData[I]); + } + } +)"; + +static void runAccumulateMemory(ID3D12Device *Device, + dxc::SpecificDllLoader &DxcSupport, + const MatrixParams &Params, bool Verbose, + float FillValue) { + const size_t NumElements = Params.totalElements(); + const size_t BufferSize = Params.totalBytes(); + + std::stringstream ExtraDefs; + ExtraDefs << " -DOFFSET=" << 0; + ExtraDefs << " -DFILL_VALUE=" << FillValue; + + std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str()); + + compileShader(DxcSupport, AccumulateMemoryShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N, + FillValue * 2, /*Increment=*/false); + + auto Op = createComputeOp(AccumulateMemoryShader, "cs_6_10", "UAV(u0)", + Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::AccumulateMemory_Wave_16x16_F16() { + MatrixParams Params = {}; + Params.CompType = ComponentType::F16; + Params.M = 16; + Params.N = 16; + Params.Use = MatrixUse::Accumulator; + Params.Scope = MatrixScope::Wave; + Params.Layout = LinalgMatrixLayout::RowMajor; + Params.NumThreads = 64; + Params.Enable16Bit = true; + runAccumulateMemory(D3DDevice, DxcSupport, Params, VerboseLogging, + /*FillValue=*/7.0f); +} + +static const char ConvertShader[] = R"( + #define CT_F16 8 + #define CT_F32 9 + + RWByteAddressBuffer Output : register(u0); + + [numthreads(1, 1, 1)] + void main() { + vector InVec = {1.0, 2.0, 3.0, 4.0}; + vector OutVec; + __builtin_LinAlg_Convert(OutVec, InVec, CT_F16, CT_F32); + Output.Store(0, OutVec.x); + Output.Store(4, OutVec.y); + Output.Store(8, OutVec.z); + Output.Store(12, OutVec.w); + } +)"; + +static void runConvert(ID3D12Device *Device, dxc::SpecificDllLoader &DxcSupport, + bool Verbose) { + std::string Args = "-HV 202x"; + MatrixDim NumElements = 4; + size_t BufferSize = elementSize(ComponentType::F32) * NumElements; + + compileShader(DxcSupport, ConvertShader, "cs_6_10", Args, Verbose); + + auto Expected = makeExpectedVec(ComponentType::F32, NumElements, 1.0); + + auto Op = createComputeOp(ConvertShader, "cs_6_10", "UAV(u0)", Args.c_str()); + addUAVBuffer(Op.get(), "Output", BufferSize, true); + addRootView(Op.get(), 0, "Output"); + + auto Result = runShaderOp(Device, DxcSupport, std::move(Op)); + + MappedData OutData; + Result->Test->GetReadBackData("Output", &OutData); + + VERIFY_IS_TRUE(verifyComponentBuffer(ComponentType::F32, OutData.data(), + Expected, NumElements, Verbose)); +} + +void DxilConf_SM610_LinAlg::Convert() { + runConvert(D3DDevice, DxcSupport, VerboseLogging); +} + } // namespace LinAlg