
Commit 65815b0

llvm-beanz, mjbedy, and V-FEXrt authored
[0035] Align matrix-vector APIs with coopvec (microsoft#741)
This aligns the matrix-vector HLSL APIs with the SM 6.9 cooperative vector feature such that the matrix is an `A` matrix and the vectors are column vectors rather than row vectors. It also aligns the argument orders between the HLSL and DXIL APIs to make them easier to read (SM 6.9 had a mismatch between the HLSL and DXIL APIs).

Co-authored-by: Michael Bedy <[email protected]>
Co-authored-by: Ashley Coleman <[email protected]>
1 parent aca7ec9 commit 65815b0

1 file changed: proposals/0035-linalg-matrix.md (+73, −74)
````diff
@@ -190,42 +190,39 @@ Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::ThreadGroup>,
 
 template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
           ComponentEnum MatrixDT, MatrixScopeEnum Scope>
-vector<OutputElTy, K> Multiply(vector<InputElTy, M>,
-                               Matrix<MatrixDT, M, K, MatrixUse::B, Scope>);
+vector<OutputElTy, M> Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
+                               vector<InputElTy, K>);
 
 template <typename OutputElTy, typename InputElTy, typename BiasElTy,
           SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT,
           MatrixScopeEnum Scope>
-vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>,
-                                  Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
-                                  vector<BiasElTy, K>);
+vector<OutputElTy, M> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
+                                  vector<InputElTy, K>, vector<BiasElTy, M>);
 
-template <typename OutputElTy, typename InputElTy,
-          ComponentEnum InputInterp, typename BiasElTy, SIZE_TYPE M,
-          SIZE_TYPE N, SIZE_TYPE K, ComponentEnum MatrixDT,
-          MatrixScopeEnum Scope>
-typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size ==
+template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
+          typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
+          ComponentEnum MatrixDT, MatrixScopeEnum Scope>
+typename hlsl::enable_if<InterpretedVector<InputElTy, VecM, InputInterp>::Size ==
                              M,
                          vector<OutputElTy, K> >::type
-MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp>,
-            Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
+MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
+            InterpretedVector<InputElTy, VecM, InputInterp>,
             vector<BiasElTy, K>);
 
 template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
           SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
 vector<OutputElTy, K>
-MultiplyAdd(vector<InputElTy, M>,
-            Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>,
-            VectorRef<BiasElTy, K>);
+MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>,
+            vector<InputElTy, M>, VectorRef<BiasElTy, K>);
 
-template <typename OutputElTy, typename InputElTy,
-          ComponentEnum InputInterp, ComponentEnum BiasElTy,
-          SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K, ComponentEnum MatrixDT>
-typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size ==
+template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
+          ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
+          ComponentEnum MatrixDT>
+typename hlsl::enable_if<InterpretedVector<InputElTy, VecM, InputInterp>::Size ==
                              M,
                          vector<OutputElTy, K> >::type
-MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp>,
-            Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>,
+MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>,
+            InterpretedVector<InputElTy, VecM, InputInterp>,
             VectorRef<BiasElTy, K>);
 
 // Outer product functions
````
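For orientation (an aside, not part of the commit): the first two new signatures above encode the column-vector convention, in which an `A` matrix with `M` rows and `K` columns consumes a `K`-element vector and produces an `M`-element one, and `MultiplyAdd` applies its bias in the result space (hence the `M`-element bias vector):

```latex
y = A x, \qquad A \in \mathbb{R}^{M \times K},\quad x \in \mathbb{R}^{K},\quad y \in \mathbb{R}^{M}
\\
y = A x + b, \qquad b \in \mathbb{R}^{M} \quad \text{(MultiplyAdd)}
```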
````diff
@@ -282,32 +279,30 @@ ByteAddressBuffer B : register(t0);
 
 void CoopVec() {
   using namespace dx::linalg;
-  using MatrixBTy = Matrix<ComponentType::F16, 16, 16, MatrixUse::B,
-                           MatrixScope::Thread>;
+  using MatrixATy =
+      Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Thread>;
 
   vector<float16_t, 16> Vec = (vector<float16_t, 16>)0;
-  MatrixBTy MatB = MatrixBTy::Load(
+  MatrixATy MatA = MatrixATy::Load(
       MBuf, 0, /* Row stride = number of columns * element size */ 16 * 4,
       MatrixLayout::RowMajor);
-  vector<float16_t, 16> Layer1 = Multiply<float16_t>(Vec, MatB);
+  vector<float16_t, 16> Layer1 = Multiply<float16_t>(MatA, Vec);
 
   vector<float16_t, 16> NullBias = (vector<float16_t, 16>)0;
-  vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(Layer1, MatB, NullBias);
+  vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(MatA, Layer1, NullBias);
 
   VectorRef<ComponentType::F8_E4M3, 16> MemBias = {MBuf,
-                                                   /*start offset*/ 4096};
-  vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(Layer2, MatB, MemBias);
+                                                   /*start offset*/ 4096};
+  vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(MatA, Layer2, MemBias);
 
   // Clang doesn't yet support packed types.
 #ifdef __hlsl_dx_compiler
   vector<uint8_t4_packed, 4> SomeData = (vector<uint8_t4_packed, 4>)0;
 
   vector<float16_t, 16> Layer4 = MultiplyAdd<float16_t>(
-      MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MatB,
-      MemBias);
+      MatA, MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MemBias);
   vector<float16_t, 16> Layer5 = MultiplyAdd<float16_t>(
-      MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MatB,
-      NullBias);
+      MatA, MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), NullBias);
 #endif
 }
 ```
````
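A reading aid for the example above, reusing its `MatA` and `Vec` (no new API here): the matrix-first argument order makes each call read in the same order as the underlying math, so chained layers compose textually the way they do algebraically.

```c++
// Sketch only: matrix-first order mirrors y = A * x, so a two-layer
// evaluation reads inside-out, just like A * (A * x).
vector<float16_t, 16> First  = Multiply<float16_t>(MatA, Vec);   // A * x
vector<float16_t, 16> Second = Multiply<float16_t>(MatA, First); // A * (A * x)
```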
````diff
@@ -416,7 +411,7 @@ The following table summarizes the operations supported for each matrix scope:
 | `Matrix::SumAccumulate()` ||||
 | `linalg::Multiply(Matrix, Matrix)` ||||
 | `linalg::Multiply(vector, Matrix)` ||||
-| `linalg::MultiplyAdd(vector, Matrix, vector)` ||||
+| `linalg::MultiplyAdd(Matrix, vector, vector)` ||||
 | `linalg::OuterProduct(vector, vector)` ||||
 
 Throughout this document a matrix may be described as having a scope as
````
````diff
@@ -697,7 +692,7 @@ Requires `Wave` or `ThreadGroup` scope matrix.
 
 Returns the number of matrix components accessible to the current thread. If the
 matrix's elements are stored in a packed type, `Length` will return the number of
-packed elements (e.g. if a thread has 8 accessible elements of `int8` type 
+packed elements (e.g. if a thread has 8 accessible elements of `int8` type
 packed into 2 `int8_t4_packed`, `Length` will return 2). The mapping and
 distribution of threads to matrix elements is opaque and
 implementation-specific. The value returned by `Length` may be different for
````
````diff
@@ -928,24 +923,30 @@ infers the type of the output accumulator to match the input vector element type
 the other overload takes a template parameter for the output matrix element type.
 All matrix scopes are allowed for the output matrix.
 
-#### linalg::MultiplyAdd(vector, Matrix, vector)
+#### linalg::MultiplyAdd(Matrix, vector, vector)
 
 ``` c++
 template <typename OutputElTy, typename InputElTy, typename BiasElTy, uint M,
           uint K, ComponentType MatrixDT>
 vector<OutputElTy, K>
-linalg::MultiplyAdd(vector<InputElTy, M>,
-                    Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>,
+linalg::MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>,
+                    vector<InputElTy, M>,
                     vector<BiasElTy, K>);
 ```
 
 Requires `Thread` scope matrix input, may be called from divergent control flow.
 
-The `linalg::MultiplyAdd` function has an overload that takes an `M`-element, an
-MxK `B` matrix with `Thread` scope, and a `K`-element vector. The operation
+The `linalg::MultiplyAdd` function has an overload that takes an MxK `A` matrix
+with `Thread` scope, an `M`-element vector, and a `K`-element vector. The operation
 multiplies the `M`-element vector by the matrix then adds the `K`-element vector
 producing a result `K`-element vector.
 
+Either vector may be a native vector or an `InterpretedVector` which combines a
+packed element vector with an interpretation type. The `K`-element vector may
+also be a `VectorRef` which refers to a vector in memory. Using the `VectorRef`
+overload makes it easier for the backend compiler to optimize the bias vector
+loads with the ALU operations.
+
 ### DXIL Types
 
 This feature adds the following new DXIL enumerations, which used as immediate
````
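Element-wise, and following this overload's own prose (an `M`-element input vector multiplied against the MxK matrix, then a `K`-element bias added), the operation computes:

```latex
y_k = \sum_{m=1}^{M} x_m \, A_{m,k} + b_k, \qquad k = 1, \dots, K
```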
````diff
@@ -1212,37 +1213,37 @@ Must be called from wave-uniform control flow.
 ``` llvm
 declare <[NUMo] x [TYo]> @dx.op.matvecmul.v[NUMo][TYo].v[NUMi][TYi](
     immarg i32,          ; opcode
+    %dx.types.MatrixRef, ; matrix A
     <[NUMi] x [TYi]>,    ; input vector
-    immarg i32,          ; input interpretation type (DXILComponentType)
-    %dx.types.MatrixRef  ; matrix A
+    immarg i32           ; input interpretation type (DXILComponentType)
 )
 ```
 
-This operation implements a row-vector multiplication against a `B` matrix of
+This operation implements a row-vector multiplication against an `A` matrix of
 `Thread` scope.
 
 Validation will enforce that:
-* The input vector length matches the `M` matrix dimension
-* The matrix A is a `B` matrix of `Thread` scope
+* The input vector length matches the `K` matrix dimension
+* The matrix A is an `A` matrix of `Thread` scope
 
 ``` llvm
 declare <[NUMo] x [TYo]> @dx.op.matvecmuladd.v[NUMo][TYo].v[NUMi][TYi].v[NUMo][TYb](
     immarg i32,          ; opcode
+    %dx.types.MatrixRef, ; matrix A
     <[NUMi] x [TYi]>,    ; input vector
     immarg i32,          ; input interpretation type (DXILComponentType)
-    %dx.types.MatrixRef, ; matrix A
     <[NUMo] x [TYb]>,    ; bias vector
     immarg i32           ; bias interpretation type (DXILComponentType)
 )
 ```
 
-This operation implements a row-vector multiplication against a `B` matrix of
+This operation implements a row-vector multiplication against an `A` matrix of
 `Thread` scope with a bias vector added to the result.
 
 Validation will enforce that:
-* The input vector length matches the `M` matrix dimension
-* The bias vector length matches the `N` matrix dimension
-* The matrix A is a `B` matrix of `Thread` scope
+* The input vector length matches the `K` matrix dimension
+* The bias vector length matches the `M` matrix dimension
+* The matrix A is an `A` matrix of `Thread` scope
 
 ```llvm
 declare void @dx.op.matrixAccumulateToDescriptor(
````
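For concreteness, a hypothetical instantiation of the `matvecmuladd` declaration above with `NUMo = NUMi = 16` and `TYo = TYi = TYb = f16`; the `i32 0` immediates stand in for the real opcode and `DXILComponentType` values, and `%mat` is assumed to come from an earlier matrix operation:

```llvm
%result = call <16 x half> @dx.op.matvecmuladd.v16f16.v16f16.v16f16(
    i32 0,                    ; opcode (placeholder value)
    %dx.types.MatrixRef %mat, ; matrix A
    <16 x half> %vec,         ; input vector
    i32 0,                    ; input interpretation type (placeholder)
    <16 x half> %bias,        ; bias vector
    i32 0)                    ; bias interpretation type (placeholder)
```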
````diff
@@ -1374,7 +1375,7 @@ in the [`DXILComponentType` enumeration](#dxil-enumerations).
 
 ## Appendix 2: HLSL Header
 
-[Compiler Explorer](https://godbolt.org/z/W5a7zbPr3)
+[Compiler Explorer](https://godbolt.org/z/aPWK1KjeE)
 > Note: this mostly works with Clang, but has some issues to work out still.
 
 ```cpp
````
````diff
@@ -1639,41 +1640,39 @@ Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::ThreadGroup>,
 
 template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
           ComponentEnum MatrixDT, MatrixScopeEnum Scope>
-vector<OutputElTy, K> Multiply(vector<InputElTy, M>,
-                               Matrix<MatrixDT, M, K, MatrixUse::B, Scope>);
+vector<OutputElTy, M> Multiply(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
+                               vector<InputElTy, K>);
 
 template <typename OutputElTy, typename InputElTy, typename BiasElTy,
           SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT,
           MatrixScopeEnum Scope>
-vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>,
-                                  Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
-                                  vector<BiasElTy, K>);
+vector<OutputElTy, M> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
+                                  vector<InputElTy, K>, vector<BiasElTy, M>);
 
 template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
-          typename BiasElTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K,
+          typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
           ComponentEnum MatrixDT, MatrixScopeEnum Scope>
-typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size ==
+typename hlsl::enable_if<InterpretedVector<InputElTy, VecM, InputInterp>::Size ==
                              M,
                          vector<OutputElTy, K> >::type
-MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp>,
-            Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
+MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
+            InterpretedVector<InputElTy, VecM, InputInterp>,
             vector<BiasElTy, K>);
 
 template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
           SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
 vector<OutputElTy, K>
-MultiplyAdd(vector<InputElTy, M>,
-            Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>,
-            VectorRef<BiasElTy, K>);
+MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>,
+            vector<InputElTy, M>, VectorRef<BiasElTy, K>);
 
 template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
-          ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K,
+          ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
           ComponentEnum MatrixDT>
-typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size ==
+typename hlsl::enable_if<InterpretedVector<InputElTy, VecM, InputInterp>::Size ==
                              M,
                          vector<OutputElTy, K> >::type
-MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp>,
-            Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>,
+MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>,
+            InterpretedVector<InputElTy, VecM, InputInterp>,
             VectorRef<BiasElTy, K>);
 
 // Outer product functions
````
````diff
@@ -1722,30 +1721,30 @@ ByteAddressBuffer MBuf : register(t0);
 
 void CoopVec() {
   using namespace dx::linalg;
-  using MatrixBTy =
-      Matrix<ComponentType::F16, 16, 16, MatrixUse::B, MatrixScope::Thread>;
+  using MatrixATy =
+      Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Thread>;
 
   vector<float16_t, 16> Vec = (vector<float16_t, 16>)0;
-  MatrixBTy MatB = MatrixBTy::Load(
+  MatrixATy MatA = MatrixATy::Load(
       MBuf, 0, /* Row stride = number of columns * element size */ 16 * 4,
       MatrixLayout::RowMajor);
-  vector<float16_t, 16> Layer1 = Multiply<float16_t>(Vec, MatB);
+  vector<float16_t, 16> Layer1 = Multiply<float16_t>(MatA, Vec);
 
   vector<float16_t, 16> NullBias = (vector<float16_t, 16>)0;
-  vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(Layer1, MatB, NullBias);
+  vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(MatA, Layer1, NullBias);
 
   VectorRef<ComponentType::F8_E4M3, 16> MemBias = {MBuf,
                                                    /*start offset*/ 4096};
-  vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(Layer2, MatB, MemBias);
+  vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(MatA, Layer2, MemBias);
 
   // Clang doesn't yet support packed types.
 #ifdef __hlsl_dx_compiler
   vector<uint8_t4_packed, 4> SomeData = (vector<uint8_t4_packed, 4>)0;
 
   vector<float16_t, 16> Layer4 = MultiplyAdd<float16_t>(
-      MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MatB, MemBias);
+      MatA, MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MemBias);
   vector<float16_t, 16> Layer5 = MultiplyAdd<float16_t>(
-      MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), MatB, NullBias);
+      MatA, MakeInterpretedVector<ComponentType::F8_E4M3>(SomeData), NullBias);
 #endif
 }
 
````