//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for pairwise fadd pattern
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                   (g_vector_shuffle (vXs32 Other) undef <1,X,...>)) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///                (g_extract_vector_elt (vXs32 Other) 1)))
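///
/// On AArch64 the two lane extracts feeding a single fadd can then be
/// selected as a scalar pairwise add (FADDP), which is presumably the payoff
/// this combine, like its SDAG counterpart, is aiming for.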
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  auto Cst = getConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, we must restrict these types to just
  // s64.
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

bool applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
  return true;
}

static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: also check for an extended build vector.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: also check for an extended build vector.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  const APInt ConstValue = Const->Value.sextOrSelf(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication by a power of two plus/minus one can be done more
  // cheaply as a shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45,
  // which equals (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
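  // For example, C = 6 has one trailing zero and ShiftedConstValue = 3, so
  // x * 6 can be emitted as (shl (add (shl x, 1), x), 1).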
  unsigned TrailingZeroes = ConstValue.countTrailingZeros();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      unsigned UseOpc = UseMI.getOpcode();
      if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
          UseOpc == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
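    // For example, x * 5 becomes (add (shl x, 2), x) and x * 7 becomes
    // (sub (shl x, 3), x).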
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
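    // For example, x * -7 becomes (sub x, (shl x, 3)) and x * -9 becomes
    // (sub 0, (add (shl x, 3), x)).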
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

bool applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
  return true;
}

/// Form a G_SBFX from a G_SEXT_INREG fed by a right shift.
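/// That is, for a single-use right shift feeding the sign-extension:
///   (G_SEXT_INREG (G_ASHR/G_LSHR x, ShiftImm), Width)
///     -> (G_SBFX x, ShiftImm, Width)
/// provided ShiftImm >= 0 and ShiftImm + Width does not exceed the bit width.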
static bool matchBitfieldExtractFromSExtInReg(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &)> &MatchInfo) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  int64_t Width = MI.getOperand(2).getImm();
  LLT Ty = MRI.getType(Src);
  assert((Ty == LLT::scalar(32) || Ty == LLT::scalar(64)) &&
         "Unexpected type for G_SEXT_INREG?");
  Register ShiftSrc;
  int64_t ShiftImm;
  if (!mi_match(
          Src, MRI,
          m_OneNonDBGUse(m_any_of(m_GAShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)),
                                  m_GLShr(m_Reg(ShiftSrc), m_ICst(ShiftImm))))))
    return false;
  if (ShiftImm < 0 || ShiftImm + Width > Ty.getSizeInBits())
    return false;
  MatchInfo = [=](MachineIRBuilder &B) {
    auto Cst1 = B.buildConstant(Ty, ShiftImm);
    auto Cst2 = B.buildConstant(Ty, Width);
    B.buildInstr(TargetOpcode::G_SBFX, {Dst}, {ShiftSrc, Cst1, Cst2});
  };
  return true;
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS

namespace {
#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_H

class AArch64PostLegalizerCombinerInfo : public CombinerInfo {
  GISelKnownBits *KB;
  MachineDominatorTree *MDT;

public:
  AArch64GenPostLegalizerCombinerHelperRuleConfig GeneratedRuleCfg;

  AArch64PostLegalizerCombinerInfo(bool EnableOpt, bool OptSize, bool MinSize,
                                   GISelKnownBits *KB,
                                   MachineDominatorTree *MDT)
      : CombinerInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, OptSize, MinSize),
        KB(KB), MDT(MDT) {
    if (!GeneratedRuleCfg.parseCommandLineOption())
      report_fatal_error("Invalid rule identifier");
  }

  virtual bool combine(GISelChangeObserver &Observer, MachineInstr &MI,
                       MachineIRBuilder &B) const override;
};

bool AArch64PostLegalizerCombinerInfo::combine(GISelChangeObserver &Observer,
                                               MachineInstr &MI,
                                               MachineIRBuilder &B) const {
  const auto *LI =
      MI.getParent()->getParent()->getSubtarget().getLegalizerInfo();
  CombinerHelper Helper(Observer, B, KB, MDT, LI);
  AArch64GenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg);
  return Generated.tryCombineAll(Observer, MI, B, Helper);
}

#define AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef AARCH64POSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_CPP

class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool IsOptNone;
};
} // end anonymous namespace

void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  if (!IsOptNone) {
    AU.addRequired<MachineDominatorTree>();
    AU.addPreserved<MachineDominatorTree>();
    AU.addRequired<GISelCSEAnalysisWrapperPass>();
    AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  }
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());
}

bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");
  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOpt::None && !skipFunction(F);
  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr : &getAnalysis<MachineDominatorTree>();
  AArch64PostLegalizerCombinerInfo PCInfo(EnableOpt, F.hasOptSize(),
                                          F.hasMinSize(), KB, MDT);
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());
  Combiner C(PCInfo, TPC);
  return C.combineMachineInstrs(MF, CSEInfo);
}

char AArch64PostLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 MachineInstrs after legalization", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 MachineInstrs after legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm