Skip to content

Commit 834a3cc

Browse files
authored
[SPIRV] Handle ptrcast between array and vector types (llvm#166418)
This commit adds support for legalizing pointer casts between array and vector types within the SPIRV backend. This is necessary to handle cases where a vector is loaded from or stored to an array, which can occur with HLSL matrix types. The following changes are included: - Added to load a vector from an array. - Added to store a vector to an array. - Added the test case to verify the functionality.
1 parent 1deaedd commit 834a3cc

File tree

2 files changed

+134
-0
lines changed

2 files changed

+134
-0
lines changed

llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,81 @@ class SPIRVLegalizePointerCast : public FunctionPass {
116116
return LI;
117117
}
118118

119+
// Loads elements from an array and constructs a vector.
120+
Value *loadVectorFromArray(IRBuilder<> &B, FixedVectorType *TargetType,
121+
Value *Source) {
122+
// Load each element of the array.
123+
SmallVector<Value *, 4> LoadedElements;
124+
for (unsigned i = 0; i < TargetType->getNumElements(); ++i) {
125+
// Create a GEP to access the i-th element of the array.
126+
SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
127+
SmallVector<Value *, 4> Args;
128+
Args.push_back(B.getInt1(false));
129+
Args.push_back(Source);
130+
Args.push_back(B.getInt32(0));
131+
Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
132+
auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
133+
GR->buildAssignPtr(B, TargetType->getElementType(), ElementPtr);
134+
135+
// Load the value from the element pointer.
136+
Value *Load = B.CreateLoad(TargetType->getElementType(), ElementPtr);
137+
buildAssignType(B, TargetType->getElementType(), Load);
138+
LoadedElements.push_back(Load);
139+
}
140+
141+
// Build the vector from the loaded elements.
142+
Value *NewVector = PoisonValue::get(TargetType);
143+
buildAssignType(B, TargetType, NewVector);
144+
145+
for (unsigned i = 0; i < TargetType->getNumElements(); ++i) {
146+
Value *Index = B.getInt32(i);
147+
SmallVector<Type *, 4> Types = {TargetType, TargetType,
148+
TargetType->getElementType(),
149+
Index->getType()};
150+
SmallVector<Value *> Args = {NewVector, LoadedElements[i], Index};
151+
NewVector = B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
152+
buildAssignType(B, TargetType, NewVector);
153+
}
154+
return NewVector;
155+
}
156+
157+
// Stores elements from a vector into an array.
158+
void storeArrayFromVector(IRBuilder<> &B, Value *SrcVector,
159+
Value *DstArrayPtr, ArrayType *ArrTy,
160+
Align Alignment) {
161+
auto *VecTy = cast<FixedVectorType>(SrcVector->getType());
162+
163+
// Ensure the element types of the array and vector are the same.
164+
assert(VecTy->getElementType() == ArrTy->getElementType() &&
165+
"Element types of array and vector must be the same.");
166+
167+
for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
168+
// Create a GEP to access the i-th element of the array.
169+
SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
170+
DstArrayPtr->getType()};
171+
SmallVector<Value *, 4> Args;
172+
Args.push_back(B.getInt1(false));
173+
Args.push_back(DstArrayPtr);
174+
Args.push_back(B.getInt32(0));
175+
Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
176+
auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
177+
GR->buildAssignPtr(B, ArrTy->getElementType(), ElementPtr);
178+
179+
// Extract the element from the vector and store it.
180+
Value *Index = B.getInt32(i);
181+
SmallVector<Type *, 3> EltTypes = {VecTy->getElementType(), VecTy,
182+
Index->getType()};
183+
SmallVector<Value *, 2> EltArgs = {SrcVector, Index};
184+
Value *Element =
185+
B.CreateIntrinsic(Intrinsic::spv_extractelt, {EltTypes}, {EltArgs});
186+
buildAssignType(B, VecTy->getElementType(), Element);
187+
188+
Types = {Element->getType(), ElementPtr->getType()};
189+
Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(Alignment.value())};
190+
B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
191+
}
192+
}
193+
119194
// Replaces the load instruction to get rid of the ptrcast used as source
120195
// operand.
121196
void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
@@ -154,6 +229,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
154229
// - float v = s.m;
155230
else if (SST && SST->getTypeAtIndex(0u) == ToTy)
156231
Output = loadFirstValueFromAggregate(B, ToTy, OriginalOperand, LI);
232+
else if (SAT && DVT && SAT->getElementType() == DVT->getElementType())
233+
Output = loadVectorFromArray(B, DVT, OriginalOperand);
157234
else
158235
llvm_unreachable("Unimplemented implicit down-cast from load.");
159236

@@ -288,6 +365,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
288365
auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
289366
auto *D_ST = dyn_cast<StructType>(ToTy);
290367
auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
368+
auto *D_AT = dyn_cast<ArrayType>(ToTy);
291369

