@@ -228,37 +228,58 @@ static bool fillInputBuffer(LPCSTR Name, std::vector<BYTE> &Data,
228228 return false ;
229229}
230230
231- static VariantCompType makeExpected (ComponentType CompType, size_t NumElements,
232- float StartingVal, bool Increment) {
231+ static VariantCompType makeExpected (ComponentType CompType, int32_t M,
232+ int32_t N, float StartingVal,
233+ bool Increment = true ,
234+ bool Transpose = false ) {
235+ const size_t NumElements = M * N;
236+ std::vector<float > Floats (NumElements);
237+ std::vector<int32_t > Ints (NumElements);
238+ std::vector<HLSLHalf_t> Halfs (NumElements);
239+
240+ for (size_t I = 0 ; I < M; ++I) {
241+ for (size_t J = 0 ; J < N; ++J) {
242+ size_t Value = I * M + J;
243+ size_t Idx = Transpose ? J * N + I : Value;
244+ switch (CompType) {
245+ case ComponentType::F32:
246+ Floats[Idx] = StartingVal + static_cast <float >(Increment ? Value : 0 );
247+ break ;
248+ case ComponentType::I32:
249+ VERIFY_IS_TRUE (StartingVal < static_cast <float >(
250+ std::numeric_limits<int32_t >::max ()),
251+ " Value too large to cast to int32_t" );
252+ VERIFY_IS_TRUE (StartingVal > static_cast <float >(
253+ std::numeric_limits<int32_t >::min ()),
254+ " Value too small to cast to int32_t" );
255+ Ints[Idx] = static_cast <int32_t >(StartingVal) +
256+ static_cast <int32_t >(Increment ? Value : 0 );
257+ break ;
258+ case ComponentType::F16: {
259+ // Downcasting is safe here since HLSLHalf_t will clamp if F is too
260+ // large.
261+ float F = StartingVal + static_cast <float >(Increment ? Value : 0 );
262+ Halfs[Idx] = HLSLHalf_t (F);
263+ break ;
264+ }
265+ default :
266+ VERIFY_IS_TRUE (false , " Unable to fill unexpected ComponentType" );
267+ break ;
268+ }
269+ }
270+ }
271+
233272 switch (CompType) {
234- case ComponentType::F32: {
235- std::vector<float > Floats (NumElements);
236- for (size_t I = 0 ; I < NumElements; I++)
237- Floats[I] = StartingVal + static_cast <float >(Increment ? I : 0 );
273+ case ComponentType::F32:
238274 return Floats;
239- }
240- case ComponentType::I32: {
241- DXASSERT (StartingVal < static_cast <float >(INT_MAX),
242- " Value too large to cast to int32_t" );
243- std::vector<int32_t > Ints (NumElements);
244- for (size_t I = 0 ; I < NumElements; I++)
245- Ints[I] = static_cast <int32_t >(StartingVal) +
246- static_cast <int32_t >(Increment ? I : 0 );
275+ case ComponentType::I32:
247276 return Ints;
248- }
249- case ComponentType::F16: {
250- std::vector<HLSLHalf_t> Halfs (NumElements);
251- for (size_t I = 0 ; I < NumElements; I++) {
252- // Downcasting is safe here since HLSLHalf_t will clamp if F is too large.
253- float F = StartingVal + static_cast <float >(Increment ? I : 0 );
254- Halfs[I] = HLSLHalf_t (F);
255- }
277+ case ComponentType::F16:
256278 return Halfs;
279+ default :
280+ VERIFY_IS_TRUE (false , " Unable to fill unexpected ComponentType" );
281+ return Floats;
257282 }
258- }
259-
260- DXASSERT (false , " Unable to fill unexpected ComponentType" );
261- return std::vector<float >();
262283}
263284
264285static void logCompiledButSkipping () {
@@ -384,6 +405,7 @@ static const char LoadStoreShader[] = R"(
384405 RWByteAddressBuffer Output : register(u1);
385406
386407#ifndef EMULATE_TEST
408+ [WaveSize(4, 64)]
387409 [numthreads(NUMTHREADS, 1, 1)]
388410 void main() {
389411 __builtin_LinAlgMatrix
@@ -429,7 +451,7 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
429451 return ;
430452 }
431453
432- auto Expected = makeExpected (Params.CompType , NumElements, 1 , true );
454+ auto Expected = makeExpected (Params.CompType , Params. M , Params. N , 1 );
433455
434456 // Construct the ShaderOp: two UAV buffers, load from one, store to other.
435457 auto Op = createComputeOp (LoadStoreShader, Target.c_str (), " UAV(u0), UAV(u1)" ,
@@ -463,7 +485,7 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
463485 Params.Use = MatrixUse::A;
464486 Params.Scope = MatrixScope::Wave;
465487 Params.Layout = LinalgMatrixLayout::RowMajor;
466- Params.NumThreads = 4 ;
488+ Params.NumThreads = 64 ;
467489 Params.Enable16Bit = true ;
468490 Params.EmulateTest = EmulateTest;
469491 runLoadStoreRoundtrip (D3DDevice, DxcSupport, Params, VerboseLogging,
@@ -474,6 +496,7 @@ static const char SplatStoreShader[] = R"(
474496 RWByteAddressBuffer Output : register(u0);
475497
476498#ifndef EMULATE_TEST
499+ [WaveSize(4, 64)]
477500 [numthreads(NUMTHREADS, 1, 1)]
478501 void main() {
479502 __builtin_LinAlgMatrix
@@ -517,7 +540,8 @@ static void runSplatStore(ID3D12Device *Device,
517540 return ;
518541 }
519542
520- auto Expected = makeExpected (Params.CompType , NumElements, FillValue, false );
543+ auto Expected =
544+ makeExpected (Params.CompType , Params.M , Params.N , FillValue, false );
521545
522546 auto Op = createComputeOp (SplatStoreShader, Target.c_str (), " UAV(u0)" ,
523547 Args.c_str ());
@@ -541,7 +565,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
541565 Params.Use = MatrixUse::Accumulator;
542566 Params.Scope = MatrixScope::Wave;
543567 Params.Layout = LinalgMatrixLayout::RowMajor;
544- Params.NumThreads = 4 ;
568+ Params.NumThreads = 64 ;
545569 Params.Enable16Bit = true ;
546570 Params.EmulateTest = EmulateTest;
547571 runSplatStore (D3DDevice, DxcSupport, Params, 42 .0f , VerboseLogging,
@@ -553,11 +577,13 @@ static const char ElementAccessShader[] = R"(
553577 RWByteAddressBuffer Output : register(u1);
554578
555579 // flatten the 2D index into a 1D index then scale by element size
580+ // Always store row-major and work it out in the test runner
556581 uint coordToByteOffset(uint2 coord) {
557- return (coord.x * MAJOR_DIM + coord.y ) * ELEM_SIZE;
582+ return (coord.y * N_DIM + coord.x ) * ELEM_SIZE;
558583 }
559584
560585#ifndef EMULATE_TEST
586+ [WaveSize(4, 64)]
561587 [numthreads(NUMTHREADS, 1, 1)]
562588 void main(uint threadIndex : SV_GroupIndex) {
563589 __builtin_LinAlgMatrix
@@ -605,8 +631,7 @@ static void runElementAccess(ID3D12Device *Device,
605631 const size_t NumThreads = Params.NumThreads ;
606632 const size_t InputBufSize = Params.totalBytes ();
607633 const size_t ElementSize = elementSize (Params.CompType );
608- const size_t MajorDim =
609- Params.Layout == LinalgMatrixLayout::RowMajor ? Params.M : Params.N ;
634+
610635 // Output: ElementSize bytes per element
611636 // 1 element for each mat idx
612637 // 1 uint for each thread's length
@@ -618,7 +643,6 @@ static void runElementAccess(ID3D12Device *Device,
618643 Target = " cs_6_8" ;
619644
620645 std::stringstream ExtraDefs;
621- ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
622646 std::string Args = buildCompilerArgs (Params, ExtraDefs.str ().c_str ());
623647
624648 compileShader (DxcSupport, ElementAccessShader, Target.c_str (), Args, Verbose);
@@ -628,7 +652,7 @@ static void runElementAccess(ID3D12Device *Device,
628652 return ;
629653 }
630654
631- auto Expected = makeExpected (Params.CompType , NumElements, 1 , true );
655+ auto Expected = makeExpected (Params.CompType , Params. M , Params. N , 1 );
632656
633657 auto Op = createComputeOp (ElementAccessShader, Target.c_str (),
634658 " UAV(u0), UAV(u1)" , Args.c_str ());
@@ -674,7 +698,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
674698 Params.Use = MatrixUse::Accumulator;
675699 Params.Scope = MatrixScope::Wave;
676700 Params.Layout = LinalgMatrixLayout::RowMajor;
677- Params.NumThreads = 4 ;
701+ Params.NumThreads = 64 ;
678702 Params.Enable16Bit = true ;
679703 Params.EmulateTest = EmulateTest;
680704 runElementAccess (D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
0 commit comments