Skip to content

Commit a5877b6

Browse files
[SM6.10][Exec] Fix GetElement test, add transpose to helper (#8361)
- Fixes Wave sizes on the GetElement test - Always store the matrix back in row-major - Removes now unnecessary `MAJOR_DIM` define - Adds transpose to `makeExpected` --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 13efb6a commit a5877b6

File tree

1 file changed

+60
-36
lines changed

1 file changed

+60
-36
lines changed

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 60 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -228,37 +228,58 @@ static bool fillInputBuffer(LPCSTR Name, std::vector<BYTE> &Data,
228228
return false;
229229
}
230230

231-
static VariantCompType makeExpected(ComponentType CompType, size_t NumElements,
232-
float StartingVal, bool Increment) {
231+
static VariantCompType makeExpected(ComponentType CompType, int32_t M,
232+
int32_t N, float StartingVal,
233+
bool Increment = true,
234+
bool Transpose = false) {
235+
const size_t NumElements = M * N;
236+
std::vector<float> Floats(NumElements);
237+
std::vector<int32_t> Ints(NumElements);
238+
std::vector<HLSLHalf_t> Halfs(NumElements);
239+
240+
for (size_t I = 0; I < M; ++I) {
241+
for (size_t J = 0; J < N; ++J) {
242+
size_t Value = I * M + J;
243+
size_t Idx = Transpose ? J * N + I : Value;
244+
switch (CompType) {
245+
case ComponentType::F32:
246+
Floats[Idx] = StartingVal + static_cast<float>(Increment ? Value : 0);
247+
break;
248+
case ComponentType::I32:
249+
VERIFY_IS_TRUE(StartingVal < static_cast<float>(
250+
std::numeric_limits<int32_t>::max()),
251+
"Value too large to cast to int32_t");
252+
VERIFY_IS_TRUE(StartingVal > static_cast<float>(
253+
std::numeric_limits<int32_t>::min()),
254+
"Value too small to cast to int32_t");
255+
Ints[Idx] = static_cast<int32_t>(StartingVal) +
256+
static_cast<int32_t>(Increment ? Value : 0);
257+
break;
258+
case ComponentType::F16: {
259+
// Downcasting is safe here since HLSLHalf_t will clamp if F is too
260+
// large.
261+
float F = StartingVal + static_cast<float>(Increment ? Value : 0);
262+
Halfs[Idx] = HLSLHalf_t(F);
263+
break;
264+
}
265+
default:
266+
VERIFY_IS_TRUE(false, "Unable to fill unexpected ComponentType");
267+
break;
268+
}
269+
}
270+
}
271+
233272
switch (CompType) {
234-
case ComponentType::F32: {
235-
std::vector<float> Floats(NumElements);
236-
for (size_t I = 0; I < NumElements; I++)
237-
Floats[I] = StartingVal + static_cast<float>(Increment ? I : 0);
273+
case ComponentType::F32:
238274
return Floats;
239-
}
240-
case ComponentType::I32: {
241-
DXASSERT(StartingVal < static_cast<float>(INT_MAX),
242-
"Value too large to cast to int32_t");
243-
std::vector<int32_t> Ints(NumElements);
244-
for (size_t I = 0; I < NumElements; I++)
245-
Ints[I] = static_cast<int32_t>(StartingVal) +
246-
static_cast<int32_t>(Increment ? I : 0);
275+
case ComponentType::I32:
247276
return Ints;
248-
}
249-
case ComponentType::F16: {
250-
std::vector<HLSLHalf_t> Halfs(NumElements);
251-
for (size_t I = 0; I < NumElements; I++) {
252-
// Downcasting is safe here since HLSLHalf_t will clamp if F is too large.
253-
float F = StartingVal + static_cast<float>(Increment ? I : 0);
254-
Halfs[I] = HLSLHalf_t(F);
255-
}
277+
case ComponentType::F16:
256278
return Halfs;
279+
default:
280+
VERIFY_IS_TRUE(false, "Unable to fill unexpected ComponentType");
281+
return Floats;
257282
}
258-
}
259-
260-
DXASSERT(false, "Unable to fill unexpected ComponentType");
261-
return std::vector<float>();
262283
}
263284

264285
static void logCompiledButSkipping() {
@@ -384,6 +405,7 @@ static const char LoadStoreShader[] = R"(
384405
RWByteAddressBuffer Output : register(u1);
385406
386407
#ifndef EMULATE_TEST
408+
[WaveSize(4, 64)]
387409
[numthreads(NUMTHREADS, 1, 1)]
388410
void main() {
389411
__builtin_LinAlgMatrix
@@ -429,7 +451,7 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
429451
return;
430452
}
431453

432-
auto Expected = makeExpected(Params.CompType, NumElements, 1, true);
454+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1);
433455

434456
// Construct the ShaderOp: two UAV buffers, load from one, store to other.
435457
auto Op = createComputeOp(LoadStoreShader, Target.c_str(), "UAV(u0), UAV(u1)",
@@ -463,7 +485,7 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
463485
Params.Use = MatrixUse::A;
464486
Params.Scope = MatrixScope::Wave;
465487
Params.Layout = LinalgMatrixLayout::RowMajor;
466-
Params.NumThreads = 4;
488+
Params.NumThreads = 64;
467489
Params.Enable16Bit = true;
468490
Params.EmulateTest = EmulateTest;
469491
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging,
@@ -474,6 +496,7 @@ static const char SplatStoreShader[] = R"(
474496
RWByteAddressBuffer Output : register(u0);
475497
476498
#ifndef EMULATE_TEST
499+
[WaveSize(4, 64)]
477500
[numthreads(NUMTHREADS, 1, 1)]
478501
void main() {
479502
__builtin_LinAlgMatrix
@@ -517,7 +540,8 @@ static void runSplatStore(ID3D12Device *Device,
517540
return;
518541
}
519542

520-
auto Expected = makeExpected(Params.CompType, NumElements, FillValue, false);
543+
auto Expected =
544+
makeExpected(Params.CompType, Params.M, Params.N, FillValue, false);
521545

522546
auto Op = createComputeOp(SplatStoreShader, Target.c_str(), "UAV(u0)",
523547
Args.c_str());
@@ -541,7 +565,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
541565
Params.Use = MatrixUse::Accumulator;
542566
Params.Scope = MatrixScope::Wave;
543567
Params.Layout = LinalgMatrixLayout::RowMajor;
544-
Params.NumThreads = 4;
568+
Params.NumThreads = 64;
545569
Params.Enable16Bit = true;
546570
Params.EmulateTest = EmulateTest;
547571
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging,
@@ -553,11 +577,13 @@ static const char ElementAccessShader[] = R"(
553577
RWByteAddressBuffer Output : register(u1);
554578
555579
// flatten the 2D index into a 1D index then scale by element size
580+
// Always store row-major and work it out in the test runner
556581
uint coordToByteOffset(uint2 coord) {
557-
return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE;
582+
return (coord.y * N_DIM + coord.x) * ELEM_SIZE;
558583
}
559584
560585
#ifndef EMULATE_TEST
586+
[WaveSize(4, 64)]
561587
[numthreads(NUMTHREADS, 1, 1)]
562588
void main(uint threadIndex : SV_GroupIndex) {
563589
__builtin_LinAlgMatrix
@@ -605,8 +631,7 @@ static void runElementAccess(ID3D12Device *Device,
605631
const size_t NumThreads = Params.NumThreads;
606632
const size_t InputBufSize = Params.totalBytes();
607633
const size_t ElementSize = elementSize(Params.CompType);
608-
const size_t MajorDim =
609-
Params.Layout == LinalgMatrixLayout::RowMajor ? Params.M : Params.N;
634+
610635
// Output: ElementSize bytes per element
611636
// 1 element for each mat idx
612637
// 1 uint for each thread's length
@@ -618,7 +643,6 @@ static void runElementAccess(ID3D12Device *Device,
618643
Target = "cs_6_8";
619644

620645
std::stringstream ExtraDefs;
621-
ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
622646
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());
623647

624648
compileShader(DxcSupport, ElementAccessShader, Target.c_str(), Args, Verbose);
@@ -628,7 +652,7 @@ static void runElementAccess(ID3D12Device *Device,
628652
return;
629653
}
630654

631-
auto Expected = makeExpected(Params.CompType, NumElements, 1, true);
655+
auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1);
632656

633657
auto Op = createComputeOp(ElementAccessShader, Target.c_str(),
634658
"UAV(u0), UAV(u1)", Args.c_str());
@@ -674,7 +698,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
674698
Params.Use = MatrixUse::Accumulator;
675699
Params.Scope = MatrixScope::Wave;
676700
Params.Layout = LinalgMatrixLayout::RowMajor;
677-
Params.NumThreads = 4;
701+
Params.NumThreads = 64;
678702
Params.Enable16Bit = true;
679703
Params.EmulateTest = EmulateTest;
680704
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);

0 commit comments

Comments
 (0)