1414#include " dxc/DXIL/DxilModule.h"
1515#include " dxc/HLSL/DxilGenerationPass.h"
1616
17+ #include " llvm/ADT/SmallVector.h"
1718#include " llvm/ADT/StringRef.h"
1819#include " llvm/IR/Constant.h"
1920#include " llvm/IR/Constants.h"
@@ -33,6 +34,7 @@ static bool scalarizeVectorStore(hlsl::OP *HlslOP, const DataLayout &DL,
3334 CallInst *CI);
3435static bool scalarizeVectorIntrinsic (hlsl::OP *HlslOP, CallInst *CI);
3536static bool scalarizeVectorReduce (hlsl::OP *HlslOP, CallInst *CI);
37+ static bool scalarizeVectorDot (hlsl::OP *HlslOP, CallInst *CI);
3638static bool scalarizeVectorWaveMatch (hlsl::OP *HlslOP, CallInst *CI);
3739
3840class DxilScalarizeVectorIntrinsics : public ModulePass {
@@ -66,6 +68,7 @@ class DxilScalarizeVectorIntrinsics : public ModulePass {
6668 OpClass == DXIL::OpCodeClass::RawBufferVectorLoad ||
6769 OpClass == DXIL::OpCodeClass::RawBufferVectorStore ||
6870 OpClass == DXIL::OpCodeClass::VectorReduce ||
71+ OpClass == DXIL::OpCodeClass::Dot ||
6972 OpClass == DXIL::OpCodeClass::WaveMatch);
7073 if (!CouldRewrite)
7174 continue ;
@@ -84,6 +87,9 @@ class DxilScalarizeVectorIntrinsics : public ModulePass {
8487 case DXIL::OpCodeClass::VectorReduce:
8588 Changed |= scalarizeVectorReduce (HlslOP, CI);
8689 continue ;
90+ case DXIL::OpCodeClass::Dot:
91+ Changed |= scalarizeVectorDot (HlslOP, CI);
92+ continue ;
8793 case DXIL::OpCodeClass::WaveMatch:
8894 Changed |= scalarizeVectorWaveMatch (HlslOP, CI);
8995 continue ;
@@ -337,6 +343,66 @@ static bool scalarizeVectorWaveMatch(hlsl::OP *HlslOP, CallInst *CI) {
337343 return true ;
338344}
339345
346+ // Scalarize vectorized dot product
347+ static bool scalarizeVectorDot (hlsl::OP *HlslOP, CallInst *CI) {
348+ IRBuilder<> Builder (CI);
349+
350+ Value *AVecArg = CI->getArgOperand (1 );
351+ Value *BVecArg = CI->getArgOperand (2 );
352+ VectorType *VecTy = cast<VectorType>(AVecArg->getType ());
353+ Type *ScalarTy = VecTy->getScalarType ();
354+ const unsigned VecSize = VecTy->getNumElements ();
355+
356+ // The only valid opcode is FDot which only has floating point overload.
357+ // If we hit this assert then this functions lowering needs to be updated
358+ assert (ScalarTy->isFloatingPointTy () && " Unexpected scalar type" );
359+
360+ SmallVector<Value *, 4 > AElts (VecSize);
361+ SmallVector<Value *, 4 > BElts (VecSize);
362+
363+ for (unsigned EltIdx = 0 ; EltIdx < VecSize; EltIdx++) {
364+ AElts[EltIdx] = Builder.CreateExtractElement (AVecArg, EltIdx);
365+ BElts[EltIdx] = Builder.CreateExtractElement (BVecArg, EltIdx);
366+ }
367+
368+ DXIL::OpCode DotOp = DXIL::OpCode::Dot4;
369+ switch (VecSize) {
370+ // Calling dot on a vec1 is not typical but also not impossible
371+ // DXIL doesn't have a native Dot1 opcode but thats the same as a
372+ // single FMul. HLOperation lower is expected to do the conversion
373+ // so we assert here in case that ever changes.
374+ case 1 :
375+ assert (false && " vector dot shouldn't appear for vec1" );
376+ break ;
377+ case 2 :
378+ DotOp = DXIL::OpCode::Dot2;
379+ break ;
380+ case 3 :
381+ DotOp = DXIL::OpCode::Dot3;
382+ break ;
383+ case 4 :
384+ DotOp = DXIL::OpCode::Dot4;
385+ break ;
386+ default :
387+ assert (false &&
388+ " Vectors larger than 4 components are not supported in SM6.8" );
389+ break ;
390+ }
391+
392+ SmallVector<Value *, 9 > Args (VecSize * 2 + 1 );
393+ Args[0 ] = Builder.getInt32 ((unsigned )DotOp);
394+
395+ for (unsigned EltIdx = 0 ; EltIdx < VecSize; EltIdx++) {
396+ Args[EltIdx + 1 ] = AElts[EltIdx];
397+ Args[EltIdx + 1 + VecSize] = BElts[EltIdx];
398+ }
399+
400+ Function *Func = HlslOP->GetOpFunc (DotOp, ScalarTy);
401+ Value *Dot = Builder.CreateCall (Func, Args, CI->getName ());
402+ CI->replaceAllUsesWith (Dot);
403+ return true ;
404+ }
405+
// Scalarize native vector operation represented by `CI`, generating
// scalar calls for each element of its vector parameters.
// Use `HlslOP` to retrieve the associated scalar op function.
0 commit comments