@@ -190,42 +190,39 @@ Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::ThreadGroup>,
190190
191191template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
192192 ComponentEnum MatrixDT, MatrixScopeEnum Scope>
193- vector<OutputElTy, K > Multiply(vector<InputElTy , M>,
194- Matrix<MatrixDT, M, K, MatrixUse::B, Scope >);
193+ vector<OutputElTy, M > Multiply(Matrix<MatrixDT , M, K, MatrixUse::A, Scope >,
194+ vector<InputElTy, K >);
195195
196196template <typename OutputElTy, typename InputElTy, typename BiasElTy,
197197 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT,
198198 MatrixScopeEnum Scope>
199- vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>,
200- Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
201- vector<BiasElTy, K>);
199+ vector<OutputElTy, M> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
200+ vector<InputElTy, K>, vector<BiasElTy, M>);
202201
203- template <typename OutputElTy, typename InputElTy,
204- ComponentEnum InputInterp, typename BiasElTy, SIZE_TYPE M,
205- SIZE_TYPE N, SIZE_TYPE K, ComponentEnum MatrixDT,
206- MatrixScopeEnum Scope>
207- typename hlsl::enable_if<InterpretedVector<InputElTy, N, InputInterp>::Size ==
202+ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
203+ typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K,
204+ ComponentEnum MatrixDT, MatrixScopeEnum Scope>
205+ typename hlsl::enable_if<InterpretedVector<InputElTy, VecM, InputInterp>::Size ==
208- M,
206+ K,
209- vector<OutputElTy, K> >::type
207+ vector<OutputElTy, M> >::type
210- MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp >,
211- Matrix<MatrixDT, M, K, MatrixUse::B, Scope >,
208+ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope >,
209+ InterpretedVector<InputElTy, VecM, InputInterp >,
212- vector<BiasElTy, K>);
210+ vector<BiasElTy, M>);
213211
214212template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
215213 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
216- vector<OutputElTy, K>
214+ vector<OutputElTy, M>
217- MultiplyAdd(vector<InputElTy, M>,
218- Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>,
219- VectorRef<BiasElTy, K>);
215+ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>,
216+ vector<InputElTy, K>, VectorRef<BiasElTy, M>);
220217
221- template <typename OutputElTy, typename InputElTy,
222- ComponentEnum InputInterp, ComponentEnum BiasElTy ,
223- SIZE_TYPE M, SIZE_TYPE N, SIZE_TYPE K, ComponentEnum MatrixDT>
224- typename hlsl::enable_if<InterpretedVector<InputElTy, N , InputInterp>::Size ==
218+ template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
219+ ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM, SIZE_TYPE K ,
220+ ComponentEnum MatrixDT>
221+ typename hlsl::enable_if<InterpretedVector<InputElTy, VecM , InputInterp>::Size ==
225- M,
222+ K,
226- vector<OutputElTy, K> >::type
223+ vector<OutputElTy, M> >::type
227- MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp >,
228- Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread >,
224+ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread >,
225+ InterpretedVector<InputElTy, VecM, InputInterp >,
229- VectorRef<BiasElTy, K>);
226+ VectorRef<BiasElTy, M>);
230227
231228// Outer product functions
@@ -282,32 +279,30 @@ ByteAddressBuffer B : register(t0);
282279
283280void CoopVec () {
284281 using namespace dx::linalg;
285- using MatrixBTy = Matrix<ComponentType::F16, 16, 16, MatrixUse::B,
286- MatrixScope::Thread>;
282+ using MatrixATy =
283+ Matrix<ComponentType::F16, 16, 16, MatrixUse::A, MatrixScope::Thread>;
287284
288285 vector<float16_t, 16> Vec = (vector<float16_t, 16>)0;
289- MatrixBTy MatB = MatrixBTy ::Load(
286+ MatrixATy MatA = MatrixATy ::Load(
290287 MBuf, 0, /* Row stride = number of columns * element size * / 16 * 4,
291288 MatrixLayout::RowMajor);
292- vector<float16_t, 16> Layer1 = Multiply<float16_t>(Vec, MatB );
289+ vector<float16_t, 16> Layer1 = Multiply<float16_t>(MatA, Vec );
293290
294291 vector<float16_t, 16> NullBias = (vector<float16_t, 16>)0;
295- vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(Layer1, MatB , NullBias);
292+ vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(MatA, Layer1 , NullBias);
296293
297294 VectorRef<ComponentType::F8_E4M3, 16> MemBias = {MBuf,
298- /* start offset* / 4096};
299- vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(Layer2, MatB , MemBias);
295+ /* start offset* / 4096};
296+ vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(MatA, Layer2 , MemBias);
300297
301298 // Clang doesn't yet support packed types.
302299#ifdef __hlsl_dx_compiler
303300 vector<uint8_t4_packed, 4> SomeData = (vector<uint8_t4_packed, 4>)0;
304301
305302 vector<float16_t, 16> Layer4 = MultiplyAdd<float16_t>(
306- MakeInterpretedVector< ComponentType::F8_E4M3 > (SomeData), MatB,
307- MemBias);
303+ MatA, MakeInterpretedVector< ComponentType::F8_E4M3 > (SomeData), MemBias);
308304 vector<float16_t, 16> Layer5 = MultiplyAdd<float16_t>(
309- MakeInterpretedVector< ComponentType::F8_E4M3 > (SomeData), MatB,
310- NullBias);
305+ MatA, MakeInterpretedVector< ComponentType::F8_E4M3 > (SomeData), NullBias);
311306#endif
312307}
313308```
@@ -416,7 +411,7 @@ The following table summarizes the operations supported for each matrix scope:
416411| ` Matrix::SumAccumulate() ` | ✗ | ✓ | ✓ |
417412| ` linalg::Multiply(Matrix, Matrix) ` | ✗ | ✓ | ✓ |
418- | ` linalg::Multiply(vector, Matrix) ` | ✓ | ✗ | ✗ |
413+ | ` linalg::Multiply(Matrix, vector) ` | ✓ | ✗ | ✗ |
419- | ` linalg::MultiplyAdd(vector, Matrix , vector) ` | ✓ | ✗ | ✗ |
414+ | ` linalg::MultiplyAdd(Matrix, vector , vector) ` | ✓ | ✗ | ✗ |
420415| ` linalg::OuterProduct(vector, vector) ` | ✓ | ✓ | ✓ |
421416
422417Throughout this document a matrix may be described as having a scope as
@@ -697,7 +692,7 @@ Requires `Wave` or `ThreadGroup` scope matrix.
697692
698693Returns the number of matrix components accessible to the current thread. If the
699694matrix's elements are stored in a packed type, ` Length ` will return the number of
700- packed elements (e.g. if a thread has 8 accessible elements of ` int8 ` type
695+ packed elements (e.g. if a thread has 8 accessible elements of ` int8 ` type
701696packed into 2 ` int8_t4_packed ` , ` Length ` will return 2). The mapping and
702697distribution of threads to matrix elements is opaque and
703698implementation-specific. The value returned by ` Length ` may be different for
@@ -928,24 +923,30 @@ infers the type of the output accumulator to match the input vector element type
928923the other overload takes a template parameter for the output matrix element type.
929924All matrix scopes are allowed for the output matrix.
930925
931- #### linalg::MultiplyAdd(vector, Matrix , vector)
926+ #### linalg::MultiplyAdd(Matrix, vector , vector)
932927
933928``` c++
934929template <typename OutputElTy, typename InputElTy, typename BiasElTy, uint M,
935930 uint K, ComponentType MatrixDT>
936- vector<OutputElTy, K>
931+ vector<OutputElTy, M>
937- linalg::MultiplyAdd(vector<InputElTy , M>,
938- Matrix<MatrixDT , M, K, MatrixUse::B, MatrixScope::Thread >,
932+ linalg::MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>,
933+ vector<InputElTy, K>,
939- vector<BiasElTy, K>);
934+ vector<BiasElTy, M>);
940935```
941936
942937Requires ` Thread ` scope matrix input, may be called from divergent control flow.
943938
944- The ` linalg::MultiplyAdd ` function has an overload that takes an ` M ` -element, an
945- MxK ` B ` matrix with ` Thread ` scope , and a ` K ` -element vector. The operation
939+ The ` linalg::MultiplyAdd ` function has an overload that takes an MxK ` A ` matrix
940+ with ` Thread ` scope, a ` K ` -element input vector, and an ` M ` -element bias vector.
946- multiplies the ` M ` -element vector by the matrix then adds the ` K ` -element vector
947- producing a result ` K ` -element vector.
941+ The operation multiplies the matrix by the ` K ` -element input vector then adds
942+ the ` M ` -element bias vector producing a result ` M ` -element vector.
948943
944+ The input vector may be a native vector or an ` InterpretedVector ` which combines
945+ a packed element vector with an interpretation type. The ` M ` -element bias vector
946+ may also be a ` VectorRef ` which refers to a vector in memory. Using the ` VectorRef `
947+ overload makes it easier for the backend compiler to optimize the bias vector
948+ loads with the ALU operations.
949+
949950### DXIL Types
950951
951952This feature adds the following new DXIL enumerations, which used as immediate
@@ -1212,37 +1213,37 @@ Must be called from wave-uniform control flow.
12121213``` llvm
12131214declare <[NUMo] x [TYo]> @dx.op.matvecmul.v[NUMo][TYo].v[NUMi][TYi](
12141215 immarg i32, ; opcode
1216+ %dx.types.MatrixRef, ; matrix A
12151217 <[NUMi] x [TYi]>, ; input vector
1216- immarg i32, ; input interpretation type (DXILComponentType)
1217- %dx.types.MatrixRef ; matrix A
1218+ immarg i32 ; input interpretation type (DXILComponentType)
12181219)
12191220```
12201221
1221- This operation implements a row-vector multiplication against a ` B ` matrix of
1222+ This operation implements a matrix-vector multiplication using an ` A ` matrix of
12221223` Thread ` scope.
12231224
12241225Validation will enforce that:
1225- * The input vector length matches the ` M ` matrix dimension
1226- * The matrix A is a ` B ` matrix of ` Thread ` scope
1226+ * The input vector length matches the ` K ` matrix dimension
1227+ * The matrix A is an ` A ` matrix of ` Thread ` scope
12271228
12281229``` llvm
12291230declare <[NUMo] x [TYo]> @dx.op.matvecmuladd.v[NUMo][TYo].v[NUMi][TYi].v[NUMo][TYb](
12301231 immarg i32, ; opcode
1232+ %dx.types.MatrixRef, ; matrix A
12311233 <[NUMi] x [TYi]>, ; input vector
12321234 immarg i32, ; input interpretation type (DXILComponentType)
1233- %dx.types.MatrixRef, ; matrix A
12341235 <[NUMo] x [TYb]>, ; bias vector
12351236 immarg i32 ; bias interpretation type (DXILComponentType)
12361237)
12371238```
12381239
1239- This operation implements a row-vector multiplication against a ` B ` matrix of
1240+ This operation implements a matrix-vector multiplication using an ` A ` matrix of
12401241` Thread ` scope with a bias vector added to the result.
12411242
12421243Validation will enforce that:
1243- * The input vector length matches the ` M ` matrix dimension
1244- * The bias vector length matches the ` N ` matrix dimension
1245- * The matrix A is a ` B ` matrix of ` Thread ` scope
1244+ * The input vector length matches the ` K ` matrix dimension
1245+ * The bias vector length matches the ` M ` matrix dimension
1246+ * The matrix A is an ` A ` matrix of ` Thread ` scope
12461247
12471248``` llvm
12481249declare void @dx.op.matrixAccumulateToDescriptor(
@@ -1374,7 +1375,7 @@ in the [`DXILComponentType` enumeration](#dxil-enumerations).
13741375
13751376## Appendix 2: HLSL Header
13761377
1377- [ Compiler Explorer] ( https://godbolt.org/z/W5a7zbPr3 )
1378+ [ Compiler Explorer] ( https://godbolt.org/z/aPWK1KjeE )
13781379> Note: this mostly works with Clang, but has some issues to work out still.
13791380
13801381``` cpp
@@ -1639,41 +1640,39 @@ Multiply(const Matrix<T, M, K, MatrixUse::A, MatrixScope::ThreadGroup>,
16391640
16401641template <typename OutputElTy, typename InputElTy, SIZE_TYPE M, SIZE_TYPE K,
16411642 ComponentEnum MatrixDT, MatrixScopeEnum Scope>
1642- vector<OutputElTy, K > Multiply(vector<InputElTy , M>,
1643- Matrix<MatrixDT, M, K, MatrixUse::B, Scope >);
1643+ vector<OutputElTy, M > Multiply(Matrix<MatrixDT , M, K, MatrixUse::A, Scope >,
1644+ vector<InputElTy, K >);
16441645
16451646template <typename OutputElTy, typename InputElTy, typename BiasElTy,
16461647 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT,
16471648 MatrixScopeEnum Scope>
1648- vector<OutputElTy, K> MultiplyAdd(vector<InputElTy, M>,
1649- Matrix<MatrixDT, M, K, MatrixUse::B, Scope>,
1650- vector<BiasElTy, K>);
1649+ vector<OutputElTy, M> MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope>,
1650+ vector<InputElTy, K>, vector<BiasElTy, M>);
16511651
16521652template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
1653- typename BiasElTy, SIZE_TYPE M, SIZE_TYPE N , SIZE_TYPE K,
1653+ typename BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM , SIZE_TYPE K,
16541654 ComponentEnum MatrixDT, MatrixScopeEnum Scope>
1655- typename hlsl::enable_if<InterpretedVector<InputElTy, N , InputInterp>::Size ==
1655+ typename hlsl::enable_if<InterpretedVector<InputElTy, VecM , InputInterp>::Size ==
1656- M,
1656+ K,
1657- vector<OutputElTy, K> >::type
1657+ vector<OutputElTy, M> >::type
1658- MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp >,
1659- Matrix<MatrixDT, M, K, MatrixUse::B, Scope >,
1658+ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, Scope >,
1659+ InterpretedVector<InputElTy, VecM, InputInterp >,
1660- vector<BiasElTy, K>);
1660+ vector<BiasElTy, M>);
16611661
16621662template <typename OutputElTy, typename InputElTy, ComponentEnum BiasElTy,
16631663 SIZE_TYPE M, SIZE_TYPE K, ComponentEnum MatrixDT>
1664- vector<OutputElTy, K>
1664+ vector<OutputElTy, M>
1665- MultiplyAdd(vector<InputElTy, M>,
1666- Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread>,
1667- VectorRef<BiasElTy, K>);
1665+ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread>,
1666+ vector<InputElTy, K>, VectorRef<BiasElTy, M>);
16681667
16691668template <typename OutputElTy, typename InputElTy, ComponentEnum InputInterp,
1670- ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE N , SIZE_TYPE K,
1669+ ComponentEnum BiasElTy, SIZE_TYPE M, SIZE_TYPE VecM , SIZE_TYPE K,
16711670 ComponentEnum MatrixDT>
1672- typename hlsl::enable_if<InterpretedVector<InputElTy, N , InputInterp>::Size ==
1671+ typename hlsl::enable_if<InterpretedVector<InputElTy, VecM , InputInterp>::Size ==
1673- M,
1672+ K,
1674- vector<OutputElTy, K> >::type
1673+ vector<OutputElTy, M> >::type
1675- MultiplyAdd(InterpretedVector<InputElTy, N, InputInterp >,
1676- Matrix<MatrixDT, M, K, MatrixUse::B, MatrixScope::Thread >,
1674+ MultiplyAdd(Matrix<MatrixDT, M, K, MatrixUse::A, MatrixScope::Thread >,
1675+ InterpretedVector<InputElTy, VecM, InputInterp >,
1677- VectorRef<BiasElTy, K>);
1676+ VectorRef<BiasElTy, M>);
16781677
16791678// Outer product functions
@@ -1722,30 +1721,30 @@ ByteAddressBuffer MBuf : register(t0);
17221721
17231722void CoopVec() {
17241723 using namespace dx::linalg;
1725- using MatrixBTy =
1726- Matrix<ComponentType::F16, 16, 16, MatrixUse::B , MatrixScope::Thread>;
1724+ using MatrixATy =
1725+ Matrix<ComponentType::F16, 16, 16, MatrixUse::A , MatrixScope::Thread>;
17271726
17281727 vector<float16_t, 16> Vec = (vector<float16_t, 16>)0;
1729- MatrixBTy MatB = MatrixBTy ::Load(
1728+ MatrixATy MatA = MatrixATy ::Load(
17301729 MBuf, 0, /* Row stride = number of columns * element size * / 16 * 4,
17311730 MatrixLayout::RowMajor);
1732- vector<float16_t, 16> Layer1 = Multiply<float16_t>(Vec, MatB );
1731+ vector<float16_t, 16> Layer1 = Multiply<float16_t>(MatA, Vec );
17331732
17341733 vector<float16_t, 16> NullBias = (vector<float16_t, 16>)0;
1735- vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(Layer1, MatB , NullBias);
1734+ vector<float16_t, 16> Layer2 = MultiplyAdd<float16_t>(MatA, Layer1 , NullBias);
17361735
17371736 VectorRef<ComponentType::F8_E4M3, 16> MemBias = {MBuf,
17381737 /* start offset* / 4096};
1739- vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(Layer2, MatB , MemBias);
1738+ vector<float16_t, 16> Layer3 = MultiplyAdd<float16_t>(MatA, Layer2 , MemBias);
17401739
17411740 // Clang doesn't yet support packed types.
17421741#ifdef __ hlsl_dx_compiler
17431742 vector<uint8_t4_packed, 4> SomeData = (vector<uint8_t4_packed, 4>)0;
17441743
17451744 vector<float16_t, 16> Layer4 = MultiplyAdd<float16_t>(
1746- MakeInterpretedVector< ComponentType::F8_E4M3 > (SomeData), MatB , MemBias);
1745+ MatA, MakeInterpretedVector< ComponentType::F8_E4M3 > (SomeData), MemBias);
17471746 vector<float16_t, 16> Layer5 = MultiplyAdd<float16_t>(
1748- MakeInterpretedVector< ComponentType::F8_E4M3 > (SomeData), MatB , NullBias);
1747+ MatA, MakeInterpretedVector< ComponentType::F8_E4M3 > (SomeData), NullBias);
17491748#endif
17501749}
17511750
0 commit comments