//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Post-legalization combines on generic MachineInstrs.
///
/// The combines here must preserve instruction legality.
///
/// Lowering combines (e.g. pseudo matching) should be handled by
/// AArch64PostLegalizerLowering.
///
/// Combines which don't rely on instruction legality should go in the
/// AArch64PreLegalizerCombiner.
///
//===----------------------------------------------------------------------===//

#include "AArch64TargetMachine.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/CodeGen/GlobalISel/CSEInfo.h"
#include "llvm/CodeGen/GlobalISel/CSEMIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Combiner.h"
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"

#define GET_GICOMBINER_DEPS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_DEPS

#define DEBUG_TYPE "aarch64-postlegalizer-combiner"

using namespace llvm;
using namespace MIPatternMatch;

namespace {

#define GET_GICOMBINER_TYPES
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_TYPES

/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
/// Rewrite for pairwise fadd pattern
///   (s32 (g_extract_vector_elt
///           (g_fadd (vXs32 Other)
///                   (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0))
/// ->
///   (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
///                (g_extract_vector_elt (vXs32 Other) 1))
bool matchExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  Register Src1 = MI.getOperand(1).getReg();
  Register Src2 = MI.getOperand(2).getReg();
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  auto Cst = getIConstantVRegValWithLookThrough(Src2, MRI);
  if (!Cst || Cst->Value != 0)
    return false;
  // SDAG also checks for FullFP16, but this looks to be beneficial anyway.

  // Now check for an fadd operation. TODO: expand this for integer add?
  auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
  if (!FAddMI)
    return false;

  // If we add support for integer add, must restrict these types to just s64.
  unsigned DstSize = DstTy.getSizeInBits();
  if (DstSize != 16 && DstSize != 32 && DstSize != 64)
    return false;

  Register Src1Op1 = FAddMI->getOperand(1).getReg();
  Register Src1Op2 = FAddMI->getOperand(2).getReg();
  MachineInstr *Shuffle =
      getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
  MachineInstr *Other = MRI.getVRegDef(Src1Op1);
  if (!Shuffle) {
    Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
    Other = MRI.getVRegDef(Src1Op2);
  }

  // We're looking for a shuffle that moves the second element to index 0.
  if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
      Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
    std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
    std::get<1>(MatchInfo) = DstTy;
    std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
    return true;
  }
  return false;
}

void applyExtractVecEltPairwiseAdd(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::tuple<unsigned, LLT, Register> &MatchInfo) {
  unsigned Opc = std::get<0>(MatchInfo);
  assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
  // We want to generate two extracts of elements 0 and 1, and add them.
  LLT Ty = std::get<1>(MatchInfo);
  Register Src = std::get<2>(MatchInfo);
  LLT s64 = LLT::scalar(64);
  B.setInstrAndDebugLoc(MI);
  auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
  auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
  B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
  MI.eraseFromParent();
}

bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  unsigned Opc = MRI.getVRegDef(R)->getOpcode();
  return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}

bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
  // TODO: check if extended build vector as well.
  return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}

bool matchAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  assert(MI.getOpcode() == TargetOpcode::G_MUL);
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  Register Dst = MI.getOperand(0).getReg();
  const LLT Ty = MRI.getType(LHS);

  // The below optimizations require a constant RHS.
  auto Const = getIConstantVRegValWithLookThrough(RHS, MRI);
  if (!Const)
    return false;

  APInt ConstValue = Const->Value.sext(Ty.getSizeInBits());
  // The following code is ported from AArch64ISelLowering.
  // Multiplication of a power of two plus/minus one can be done more
  // cheaply as shift+add/sub. For now, this is true unilaterally. If
  // future CPUs have a cheaper MADD instruction, this may need to be
  // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
  // 64-bit is 5 cycles, so this is always a win.
  // More aggressively, some multiplications N0 * C can be lowered to
  // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
  // e.g. 6=3*2=(2+1)*2.
  // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
  // which equals (1+2)*16-(1+2).
  // TrailingZeroes is used to test if the mul can be lowered to
  // shift+add+shift.
  unsigned TrailingZeroes = ConstValue.countr_zero();
  if (TrailingZeroes) {
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into smul or umul.
    if (MRI.hasOneNonDBGUse(LHS) &&
        (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
      return false;
    // Conservatively do not lower to shift+add+shift if the mul might be
    // folded into madd or msub.
    if (MRI.hasOneNonDBGUse(Dst)) {
      MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
      unsigned UseOpc = UseMI.getOpcode();
      if (UseOpc == TargetOpcode::G_ADD || UseOpc == TargetOpcode::G_PTR_ADD ||
          UseOpc == TargetOpcode::G_SUB)
        return false;
    }
  }
  // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
  // and shift+add+shift.
  APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

  unsigned ShiftAmt, AddSubOpc;
  // Is the shifted value the LHS operand of the add/sub?
  bool ShiftValUseIsLHS = true;
  // Do we need to negate the result?
  bool NegateResult = false;

  if (ConstValue.isNonNegative()) {
    // (mul x, 2^N + 1) => (add (shl x, N), x)
    // (mul x, 2^N - 1) => (sub (shl x, N), x)
    // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
    APInt SCVMinus1 = ShiftedConstValue - 1;
    APInt CVPlus1 = ConstValue + 1;
    if (SCVMinus1.isPowerOf2()) {
      ShiftAmt = SCVMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
    } else if (CVPlus1.isPowerOf2()) {
      ShiftAmt = CVPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
    } else
      return false;
  } else {
    // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
    // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
    APInt CVNegPlus1 = -ConstValue + 1;
    APInt CVNegMinus1 = -ConstValue - 1;
    if (CVNegPlus1.isPowerOf2()) {
      ShiftAmt = CVNegPlus1.logBase2();
      AddSubOpc = TargetOpcode::G_SUB;
      ShiftValUseIsLHS = false;
    } else if (CVNegMinus1.isPowerOf2()) {
      ShiftAmt = CVNegMinus1.logBase2();
      AddSubOpc = TargetOpcode::G_ADD;
      NegateResult = true;
    } else
      return false;
  }

  if (NegateResult && TrailingZeroes)
    return false;

  ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
    auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
    auto ShiftedVal = B.buildShl(Ty, LHS, Shift);

    Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
    Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
    auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
    assert(!(NegateResult && TrailingZeroes) &&
           "NegateResult and TrailingZeroes cannot both be true for now.");
    // Negate the result.
    if (NegateResult) {
      B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
      return;
    }
    // Shift the result.
    if (TrailingZeroes) {
      B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
      return;
    }
    B.buildCopy(DstReg, Res.getReg(0));
  };
  return true;
}

void applyAArch64MulConstCombine(
    MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
    std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
  B.setInstrAndDebugLoc(MI);
  ApplyFn(B, MI.getOperand(0).getReg());
  MI.eraseFromParent();
}

/// Try to fold a G_MERGE_VALUES of 2 s32 sources, where the second source
/// is a zero, into a G_ZEXT of the first.
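/// For example:
///   %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
/// becomes
///   %d(s64) = G_ZEXT %a(s32)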
bool matchFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI) {
  auto &Merge = cast<GMerge>(MI);
  LLT SrcTy = MRI.getType(Merge.getSourceReg(0));
  if (SrcTy != LLT::scalar(32) || Merge.getNumSources() != 2)
    return false;
  return mi_match(Merge.getSourceReg(1), MRI, m_SpecificICst(0));
}

void applyFoldMergeToZext(MachineInstr &MI, MachineRegisterInfo &MRI,
                          MachineIRBuilder &B, GISelChangeObserver &Observer) {
  // Mutate %d(s64) = G_MERGE_VALUES %a(s32), 0(s32)
  //  ->
  //        %d(s64) = G_ZEXT %a(s32)
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  MI.removeOperand(2);
  Observer.changedInstr(MI);
}

/// \returns True if a G_ANYEXT instruction \p MI should be mutated to a G_ZEXT
/// instruction.
bool matchMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI) {
  // If this is coming from a scalar compare then we can use a G_ZEXT instead of
  // a G_ANYEXT:
  //
  // %cmp:_(s32) = G_[I|F]CMP ... <-- produces 0/1.
  // %ext:_(s64) = G_ANYEXT %cmp(s32)
  //
  // By doing this, we can leverage more KnownBits combines.
  assert(MI.getOpcode() == TargetOpcode::G_ANYEXT);
  Register Dst = MI.getOperand(0).getReg();
  Register Src = MI.getOperand(1).getReg();
  return MRI.getType(Dst).isScalar() &&
         mi_match(Src, MRI,
                  m_any_of(m_GICmp(m_Pred(), m_Reg(), m_Reg()),
                           m_GFCmp(m_Pred(), m_Reg(), m_Reg())));
}

void applyMutateAnyExtToZExt(MachineInstr &MI, MachineRegisterInfo &MRI,
                             MachineIRBuilder &B,
                             GISelChangeObserver &Observer) {
  Observer.changingInstr(MI);
  MI.setDesc(B.getTII().get(TargetOpcode::G_ZEXT));
  Observer.changedInstr(MI);
}

/// Match a 128b store of zero and split it into two 64 bit stores, for
/// size/performance reasons.
bool matchSplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI) {
  GStore &Store = cast<GStore>(MI);
  if (!Store.isSimple())
    return false;
  LLT ValTy = MRI.getType(Store.getValueReg());
  if (ValTy.isScalableVector())
    return false;
  if (!ValTy.isVector() || ValTy.getSizeInBits() != 128)
    return false;
  if (Store.getMemSizeInBits() != ValTy.getSizeInBits())
    return false; // Don't split truncating stores.
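  // Only split when this store is the sole non-debug user of the value;
  // otherwise the 128-bit zero still has to be materialized for its other
  // users and splitting saves nothing.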
  if (!MRI.hasOneNonDBGUse(Store.getValueReg()))
    return false;
  auto MaybeCst = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(Store.getValueReg()), MRI);
  return MaybeCst && MaybeCst->isZero();
}

void applySplitStoreZero128(MachineInstr &MI, MachineRegisterInfo &MRI,
                            MachineIRBuilder &B,
                            GISelChangeObserver &Observer) {
  B.setInstrAndDebugLoc(MI);
  GStore &Store = cast<GStore>(MI);
  assert(MRI.getType(Store.getValueReg()).isVector() &&
         "Expected a vector store value");
  LLT NewTy = LLT::scalar(64);
  Register PtrReg = Store.getPointerReg();
  auto Zero = B.buildConstant(NewTy, 0);
  auto HighPtr = B.buildPtrAdd(MRI.getType(PtrReg), PtrReg,
                               B.buildConstant(LLT::scalar(64), 8));
  auto &MF = *MI.getMF();
  auto *LowMMO = MF.getMachineMemOperand(&Store.getMMO(), 0, NewTy);
  auto *HighMMO = MF.getMachineMemOperand(&Store.getMMO(), 8, NewTy);
  B.buildStore(Zero, PtrReg, *LowMMO);
  B.buildStore(Zero, HighPtr, *HighMMO);
  Store.eraseFromParent();
}

bool matchOrToBSP(MachineInstr &MI, MachineRegisterInfo &MRI,
                  std::tuple<Register, Register, Register> &MatchInfo) {
  const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  if (!DstTy.isVector())
    return false;

  Register AO1, AO2, BVO1, BVO2;
  if (!mi_match(MI, MRI,
                m_GOr(m_GAnd(m_Reg(AO1), m_Reg(BVO1)),
                      m_GAnd(m_Reg(AO2), m_Reg(BVO2)))))
    return false;

  auto *BV1 = getOpcodeDef<GBuildVector>(BVO1, MRI);
  auto *BV2 = getOpcodeDef<GBuildVector>(BVO2, MRI);
  if (!BV1 || !BV2)
    return false;

  for (int I = 0, E = DstTy.getNumElements(); I < E; I++) {
    auto ValAndVReg1 =
        getIConstantVRegValWithLookThrough(BV1->getSourceReg(I), MRI);
    auto ValAndVReg2 =
        getIConstantVRegValWithLookThrough(BV2->getSourceReg(I), MRI);
    if (!ValAndVReg1 || !ValAndVReg2 ||
        ValAndVReg1->Value != ~ValAndVReg2->Value)
      return false;
  }

  MatchInfo = {AO1, AO2, BVO1};
  return true;
}

void applyOrToBSP(MachineInstr &MI, MachineRegisterInfo &MRI,
                  MachineIRBuilder &B,
                  std::tuple<Register, Register, Register> &MatchInfo) {
  B.setInstrAndDebugLoc(MI);
  B.buildInstr(
      AArch64::G_BSP, {MI.getOperand(0).getReg()},
      {std::get<2>(MatchInfo), std::get<0>(MatchInfo), std::get<1>(MatchInfo)});
  MI.eraseFromParent();
}

// Combines Mul(And(Srl(X, 15), 0x10001), 0xffff) into CMLTz
bool matchCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
                         Register &SrcReg) {
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());

  if (DstTy != LLT::fixed_vector(2, 64) && DstTy != LLT::fixed_vector(2, 32) &&
      DstTy != LLT::fixed_vector(4, 32) && DstTy != LLT::fixed_vector(4, 16) &&
      DstTy != LLT::fixed_vector(8, 16))
    return false;

  auto AndMI = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
  if (AndMI->getOpcode() != TargetOpcode::G_AND)
    return false;
  auto LShrMI = getDefIgnoringCopies(AndMI->getOperand(1).getReg(), MRI);
  if (LShrMI->getOpcode() != TargetOpcode::G_LSHR)
    return false;

  // Check the constant splat values
  auto V1 = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(MI.getOperand(2).getReg()), MRI);
  auto V2 = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(AndMI->getOperand(2).getReg()), MRI);
  auto V3 = isConstantOrConstantSplatVector(
      *MRI.getVRegDef(LShrMI->getOperand(2).getReg()), MRI);
  if (!V1.has_value() || !V2.has_value() || !V3.has_value())
    return false;

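  // For the CMLTz pattern, the shift amount must move each half-width lane's
  // sign bit to the bottom of its half (HalfSize - 1), the AND mask must keep
  // exactly those bits (e.g. 0x10001 for 32-bit lanes), and the multiplier
  // must be the low-half mask that smears them across the half (e.g. 0xffff).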
  unsigned HalfSize = DstTy.getScalarSizeInBits() / 2;
  if (!V1.value().isMask(HalfSize) || V2.value() != (1ULL | 1ULL << HalfSize) ||
      V3 != (HalfSize - 1))
    return false;

  SrcReg = LShrMI->getOperand(1).getReg();

  return true;
}

void applyCombineMulCMLT(MachineInstr &MI, MachineRegisterInfo &MRI,
                         MachineIRBuilder &B, Register &SrcReg) {
  Register DstReg = MI.getOperand(0).getReg();
  LLT DstTy = MRI.getType(DstReg);
  LLT HalfTy =
      DstTy.changeElementCount(DstTy.getElementCount().multiplyCoefficientBy(2))
          .changeElementSize(DstTy.getScalarSizeInBits() / 2);

  Register ZeroVec = B.buildConstant(HalfTy, 0).getReg(0);
  Register CastReg =
      B.buildInstr(TargetOpcode::G_BITCAST, {HalfTy}, {SrcReg}).getReg(0);
  Register CMLTReg =
      B.buildICmp(CmpInst::Predicate::ICMP_SLT, HalfTy, CastReg, ZeroVec)
          .getReg(0);

  B.buildInstr(TargetOpcode::G_BITCAST, {DstReg}, {CMLTReg}).getReg(0);
  MI.eraseFromParent();
}

class AArch64PostLegalizerCombinerImpl : public Combiner {
protected:
  const CombinerHelper Helper;
  const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig;
  const AArch64Subtarget &STI;

public:
  AArch64PostLegalizerCombinerImpl(
      MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
      GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
      const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig,
      const AArch64Subtarget &STI, MachineDominatorTree *MDT,
      const LegalizerInfo *LI);

  static const char *getName() { return "AArch64PostLegalizerCombiner"; }

  bool tryCombineAll(MachineInstr &I) const override;

private:
#define GET_GICOMBINER_CLASS_MEMBERS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CLASS_MEMBERS
};

#define GET_GICOMBINER_IMPL
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_IMPL

AArch64PostLegalizerCombinerImpl::AArch64PostLegalizerCombinerImpl(
    MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
    GISelKnownBits &KB, GISelCSEInfo *CSEInfo,
    const AArch64PostLegalizerCombinerImplRuleConfig &RuleConfig,
    const AArch64Subtarget &STI, MachineDominatorTree *MDT,
    const LegalizerInfo *LI)
    : Combiner(MF, CInfo, TPC, &KB, CSEInfo),
      Helper(Observer, B, /*IsPreLegalize*/ false, &KB, MDT, LI),
      RuleConfig(RuleConfig), STI(STI),
#define GET_GICOMBINER_CONSTRUCTOR_INITS
#include "AArch64GenPostLegalizeGICombiner.inc"
#undef GET_GICOMBINER_CONSTRUCTOR_INITS
{
}

class AArch64PostLegalizerCombiner : public MachineFunctionPass {
public:
  static char ID;

  AArch64PostLegalizerCombiner(bool IsOptNone = false);

  StringRef getPassName() const override {
    return "AArch64PostLegalizerCombiner";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool IsOptNone;
  AArch64PostLegalizerCombinerImplRuleConfig RuleConfig;

  struct StoreInfo {
    GStore *St = nullptr;
    // The G_PTR_ADD that's used by the store. We keep this to cache the
    // MachineInstr def.
    GPtrAdd *Ptr = nullptr;
    // The signed offset to the Ptr instruction.
    int64_t Offset = 0;
    LLT StoredType;
  };
  bool tryOptimizeConsecStores(SmallVectorImpl<StoreInfo> &Stores,
                               CSEMIRBuilder &MIB);

  bool optimizeConsecutiveMemOpAddressing(MachineFunction &MF,
                                          CSEMIRBuilder &MIB);
};
} // end anonymous namespace

void AArch64PostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.setPreservesCFG();
  getSelectionDAGFallbackAnalysisUsage(AU);
  AU.addRequired<GISelKnownBitsAnalysis>();
  AU.addPreserved<GISelKnownBitsAnalysis>();
  if (!IsOptNone) {
    AU.addRequired<MachineDominatorTreeWrapperPass>();
    AU.addPreserved<MachineDominatorTreeWrapperPass>();
    AU.addRequired<GISelCSEAnalysisWrapperPass>();
    AU.addPreserved<GISelCSEAnalysisWrapperPass>();
  }
  MachineFunctionPass::getAnalysisUsage(AU);
}

AArch64PostLegalizerCombiner::AArch64PostLegalizerCombiner(bool IsOptNone)
    : MachineFunctionPass(ID), IsOptNone(IsOptNone) {
  initializeAArch64PostLegalizerCombinerPass(*PassRegistry::getPassRegistry());

  if (!RuleConfig.parseCommandLineOption())
    report_fatal_error("Invalid rule identifier");
}

bool AArch64PostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) {
  if (MF.getProperties().hasProperty(
          MachineFunctionProperties::Property::FailedISel))
    return false;
  assert(MF.getProperties().hasProperty(
             MachineFunctionProperties::Property::Legalized) &&
         "Expected a legalized function?");
  auto *TPC = &getAnalysis<TargetPassConfig>();
  const Function &F = MF.getFunction();
  bool EnableOpt =
      MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F);

  const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
  const auto *LI = ST.getLegalizerInfo();

  GISelKnownBits *KB = &getAnalysis<GISelKnownBitsAnalysis>().get(MF);
  MachineDominatorTree *MDT =
      IsOptNone ? nullptr
                : &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
  GISelCSEAnalysisWrapper &Wrapper =
      getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper();
  auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig());

  CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
                     /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(),
                     F.hasMinSize());
  // Disable fixed-point iteration to reduce compile-time
  CInfo.MaxIterations = 1;
  CInfo.ObserverLvl = CombinerInfo::ObserverLevel::SinglePass;
  // Legalizer performs DCE, so a full DCE pass is unnecessary.
  CInfo.EnableFullDCE = false;
  AArch64PostLegalizerCombinerImpl Impl(MF, CInfo, TPC, *KB, CSEInfo,
                                        RuleConfig, ST, MDT, LI);
  bool Changed = Impl.combineMachineInstrs();

  auto MIB = CSEMIRBuilder(MF);
  MIB.setCSEInfo(CSEInfo);
  Changed |= optimizeConsecutiveMemOpAddressing(MF, MIB);
  return Changed;
}

bool AArch64PostLegalizerCombiner::tryOptimizeConsecStores(
    SmallVectorImpl<StoreInfo> &Stores, CSEMIRBuilder &MIB) {
  if (Stores.size() <= 2)
    return false;

  // Profitability checks:
  int64_t BaseOffset = Stores[0].Offset;
  unsigned NumPairsExpected = Stores.size() / 2;
  unsigned TotalInstsExpected = NumPairsExpected + (Stores.size() % 2);
  // Size savings will depend on whether we can fold the offset, as an
  // immediate of an ADD.
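  // A base offset that cannot be encoded as an ADD immediate costs one extra
  // instruction to materialize.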
  auto &TLI = *MIB.getMF().getSubtarget().getTargetLowering();
  if (!TLI.isLegalAddImmediate(BaseOffset))
    TotalInstsExpected++;
  int SavingsExpected = Stores.size() - TotalInstsExpected;
  if (SavingsExpected <= 0)
    return false;

  auto &MRI = MIB.getMF().getRegInfo();

  // We have a series of consecutive stores. Factor out the common base
  // pointer and rewrite the offsets.
  Register NewBase = Stores[0].Ptr->getReg(0);
  for (auto &SInfo : Stores) {
    // Compute a new pointer with the new base ptr and adjusted offset.
    MIB.setInstrAndDebugLoc(*SInfo.St);
    auto NewOff = MIB.buildConstant(LLT::scalar(64), SInfo.Offset - BaseOffset);
    auto NewPtr = MIB.buildPtrAdd(MRI.getType(SInfo.St->getPointerReg()),
                                  NewBase, NewOff);
    if (MIB.getObserver())
      MIB.getObserver()->changingInstr(*SInfo.St);
    SInfo.St->getOperand(1).setReg(NewPtr.getReg(0));
    if (MIB.getObserver())
      MIB.getObserver()->changedInstr(*SInfo.St);
  }
  LLVM_DEBUG(dbgs() << "Split a series of " << Stores.size()
                    << " stores into a base pointer and offsets.\n");
  return true;
}

static cl::opt<bool>
    EnableConsecutiveMemOpOpt("aarch64-postlegalizer-consecutive-memops",
                              cl::init(true), cl::Hidden,
                              cl::desc("Enable consecutive memop optimization "
                                       "in AArch64PostLegalizerCombiner"));

bool AArch64PostLegalizerCombiner::optimizeConsecutiveMemOpAddressing(
    MachineFunction &MF, CSEMIRBuilder &MIB) {
  // This combine needs to run after all reassociations/folds on pointer
  // addressing have been done, specifically those that combine two G_PTR_ADDs
  // with constant offsets into a single G_PTR_ADD with a combined offset.
  // The goal of this optimization is to undo that combine in the case where
  // doing so has prevented the formation of pair stores due to illegal
  // addressing modes of STP. The reason that we do it here is because
  // it's much easier to undo the transformation of a series of consecutive
  // mem ops, than it is to detect when doing it would be a bad idea looking
  // at a single G_PTR_ADD in the reassociation/ptradd_immed_chain combine.
  //
  // An example:
  //   G_STORE %11:_(<2 x s64>), %base:_(p0) :: (store (<2 x s64>), align 1)
  //   %off1:_(s64) = G_CONSTANT i64 4128
  //   %p1:_(p0) = G_PTR_ADD %base:_, %off1:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p1:_(p0) :: (store (<2 x s64>), align 1)
  //   %off2:_(s64) = G_CONSTANT i64 4144
  //   %p2:_(p0) = G_PTR_ADD %base:_, %off2:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p2:_(p0) :: (store (<2 x s64>), align 1)
  //   %off3:_(s64) = G_CONSTANT i64 4160
  //   %p3:_(p0) = G_PTR_ADD %base:_, %off3:_(s64)
  //   G_STORE %11:_(<2 x s64>), %p3:_(p0) :: (store (<2 x s64>), align 1)
  bool Changed = false;
  auto &MRI = MF.getRegInfo();

  if (!EnableConsecutiveMemOpOpt)
    return Changed;

  SmallVector<StoreInfo, 8> Stores;
  // If we see a load, then we keep track of any values defined by it.
  // In the following example, STP formation will fail anyway because
  // the latter store is using a load result that appears after the
  // prior store. In this situation if we factor out the offset then
  // we increase code size for no benefit.
  //   G_STORE %v1:_(s64), %base:_(p0) :: (store (s64))
  //   %v2:_(s64) = G_LOAD %ldptr:_(p0) :: (load (s64))
  //   G_STORE %v2:_(s64), %base:_(p0) :: (store (s64))
  SmallVector<Register> LoadValsSinceLastStore;

  auto storeIsValid = [&](StoreInfo &Last, StoreInfo New) {
    // Check if this store is consecutive to the last one.
    if (Last.Ptr->getBaseReg() != New.Ptr->getBaseReg() ||
        (Last.Offset + static_cast<int64_t>(Last.StoredType.getSizeInBytes()) !=
         New.Offset) ||
        Last.StoredType != New.StoredType)
      return false;

    // Check if this store is using a load result that appears after the
    // last store. If so, bail out.
    if (any_of(LoadValsSinceLastStore, [&](Register LoadVal) {
          return New.St->getValueReg() == LoadVal;
        }))
      return false;

    // Check if the current offset would be too large for STP.
    // If not, then STP formation should be able to handle it, so we don't
    // need to do anything.
    int64_t MaxLegalOffset;
    switch (New.StoredType.getSizeInBits()) {
    case 32:
      MaxLegalOffset = 252;
      break;
    case 64:
      MaxLegalOffset = 504;
      break;
    case 128:
      MaxLegalOffset = 1008;
      break;
    default:
      llvm_unreachable("Unexpected stored type size");
    }
    if (New.Offset < MaxLegalOffset)
      return false;

    // If factoring it out still wouldn't help then don't bother.
    return New.Offset - Stores[0].Offset <= MaxLegalOffset;
  };

  auto resetState = [&]() {
    Stores.clear();
    LoadValsSinceLastStore.clear();
  };

  for (auto &MBB : MF) {
    // We're looking inside a single BB at a time since the memset pattern
    // should only be in a single block.
    resetState();
    for (auto &MI : MBB) {
      // Skip for scalable vectors
      if (auto *LdSt = dyn_cast<GLoadStore>(&MI);
          LdSt && MRI.getType(LdSt->getOperand(0).getReg()).isScalableVector())
        continue;

      if (auto *St = dyn_cast<GStore>(&MI)) {
        Register PtrBaseReg;
        APInt Offset;
        LLT StoredValTy = MRI.getType(St->getValueReg());
        unsigned ValSize = StoredValTy.getSizeInBits();
        if (ValSize < 32 || St->getMMO().getSizeInBits() != ValSize)
          continue;

        Register PtrReg = St->getPointerReg();
        if (mi_match(
                PtrReg, MRI,
                m_OneNonDBGUse(m_GPtrAdd(m_Reg(PtrBaseReg), m_ICst(Offset))))) {
          GPtrAdd *PtrAdd = cast<GPtrAdd>(MRI.getVRegDef(PtrReg));
          StoreInfo New = {St, PtrAdd, Offset.getSExtValue(), StoredValTy};

          if (Stores.empty()) {
            Stores.push_back(New);
            continue;
          }

          // Check if this store is a valid continuation of the sequence.
          auto &Last = Stores.back();
          if (storeIsValid(Last, New)) {
            Stores.push_back(New);
            LoadValsSinceLastStore.clear(); // Reset the load value tracking.
          } else {
            // The store isn't valid to consider for the prior sequence,
            // so try to optimize what we have so far and start a new sequence.
            Changed |= tryOptimizeConsecStores(Stores, MIB);
            resetState();
            Stores.push_back(New);
          }
        }
      } else if (auto *Ld = dyn_cast<GLoad>(&MI)) {
        LoadValsSinceLastStore.push_back(Ld->getDstReg());
      }
    }
    Changed |= tryOptimizeConsecStores(Stores, MIB);
    resetState();
  }

  return Changed;
}

char AArch64PostLegalizerCombiner::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                      "Combine AArch64 MachineInstrs after legalization", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(GISelKnownBitsAnalysis)
INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
                    "Combine AArch64 MachineInstrs after legalization", false,
                    false)

namespace llvm {
FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
  return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm