Commit c9b595d
Optimize AArch64 memset to use NEON DUP instruction for small sizes
This change improves memset code generation for non-zero values on AArch64
for sizes 4, 8, and 16 bytes by using NEON's DUP instruction instead of the
less efficient multiplication with the 0x01010101 pattern.

Changes:

1. In SelectionDAG.cpp: for AArch64 targets, generate vector splats for
   scalar i32/i64 memset operations, which are then efficiently lowered to
   DUP instructions.
2. In AArch64ISelLowering.cpp: modify getOptimalMemOpType and
   getOptimalMemOpLLT to return v16i8 for non-zero memset operations of any
   size when NEON is available (previously only for sizes >= 32 bytes).
3. Update test expectations to verify the new DUP-based code generation for
   both NEON and GPR code paths.

The optimization is restricted to AArch64 only to avoid breaking RISCV and
X86 tests.

Signed-off-by: Osama Abdelkader <[email protected]>
1 parent 269f264 commit c9b595d
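For context on the arithmetic this patch sidesteps: multiplying a zero-extended byte by a 0x0101... constant replicates that byte across the whole word, which is exactly what the GPR mov/and/mul sequences in the updated test expectations below compute. A minimal standalone C++ sketch of that identity (illustrative only, not part of the patch):

#include <cassert>
#include <cstdint>

// The two magic immediates from the GPR check lines below:
// 16843009 == 0x01010101 and 72340172838076673 == 0x0101010101010101.
uint32_t splat4(uint8_t V) { return uint32_t(V) * 0x01010101u; }
uint64_t splat8(uint8_t V) { return uint64_t(V) * 0x0101010101010101ull; }

int main() {
  assert(splat4(0xAB) == 0xABABABABu);
  assert(splat8(0xAB) == 0xABABABABABABABABull);
  return 0;
}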

File tree

4 files changed, +113 -41 lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 14 additions & 0 deletions
@@ -8543,6 +8543,20 @@ static SDValue getMemsetValue(SDValue Value, EVT VT, SelectionDAG &DAG,
   if (!IntVT.isInteger())
     IntVT = EVT::getIntegerVT(*DAG.getContext(), IntVT.getSizeInBits());
 
+  // For repeated-byte patterns, generate a vector splat instead of MUL to
+  // enable efficient lowering to DUP on targets like AArch64.
+  // Only do this on AArch64 targets to avoid breaking other architectures.
+  const TargetMachine &TM = DAG.getTarget();
+  if (NumBits > 8 && VT.isInteger() && !VT.isVector() &&
+      (NumBits == 32 || NumBits == 64) &&
+      TM.getTargetTriple().getArch() == Triple::aarch64) {
+    // Generate a vector of bytes: v4i8 for i32, v8i8 for i64
+    EVT ByteVecTy = EVT::getVectorVT(*DAG.getContext(), MVT::i8, NumBits / 8);
+    SDValue VecSplat = DAG.getSplatBuildVector(ByteVecTy, dl, Value);
+    // Bitcast back to the target integer type
+    return DAG.getNode(ISD::BITCAST, dl, IntVT, VecSplat);
+  }
+
   Value = DAG.getNode(ISD::ZERO_EXTEND, dl, IntVT, Value);
   if (NumBits > 8) {
     // Use a multiplication with 0x010101... to extend the input to the
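The fast path added above builds a byte-vector splat and bitcasts it back to the scalar type; on AArch64 the resulting BUILD_VECTOR splat later selects to a single dup, replacing the mov/and/mul chain. A standalone C++ model of the value computation (a sketch assuming a little-endian target; names are illustrative, not LLVM API):

#include <cassert>
#include <cstdint>
#include <cstring>

// ByteVecTy is v4i8 for i32 and v8i8 for i64, i.e. NumBits / 8 lanes,
// every lane holding the memset byte (models getSplatBuildVector).
uint64_t splatViaByteVector(uint8_t V, unsigned NumBits) {
  uint8_t Lanes[8] = {};
  for (unsigned I = 0; I != NumBits / 8; ++I)
    Lanes[I] = V;
  uint64_t Out = 0;
  std::memcpy(&Out, Lanes, NumBits / 8); // models the ISD::BITCAST back
  return Out;
}

int main() {
  assert(splatViaByteVector(0xAB, 32) == 0xABABABABull);
  assert(splatViaByteVector(0xAB, 64) == 0xABABABABABABABABull);
  return 0;
}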

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 41 additions & 10 deletions
@@ -18328,10 +18328,11 @@ EVT AArch64TargetLowering::getOptimalMemOpType(
   bool CanImplicitFloat = !FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat);
   bool CanUseNEON = Subtarget->hasNEON() && CanImplicitFloat;
   bool CanUseFP = Subtarget->hasFPARMv8() && CanImplicitFloat;
-  // Only use AdvSIMD to implement memset of 32-byte and above. It would have
+  // For zero memset, only use AdvSIMD for 32-byte and above. It would have
   // taken one instruction to materialize the v2i64 zero and one store (with
   // restrictive addressing mode). Just do i64 stores.
-  bool IsSmallMemset = Op.isMemset() && Op.size() < 32;
+  // For non-zero memset, use NEON even for smaller sizes as dup is efficient.
+  bool IsSmallZeroMemset = Op.isMemset() && Op.size() < 32 && Op.isZeroMemset();
   auto AlignmentIsAcceptable = [&](EVT VT, Align AlignCheck) {
     if (Op.isAligned(AlignCheck))
       return true;
@@ -18341,10 +18342,12 @@ EVT AArch64TargetLowering::getOptimalMemOpType(
         Fast;
   };
 
-  if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
-      AlignmentIsAcceptable(MVT::v16i8, Align(16)))
+  // For non-zero memset, use NEON even for smaller sizes as dup + scalar store
+  // is efficient
+  if (CanUseNEON && Op.isMemset() && !IsSmallZeroMemset)
     return MVT::v16i8;
-  if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
+  if (CanUseFP && !IsSmallZeroMemset &&
+      AlignmentIsAcceptable(MVT::f128, Align(16)))
     return MVT::f128;
   if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
     return MVT::i64;
@@ -18358,10 +18361,11 @@ LLT AArch64TargetLowering::getOptimalMemOpLLT(
   bool CanImplicitFloat = !FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat);
   bool CanUseNEON = Subtarget->hasNEON() && CanImplicitFloat;
   bool CanUseFP = Subtarget->hasFPARMv8() && CanImplicitFloat;
-  // Only use AdvSIMD to implement memset of 32-byte and above. It would have
+  // For zero memset, only use AdvSIMD for 32-byte and above. It would have
   // taken one instruction to materialize the v2i64 zero and one store (with
   // restrictive addressing mode). Just do i64 stores.
-  bool IsSmallMemset = Op.isMemset() && Op.size() < 32;
+  // For non-zero memset, use NEON even for smaller sizes as dup is efficient.
+  bool IsSmallZeroMemset = Op.isMemset() && Op.size() < 32 && Op.isZeroMemset();
   auto AlignmentIsAcceptable = [&](EVT VT, Align AlignCheck) {
     if (Op.isAligned(AlignCheck))
       return true;
@@ -18371,10 +18375,12 @@ LLT AArch64TargetLowering::getOptimalMemOpLLT(
         Fast;
   };
 
-  if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
-      AlignmentIsAcceptable(MVT::v2i64, Align(16)))
+  // For non-zero memset, use NEON for all sizes where it's beneficial.
+  // NEON dup + scalar store works for any alignment and is efficient.
+  if (CanUseNEON && Op.isMemset() && !IsSmallZeroMemset)
     return LLT::fixed_vector(2, 64);
-  if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
+  if (CanUseFP && !IsSmallZeroMemset &&
+      AlignmentIsAcceptable(MVT::f128, Align(16)))
     return LLT::scalar(128);
   if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
     return LLT::scalar(64);
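The getOptimalMemOpType and getOptimalMemOpLLT hunks above apply the same policy shift to the SelectionDAG and GlobalISel paths. A hypothetical standalone model of the post-patch type-selection cascade (illustrative names, not LLVM API; assumes the alignment checks pass):

#include <cstdio>

enum class StoreTy { V16I8, F128, I64, I32 };

// Only *zero* memsets under 32 bytes are now steered away from NEON;
// non-zero memsets of any size may use a v16i8 splat (lowered to dup).
StoreTy pickMemsetType(unsigned Size, bool IsZero,
                       bool CanUseNEON, bool CanUseFP) {
  bool IsSmallZeroMemset = Size < 32 && IsZero; // was: Size < 32, any value
  if (CanUseNEON && !IsSmallZeroMemset)
    return StoreTy::V16I8;
  if (CanUseFP && !IsSmallZeroMemset)
    return StoreTy::F128;
  if (Size >= 8)
    return StoreTy::I64;
  return StoreTy::I32;
}

int main() {
  // A 4-byte non-zero memset now takes the NEON path; a 4-byte zero
  // memset still falls through to plain integer stores.
  std::printf("%d %d\n", int(pickMemsetType(4, false, true, true)),
              int(pickMemsetType(4, true, true, true)));
  return 0;
}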
@@ -29702,6 +29708,31 @@ AArch64TargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
       .getInstr();
 }
 
+bool AArch64TargetLowering::shallExtractConstSplatVectorElementToStore(
+    Type *VectorTy, unsigned ElemSizeInBits, unsigned &Index) const {
+  // On AArch64, we can efficiently extract a scalar from a splat vector using
+  // str s/d/q0 which extracts 32/64/128 bits from the vector register.
+  // This is useful for memset where we generate a v16i8 splat and need to store
+  // a smaller scalar (e.g., i32 for a 4-byte memset).
+  if (FixedVectorType *VTy = dyn_cast<FixedVectorType>(VectorTy)) {
+    // Only handle v16i8 splat (128 bits total, 16 elements of 8 bits each)
+    if (VTy->getNumElements() == 16 && VTy->getElementType()->isIntegerTy(8)) {
+      // Check if we're extracting a 32-bit or 64-bit element
+      if (ElemSizeInBits == 32) {
+        // Extract element 0 of the 128-bit vector as a 32-bit scalar
+        Index = 0;
+        return true;
+      }
+      if (ElemSizeInBits == 64) {
+        // Extract elements 0-7 as a 64-bit scalar
+        Index = 0;
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 bool AArch64TargetLowering::enableAggressiveFMAFusion(EVT VT) const {
   return Subtarget->hasAggressiveFMA() && VT.isFloatingPoint();
 }
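The hook above is sound because every byte of a v16i8 splat is identical, so the low 32 or 64 bits of the 128-bit register already hold the desired pattern; this is what the narrow str s0 / str d0 stores in the tests below rely on. A standalone C++ sketch of the property (little-endian assumption; not LLVM code):

#include <cassert>
#include <cstdint>
#include <cstring>

// Models dup v0.16b, w1 followed by storing element 0 at a wider width
// (str s0 writes 4 bytes, str d0 writes 8 bytes of the splat register).
uint64_t lowBitsOfSplat(uint8_t V, unsigned Bytes) {
  uint8_t Splat[16];
  std::memset(Splat, V, sizeof(Splat));
  uint64_t Out = 0;
  std::memcpy(&Out, Splat, Bytes);
  return Out;
}

int main() {
  assert(lowBitsOfSplat(0xAB, 4) == 0xABABABABull);
  assert(lowBitsOfSplat(0xAB, 8) == 0xABABABABABABABABull);
  return 0;
}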

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 0 deletions
@@ -475,6 +475,9 @@ class AArch64TargetLowering : public TargetLowering {
                              MachineBasicBlock::instr_iterator &MBBI,
                              const TargetInstrInfo *TII) const override;
 
+  bool shallExtractConstSplatVectorElementToStore(
+      Type *VectorTy, unsigned ElemSizeInBits, unsigned &Index) const override;
+
   /// Enable aggressive FMA fusion on targets that want it.
   bool enableAggressiveFMAFusion(EVT VT) const override;
 

llvm/test/CodeGen/AArch64/memset-inline.ll

Lines changed: 55 additions & 31 deletions
@@ -27,39 +27,57 @@ define void @memset_2(ptr %a, i8 %value) nounwind {
 }
 
 define void @memset_4(ptr %a, i8 %value) nounwind {
-; ALL-LABEL: memset_4:
-; ALL:       // %bb.0:
-; ALL-NEXT:    mov w8, #16843009
-; ALL-NEXT:    and w9, w1, #0xff
-; ALL-NEXT:    mul w8, w9, w8
-; ALL-NEXT:    str w8, [x0]
-; ALL-NEXT:    ret
+; GPR-LABEL: memset_4:
+; GPR:       // %bb.0:
+; GPR-NEXT:    mov w8, #16843009
+; GPR-NEXT:    and w9, w1, #0xff
+; GPR-NEXT:    mul w8, w9, w8
+; GPR-NEXT:    str w8, [x0]
+; GPR-NEXT:    ret
+;
+; NEON-LABEL: memset_4:
+; NEON:       // %bb.0:
+; NEON-NEXT:    dup v0.8b, w1
+; NEON-NEXT:    str s0, [x0]
+; NEON-NEXT:    ret
   tail call void @llvm.memset.inline.p0.i64(ptr %a, i8 %value, i64 4, i1 0)
   ret void
 }
 
 define void @memset_8(ptr %a, i8 %value) nounwind {
-; ALL-LABEL: memset_8:
-; ALL:       // %bb.0:
-; ALL-NEXT:    // kill: def $w1 killed $w1 def $x1
-; ALL-NEXT:    mov x8, #72340172838076673
-; ALL-NEXT:    and x9, x1, #0xff
-; ALL-NEXT:    mul x8, x9, x8
-; ALL-NEXT:    str x8, [x0]
-; ALL-NEXT:    ret
+; GPR-LABEL: memset_8:
+; GPR:       // %bb.0:
+; GPR-NEXT:    // kill: def $w1 killed $w1 def $x1
+; GPR-NEXT:    mov x8, #72340172838076673
+; GPR-NEXT:    and x9, x1, #0xff
+; GPR-NEXT:    mul x8, x9, x8
+; GPR-NEXT:    str x8, [x0]
+; GPR-NEXT:    ret
+;
+; NEON-LABEL: memset_8:
+; NEON:       // %bb.0:
+; NEON-NEXT:    dup v0.8b, w1
+; NEON-NEXT:    str d0, [x0]
+; NEON-NEXT:    ret
   tail call void @llvm.memset.inline.p0.i64(ptr %a, i8 %value, i64 8, i1 0)
   ret void
 }
 
 define void @memset_16(ptr %a, i8 %value) nounwind {
-; ALL-LABEL: memset_16:
-; ALL:       // %bb.0:
-; ALL-NEXT:    // kill: def $w1 killed $w1 def $x1
-; ALL-NEXT:    mov x8, #72340172838076673
-; ALL-NEXT:    and x9, x1, #0xff
-; ALL-NEXT:    mul x8, x9, x8
-; ALL-NEXT:    stp x8, x8, [x0]
-; ALL-NEXT:    ret
+; GPR-LABEL: memset_16:
+; GPR:       // %bb.0:
+; GPR-NEXT:    // kill: def $w1 killed $w1 def $x1
+; GPR-NEXT:    mov x8, #72340172838076673
+; GPR-NEXT:    and x9, x1, #0xff
+; GPR-NEXT:    mul x8, x9, x8
+; GPR-NEXT:    stp x8, x8, [x0]
+; GPR-NEXT:    ret
+;
+; NEON-LABEL: memset_16:
+; NEON:       // %bb.0:
+; NEON-NEXT:    dup v0.16b, w1
+; NEON-NEXT:    str q0, [x0]
+; NEON-NEXT:    ret
   tail call void @llvm.memset.inline.p0.i64(ptr %a, i8 %value, i64 16, i1 0)
   ret void
 }
@@ -110,14 +128,20 @@ define void @memset_64(ptr %a, i8 %value) nounwind {
 ; /////////////////////////////////////////////////////////////////////////////
 
 define void @aligned_memset_16(ptr align 16 %a, i8 %value) nounwind {
-; ALL-LABEL: aligned_memset_16:
-; ALL:       // %bb.0:
-; ALL-NEXT:    // kill: def $w1 killed $w1 def $x1
-; ALL-NEXT:    mov x8, #72340172838076673
-; ALL-NEXT:    and x9, x1, #0xff
-; ALL-NEXT:    mul x8, x9, x8
-; ALL-NEXT:    stp x8, x8, [x0]
-; ALL-NEXT:    ret
+; GPR-LABEL: aligned_memset_16:
+; GPR:       // %bb.0:
+; GPR-NEXT:    // kill: def $w1 killed $w1 def $x1
+; GPR-NEXT:    mov x8, #72340172838076673
+; GPR-NEXT:    and x9, x1, #0xff
+; GPR-NEXT:    mul x8, x9, x8
+; GPR-NEXT:    stp x8, x8, [x0]
+; GPR-NEXT:    ret
+;
+; NEON-LABEL: aligned_memset_16:
+; NEON:       // %bb.0:
+; NEON-NEXT:    dup v0.16b, w1
+; NEON-NEXT:    str q0, [x0]
+; NEON-NEXT:    ret
   tail call void @llvm.memset.inline.p0.i64(ptr align 16 %a, i8 %value, i64 16, i1 0)
   ret void
 }
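As a sanity check on the updated expectations, a standalone C++ model (not compiler output) showing that the GPR and NEON expansions of memset_16 above store identical bytes:

#include <cassert>
#include <cstdint>
#include <cstring>

// mov x8, #0x0101010101010101; and x9, x1, #0xff; mul; stp x8, x8, [x0]
void memset16Gpr(void *P, uint32_t W1) {
  uint64_t X8 = uint64_t(W1 & 0xff) * 0x0101010101010101ull;
  std::memcpy(P, &X8, 8);
  std::memcpy(static_cast<char *>(P) + 8, &X8, 8);
}

// dup v0.16b, w1; str q0, [x0]
void memset16Neon(void *P, uint32_t W1) {
  uint8_t V0[16];
  std::memset(V0, W1 & 0xff, sizeof(V0));
  std::memcpy(P, V0, sizeof(V0));
}

int main() {
  uint8_t A[16], B[16];
  memset16Gpr(A, 0x1234AB);
  memset16Neon(B, 0x1234AB);
  assert(std::memcmp(A, B, sizeof(A)) == 0);
  return 0;
}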