292370
B.SetInsertPoint(BadStore);
293371
if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
@@ -296,6 +374,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
296374
storeVectorFromVector(B, Src, Dst, Alignment);
297375
else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
298376
storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
377+
else if (D_AT && S_VT && S_VT->getElementType() == D_AT->getElementType())
378+
storeArrayFromVector(B, Src, Dst, D_AT, Alignment);
299379
else
300380
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");
301381

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: [[FLOAT:%[0-9]+]] = OpTypeFloat 32
5+
; CHECK-DAG: [[VEC4FLOAT:%[0-9]+]] = OpTypeVector [[FLOAT]] 4
6+
; CHECK-DAG: [[UINT_TYPE:%[0-9]+]] = OpTypeInt 32 0
7+
; CHECK-DAG: [[UINT4:%[0-9]+]] = OpConstant [[UINT_TYPE]] 4
8+
; CHECK-DAG: [[ARRAY4FLOAT:%[0-9]+]] = OpTypeArray [[FLOAT]] [[UINT4]]
9+
; CHECK-DAG: [[PTR_ARRAY4FLOAT:%[0-9]+]] = OpTypePointer Private [[ARRAY4FLOAT]]
10+
; CHECK-DAG: [[G_IN:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private
11+
; CHECK-DAG: [[G_OUT:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private
12+
; CHECK-DAG: [[UINT0:%[0-9]+]] = OpConstant [[UINT_TYPE]] 0
13+
; CHECK-DAG: [[UINT1:%[0-9]+]] = OpConstant [[UINT_TYPE]] 1
14+
; CHECK-DAG: [[UINT2:%[0-9]+]] = OpConstant [[UINT_TYPE]] 2
15+
; CHECK-DAG: [[UINT3:%[0-9]+]] = OpConstant [[UINT_TYPE]] 3
16+
; CHECK-DAG: [[PTR_FLOAT:%[0-9]+]] = OpTypePointer Private [[FLOAT]]
17+
; CHECK-DAG: [[UNDEF_VEC:%[0-9]+]] = OpUndef [[VEC4FLOAT]]
18+
19+
@G_in = internal addrspace(10) global [4 x float] zeroinitializer
20+
@G_out = internal addrspace(10) global [4 x float] zeroinitializer
21+
22+
define spir_func void @main() {
23+
entry:
24+
; CHECK: [[GEP0:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT0]]
25+
; CHECK-NEXT: [[LOAD0:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP0]]
26+
; CHECK-NEXT: [[GEP1:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT1]]
27+
; CHECK-NEXT: [[LOAD1:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP1]]
28+
; CHECK-NEXT: [[GEP2:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT2]]
29+
; CHECK-NEXT: [[LOAD2:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP2]]
30+
; CHECK-NEXT: [[GEP3:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT3]]
31+
; CHECK-NEXT: [[LOAD3:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP3]]
32+
; CHECK-NEXT: [[VEC_INSERT0:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD0]] [[UNDEF_VEC]] 0
33+
; CHECK-NEXT: [[VEC_INSERT1:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD1]] [[VEC_INSERT0]] 1
34+
; CHECK-NEXT: [[VEC_INSERT2:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD2]] [[VEC_INSERT1]] 2
35+
; CHECK-NEXT: [[VEC:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD3]] [[VEC_INSERT2]] 3
36+
%0 = load <4 x float>, ptr addrspace(10) @G_in, align 64
37+
38+
; CHECK-NEXT: [[GEP_OUT0:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT0]]
39+
; CHECK-NEXT: [[VEC_EXTRACT0:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 0
40+
; CHECK-NEXT: OpStore [[GEP_OUT0]] [[VEC_EXTRACT0]]
41+
; CHECK-NEXT: [[GEP_OUT1:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT1]]
42+
; CHECK-NEXT: [[VEC_EXTRACT1:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 1
43+
; CHECK-NEXT: OpStore [[GEP_OUT1]] [[VEC_EXTRACT1]]
44+
; CHECK-NEXT: [[GEP_OUT2:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT2]]
45+
; CHECK-NEXT: [[VEC_EXTRACT2:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 2
46+
; CHECK-NEXT: OpStore [[GEP_OUT2]] [[VEC_EXTRACT2]]
47+
; CHECK-NEXT: [[GEP_OUT3:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT3]]
48+
; CHECK-NEXT: [[VEC_EXTRACT3:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 3
49+
; CHECK-NEXT: OpStore [[GEP_OUT3]] [[VEC_EXTRACT3]]
50+
store <4 x float> %0, ptr addrspace(10) @G_out, align 64
51+
52+
; CHECK-NEXT: OpReturn
53+
ret void
54+
}

0 commit comments

Comments
 (0)