Skip to content

Commit 40ede0e

Browse files
authored
[LinAlg spec change] Update vector sizes on Multiply* to match column-vector multiplication (microsoft#843)
Matrix-vector APIs should be aligned with coopvec, meaning the multiplication would be matrix x column-based vector. This was originally done in microsoft#741 but somehow got mostly reverted.
1 parent cf59259 commit 40ede0e

1 file changed

Lines changed: 50 additions & 50 deletions

File tree

proposals/0035-linalg-matrix.md

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -234,42 +234,42 @@ Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
234234

235235
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
236236
ComponentEnum MatrixDT>
237-
vector<OutputElTy, K>
237+
vector<OutputElTy, M>
238238
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
239-
vector<InputElTy, M> Vec);
239+
vector<InputElTy, K> Vec);
240240

241241
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
242242
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
243-
vector<OutputElTy, K>
243+
vector<OutputElTy, M>
244244
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
245-
vector<InputElTy, M>, vector<BiasElTy, K> Vec);
245+
vector<InputElTy, K>, vector<BiasElTy, M> Vec);
246246

247247
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
248-
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
248+
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
249249
ComponentEnum MatrixDT>
250250
typename hlsl::enable_if<
251-
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
252-
vector<OutputElTy, K> >::type
251+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
252+
vector<OutputElTy, M> >::type
253253
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
254-
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
255-
vector<BiasElTy, K> Bias);
254+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
255+
vector<BiasElTy, M> Bias);
256256

257257
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
258258
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
259259
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
260-
vector<OutputElTy, K> >::type
260+
vector<OutputElTy, M> >::type
261261
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
262-
vector<InputElTy, M> Vec, VectorRef<BiasElTy, K> BiasRef);
262+
vector<InputElTy, K> Vec, VectorRef<BiasElTy, M> BiasRef);
263263

264264
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
265-
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
265+
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
266266
ComponentEnum MatrixDT>
267267
typename hlsl::enable_if<
268-
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
269-
vector<OutputElTy, K> >::type
268+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
269+
vector<OutputElTy, M> >::type
270270
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
271-
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
272-
VectorRef<BiasElTy, K> BiasRef);
271+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
272+
VectorRef<BiasElTy, M> BiasRef);
273273

274274
// Outer product functions
275275
template <ComponentEnum OutTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE N>
@@ -1086,16 +1086,16 @@ type and takes arguments with potentially mismatched element types.
10861086
``` c++
10871087
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
10881088
ComponentEnum MatrixDT>
1089-
vector<OutputElTy, K>
1089+
vector<OutputElTy, M>
10901090
linalg::Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
1091-
vector<InputElTy, M> Vec);
1091+
vector<InputElTy, K> Vec);
10921092
```
10931093

10941094
Requires `Thread` scope matrix input, may be called from divergent control flow.
10951095

1096-
The `linalg::Multiply` function has an overload that takes an `M`-element vector
1097-
and an MxK `A` matrix with `Thread` scope. The function returns a `K`-element
1098-
vector.
1096+
The `linalg::Multiply` function has an overload that takes an MxK `A` matrix
1097+
with `Thread` scope, a `K`-element vector `Vec`. The operation multiplies the
1098+
matrix by the `K`-element vector `Vec` producing a result `M`-element vector.
10991099

11001100
#### linalg::OuterProduct(vector, vector)
11011101

@@ -1116,23 +1116,23 @@ parameter for the output matrix element type.
11161116
``` c++
11171117
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
11181118
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
1119-
vector<OutputElTy, K>
1119+
vector<OutputElTy, M>
11201120
linalg::MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
1121-
vector<InputElTy, M> Vec, vector<BiasElTy, K> Bias);
1121+
vector<InputElTy, K> Vec, vector<BiasElTy, M> Bias);
11221122
```
11231123

11241124
Requires `Thread` scope matrix input, may be called from divergent control flow.
11251125

11261126
The `linalg::MultiplyAdd` function has an overload that takes an MxK `A` matrix
1127-
with `Thread` scope, an `M`-element vector, and a `K`-element vector. The operation
1128-
multiplies the `M`-element vector by the matrix then adds the `K`-element vector
1129-
producing a result `K`-element vector.
1127+
with `Thread` scope, a `K`-element vector `Vec`, and a `M`-element vector
1128+
`Bias`. The operation multiplies the matrix by the `K`-element vector `Vec` and
1129+
then adds the `M`-element vector `Bias` producing a result `M`-element vector.
11301130

