//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for pairwise fadd pattern
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                   (g_vector_shuffle (vXs32 Other) undef <1,X,...>)) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///                (g_extract_vector_elt (vXs32 Other) 1)))
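///
/// A minimal sketch of the rewrite in MIR terms (the register names below
/// are illustrative, not taken from actual output):
///   %shuf:_(<2 x s32>) = G_SHUFFLE_VECTOR %v(<2 x s32>), %undef, shufflemask(1, undef)
///   %sum:_(<2 x s32>) = G_FADD %v, %shuf
///   %res:_(s32) = G_EXTRACT_VECTOR_ELT %sum(<2 x s32>), %zero(s64)
/// becomes
///   %e0:_(s32) = G_EXTRACT_VECTOR_ELT %v(<2 x s32>), %zero(s64)
///   %e1:_(s32) = G_EXTRACT_VECTOR_ELT %v(<2 x s32>), %one(s64)
///   %res:_(s32) = G_FADD %e0, %e1
/// which instruction selection can then typically turn into a pairwise
/// FADDP.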
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  auto Cst = getConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, we must restrict these types to just
  // s64.
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

bool applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
  return true;
}

static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  const APInt ConstValue = Const->Value.sextOrSelf(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as a shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
  // which equals (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
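  // Continuing the C = 6 example above: TrailingZeroes = 1 and
  // ShiftedConstValue = 3 = 2^1 + 1, so (mul x, 6) lowers to
  // (shl (add (shl x, 1), x), 1).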
  unsigned TrailingZeroes = ConstValue.countTrailingZeros();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      unsigned UseOpc = UseMI.getOpcode();
      if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
          UseOpc == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
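    // Concrete instances (illustrative): (mul x, 9) => (add (shl x, 3), x),
    // (mul x, 7) => (sub (shl x, 3), x).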
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
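    // Concrete instances (illustrative): (mul x, -7) => (sub x, (shl x, 3)),
    // (mul x, -9) => -(add (shl x, 3), x).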
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

bool applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
  return true;
}

/// Form a G_SBFX from a G_SEXT_INREG fed by a right shift.
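///
/// A minimal sketch in MIR terms (illustrative register names; the shift
/// amount and width values are arbitrary):
///   %sh:_(s32) = G_ASHR %x(s32), %c4
///   %r:_(s32) = G_SEXT_INREG %sh, 8
/// becomes
///   %r:_(s32) = G_SBFX %x, %c4, %c8
/// i.e. sign-extract 8 bits of %x starting at bit 4.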
static bool matchBitfieldExtractFromSExtInReg(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  int64_t Width = MI.getOperand(2).getImm();
  LLT Ty = MRI.getType(Src);
  assert((Ty == LLT::scalar(32) || Ty == LLT::scalar(64)) &&
         "Unexpected type for G_SEXT_INREG?");
  Register ShiftSrc;
  int64_t ShiftImm;
  if (!mi_match(
          Src, MRI,
          m_OneNonDBGUse(m_any_of(m_GAShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)),
                                  m_GLShr(m_Reg(ShiftSrc), m_ICst(ShiftImm))))))
    return false;
  if (ShiftImm < 0 || ShiftImm + Width > Ty.getSizeInBits())
    return false;
  MatchInfo = [=](MachineIRBuilder &B) {
    auto Cst1 = B.buildConstant(Ty, ShiftImm);
    auto Cst2 = B.buildConstant(Ty, Width);
    B.buildInstr(TargetOpcode::G_SBFX, {Dst}, {ShiftSrc, Cst1, Cst2});
  };
  return true;
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS

namespace {
#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H

class AArch64PostLegalizerCombinerInfo : public CombinerInfo {
  GISelKnownBits *KB;
  MachineDominatorTree *MDT;

public:
  AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg;

  AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize,
                                   GISelKnownBits *KB,
                                   MachineDominatorTree *MDT)
      : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize),
        KB(KB), MDT(MDT) {
    if (!GeneratedRuleCfg.parseCommandLineOption())
      report_fatal_error("Invalid rule identifier");
  }

  virtual bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
                       MachineIRBuilder &B) const override;
};

bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
                                               MachineInstr &MI,
                                               MachineIRBuilder &B) const {
  const auto *LI =
      MI.getParent()->getParent()->getSubtarget().getLegalizerInfo();
  CombinerHelper Helper(Observer, B, KB, MDT, LI);
  AArch64GenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg);
  return Generated.tryCombineAll(Observer, MI, B, Helper);
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP

class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool IsOptNone;
};
} // end anonymous namespace

void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  if (!IsOptNone) {
    AU.addRequired<MachineDominatorTree>();
    AU.addPreserved<MachineDominatorTree>();
    AU.addRequired<GISelCSEAnalysisWrapperPass>();
    AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  }
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());
}

bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");
  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F);
  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>();
  AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(),
                                          F.hasMinSize(), KB, MDT);
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());
  Combiner C(PCInfo, TPC);
  return C.combineMachineInstrs(MF, CSEInfo);
}

char AArch64PostLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 MachineInstrs after legalization", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 MachineInstrs after legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm