@@ -153,13 +153,6 @@ cl::opt<bool> EnableSVEGISel(
153153 cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
154154 cl::init(false));
155155
156- // FIXME : This is a temporary flag, and is used to help transition to
157- // performing lowering the proper way using the new PARTIAL_REDUCE_MLA ISD
158- // nodes.
159- static cl::opt<bool> EnablePartialReduceNodes(
160- "aarch64-enable-partial-reduce-nodes", cl::init(false), cl::ReallyHidden,
161- cl::desc("Use the new method of lowering partial reductions."));
162-
163156/// Value type used for condition codes.
164157static const MVT MVT_CC = MVT::i32;
165158
@@ -1457,7 +1450,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14571450 for (MVT VT : { MVT::v16f16, MVT::v8f32, MVT::v4f64 })
14581451 setOperationAction(ISD::FADD, VT, Custom);
14591452
1460- if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
1453+ if (Subtarget->hasDotProd()) {
14611454 static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
14621455 ISD::PARTIAL_REDUCE_UMLA};
14631456
@@ -1895,7 +1888,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18951888 }
18961889
18971890 // Handle partial reduction operations
1898- if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
1891+ if (Subtarget->isSVEorStreamingSVEAvailable()) {
18991892 // Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
19001893 // Other pairs will default to 'Expand'.
19011894 static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
@@ -1957,17 +1950,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
19571950 setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
19581951 Custom);
19591952
1960- if (EnablePartialReduceNodes) {
1961- static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1962- ISD::PARTIAL_REDUCE_UMLA};
1963- // Must be lowered to SVE instructions.
1964- setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1965- setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1966- setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1967- setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1968- setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1969- setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
1970- }
1953+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
1954+ ISD::PARTIAL_REDUCE_UMLA};
1955+ // Must be lowered to SVE instructions.
1956+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v4i32, Custom);
1957+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v8i16, Custom);
1958+ setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
1959+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v8i16, Custom);
1960+ setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Custom);
1961+ setPartialReduceMLAAction(MLAOps, MVT::v8i16, MVT::v16i8, Custom);
19711962 }
19721963 }
19731964
@@ -2165,16 +2156,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
21652156 assert(I->getIntrinsicID() ==
21662157 Intrinsic::experimental_vector_partial_reduce_add &&
21672158 "Unexpected intrinsic!");
2168- if (EnablePartialReduceNodes)
2169- return true;
2170-
2171- EVT VT = EVT::getEVT(I->getType());
2172- auto Op1 = I->getOperand(1);
2173- EVT Op1VT = EVT::getEVT(Op1->getType());
2174- if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
2175- (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
2176- VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
2177- return false;
21782159 return true;
21792160}
21802161
@@ -2252,37 +2233,32 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22522233 bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
22532234 bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
22542235
2255- if (EnablePartialReduceNodes) {
2256- static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2257- ISD::PARTIAL_REDUCE_UMLA};
2258- unsigned NumElts = VT.getVectorNumElements();
2259- if (VT.getVectorElementType() == MVT::i64) {
2260- setPartialReduceMLAAction(MLAOps, VT,
2261- MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2262- setPartialReduceMLAAction(
2263- MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2264- setPartialReduceMLAAction(
2265- MLAOps, VT, MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2266- } else if (VT.getVectorElementType() == MVT::i32) {
2267- setPartialReduceMLAAction(MLAOps, VT,
2236+ static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
2237+ ISD::PARTIAL_REDUCE_UMLA};
2238+ unsigned NumElts = VT.getVectorNumElements();
2239+ if (VT.getVectorElementType() == MVT::i64) {
2240+ setPartialReduceMLAAction(MLAOps, VT,
2241+ MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
2242+ setPartialReduceMLAAction(MLAOps, VT,
2243+ MVT::getVectorVT(MVT::i16, NumElts * 4), Custom);
2244+ setPartialReduceMLAAction(MLAOps, VT,
2245+ MVT::getVectorVT(MVT::i32, NumElts * 2), Custom);
2246+ } else if (VT.getVectorElementType() == MVT::i32) {
2247+ setPartialReduceMLAAction(MLAOps, VT,
2248+ MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2249+ setPartialReduceMLAAction(MLAOps, VT,
2250+ MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2251+ } else if (VT.getVectorElementType() == MVT::i16) {
2252+ setPartialReduceMLAAction(MLAOps, VT,
2253+ MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2254+ }
2255+ if (Subtarget->hasMatMulInt8()) {
2256+ if (VT.getVectorElementType() == MVT::i32)
2257+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
22682258 MVT::getVectorVT(MVT::i8, NumElts * 4), Custom);
2269- setPartialReduceMLAAction(
2270- MLAOps, VT, MVT::getVectorVT(MVT::i16, NumElts * 2), Custom);
2271- } else if (VT.getVectorElementType() == MVT::i16) {
2272- setPartialReduceMLAAction(MLAOps, VT,
2273- MVT::getVectorVT(MVT::i8, NumElts * 2), Custom);
2274- }
2275-
2276- if (Subtarget->hasMatMulInt8()) {
2277- if (VT.getVectorElementType() == MVT::i32)
2278- setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
2279- MVT::getVectorVT(MVT::i8, NumElts * 4),
2280- Custom);
2281- else if (VT.getVectorElementType() == MVT::i64)
2282- setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
2283- MVT::getVectorVT(MVT::i8, NumElts * 8),
2284- Custom);
2285- }
2259+ else if (VT.getVectorElementType() == MVT::i64)
2260+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SUMLA, VT,
2261+ MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
22862262 }
22872263
22882264 // Lower fixed length vector operations to scalable equivalents.
0 commit comments