11311131
Either vector may be a native vector or an `InterpretedVector` which combines a
1132-
packed element vector with an interpretation type. The `K`-element vector may
1133-
also be a `VectorRef` which refers to a vector in memory. Using the `VectorRef`
1134-
overload makes it easier for the backend compiler to optimize the bias vector
1135-
loads with the ALU operations.
1132+
packed element vector with an interpretation type. The `M`-element vector `Bias`
1133+
may also be a `VectorRef` which refers to a vector in memory. Using the
1134+
`VectorRef` overload makes it easier for the backend compiler to optimize the
1135+
bias vector loads with the ALU operations.
11361136

11371137
### DXIL Types
11381138

@@ -1492,8 +1492,8 @@ declare <[NUMo] x [TYo]> @dx.op.linAlgMatVecMul.v[NUMo][TYo].[MatTy].v[NUMi][TYi
14921492
)
14931493
```
14941494

1495-
This operation implements a row-vector multiplication against an `A` matrix of
1496-
`Thread` scope.
1495+
This operation implements a column-vector multiplication against an `A` matrix
1496+
of `Thread` scope.
14971497

14981498
Validation will enforce that:
14991499
* The input vector length matches the `K` matrix dimension
@@ -1516,8 +1516,8 @@ declare <[NUMo] x [TYo]> @dx.op.linAlgMatVecMulAdd.v[NUMo][TYo].[MatTy].v[NUMi][
15161516
)
15171517
```
15181518

1519-
This operation implements a row-vector multiplication against an `A` matrix of
1520-
`Thread` scope with a bias vector added to the result.
1519+
This operation implements a column-vector multiplication against an `A` matrix
1520+
of `Thread` scope with a bias vector added to the result.
15211521

15221522
Validation will enforce that:
15231523
* The input vector length matches the `K` matrix dimension
@@ -2104,42 +2104,42 @@ Matrix<CompTy, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup> Multiply(
21042104

21052105
template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
21062106
ComponentEnum MatrixDT>
2107-
vector<OutputElTy, K>
2107+
vector<OutputElTy, M>
21082108
Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2109-
vector<InputElTy, M> Vec);
2109+
vector<InputElTy, K> Vec);
21102110

21112111
template <typename OutputElTy, typename InputElTy, typename BiasElTy,
21122112
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
2113-
vector<OutputElTy, K>
2113+
vector<OutputElTy, M>
21142114
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2115-
vector<InputElTy, M>, vector<BiasElTy, K> Vec);
2115+
vector<InputElTy, K> Vec, vector<BiasElTy, M> Vec);
21162116

21172117
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
2118-
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
2118+
typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
21192119
ComponentEnum MatrixDT>
21202120
typename hlsl::enable_if<
2121-
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
2122-
vector<OutputElTy, K> >::type
2121+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
2122+
vector<OutputElTy, M> >::type
21232123
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2124-
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
2125-
vector<BiasElTy, K> Bias);
2124+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
2125+
vector<BiasElTy, M> Bias);
21262126

21272127
template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
21282128
SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
21292129
typename hlsl::enable_if<hlsl::is_arithmetic<InputElTy>::value,
2130-
vector<OutputElTy, K> >::type
2130+
vector<OutputElTy, M> >::type
21312131
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2132-
vector<InputElTy, M> Vec, VectorRef<BiasElTy, K> BiasRef);
2132+
vector<InputElTy, K> Vec, VectorRef<BiasElTy, M> BiasRef);
21332133

21342134
template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
2135-
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
2135+
ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecK, SIZE_TYPE K,
21362136
ComponentEnum MatrixDT>
21372137
typename hlsl::enable_if<
2138-
InterpretedVector<InputElTy, VecM, InputInterp>::Size == M,
2139-
vector<OutputElTy, K> >::type
2138+
InterpretedVector<InputElTy, VecK, InputInterp>::Size == K,
2139+
vector<OutputElTy, M> >::type
21402140
MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread> MatrixA,
2141-
InterpretedVector<InputElTy, VecM, InputInterp> InterpVec,
2142-
VectorRef<BiasElTy, K> BiasRef);
2141+
InterpretedVector<InputElTy, VecK, InputInterp> InterpVec,
2142+
VectorRef<BiasElTy, M> BiasRef);
21432143

21442144
// Outer product functions
21452145
template <ComponentEnum OutTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE N>

0 commit comments

Comments
 (0)