Skip to content

Commit 02686ef

Browse files
authored
[SM6.10] Update HL LinAlg builtin signatures (#8173)
Resolves the `int MatrixRef` TODO left from previous changes, as well as updates most builtins to match their final shape. This should be considered an NFC change since these builtins aren't implemented yet and instead we are just updating the shape of the reversed op slots
1 parent 6de7fe1 commit 02686ef

File tree

1 file changed

+20
-25
lines changed

1 file changed

+20
-25
lines changed

utils/hct/gen_intrin_main.txt

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -397,32 +397,27 @@ void [[min_sm=6.10]] __builtin_VectorAccumulate(in LinAlg<c> InputVector, in RWB
397397

398398
// LinAlg intrinsics
399399

400-
// TODO: Replace all int MatrixRef with MatrixRef type
401400
// TODO: Replace all int GroupSharedMem with groupshared memory
402-
void [[min_sm=6.10]] __builtin_LinAlg_FillMatrix(int MatrixRef, numeric value);
403-
void [[min_sm=6.10]] __builtin_LinAlg_CopyConvertMatrix(int MatrixRefDest, int MatrixRefSrc, bool transpose);
404-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(int MatrixRef, resource buf, int32_only offset, int32_only stride, int32_only layout);
405-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(int MatrixRef, int GroupSharedMem, int32_only offset, int32_only stride, int32_only layout);
406-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(int MatrixRef);
407-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(int MatrixRef, int32_only threadLocalIndex);
408-
numeric [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(int MatrixRef, int32_only threadLocalIndex);
409-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(int MatrixRef, int32_only threadLocalIndex, numeric value);
410-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(int MatrixRef, resource buf, int32_only offset, int32_only stride, int32_only layout);
411-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(int MatrixRef, int GroupSharedMem, int32_only offset, int32_only stride, int32_only layout);
412-
int32_only [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout();
413-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(int MatrixRefA, int MatrixRefB, int MatrixRefC);
414-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(int MatrixRefA, int MatrixRefB, int MatrixRefC);
415-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(int MatrixRefRHS, int MatrixRefLHS);
416-
417-
// TODO: Fix vector types
418-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(int MatrixRef);
419-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(int MatrixRef);
420-
421-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(int MatrixRef, resource buf, int32_only offset, int32_only stride, int32_only layout);
422-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(int MatrixRef, int GroupSharedMem, int32_only offset, int32_only stride, int32_only layout);
423-
424-
// TODO: Fix vector types
425-
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(int MatrixRef);
401+
void [[min_sm=6.10]] __builtin_LinAlg_FillMatrix(out LinAlgMatrix ret, in numeric value);
402+
void [[min_sm=6.10]] __builtin_LinAlg_CopyConvertMatrix(out LinAlgMatrix ret, in LinAlgMatrix source, in bool transpose);
403+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in ByteAddressBuffer buf, in uint offset, in uint stride, in uint layout);
404+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromDescriptor(out LinAlgMatrix ret, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout);
405+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixLoadFromMemory(out LinAlgMatrix ret, in int GroupSharedMem, in uint offset, in uint stride, in uint layout);
406+
uint [[min_sm=6.10]] __builtin_LinAlg_MatrixLength(in LinAlgMatrix matrix);
407+
uint<2> [[min_sm=6.10]] __builtin_LinAlg_MatrixGetCoordinate(in LinAlgMatrix matrix, in uint threadLocalIndex);
408+
numeric [[min_sm=6.10]] __builtin_LinAlg_MatrixGetElement(in LinAlgMatrix matrix, in uint threadLocalIndex);
409+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixSetElement(out LinAlgMatrix ret, in LinAlgMatrix matrix, in uint threadLocalIndex, in numeric value);
410+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout);
411+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixStoreToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout);
412+
uint [[min_sm=6.10]] __builtin_LinAlg_MatrixQueryAccumulatorLayout();
413+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiply(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB);
414+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixMatrixMultiplyAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixA, in LinAlgMatrix matrixB);
415+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulate(out LinAlgMatrix matrixC, in LinAlgMatrix matrixLHS, in LinAlgMatrix matrixRHS);
416+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiply(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp);
417+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixVectorMultiplyAdd(out numeric<> ret, in LinAlgMatrix mat, in numeric<> input, in uint input_interp, in numeric<> bias, in uint bias_interp);
418+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToDescriptor(in LinAlgMatrix matrix, in RWByteAddressBuffer buf, in uint offset, in uint stride, in uint layout);
419+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixAccumulateToMemory(in LinAlgMatrix matrix, in int GroupSharedMem, in uint offset, in uint stride, in uint layout);
420+
void [[min_sm=6.10]] __builtin_LinAlg_MatrixOuterProduct(out LinAlgMatrix ret, in numeric<> vecA, in numeric<> vecB);
426421

427422
} namespace
428423

0 commit comments

Comments
 (0)