Skip to content

Commit b3e84a8

Browse files
V-FEXrthekota
andauthored
[SM 6.9] Backport vector dot for 6.8 linkage (microsoft#7809)
Fixes microsoft#7794 Vector dot in SM6.9 libraries will be converted to scalar dot when linked against SM6.8 shader --------- Co-authored-by: Helena Kotas <[email protected]>
1 parent 0aa9fea commit b3e84a8

File tree

3 files changed

+90
-6
lines changed

3 files changed

+90
-6
lines changed

lib/HLSL/DxilScalarizeVectorIntrinsics.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "dxc/DXIL/DxilModule.h"
1515
#include "dxc/HLSL/DxilGenerationPass.h"
1616

17+
#include "llvm/ADT/SmallVector.h"
1718
#include "llvm/ADT/StringRef.h"
1819
#include "llvm/IR/Constant.h"
1920
#include "llvm/IR/Constants.h"
@@ -33,6 +34,7 @@ static bool scalarizeVectorStore(hlsl::OP *HlslOP, const DataLayout &DL,
3334
CallInst *CI);
3435
static bool scalarizeVectorIntrinsic(hlsl::OP *HlslOP, CallInst *CI);
3536
static bool scalarizeVectorReduce(hlsl::OP *HlslOP, CallInst *CI);
37+
static bool scalarizeVectorDot(hlsl::OP *HlslOP, CallInst *CI);
3638
static bool scalarizeVectorWaveMatch(hlsl::OP *HlslOP, CallInst *CI);
3739

3840
class DxilScalarizeVectorIntrinsics : public ModulePass {
@@ -66,6 +68,7 @@ class DxilScalarizeVectorIntrinsics : public ModulePass {
6668
OpClass == DXIL::OpCodeClass::RawBufferVectorLoad ||
6769
OpClass == DXIL::OpCodeClass::RawBufferVectorStore ||
6870
OpClass == DXIL::OpCodeClass::VectorReduce ||
71+
OpClass == DXIL::OpCodeClass::Dot ||
6972
OpClass == DXIL::OpCodeClass::WaveMatch);
7073
if (!CouldRewrite)
7174
continue;
@@ -84,6 +87,9 @@ class DxilScalarizeVectorIntrinsics : public ModulePass {
8487
case DXIL::OpCodeClass::VectorReduce:
8588
Changed |= scalarizeVectorReduce(HlslOP, CI);
8689
continue;
90+
case DXIL::OpCodeClass::Dot:
91+
Changed |= scalarizeVectorDot(HlslOP, CI);
92+
continue;
8793
case DXIL::OpCodeClass::WaveMatch:
8894
Changed |= scalarizeVectorWaveMatch(HlslOP, CI);
8995
continue;
@@ -337,6 +343,66 @@ static bool scalarizeVectorWaveMatch(hlsl::OP *HlslOP, CallInst *CI) {
337343
return true;
338344
}
339345

346+
// Scalarize vectorized dot product
347+
static bool scalarizeVectorDot(hlsl::OP *HlslOP, CallInst *CI) {
348+
IRBuilder<> Builder(CI);
349+
350+
Value *AVecArg = CI->getArgOperand(1);
351+
Value *BVecArg = CI->getArgOperand(2);
352+
VectorType *VecTy = cast<VectorType>(AVecArg->getType());
353+
Type *ScalarTy = VecTy->getScalarType();
354+
const unsigned VecSize = VecTy->getNumElements();
355+
356+
// The only valid opcode is FDot which only has floating point overload.
357+
// If we hit this assert then this functions lowering needs to be updated
358+
assert(ScalarTy->isFloatingPointTy() && "Unexpected scalar type");
359+
360+
SmallVector<Value *, 4> AElts(VecSize);
361+
SmallVector<Value *, 4> BElts(VecSize);
362+
363+
for (unsigned EltIdx = 0; EltIdx < VecSize; EltIdx++) {
364+
AElts[EltIdx] = Builder.CreateExtractElement(AVecArg, EltIdx);
365+
BElts[EltIdx] = Builder.CreateExtractElement(BVecArg, EltIdx);
366+
}
367+
368+
DXIL::OpCode DotOp = DXIL::OpCode::Dot4;
369+
switch (VecSize) {
370+
// Calling dot on a vec1 is not typical but also not impossible
371+
// DXIL doesn't have a native Dot1 opcode but thats the same as a
372+
// single FMul. HLOperation lower is expected to do the conversion
373+
// so we assert here in case that ever changes.
374+
case 1:
375+
assert(false && "vector dot shouldn't appear for vec1");
376+
break;
377+
case 2:
378+
DotOp = DXIL::OpCode::Dot2;
379+
break;
380+
case 3:
381+
DotOp = DXIL::OpCode::Dot3;
382+
break;
383+
case 4:
384+
DotOp = DXIL::OpCode::Dot4;
385+
break;
386+
default:
387+
assert(false &&
388+
"Vectors larger than 4 components are not supported in SM6.8");
389+
break;
390+
}
391+
392+
SmallVector<Value *, 9> Args(VecSize * 2 + 1);
393+
Args[0] = Builder.getInt32((unsigned)DotOp);
394+
395+
for (unsigned EltIdx = 0; EltIdx < VecSize; EltIdx++) {
396+
Args[EltIdx + 1] = AElts[EltIdx];
397+
Args[EltIdx + 1 + VecSize] = BElts[EltIdx];
398+
}
399+
400+
Function *Func = HlslOP->GetOpFunc(DotOp, ScalarTy);
401+
Value *Dot = Builder.CreateCall(Func, Args, CI->getName());
402+
CI->replaceAllUsesWith(Dot);
403+
return true;
404+
}
405+
340406
// Scalarize native vector operation represented by `CI`, generating
341407
// scalar calls for each element of the its vector parameters.
342408
// Use `HlslOP` to retrieve the associated scalar op function.

lib/HLSL/HLOperationLower.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2606,8 +2606,10 @@ Value *TranslateDot(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
26062606
Type *EltTy = Ty->getScalarType();
26072607

26082608
// SM6.9 introduced a DXIL operation for vectorized dot product
2609+
// The operation is only advantageous for vect size>1, vec1s will be
2610+
// lowered to a single Mul.
26092611
if (hlslOP->GetModule()->GetHLModule().GetShaderModel()->IsSM69Plus() &&
2610-
EltTy->isFloatingPointTy()) {
2612+
EltTy->isFloatingPointTy() && Ty->getVectorNumElements() > 1) {
26112613
Value *arg1 = CI->getArgOperand(HLOperandIndex::kBinaryOpSrc1Idx);
26122614
IRBuilder<> Builder(CI);
26132615
Constant *opArg = hlslOP->GetU32Const((unsigned)DXIL::OpCode::FDot);

tools/clang/test/CodeGenDXIL/hlsl/types/longvec-scalarized-intrinsics-sm68.hlsl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,27 @@ float4 main(uint i : SV_PrimitiveID, uint4 m : M) : SV_Target {
146146
// CHECK: call float @dx.op.unary.f32(i32 21, float %{{.*}}) ; Exp(value)
147147
res += pow(vec1, vec2);
148148

149-
// CHECK: mul i32
150-
// CHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) ; UMad(a,b,c)
151-
// CHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) ; UMad(a,b,c)
152-
// CHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) ; UMad(a,b,c)
153-
res += dot(ivec1, ivec2);
149+
vector<float, 2> fDot2L = rbuf.Load< vector<float, 2> >(i++*32);
150+
vector<float, 2> fDot2R = rbuf.Load< vector<float, 2> >(i++*32);
151+
vector<float, 3> fDot3L = rbuf.Load< vector<float, 3> >(i++*32);
152+
vector<float, 3> fDot3R = rbuf.Load< vector<float, 3> >(i++*32);
153+
vector<float, 4> fDot4L = rbuf.Load< vector<float, 4> >(i++*32);
154+
vector<float, 4> fDot4R = rbuf.Load< vector<float, 4> >(i++*32);
155+
vector<float, 4> fDotRes = 0;
156+
157+
// CHECK: fmul fast float %{{.*}}, %{{.*}}
158+
fDotRes[0] = dot(fDot2L.x, fDot4R.w);
159+
160+
// CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) ; Dot2(ax,ay,bx,by)
161+
fDotRes[1] = dot(fDot2L, fDot2R);
162+
163+
// CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) ; Dot3(ax,ay,az,bx,by,bz)
164+
fDotRes[2] = dot(fDot3L, fDot3R);
165+
166+
// CHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) ; Dot4(ax,ay,az,aw,bx,by,bz,bw)
167+
fDotRes[3] = dot(fDot4L, fDot4R);
168+
169+
res += fDotRes;
154170

155171
// CHECK: call float @dx.op.unary.f32(i32 29, float %{{.*}}) ; Round_z(value)
156172
// CHECK: call float @dx.op.unary.f32(i32 29, float %{{.*}}) ; Round_z(value)

0 commit comments

Comments
 (0)