//===-- lib/CodeGen/GlobalISel/CombinerHelper.cpp -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/LowLevelTypeUtils.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RegisterBankInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/DivisionByConstantInfo.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Target/TargetMachine.h"
#include <cmath>
#include <optional>
#include <tuple>

#define DEBUG_TYPE "gi-combiner"

using namespace llvm;
using namespace MIPatternMatch;

// Option to allow testing of the combiner while no targets know about indexed
// addressing.
static cl::opt<bool>
    ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(false),
                       cl::desc("Force all indexed operations to be "
                                "legal for the GlobalISel combiner"));

CombinerHelper::CombinerHelper(GISelChangeObserver &Observer,
                               MachineIRBuilder &B, bool IsPreLegalize,
                               GISelKnownBits *KB, MachineDominatorTree *MDT,
                               const LegalizerInfo *LI)
    : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), KB(KB),
      MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI),
      RBI(Builder.getMF().getSubtarget().getRegBankInfo()),
      TRI(Builder.getMF().getSubtarget().getRegisterInfo()) {
  (void)this->KB;
}

const TargetLowering &CombinerHelper::getTargetLowering() const {
  return *Builder.getMF().getSubtarget().getTargetLowering();
}

const MachineFunction &CombinerHelper::getMachineFunction() const {
  return Builder.getMF();
}

const DataLayout &CombinerHelper::getDataLayout() const {
  return getMachineFunction().getDataLayout();
}

LLVMContext &CombinerHelper::getContext() const { return Builder.getContext(); }

/// \returns The little endian in-memory byte position of byte \p I in a
/// \p ByteWidth bytes wide type.
///
/// E.g.
Given a 4-byte type x, x[0] -> byte 0 85 static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) { 86 assert(I < ByteWidth && "I must be in [0, ByteWidth)"); 87 return I; 88 } 89 90 /// Determines the LogBase2 value for a non-null input value using the 91 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V). 92 static Register buildLogBase2(Register V, MachineIRBuilder &MIB) { 93 auto &MRI = *MIB.getMRI(); 94 LLT Ty = MRI.getType(V); 95 auto Ctlz = MIB.buildCTLZ(Ty, V); 96 auto Base = MIB.buildConstant(Ty, Ty.getScalarSizeInBits() - 1); 97 return MIB.buildSub(Ty, Base, Ctlz).getReg(0); 98 } 99 100 /// \returns The big endian in-memory byte position of byte \p I in a 101 /// \p ByteWidth bytes wide type. 102 /// 103 /// E.g. Given a 4-byte type x, x[0] -> byte 3 104 static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) { 105 assert(I < ByteWidth && "I must be in [0, ByteWidth)"); 106 return ByteWidth - I - 1; 107 } 108 109 /// Given a map from byte offsets in memory to indices in a load/store, 110 /// determine if that map corresponds to a little or big endian byte pattern. 111 /// 112 /// \param MemOffset2Idx maps memory offsets to address offsets. 113 /// \param LowestIdx is the lowest index in \p MemOffset2Idx. 114 /// 115 /// \returns true if the map corresponds to a big endian byte pattern, false if 116 /// it corresponds to a little endian byte pattern, and std::nullopt otherwise. 117 /// 118 /// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns 119 /// are as follows: 120 /// 121 /// AddrOffset Little endian Big endian 122 /// 0 0 3 123 /// 1 1 2 124 /// 2 2 1 125 /// 3 3 0 126 static std::optional<bool> 127 isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, 128 int64_t LowestIdx) { 129 // Need at least two byte positions to decide on endianness. 130 unsigned Width = MemOffset2Idx.size(); 131 if (Width < 2) 132 return std::nullopt; 133 bool BigEndian = true, LittleEndian = true; 134 for (unsigned MemOffset = 0; MemOffset < Width; ++ MemOffset) { 135 auto MemOffsetAndIdx = MemOffset2Idx.find(MemOffset); 136 if (MemOffsetAndIdx == MemOffset2Idx.end()) 137 return std::nullopt; 138 const int64_t Idx = MemOffsetAndIdx->second - LowestIdx; 139 assert(Idx >= 0 && "Expected non-negative byte offset?"); 140 LittleEndian &= Idx == littleEndianByteAt(Width, MemOffset); 141 BigEndian &= Idx == bigEndianByteAt(Width, MemOffset); 142 if (!BigEndian && !LittleEndian) 143 return std::nullopt; 144 } 145 146 assert((BigEndian != LittleEndian) && 147 "Pattern cannot be both big and little endian!"); 148 return BigEndian; 149 } 150 151 bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; } 152 153 bool CombinerHelper::isLegal(const LegalityQuery &Query) const { 154 assert(LI && "Must have LegalizerInfo to query isLegal!"); 155 return LI->getAction(Query).Action == LegalizeActions::Legal; 156 } 157 158 bool CombinerHelper::isLegalOrBeforeLegalizer( 159 const LegalityQuery &Query) const { 160 return isPreLegalize() || isLegal(Query); 161 } 162 163 bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const { 164 if (!Ty.isVector()) 165 return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}}); 166 // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs. 
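  // For illustration only (virtual register names and the element type are
  // arbitrary), such a vector constant might look like:
  //   %c0:_(s32) = G_CONSTANT i32 1
  //   %c1:_(s32) = G_CONSTANT i32 2
  //   %v:_(<2 x s32>) = G_BUILD_VECTOR %c0(s32), %c1(s32)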
167 if (isPreLegalize()) 168 return true; 169 LLT EltTy = Ty.getElementType(); 170 return isLegal({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) && 171 isLegal({TargetOpcode::G_CONSTANT, {EltTy}}); 172 } 173 174 void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg, 175 Register ToReg) const { 176 Observer.changingAllUsesOfReg(MRI, FromReg); 177 178 if (MRI.constrainRegAttrs(ToReg, FromReg)) 179 MRI.replaceRegWith(FromReg, ToReg); 180 else 181 Builder.buildCopy(FromReg, ToReg); 182 183 Observer.finishedChangingAllUsesOfReg(); 184 } 185 186 void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI, 187 MachineOperand &FromRegOp, 188 Register ToReg) const { 189 assert(FromRegOp.getParent() && "Expected an operand in an MI"); 190 Observer.changingInstr(*FromRegOp.getParent()); 191 192 FromRegOp.setReg(ToReg); 193 194 Observer.changedInstr(*FromRegOp.getParent()); 195 } 196 197 void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI, 198 unsigned ToOpcode) const { 199 Observer.changingInstr(FromMI); 200 201 FromMI.setDesc(Builder.getTII().get(ToOpcode)); 202 203 Observer.changedInstr(FromMI); 204 } 205 206 const RegisterBank *CombinerHelper::getRegBank(Register Reg) const { 207 return RBI->getRegBank(Reg, MRI, *TRI); 208 } 209 210 void CombinerHelper::setRegBank(Register Reg, 211 const RegisterBank *RegBank) const { 212 if (RegBank) 213 MRI.setRegBank(Reg, *RegBank); 214 } 215 216 bool CombinerHelper::tryCombineCopy(MachineInstr &MI) const { 217 if (matchCombineCopy(MI)) { 218 applyCombineCopy(MI); 219 return true; 220 } 221 return false; 222 } 223 bool CombinerHelper::matchCombineCopy(MachineInstr &MI) const { 224 if (MI.getOpcode() != TargetOpcode::COPY) 225 return false; 226 Register DstReg = MI.getOperand(0).getReg(); 227 Register SrcReg = MI.getOperand(1).getReg(); 228 return canReplaceReg(DstReg, SrcReg, MRI); 229 } 230 void CombinerHelper::applyCombineCopy(MachineInstr &MI) const { 231 Register DstReg = MI.getOperand(0).getReg(); 232 Register SrcReg = MI.getOperand(1).getReg(); 233 replaceRegWith(MRI, DstReg, SrcReg); 234 MI.eraseFromParent(); 235 } 236 237 bool CombinerHelper::matchFreezeOfSingleMaybePoisonOperand( 238 MachineInstr &MI, BuildFnTy &MatchInfo) const { 239 // Ported from InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating. 240 Register DstOp = MI.getOperand(0).getReg(); 241 Register OrigOp = MI.getOperand(1).getReg(); 242 243 if (!MRI.hasOneNonDBGUse(OrigOp)) 244 return false; 245 246 MachineInstr *OrigDef = MRI.getUniqueVRegDef(OrigOp); 247 // Even if only a single operand of the PHI is not guaranteed non-poison, 248 // moving freeze() backwards across a PHI can cause optimization issues for 249 // other users of that operand. 250 // 251 // Moving freeze() from one of the output registers of a G_UNMERGE_VALUES to 252 // the source register is unprofitable because it makes the freeze() more 253 // strict than is necessary (it would affect the whole register instead of 254 // just the subreg being frozen). 
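  // When the fold below applies, the rewrite looks roughly like this
  // (illustrative only; register names are arbitrary and %b is assumed to be
  // the single maybe-poison operand):
  //   %x:_(s32) = G_ADD %a, %b
  //   %y:_(s32) = G_FREEZE %x
  // =>
  //   %fb:_(s32) = G_FREEZE %b
  //   %x:_(s32) = G_ADD %a, %fb
  // with all uses of %y rewritten to use %x.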
255 if (OrigDef->isPHI() || isa<GUnmerge>(OrigDef)) 256 return false; 257 258 if (canCreateUndefOrPoison(OrigOp, MRI, 259 /*ConsiderFlagsAndMetadata=*/false)) 260 return false; 261 262 std::optional<MachineOperand> MaybePoisonOperand; 263 for (MachineOperand &Operand : OrigDef->uses()) { 264 if (!Operand.isReg()) 265 return false; 266 267 if (isGuaranteedNotToBeUndefOrPoison(Operand.getReg(), MRI)) 268 continue; 269 270 if (!MaybePoisonOperand) 271 MaybePoisonOperand = Operand; 272 else { 273 // We have more than one maybe-poison operand. Moving the freeze is 274 // unsafe. 275 return false; 276 } 277 } 278 279 // Eliminate freeze if all operands are guaranteed non-poison. 280 if (!MaybePoisonOperand) { 281 MatchInfo = [=](MachineIRBuilder &B) { 282 Observer.changingInstr(*OrigDef); 283 cast<GenericMachineInstr>(OrigDef)->dropPoisonGeneratingFlags(); 284 Observer.changedInstr(*OrigDef); 285 B.buildCopy(DstOp, OrigOp); 286 }; 287 return true; 288 } 289 290 Register MaybePoisonOperandReg = MaybePoisonOperand->getReg(); 291 LLT MaybePoisonOperandRegTy = MRI.getType(MaybePoisonOperandReg); 292 293 MatchInfo = [=](MachineIRBuilder &B) mutable { 294 Observer.changingInstr(*OrigDef); 295 cast<GenericMachineInstr>(OrigDef)->dropPoisonGeneratingFlags(); 296 Observer.changedInstr(*OrigDef); 297 B.setInsertPt(*OrigDef->getParent(), OrigDef->getIterator()); 298 auto Freeze = B.buildFreeze(MaybePoisonOperandRegTy, MaybePoisonOperandReg); 299 replaceRegOpWith( 300 MRI, *OrigDef->findRegisterUseOperand(MaybePoisonOperandReg, TRI), 301 Freeze.getReg(0)); 302 replaceRegWith(MRI, DstOp, OrigOp); 303 }; 304 return true; 305 } 306 307 bool CombinerHelper::matchCombineConcatVectors( 308 MachineInstr &MI, SmallVector<Register> &Ops) const { 309 assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS && 310 "Invalid instruction"); 311 bool IsUndef = true; 312 MachineInstr *Undef = nullptr; 313 314 // Walk over all the operands of concat vectors and check if they are 315 // build_vector themselves or undef. 316 // Then collect their operands in Ops. 317 for (const MachineOperand &MO : MI.uses()) { 318 Register Reg = MO.getReg(); 319 MachineInstr *Def = MRI.getVRegDef(Reg); 320 assert(Def && "Operand not defined"); 321 if (!MRI.hasOneNonDBGUse(Reg)) 322 return false; 323 switch (Def->getOpcode()) { 324 case TargetOpcode::G_BUILD_VECTOR: 325 IsUndef = false; 326 // Remember the operands of the build_vector to fold 327 // them into the yet-to-build flattened concat vectors. 328 for (const MachineOperand &BuildVecMO : Def->uses()) 329 Ops.push_back(BuildVecMO.getReg()); 330 break; 331 case TargetOpcode::G_IMPLICIT_DEF: { 332 LLT OpType = MRI.getType(Reg); 333 // Keep one undef value for all the undef operands. 334 if (!Undef) { 335 Builder.setInsertPt(*MI.getParent(), MI); 336 Undef = Builder.buildUndef(OpType.getScalarType()); 337 } 338 assert(MRI.getType(Undef->getOperand(0).getReg()) == 339 OpType.getScalarType() && 340 "All undefs should have the same type"); 341 // Break the undef vector in as many scalar elements as needed 342 // for the flattening. 
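      // For example (illustrative only; the types are arbitrary):
      //   %u:_(<2 x s32>) = G_IMPLICIT_DEF
      //   %cat:_(<4 x s32>) = G_CONCAT_VECTORS %bv(<2 x s32>), %u(<2 x s32>)
      // contributes a single scalar
      //   %s:_(s32) = G_IMPLICIT_DEF
      // pushed into Ops once per undef element, so the flattened
      // G_BUILD_VECTOR ends up with one operand per destination element.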
    for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
         EltIdx != EltEnd; ++EltIdx)
      Ops.push_back(Undef->getOperand(0).getReg());
    break;
  }
  default:
    return false;
  }
}

  // Check if the combine is illegal
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_BUILD_VECTOR, {DstTy, MRI.getType(Ops[0])}})) {
    return false;
  }

  if (IsUndef)
    Ops.clear();

  return true;
}
void CombinerHelper::applyCombineConcatVectors(
    MachineInstr &MI, SmallVector<Register> &Ops) const {
  // We determined that the concat_vectors can be flattened.
  // Generate the flattened build_vector.
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);
  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);

  // Note: IsUndef is sort of redundant. We could have determined it by
  // checking that all Ops are undef. Alternatively, we could have generated
  // a build_vector of undefs and relied on another combine to clean that up.
  // For now, given we already gather this information in
  // matchCombineConcatVectors, just save compile time and issue the right
  // thing.
  if (Ops.empty())
    Builder.buildUndef(NewDstReg);
  else
    Builder.buildBuildVector(NewDstReg, Ops);
  replaceRegWith(MRI, DstReg, NewDstReg);
  MI.eraseFromParent();
}

bool CombinerHelper::matchCombineShuffleConcat(
    MachineInstr &MI, SmallVector<Register> &Ops) const {
  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
  auto ConcatMI1 =
      dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(1).getReg()));
  auto ConcatMI2 =
      dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(2).getReg()));
  if (!ConcatMI1 || !ConcatMI2)
    return false;

  // Check that the sources of the Concat instructions have the same type.
  if (MRI.getType(ConcatMI1->getSourceReg(0)) !=
      MRI.getType(ConcatMI2->getSourceReg(0)))
    return false;

  LLT ConcatSrcTy = MRI.getType(ConcatMI1->getReg(1));
  LLT ShuffleSrcTy1 = MRI.getType(MI.getOperand(1).getReg());
  unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements();
  for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) {
    // Check if the index takes a whole source register from G_CONCAT_VECTORS.
    // Assumes that all sources of G_CONCAT_VECTORS are the same type.
    if (Mask[i] == -1) {
      for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
        if (i + j >= Mask.size())
          return false;
        if (Mask[i + j] != -1)
          return false;
      }
      if (!isLegalOrBeforeLegalizer(
              {TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}}))
        return false;
      Ops.push_back(0);
    } else if (Mask[i] % ConcatSrcNumElt == 0) {
      for (unsigned j = 1; j < ConcatSrcNumElt; j++) {
        if (i + j >= Mask.size())
          return false;
        if (Mask[i + j] != Mask[i] + static_cast<int>(j))
          return false;
      }
      // Retrieve the source register from its respective G_CONCAT_VECTORS
      // instruction.
      if (Mask[i] < ShuffleSrcTy1.getNumElements()) {
        Ops.push_back(ConcatMI1->getSourceReg(Mask[i] / ConcatSrcNumElt));
      } else {
        Ops.push_back(ConcatMI2->getSourceReg(Mask[i] / ConcatSrcNumElt -
                                              ConcatMI1->getNumSources()));
      }
    } else {
      return false;
    }
  }

  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_CONCAT_VECTORS,
           {MRI.getType(MI.getOperand(0).getReg()), ConcatSrcTy}}))
    return false;

  return !Ops.empty();
}
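// As a concrete illustration of the match above (register names and types are
// arbitrary, not taken from any particular test):
//   %c1:_(<4 x s32>) = G_CONCAT_VECTORS %a(<2 x s32>), %b(<2 x s32>)
//   %c2:_(<4 x s32>) = G_CONCAT_VECTORS %c(<2 x s32>), %d(<2 x s32>)
//   %s:_(<4 x s32>) = G_SHUFFLE_VECTOR %c1, %c2, shufflemask(2,3,4,5)
// is combined into:
//   %s:_(<4 x s32>) = G_CONCAT_VECTORS %b(<2 x s32>), %c(<2 x s32>)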
446 447 void CombinerHelper::applyCombineShuffleConcat( 448 MachineInstr &MI, SmallVector<Register> &Ops) const { 449 LLT SrcTy; 450 for (Register &Reg : Ops) { 451 if (Reg != 0) 452 SrcTy = MRI.getType(Reg); 453 } 454 assert(SrcTy.isValid() && "Unexpected full undef vector in concat combine"); 455 456 Register UndefReg = 0; 457 458 for (Register &Reg : Ops) { 459 if (Reg == 0) { 460 if (UndefReg == 0) 461 UndefReg = Builder.buildUndef(SrcTy).getReg(0); 462 Reg = UndefReg; 463 } 464 } 465 466 if (Ops.size() > 1) 467 Builder.buildConcatVectors(MI.getOperand(0).getReg(), Ops); 468 else 469 Builder.buildCopy(MI.getOperand(0).getReg(), Ops[0]); 470 MI.eraseFromParent(); 471 } 472 473 bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) const { 474 SmallVector<Register, 4> Ops; 475 if (matchCombineShuffleVector(MI, Ops)) { 476 applyCombineShuffleVector(MI, Ops); 477 return true; 478 } 479 return false; 480 } 481 482 bool CombinerHelper::matchCombineShuffleVector( 483 MachineInstr &MI, SmallVectorImpl<Register> &Ops) const { 484 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR && 485 "Invalid instruction kind"); 486 LLT DstType = MRI.getType(MI.getOperand(0).getReg()); 487 Register Src1 = MI.getOperand(1).getReg(); 488 LLT SrcType = MRI.getType(Src1); 489 // As bizarre as it may look, shuffle vector can actually produce 490 // scalar! This is because at the IR level a <1 x ty> shuffle 491 // vector is perfectly valid. 492 unsigned DstNumElts = DstType.isVector() ? DstType.getNumElements() : 1; 493 unsigned SrcNumElts = SrcType.isVector() ? SrcType.getNumElements() : 1; 494 495 // If the resulting vector is smaller than the size of the source 496 // vectors being concatenated, we won't be able to replace the 497 // shuffle vector into a concat_vectors. 498 // 499 // Note: We may still be able to produce a concat_vectors fed by 500 // extract_vector_elt and so on. It is less clear that would 501 // be better though, so don't bother for now. 502 // 503 // If the destination is a scalar, the size of the sources doesn't 504 // matter. we will lower the shuffle to a plain copy. This will 505 // work only if the source and destination have the same size. But 506 // that's covered by the next condition. 507 // 508 // TODO: If the size between the source and destination don't match 509 // we could still emit an extract vector element in that case. 510 if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1) 511 return false; 512 513 // Check that the shuffle mask can be broken evenly between the 514 // different sources. 515 if (DstNumElts % SrcNumElts != 0) 516 return false; 517 518 // Mask length is a multiple of the source vector length. 519 // Check if the shuffle is some kind of concatenation of the input 520 // vectors. 521 unsigned NumConcat = DstNumElts / SrcNumElts; 522 SmallVector<int, 8> ConcatSrcs(NumConcat, -1); 523 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask(); 524 for (unsigned i = 0; i != DstNumElts; ++i) { 525 int Idx = Mask[i]; 526 // Undef value. 527 if (Idx < 0) 528 continue; 529 // Ensure the indices in each SrcType sized piece are sequential and that 530 // the same source is used for the whole piece. 531 if ((Idx % SrcNumElts != (i % SrcNumElts)) || 532 (ConcatSrcs[i / SrcNumElts] >= 0 && 533 ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts))) 534 return false; 535 // Remember which source this index came from. 536 ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts; 537 } 538 539 // The shuffle is concatenating multiple vectors together. 
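  // For instance (illustrative only): with <2 x s32> inputs, the mask
  // (0,1,2,3) is a pure concatenation of %src1 and %src2, while (0,1,0,1)
  // concatenates %src1 with itself.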
540 // Collect the different operands for that. 541 Register UndefReg; 542 Register Src2 = MI.getOperand(2).getReg(); 543 for (auto Src : ConcatSrcs) { 544 if (Src < 0) { 545 if (!UndefReg) { 546 Builder.setInsertPt(*MI.getParent(), MI); 547 UndefReg = Builder.buildUndef(SrcType).getReg(0); 548 } 549 Ops.push_back(UndefReg); 550 } else if (Src == 0) 551 Ops.push_back(Src1); 552 else 553 Ops.push_back(Src2); 554 } 555 return true; 556 } 557 558 void CombinerHelper::applyCombineShuffleVector( 559 MachineInstr &MI, const ArrayRef<Register> Ops) const { 560 Register DstReg = MI.getOperand(0).getReg(); 561 Builder.setInsertPt(*MI.getParent(), MI); 562 Register NewDstReg = MRI.cloneVirtualRegister(DstReg); 563 564 if (Ops.size() == 1) 565 Builder.buildCopy(NewDstReg, Ops[0]); 566 else 567 Builder.buildMergeLikeInstr(NewDstReg, Ops); 568 569 replaceRegWith(MRI, DstReg, NewDstReg); 570 MI.eraseFromParent(); 571 } 572 573 bool CombinerHelper::matchShuffleToExtract(MachineInstr &MI) const { 574 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR && 575 "Invalid instruction kind"); 576 577 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask(); 578 return Mask.size() == 1; 579 } 580 581 void CombinerHelper::applyShuffleToExtract(MachineInstr &MI) const { 582 Register DstReg = MI.getOperand(0).getReg(); 583 Builder.setInsertPt(*MI.getParent(), MI); 584 585 int I = MI.getOperand(3).getShuffleMask()[0]; 586 Register Src1 = MI.getOperand(1).getReg(); 587 LLT Src1Ty = MRI.getType(Src1); 588 int Src1NumElts = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1; 589 Register SrcReg; 590 if (I >= Src1NumElts) { 591 SrcReg = MI.getOperand(2).getReg(); 592 I -= Src1NumElts; 593 } else if (I >= 0) 594 SrcReg = Src1; 595 596 if (I < 0) 597 Builder.buildUndef(DstReg); 598 else if (!MRI.getType(SrcReg).isVector()) 599 Builder.buildCopy(DstReg, SrcReg); 600 else 601 Builder.buildExtractVectorElementConstant(DstReg, SrcReg, I); 602 603 MI.eraseFromParent(); 604 } 605 606 namespace { 607 608 /// Select a preference between two uses. CurrentUse is the current preference 609 /// while *ForCandidate is attributes of the candidate under consideration. 610 PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI, 611 PreferredTuple &CurrentUse, 612 const LLT TyForCandidate, 613 unsigned OpcodeForCandidate, 614 MachineInstr *MIForCandidate) { 615 if (!CurrentUse.Ty.isValid()) { 616 if (CurrentUse.ExtendOpcode == OpcodeForCandidate || 617 CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT) 618 return {TyForCandidate, OpcodeForCandidate, MIForCandidate}; 619 return CurrentUse; 620 } 621 622 // We permit the extend to hoist through basic blocks but this is only 623 // sensible if the target has extending loads. If you end up lowering back 624 // into a load and extend during the legalizer then the end result is 625 // hoisting the extend up to the load. 626 627 // Prefer defined extensions to undefined extensions as these are more 628 // likely to reduce the number of instructions. 629 if (OpcodeForCandidate == TargetOpcode::G_ANYEXT && 630 CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT) 631 return CurrentUse; 632 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT && 633 OpcodeForCandidate != TargetOpcode::G_ANYEXT) 634 return {TyForCandidate, OpcodeForCandidate, MIForCandidate}; 635 636 // Prefer sign extensions to zero extensions as sign-extensions tend to be 637 // more expensive. 
Don't do this if the load is already a zero-extend load 638 // though, otherwise we'll rewrite a zero-extend load into a sign-extend 639 // later. 640 if (!isa<GZExtLoad>(LoadMI) && CurrentUse.Ty == TyForCandidate) { 641 if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT && 642 OpcodeForCandidate == TargetOpcode::G_ZEXT) 643 return CurrentUse; 644 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT && 645 OpcodeForCandidate == TargetOpcode::G_SEXT) 646 return {TyForCandidate, OpcodeForCandidate, MIForCandidate}; 647 } 648 649 // This is potentially target specific. We've chosen the largest type 650 // because G_TRUNC is usually free. One potential catch with this is that 651 // some targets have a reduced number of larger registers than smaller 652 // registers and this choice potentially increases the live-range for the 653 // larger value. 654 if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) { 655 return {TyForCandidate, OpcodeForCandidate, MIForCandidate}; 656 } 657 return CurrentUse; 658 } 659 660 /// Find a suitable place to insert some instructions and insert them. This 661 /// function accounts for special cases like inserting before a PHI node. 662 /// The current strategy for inserting before PHI's is to duplicate the 663 /// instructions for each predecessor. However, while that's ok for G_TRUNC 664 /// on most targets since it generally requires no code, other targets/cases may 665 /// want to try harder to find a dominating block. 666 static void InsertInsnsWithoutSideEffectsBeforeUse( 667 MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO, 668 std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator, 669 MachineOperand &UseMO)> 670 Inserter) { 671 MachineInstr &UseMI = *UseMO.getParent(); 672 673 MachineBasicBlock *InsertBB = UseMI.getParent(); 674 675 // If the use is a PHI then we want the predecessor block instead. 676 if (UseMI.isPHI()) { 677 MachineOperand *PredBB = std::next(&UseMO); 678 InsertBB = PredBB->getMBB(); 679 } 680 681 // If the block is the same block as the def then we want to insert just after 682 // the def instead of at the start of the block. 683 if (InsertBB == DefMI.getParent()) { 684 MachineBasicBlock::iterator InsertPt = &DefMI; 685 Inserter(InsertBB, std::next(InsertPt), UseMO); 686 return; 687 } 688 689 // Otherwise we want the start of the BB 690 Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO); 691 } 692 } // end anonymous namespace 693 694 bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) const { 695 PreferredTuple Preferred; 696 if (matchCombineExtendingLoads(MI, Preferred)) { 697 applyCombineExtendingLoads(MI, Preferred); 698 return true; 699 } 700 return false; 701 } 702 703 static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) { 704 unsigned CandidateLoadOpc; 705 switch (ExtOpc) { 706 case TargetOpcode::G_ANYEXT: 707 CandidateLoadOpc = TargetOpcode::G_LOAD; 708 break; 709 case TargetOpcode::G_SEXT: 710 CandidateLoadOpc = TargetOpcode::G_SEXTLOAD; 711 break; 712 case TargetOpcode::G_ZEXT: 713 CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD; 714 break; 715 default: 716 llvm_unreachable("Unexpected extend opc"); 717 } 718 return CandidateLoadOpc; 719 } 720 721 bool CombinerHelper::matchCombineExtendingLoads( 722 MachineInstr &MI, PreferredTuple &Preferred) const { 723 // We match the loads and follow the uses to the extend instead of matching 724 // the extends and following the def to the load. 
This is because the load 725 // must remain in the same position for correctness (unless we also add code 726 // to find a safe place to sink it) whereas the extend is freely movable. 727 // It also prevents us from duplicating the load for the volatile case or just 728 // for performance. 729 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(&MI); 730 if (!LoadMI) 731 return false; 732 733 Register LoadReg = LoadMI->getDstReg(); 734 735 LLT LoadValueTy = MRI.getType(LoadReg); 736 if (!LoadValueTy.isScalar()) 737 return false; 738 739 // Most architectures are going to legalize <s8 loads into at least a 1 byte 740 // load, and the MMOs can only describe memory accesses in multiples of bytes. 741 // If we try to perform extload combining on those, we can end up with 742 // %a(s8) = extload %ptr (load 1 byte from %ptr) 743 // ... which is an illegal extload instruction. 744 if (LoadValueTy.getSizeInBits() < 8) 745 return false; 746 747 // For non power-of-2 types, they will very likely be legalized into multiple 748 // loads. Don't bother trying to match them into extending loads. 749 if (!llvm::has_single_bit<uint32_t>(LoadValueTy.getSizeInBits())) 750 return false; 751 752 // Find the preferred type aside from the any-extends (unless it's the only 753 // one) and non-extending ops. We'll emit an extending load to that type and 754 // and emit a variant of (extend (trunc X)) for the others according to the 755 // relative type sizes. At the same time, pick an extend to use based on the 756 // extend involved in the chosen type. 757 unsigned PreferredOpcode = 758 isa<GLoad>(&MI) 759 ? TargetOpcode::G_ANYEXT 760 : isa<GSExtLoad>(&MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT; 761 Preferred = {LLT(), PreferredOpcode, nullptr}; 762 for (auto &UseMI : MRI.use_nodbg_instructions(LoadReg)) { 763 if (UseMI.getOpcode() == TargetOpcode::G_SEXT || 764 UseMI.getOpcode() == TargetOpcode::G_ZEXT || 765 (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) { 766 const auto &MMO = LoadMI->getMMO(); 767 // Don't do anything for atomics. 768 if (MMO.isAtomic()) 769 continue; 770 // Check for legality. 771 if (!isPreLegalize()) { 772 LegalityQuery::MemDesc MMDesc(MMO); 773 unsigned CandidateLoadOpc = getExtLoadOpcForExtend(UseMI.getOpcode()); 774 LLT UseTy = MRI.getType(UseMI.getOperand(0).getReg()); 775 LLT SrcTy = MRI.getType(LoadMI->getPointerReg()); 776 if (LI->getAction({CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}}) 777 .Action != LegalizeActions::Legal) 778 continue; 779 } 780 Preferred = ChoosePreferredUse(MI, Preferred, 781 MRI.getType(UseMI.getOperand(0).getReg()), 782 UseMI.getOpcode(), &UseMI); 783 } 784 } 785 786 // There were no extends 787 if (!Preferred.MI) 788 return false; 789 // It should be impossible to chose an extend without selecting a different 790 // type since by definition the result of an extend is larger. 791 assert(Preferred.Ty != LoadValueTy && "Extending to same type?"); 792 793 LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI); 794 return true; 795 } 796 797 void CombinerHelper::applyCombineExtendingLoads( 798 MachineInstr &MI, PreferredTuple &Preferred) const { 799 // Rewrite the load to the chosen extending load. 800 Register ChosenDstReg = Preferred.MI->getOperand(0).getReg(); 801 802 // Inserter to insert a truncate back to the original type at a given point 803 // with some basic CSE to limit truncate duplication to one per BB. 
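  // For example (illustrative only): if two non-extend uses in the same block
  // still need the original narrow type, a single
  //   %t:_(s8) = G_TRUNC %wide(s32)
  // is emitted and both uses are rewritten to %t.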
  DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
  auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
                           MachineBasicBlock::iterator InsertBefore,
                           MachineOperand &UseMO) {
    MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(InsertIntoBB);
    if (PreviouslyEmitted) {
      Observer.changingInstr(*UseMO.getParent());
      UseMO.setReg(PreviouslyEmitted->getOperand(0).getReg());
      Observer.changedInstr(*UseMO.getParent());
      return;
    }

    Builder.setInsertPt(*InsertIntoBB, InsertBefore);
    Register NewDstReg = MRI.cloneVirtualRegister(MI.getOperand(0).getReg());
    MachineInstr *NewMI = Builder.buildTrunc(NewDstReg, ChosenDstReg);
    EmittedInsns[InsertIntoBB] = NewMI;
    replaceRegOpWith(MRI, UseMO, NewDstReg);
  };

  Observer.changingInstr(MI);
  unsigned LoadOpc = getExtLoadOpcForExtend(Preferred.ExtendOpcode);
  MI.setDesc(Builder.getTII().get(LoadOpc));

  // Rewrite all the uses to fix up the types.
  auto &LoadValue = MI.getOperand(0);
  SmallVector<MachineOperand *, 4> Uses;
  for (auto &UseMO : MRI.use_operands(LoadValue.getReg()))
    Uses.push_back(&UseMO);

  for (auto *UseMO : Uses) {
    MachineInstr *UseMI = UseMO->getParent();

    // If the extend is compatible with the preferred extend then we should fix
    // up the type and extend so that it uses the preferred use.
    if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
        UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
      Register UseDstReg = UseMI->getOperand(0).getReg();
      MachineOperand &UseSrcMO = UseMI->getOperand(1);
      const LLT UseDstTy = MRI.getType(UseDstReg);
      if (UseDstReg != ChosenDstReg) {
        if (Preferred.Ty == UseDstTy) {
          // If the use has the same type as the preferred use, then merge
          // the vregs and erase the extend. For example:
          //   %1:_(s8) = G_LOAD ...
          //   %2:_(s32) = G_SEXT %1(s8)
          //   %3:_(s32) = G_ANYEXT %1(s8)
          //   ... = ... %3(s32)
          // rewrites to:
          //   %2:_(s32) = G_SEXTLOAD ...
          //   ... = ... %2(s32)
          replaceRegWith(MRI, UseDstReg, ChosenDstReg);
          Observer.erasingInstr(*UseMO->getParent());
          UseMO->getParent()->eraseFromParent();
        } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
          // If the preferred size is smaller, then keep the extend but extend
          // from the result of the extending load. For example:
          //   %1:_(s8) = G_LOAD ...
          //   %2:_(s32) = G_SEXT %1(s8)
          //   %3:_(s64) = G_ANYEXT %1(s8)
          //   ... = ... %3(s64)
          // rewrites to:
          //   %2:_(s32) = G_SEXTLOAD ...
          //   %3:_(s64) = G_ANYEXT %2:_(s32)
          //   ... = ... %3(s64)
          replaceRegOpWith(MRI, UseSrcMO, ChosenDstReg);
        } else {
          // If the preferred size is larger, then insert a truncate. For
          // example:
          //   %1:_(s8) = G_LOAD ...
          //   %2:_(s64) = G_SEXT %1(s8)
          //   %3:_(s32) = G_ZEXT %1(s8)
          //   ... = ... %3(s32)
          // rewrites to:
          //   %2:_(s64) = G_SEXTLOAD ...
          //   %4:_(s8) = G_TRUNC %2:_(s64)
          //   %3:_(s32) = G_ZEXT %4:_(s8)
          //   ... = ... %3(s32)
          InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO,
                                                 InsertTruncAt);
        }
        continue;
      }
      // The use is (one of) the uses of the preferred use we chose earlier.
      // We're going to update the load to def this value later so just erase
      // the old extend.
      Observer.erasingInstr(*UseMO->getParent());
      UseMO->getParent()->eraseFromParent();
      continue;
    }

    // The use isn't an extend. Truncate back to the type we originally loaded.
895 // This is free on many targets. 896 InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO, InsertTruncAt); 897 } 898 899 MI.getOperand(0).setReg(ChosenDstReg); 900 Observer.changedInstr(MI); 901 } 902 903 bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI, 904 BuildFnTy &MatchInfo) const { 905 assert(MI.getOpcode() == TargetOpcode::G_AND); 906 907 // If we have the following code: 908 // %mask = G_CONSTANT 255 909 // %ld = G_LOAD %ptr, (load s16) 910 // %and = G_AND %ld, %mask 911 // 912 // Try to fold it into 913 // %ld = G_ZEXTLOAD %ptr, (load s8) 914 915 Register Dst = MI.getOperand(0).getReg(); 916 if (MRI.getType(Dst).isVector()) 917 return false; 918 919 auto MaybeMask = 920 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI); 921 if (!MaybeMask) 922 return false; 923 924 APInt MaskVal = MaybeMask->Value; 925 926 if (!MaskVal.isMask()) 927 return false; 928 929 Register SrcReg = MI.getOperand(1).getReg(); 930 // Don't use getOpcodeDef() here since intermediate instructions may have 931 // multiple users. 932 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(MRI.getVRegDef(SrcReg)); 933 if (!LoadMI || !MRI.hasOneNonDBGUse(LoadMI->getDstReg())) 934 return false; 935 936 Register LoadReg = LoadMI->getDstReg(); 937 LLT RegTy = MRI.getType(LoadReg); 938 Register PtrReg = LoadMI->getPointerReg(); 939 unsigned RegSize = RegTy.getSizeInBits(); 940 LocationSize LoadSizeBits = LoadMI->getMemSizeInBits(); 941 unsigned MaskSizeBits = MaskVal.countr_one(); 942 943 // The mask may not be larger than the in-memory type, as it might cover sign 944 // extended bits 945 if (MaskSizeBits > LoadSizeBits.getValue()) 946 return false; 947 948 // If the mask covers the whole destination register, there's nothing to 949 // extend 950 if (MaskSizeBits >= RegSize) 951 return false; 952 953 // Most targets cannot deal with loads of size < 8 and need to re-legalize to 954 // at least byte loads. Avoid creating such loads here 955 if (MaskSizeBits < 8 || !isPowerOf2_32(MaskSizeBits)) 956 return false; 957 958 const MachineMemOperand &MMO = LoadMI->getMMO(); 959 LegalityQuery::MemDesc MemDesc(MMO); 960 961 // Don't modify the memory access size if this is atomic/volatile, but we can 962 // still adjust the opcode to indicate the high bit behavior. 963 if (LoadMI->isSimple()) 964 MemDesc.MemoryTy = LLT::scalar(MaskSizeBits); 965 else if (LoadSizeBits.getValue() > MaskSizeBits || 966 LoadSizeBits.getValue() == RegSize) 967 return false; 968 969 // TODO: Could check if it's legal with the reduced or original memory size. 
970 if (!isLegalOrBeforeLegalizer( 971 {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(PtrReg)}, {MemDesc}})) 972 return false; 973 974 MatchInfo = [=](MachineIRBuilder &B) { 975 B.setInstrAndDebugLoc(*LoadMI); 976 auto &MF = B.getMF(); 977 auto PtrInfo = MMO.getPointerInfo(); 978 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, MemDesc.MemoryTy); 979 B.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, Dst, PtrReg, *NewMMO); 980 LoadMI->eraseFromParent(); 981 }; 982 return true; 983 } 984 985 bool CombinerHelper::isPredecessor(const MachineInstr &DefMI, 986 const MachineInstr &UseMI) const { 987 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() && 988 "shouldn't consider debug uses"); 989 assert(DefMI.getParent() == UseMI.getParent()); 990 if (&DefMI == &UseMI) 991 return true; 992 const MachineBasicBlock &MBB = *DefMI.getParent(); 993 auto DefOrUse = find_if(MBB, [&DefMI, &UseMI](const MachineInstr &MI) { 994 return &MI == &DefMI || &MI == &UseMI; 995 }); 996 if (DefOrUse == MBB.end()) 997 llvm_unreachable("Block must contain both DefMI and UseMI!"); 998 return &*DefOrUse == &DefMI; 999 } 1000 1001 bool CombinerHelper::dominates(const MachineInstr &DefMI, 1002 const MachineInstr &UseMI) const { 1003 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() && 1004 "shouldn't consider debug uses"); 1005 if (MDT) 1006 return MDT->dominates(&DefMI, &UseMI); 1007 else if (DefMI.getParent() != UseMI.getParent()) 1008 return false; 1009 1010 return isPredecessor(DefMI, UseMI); 1011 } 1012 1013 bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) const { 1014 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); 1015 Register SrcReg = MI.getOperand(1).getReg(); 1016 Register LoadUser = SrcReg; 1017 1018 if (MRI.getType(SrcReg).isVector()) 1019 return false; 1020 1021 Register TruncSrc; 1022 if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) 1023 LoadUser = TruncSrc; 1024 1025 uint64_t SizeInBits = MI.getOperand(2).getImm(); 1026 // If the source is a G_SEXTLOAD from the same bit width, then we don't 1027 // need any extend at all, just a truncate. 1028 if (auto *LoadMI = getOpcodeDef<GSExtLoad>(LoadUser, MRI)) { 1029 // If truncating more than the original extended value, abort. 1030 auto LoadSizeBits = LoadMI->getMemSizeInBits(); 1031 if (TruncSrc && 1032 MRI.getType(TruncSrc).getSizeInBits() < LoadSizeBits.getValue()) 1033 return false; 1034 if (LoadSizeBits == SizeInBits) 1035 return true; 1036 } 1037 return false; 1038 } 1039 1040 void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) const { 1041 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); 1042 Builder.buildCopy(MI.getOperand(0).getReg(), MI.getOperand(1).getReg()); 1043 MI.eraseFromParent(); 1044 } 1045 1046 bool CombinerHelper::matchSextInRegOfLoad( 1047 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const { 1048 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); 1049 1050 Register DstReg = MI.getOperand(0).getReg(); 1051 LLT RegTy = MRI.getType(DstReg); 1052 1053 // Only supports scalars for now. 1054 if (RegTy.isVector()) 1055 return false; 1056 1057 Register SrcReg = MI.getOperand(1).getReg(); 1058 auto *LoadDef = getOpcodeDef<GLoad>(SrcReg, MRI); 1059 if (!LoadDef || !MRI.hasOneNonDBGUse(SrcReg)) 1060 return false; 1061 1062 uint64_t MemBits = LoadDef->getMemSizeInBits().getValue(); 1063 1064 // If the sign extend extends from a narrower width than the load's width, 1065 // then we can narrow the load width when we combine to a G_SEXTLOAD. 1066 // Avoid widening the load at all. 
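  // For example (illustrative only): a G_SEXT_INREG of 8 bits fed by a
  // (load s16) can become a G_SEXTLOAD with a (load s8) memory operand, while
  // a G_SEXT_INREG of 24 bits over the same load is clamped back to s16 below.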
  unsigned NewSizeBits = std::min((uint64_t)MI.getOperand(2).getImm(), MemBits);

  // Don't generate G_SEXTLOADs with a < 1 byte width.
  if (NewSizeBits < 8)
    return false;
  // Don't bother creating a non-power-of-2 sextload, it will likely be broken
  // up anyway for most targets.
  if (!isPowerOf2_32(NewSizeBits))
    return false;

  const MachineMemOperand &MMO = LoadDef->getMMO();
  LegalityQuery::MemDesc MMDesc(MMO);

  // Don't modify the memory access size if this is atomic/volatile, but we can
  // still adjust the opcode to indicate the high bit behavior.
  if (LoadDef->isSimple())
    MMDesc.MemoryTy = LLT::scalar(NewSizeBits);
  else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
    return false;

  // TODO: Could check if it's legal with the reduced or original memory size.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SEXTLOAD,
                                 {MRI.getType(LoadDef->getDstReg()),
                                  MRI.getType(LoadDef->getPointerReg())},
                                 {MMDesc}}))
    return false;

  MatchInfo = std::make_tuple(LoadDef->getDstReg(), NewSizeBits);
  return true;
}

void CombinerHelper::applySextInRegOfLoad(
    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register LoadReg;
  unsigned ScalarSizeBits;
  std::tie(LoadReg, ScalarSizeBits) = MatchInfo;
  GLoad *LoadDef = cast<GLoad>(MRI.getVRegDef(LoadReg));

  // If we have the following:
  //   %ld = G_LOAD %ptr, (load 2)
  //   %ext = G_SEXT_INREG %ld, 8
  // ==>
  //   %ld = G_SEXTLOAD %ptr (load 1)

  auto &MMO = LoadDef->getMMO();
  Builder.setInstrAndDebugLoc(*LoadDef);
  auto &MF = Builder.getMF();
  auto PtrInfo = MMO.getPointerInfo();
  auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, ScalarSizeBits / 8);
  Builder.buildLoadInstr(TargetOpcode::G_SEXTLOAD, MI.getOperand(0).getReg(),
                         LoadDef->getPointerReg(), *NewMMO);
  MI.eraseFromParent();

  // Not all loads can be deleted, so make sure the old one is removed.
  LoadDef->eraseFromParent();
}

/// Return true if 'MI' is a load or a store whose address operand may be
/// folded into the load / store addressing mode.
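/// For example (illustrative only; profitability is entirely target
/// dependent), on a target with [reg + imm] addressing a pattern like
///   %addr:_(p0) = G_PTR_ADD %base, %cst
///   G_STORE %val(s32), %addr(p0)
/// may be selectable as a single store from %base with an immediate offset.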
1127 static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI, 1128 MachineRegisterInfo &MRI) { 1129 TargetLowering::AddrMode AM; 1130 auto *MF = MI->getMF(); 1131 auto *Addr = getOpcodeDef<GPtrAdd>(MI->getPointerReg(), MRI); 1132 if (!Addr) 1133 return false; 1134 1135 AM.HasBaseReg = true; 1136 if (auto CstOff = getIConstantVRegVal(Addr->getOffsetReg(), MRI)) 1137 AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm] 1138 else 1139 AM.Scale = 1; // [reg +/- reg] 1140 1141 return TLI.isLegalAddressingMode( 1142 MF->getDataLayout(), AM, 1143 getTypeForLLT(MI->getMMO().getMemoryType(), 1144 MF->getFunction().getContext()), 1145 MI->getMMO().getAddrSpace()); 1146 } 1147 1148 static unsigned getIndexedOpc(unsigned LdStOpc) { 1149 switch (LdStOpc) { 1150 case TargetOpcode::G_LOAD: 1151 return TargetOpcode::G_INDEXED_LOAD; 1152 case TargetOpcode::G_STORE: 1153 return TargetOpcode::G_INDEXED_STORE; 1154 case TargetOpcode::G_ZEXTLOAD: 1155 return TargetOpcode::G_INDEXED_ZEXTLOAD; 1156 case TargetOpcode::G_SEXTLOAD: 1157 return TargetOpcode::G_INDEXED_SEXTLOAD; 1158 default: 1159 llvm_unreachable("Unexpected opcode"); 1160 } 1161 } 1162 1163 bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const { 1164 // Check for legality. 1165 LLT PtrTy = MRI.getType(LdSt.getPointerReg()); 1166 LLT Ty = MRI.getType(LdSt.getReg(0)); 1167 LLT MemTy = LdSt.getMMO().getMemoryType(); 1168 SmallVector<LegalityQuery::MemDesc, 2> MemDescrs( 1169 {{MemTy, MemTy.getSizeInBits().getKnownMinValue(), 1170 AtomicOrdering::NotAtomic}}); 1171 unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode()); 1172 SmallVector<LLT> OpTys; 1173 if (IndexedOpc == TargetOpcode::G_INDEXED_STORE) 1174 OpTys = {PtrTy, Ty, Ty}; 1175 else 1176 OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD 1177 1178 LegalityQuery Q(IndexedOpc, OpTys, MemDescrs); 1179 return isLegal(Q); 1180 } 1181 1182 static cl::opt<unsigned> PostIndexUseThreshold( 1183 "post-index-use-threshold", cl::Hidden, cl::init(32), 1184 cl::desc("Number of uses of a base pointer to check before it is no longer " 1185 "considered for post-indexing.")); 1186 1187 bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr, 1188 Register &Base, Register &Offset, 1189 bool &RematOffset) const { 1190 // We're looking for the following pattern, for either load or store: 1191 // %baseptr:_(p0) = ... 1192 // G_STORE %val(s64), %baseptr(p0) 1193 // %offset:_(s64) = G_CONSTANT i64 -256 1194 // %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64) 1195 const auto &TLI = getTargetLowering(); 1196 1197 Register Ptr = LdSt.getPointerReg(); 1198 // If the store is the only use, don't bother. 1199 if (MRI.hasOneNonDBGUse(Ptr)) 1200 return false; 1201 1202 if (!isIndexedLoadStoreLegal(LdSt)) 1203 return false; 1204 1205 if (getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Ptr, MRI)) 1206 return false; 1207 1208 MachineInstr *StoredValDef = getDefIgnoringCopies(LdSt.getReg(0), MRI); 1209 auto *PtrDef = MRI.getVRegDef(Ptr); 1210 1211 unsigned NumUsesChecked = 0; 1212 for (auto &Use : MRI.use_nodbg_instructions(Ptr)) { 1213 if (++NumUsesChecked > PostIndexUseThreshold) 1214 return false; // Try to avoid exploding compile time. 1215 1216 auto *PtrAdd = dyn_cast<GPtrAdd>(&Use); 1217 // The use itself might be dead. This can happen during combines if DCE 1218 // hasn't had a chance to run yet. Don't allow it to form an indexed op. 
1219 if (!PtrAdd || MRI.use_nodbg_empty(PtrAdd->getReg(0))) 1220 continue; 1221 1222 // Check the user of this isn't the store, otherwise we'd be generate a 1223 // indexed store defining its own use. 1224 if (StoredValDef == &Use) 1225 continue; 1226 1227 Offset = PtrAdd->getOffsetReg(); 1228 if (!ForceLegalIndexing && 1229 !TLI.isIndexingLegal(LdSt, PtrAdd->getBaseReg(), Offset, 1230 /*IsPre*/ false, MRI)) 1231 continue; 1232 1233 // Make sure the offset calculation is before the potentially indexed op. 1234 MachineInstr *OffsetDef = MRI.getVRegDef(Offset); 1235 RematOffset = false; 1236 if (!dominates(*OffsetDef, LdSt)) { 1237 // If the offset however is just a G_CONSTANT, we can always just 1238 // rematerialize it where we need it. 1239 if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT) 1240 continue; 1241 RematOffset = true; 1242 } 1243 1244 for (auto &BasePtrUse : MRI.use_nodbg_instructions(PtrAdd->getBaseReg())) { 1245 if (&BasePtrUse == PtrDef) 1246 continue; 1247 1248 // If the user is a later load/store that can be post-indexed, then don't 1249 // combine this one. 1250 auto *BasePtrLdSt = dyn_cast<GLoadStore>(&BasePtrUse); 1251 if (BasePtrLdSt && BasePtrLdSt != &LdSt && 1252 dominates(LdSt, *BasePtrLdSt) && 1253 isIndexedLoadStoreLegal(*BasePtrLdSt)) 1254 return false; 1255 1256 // Now we're looking for the key G_PTR_ADD instruction, which contains 1257 // the offset add that we want to fold. 1258 if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(&BasePtrUse)) { 1259 Register PtrAddDefReg = BasePtrUseDef->getReg(0); 1260 for (auto &BaseUseUse : MRI.use_nodbg_instructions(PtrAddDefReg)) { 1261 // If the use is in a different block, then we may produce worse code 1262 // due to the extra register pressure. 1263 if (BaseUseUse.getParent() != LdSt.getParent()) 1264 return false; 1265 1266 if (auto *UseUseLdSt = dyn_cast<GLoadStore>(&BaseUseUse)) 1267 if (canFoldInAddressingMode(UseUseLdSt, TLI, MRI)) 1268 return false; 1269 } 1270 if (!dominates(LdSt, BasePtrUse)) 1271 return false; // All use must be dominated by the load/store. 1272 } 1273 } 1274 1275 Addr = PtrAdd->getReg(0); 1276 Base = PtrAdd->getBaseReg(); 1277 return true; 1278 } 1279 1280 return false; 1281 } 1282 1283 bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr, 1284 Register &Base, 1285 Register &Offset) const { 1286 auto &MF = *LdSt.getParent()->getParent(); 1287 const auto &TLI = *MF.getSubtarget().getTargetLowering(); 1288 1289 Addr = LdSt.getPointerReg(); 1290 if (!mi_match(Addr, MRI, m_GPtrAdd(m_Reg(Base), m_Reg(Offset))) || 1291 MRI.hasOneNonDBGUse(Addr)) 1292 return false; 1293 1294 if (!ForceLegalIndexing && 1295 !TLI.isIndexingLegal(LdSt, Base, Offset, /*IsPre*/ true, MRI)) 1296 return false; 1297 1298 if (!isIndexedLoadStoreLegal(LdSt)) 1299 return false; 1300 1301 MachineInstr *BaseDef = getDefIgnoringCopies(Base, MRI); 1302 if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX) 1303 return false; 1304 1305 if (auto *St = dyn_cast<GStore>(&LdSt)) { 1306 // Would require a copy. 1307 if (Base == St->getValueReg()) 1308 return false; 1309 1310 // We're expecting one use of Addr in MI, but it could also be the 1311 // value stored, which isn't actually dominated by the instruction. 1312 if (St->getValueReg() == Addr) 1313 return false; 1314 } 1315 1316 // Avoid increasing cross-block register pressure. 
1317 for (auto &AddrUse : MRI.use_nodbg_instructions(Addr)) 1318 if (AddrUse.getParent() != LdSt.getParent()) 1319 return false; 1320 1321 // FIXME: check whether all uses of the base pointer are constant PtrAdds. 1322 // That might allow us to end base's liveness here by adjusting the constant. 1323 bool RealUse = false; 1324 for (auto &AddrUse : MRI.use_nodbg_instructions(Addr)) { 1325 if (!dominates(LdSt, AddrUse)) 1326 return false; // All use must be dominated by the load/store. 1327 1328 // If Ptr may be folded in addressing mode of other use, then it's 1329 // not profitable to do this transformation. 1330 if (auto *UseLdSt = dyn_cast<GLoadStore>(&AddrUse)) { 1331 if (!canFoldInAddressingMode(UseLdSt, TLI, MRI)) 1332 RealUse = true; 1333 } else { 1334 RealUse = true; 1335 } 1336 } 1337 return RealUse; 1338 } 1339 1340 bool CombinerHelper::matchCombineExtractedVectorLoad( 1341 MachineInstr &MI, BuildFnTy &MatchInfo) const { 1342 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); 1343 1344 // Check if there is a load that defines the vector being extracted from. 1345 auto *LoadMI = getOpcodeDef<GLoad>(MI.getOperand(1).getReg(), MRI); 1346 if (!LoadMI) 1347 return false; 1348 1349 Register Vector = MI.getOperand(1).getReg(); 1350 LLT VecEltTy = MRI.getType(Vector).getElementType(); 1351 1352 assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy); 1353 1354 // Checking whether we should reduce the load width. 1355 if (!MRI.hasOneNonDBGUse(Vector)) 1356 return false; 1357 1358 // Check if the defining load is simple. 1359 if (!LoadMI->isSimple()) 1360 return false; 1361 1362 // If the vector element type is not a multiple of a byte then we are unable 1363 // to correctly compute an address to load only the extracted element as a 1364 // scalar. 1365 if (!VecEltTy.isByteSized()) 1366 return false; 1367 1368 // Check for load fold barriers between the extraction and the load. 1369 if (MI.getParent() != LoadMI->getParent()) 1370 return false; 1371 const unsigned MaxIter = 20; 1372 unsigned Iter = 0; 1373 for (auto II = LoadMI->getIterator(), IE = MI.getIterator(); II != IE; ++II) { 1374 if (II->isLoadFoldBarrier()) 1375 return false; 1376 if (Iter++ == MaxIter) 1377 return false; 1378 } 1379 1380 // Check if the new load that we are going to create is legal 1381 // if we are in the post-legalization phase. 1382 MachineMemOperand MMO = LoadMI->getMMO(); 1383 Align Alignment = MMO.getAlign(); 1384 MachinePointerInfo PtrInfo; 1385 uint64_t Offset; 1386 1387 // Finding the appropriate PtrInfo if offset is a known constant. 1388 // This is required to create the memory operand for the narrowed load. 1389 // This machine memory operand object helps us infer about legality 1390 // before we proceed to combine the instruction. 1391 if (auto CVal = getIConstantVRegVal(Vector, MRI)) { 1392 int Elt = CVal->getZExtValue(); 1393 // FIXME: should be (ABI size)*Elt. 1394 Offset = VecEltTy.getSizeInBits() * Elt / 8; 1395 PtrInfo = MMO.getPointerInfo().getWithOffset(Offset); 1396 } else { 1397 // Discard the pointer info except the address space because the memory 1398 // operand can't represent this new access since the offset is variable. 
1399 Offset = VecEltTy.getSizeInBits() / 8; 1400 PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace()); 1401 } 1402 1403 Alignment = commonAlignment(Alignment, Offset); 1404 1405 Register VecPtr = LoadMI->getPointerReg(); 1406 LLT PtrTy = MRI.getType(VecPtr); 1407 1408 MachineFunction &MF = *MI.getMF(); 1409 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, VecEltTy); 1410 1411 LegalityQuery::MemDesc MMDesc(*NewMMO); 1412 1413 LegalityQuery Q = {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}}; 1414 1415 if (!isLegalOrBeforeLegalizer(Q)) 1416 return false; 1417 1418 // Load must be allowed and fast on the target. 1419 LLVMContext &C = MF.getFunction().getContext(); 1420 auto &DL = MF.getDataLayout(); 1421 unsigned Fast = 0; 1422 if (!getTargetLowering().allowsMemoryAccess(C, DL, VecEltTy, *NewMMO, 1423 &Fast) || 1424 !Fast) 1425 return false; 1426 1427 Register Result = MI.getOperand(0).getReg(); 1428 Register Index = MI.getOperand(2).getReg(); 1429 1430 MatchInfo = [=](MachineIRBuilder &B) { 1431 GISelObserverWrapper DummyObserver; 1432 LegalizerHelper Helper(B.getMF(), DummyObserver, B); 1433 //// Get pointer to the vector element. 1434 Register finalPtr = Helper.getVectorElementPointer( 1435 LoadMI->getPointerReg(), MRI.getType(LoadMI->getOperand(0).getReg()), 1436 Index); 1437 // New G_LOAD instruction. 1438 B.buildLoad(Result, finalPtr, PtrInfo, Alignment); 1439 // Remove original GLOAD instruction. 1440 LoadMI->eraseFromParent(); 1441 }; 1442 1443 return true; 1444 } 1445 1446 bool CombinerHelper::matchCombineIndexedLoadStore( 1447 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const { 1448 auto &LdSt = cast<GLoadStore>(MI); 1449 1450 if (LdSt.isAtomic()) 1451 return false; 1452 1453 MatchInfo.IsPre = findPreIndexCandidate(LdSt, MatchInfo.Addr, MatchInfo.Base, 1454 MatchInfo.Offset); 1455 if (!MatchInfo.IsPre && 1456 !findPostIndexCandidate(LdSt, MatchInfo.Addr, MatchInfo.Base, 1457 MatchInfo.Offset, MatchInfo.RematOffset)) 1458 return false; 1459 1460 return true; 1461 } 1462 1463 void CombinerHelper::applyCombineIndexedLoadStore( 1464 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const { 1465 MachineInstr &AddrDef = *MRI.getUniqueVRegDef(MatchInfo.Addr); 1466 unsigned Opcode = MI.getOpcode(); 1467 bool IsStore = Opcode == TargetOpcode::G_STORE; 1468 unsigned NewOpcode = getIndexedOpc(Opcode); 1469 1470 // If the offset constant didn't happen to dominate the load/store, we can 1471 // just clone it as needed. 
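  // The rewrite built below looks roughly like this for a post-indexed store
  // (illustrative only; register names are arbitrary):
  //   G_STORE %val(s64), %base(p0)
  //   %off:_(s64) = G_CONSTANT i64 8
  //   %new:_(p0) = G_PTR_ADD %base, %off
  // =>
  //   %new:_(p0) = G_INDEXED_STORE %val(s64), %base(p0), %off(s64), 0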
1472 if (MatchInfo.RematOffset) { 1473 auto *OldCst = MRI.getVRegDef(MatchInfo.Offset); 1474 auto NewCst = Builder.buildConstant(MRI.getType(MatchInfo.Offset), 1475 *OldCst->getOperand(1).getCImm()); 1476 MatchInfo.Offset = NewCst.getReg(0); 1477 } 1478 1479 auto MIB = Builder.buildInstr(NewOpcode); 1480 if (IsStore) { 1481 MIB.addDef(MatchInfo.Addr); 1482 MIB.addUse(MI.getOperand(0).getReg()); 1483 } else { 1484 MIB.addDef(MI.getOperand(0).getReg()); 1485 MIB.addDef(MatchInfo.Addr); 1486 } 1487 1488 MIB.addUse(MatchInfo.Base); 1489 MIB.addUse(MatchInfo.Offset); 1490 MIB.addImm(MatchInfo.IsPre); 1491 MIB->cloneMemRefs(*MI.getMF(), MI); 1492 MI.eraseFromParent(); 1493 AddrDef.eraseFromParent(); 1494 1495 LLVM_DEBUG(dbgs() << " Combinined to indexed operation"); 1496 } 1497 1498 bool CombinerHelper::matchCombineDivRem(MachineInstr &MI, 1499 MachineInstr *&OtherMI) const { 1500 unsigned Opcode = MI.getOpcode(); 1501 bool IsDiv, IsSigned; 1502 1503 switch (Opcode) { 1504 default: 1505 llvm_unreachable("Unexpected opcode!"); 1506 case TargetOpcode::G_SDIV: 1507 case TargetOpcode::G_UDIV: { 1508 IsDiv = true; 1509 IsSigned = Opcode == TargetOpcode::G_SDIV; 1510 break; 1511 } 1512 case TargetOpcode::G_SREM: 1513 case TargetOpcode::G_UREM: { 1514 IsDiv = false; 1515 IsSigned = Opcode == TargetOpcode::G_SREM; 1516 break; 1517 } 1518 } 1519 1520 Register Src1 = MI.getOperand(1).getReg(); 1521 unsigned DivOpcode, RemOpcode, DivremOpcode; 1522 if (IsSigned) { 1523 DivOpcode = TargetOpcode::G_SDIV; 1524 RemOpcode = TargetOpcode::G_SREM; 1525 DivremOpcode = TargetOpcode::G_SDIVREM; 1526 } else { 1527 DivOpcode = TargetOpcode::G_UDIV; 1528 RemOpcode = TargetOpcode::G_UREM; 1529 DivremOpcode = TargetOpcode::G_UDIVREM; 1530 } 1531 1532 if (!isLegalOrBeforeLegalizer({DivremOpcode, {MRI.getType(Src1)}})) 1533 return false; 1534 1535 // Combine: 1536 // %div:_ = G_[SU]DIV %src1:_, %src2:_ 1537 // %rem:_ = G_[SU]REM %src1:_, %src2:_ 1538 // into: 1539 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_ 1540 1541 // Combine: 1542 // %rem:_ = G_[SU]REM %src1:_, %src2:_ 1543 // %div:_ = G_[SU]DIV %src1:_, %src2:_ 1544 // into: 1545 // %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_ 1546 1547 for (auto &UseMI : MRI.use_nodbg_instructions(Src1)) { 1548 if (MI.getParent() == UseMI.getParent() && 1549 ((IsDiv && UseMI.getOpcode() == RemOpcode) || 1550 (!IsDiv && UseMI.getOpcode() == DivOpcode)) && 1551 matchEqualDefs(MI.getOperand(2), UseMI.getOperand(2)) && 1552 matchEqualDefs(MI.getOperand(1), UseMI.getOperand(1))) { 1553 OtherMI = &UseMI; 1554 return true; 1555 } 1556 } 1557 1558 return false; 1559 } 1560 1561 void CombinerHelper::applyCombineDivRem(MachineInstr &MI, 1562 MachineInstr *&OtherMI) const { 1563 unsigned Opcode = MI.getOpcode(); 1564 assert(OtherMI && "OtherMI shouldn't be empty."); 1565 1566 Register DestDivReg, DestRemReg; 1567 if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) { 1568 DestDivReg = MI.getOperand(0).getReg(); 1569 DestRemReg = OtherMI->getOperand(0).getReg(); 1570 } else { 1571 DestDivReg = OtherMI->getOperand(0).getReg(); 1572 DestRemReg = MI.getOperand(0).getReg(); 1573 } 1574 1575 bool IsSigned = 1576 Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM; 1577 1578 // Check which instruction is first in the block so we don't break def-use 1579 // deps by "moving" the instruction incorrectly. Also keep track of which 1580 // instruction is first so we pick it's operands, avoiding use-before-def 1581 // bugs. 
1582 MachineInstr *FirstInst = dominates(MI, *OtherMI) ? &MI : OtherMI; 1583 Builder.setInstrAndDebugLoc(*FirstInst); 1584 1585 Builder.buildInstr(IsSigned ? TargetOpcode::G_SDIVREM 1586 : TargetOpcode::G_UDIVREM, 1587 {DestDivReg, DestRemReg}, 1588 { FirstInst->getOperand(1), FirstInst->getOperand(2) }); 1589 MI.eraseFromParent(); 1590 OtherMI->eraseFromParent(); 1591 } 1592 1593 bool CombinerHelper::matchOptBrCondByInvertingCond( 1594 MachineInstr &MI, MachineInstr *&BrCond) const { 1595 assert(MI.getOpcode() == TargetOpcode::G_BR); 1596 1597 // Try to match the following: 1598 // bb1: 1599 // G_BRCOND %c1, %bb2 1600 // G_BR %bb3 1601 // bb2: 1602 // ... 1603 // bb3: 1604 1605 // The above pattern does not have a fall through to the successor bb2, always 1606 // resulting in a branch no matter which path is taken. Here we try to find 1607 // and replace that pattern with conditional branch to bb3 and otherwise 1608 // fallthrough to bb2. This is generally better for branch predictors. 1609 1610 MachineBasicBlock *MBB = MI.getParent(); 1611 MachineBasicBlock::iterator BrIt(MI); 1612 if (BrIt == MBB->begin()) 1613 return false; 1614 assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator"); 1615 1616 BrCond = &*std::prev(BrIt); 1617 if (BrCond->getOpcode() != TargetOpcode::G_BRCOND) 1618 return false; 1619 1620 // Check that the next block is the conditional branch target. Also make sure 1621 // that it isn't the same as the G_BR's target (otherwise, this will loop.) 1622 MachineBasicBlock *BrCondTarget = BrCond->getOperand(1).getMBB(); 1623 return BrCondTarget != MI.getOperand(0).getMBB() && 1624 MBB->isLayoutSuccessor(BrCondTarget); 1625 } 1626 1627 void CombinerHelper::applyOptBrCondByInvertingCond( 1628 MachineInstr &MI, MachineInstr *&BrCond) const { 1629 MachineBasicBlock *BrTarget = MI.getOperand(0).getMBB(); 1630 Builder.setInstrAndDebugLoc(*BrCond); 1631 LLT Ty = MRI.getType(BrCond->getOperand(0).getReg()); 1632 // FIXME: Does int/fp matter for this? If so, we might need to restrict 1633 // this to i1 only since we might not know for sure what kind of 1634 // compare generated the condition value. 1635 auto True = Builder.buildConstant( 1636 Ty, getICmpTrueVal(getTargetLowering(), false, false)); 1637 auto Xor = Builder.buildXor(Ty, BrCond->getOperand(0), True); 1638 1639 auto *FallthroughBB = BrCond->getOperand(1).getMBB(); 1640 Observer.changingInstr(MI); 1641 MI.getOperand(0).setMBB(FallthroughBB); 1642 Observer.changedInstr(MI); 1643 1644 // Change the conditional branch to use the inverted condition and 1645 // new target block. 
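// Illustrative shape of the rewrite (block and register names are made up):
//   before:  bb1: G_BRCOND %c, %bb2
//                 G_BR %bb3
//   after:   bb1: G_BRCOND %c_inv, %bb3
//                 G_BR %bb2   ; bb2 is the layout successor, so this branch
//                             ; can become a fallthrough.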
1646 Observer.changingInstr(*BrCond); 1647 BrCond->getOperand(0).setReg(Xor.getReg(0)); 1648 BrCond->getOperand(1).setMBB(BrTarget); 1649 Observer.changedInstr(*BrCond); 1650 } 1651 1652 bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) const { 1653 MachineIRBuilder HelperBuilder(MI); 1654 GISelObserverWrapper DummyObserver; 1655 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); 1656 return Helper.lowerMemcpyInline(MI) == 1657 LegalizerHelper::LegalizeResult::Legalized; 1658 } 1659 1660 bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI, 1661 unsigned MaxLen) const { 1662 MachineIRBuilder HelperBuilder(MI); 1663 GISelObserverWrapper DummyObserver; 1664 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); 1665 return Helper.lowerMemCpyFamily(MI, MaxLen) == 1666 LegalizerHelper::LegalizeResult::Legalized; 1667 } 1668 1669 static APFloat constantFoldFpUnary(const MachineInstr &MI, 1670 const MachineRegisterInfo &MRI, 1671 const APFloat &Val) { 1672 APFloat Result(Val); 1673 switch (MI.getOpcode()) { 1674 default: 1675 llvm_unreachable("Unexpected opcode!"); 1676 case TargetOpcode::G_FNEG: { 1677 Result.changeSign(); 1678 return Result; 1679 } 1680 case TargetOpcode::G_FABS: { 1681 Result.clearSign(); 1682 return Result; 1683 } 1684 case TargetOpcode::G_FPTRUNC: { 1685 bool Unused; 1686 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 1687 Result.convert(getFltSemanticForLLT(DstTy), APFloat::rmNearestTiesToEven, 1688 &Unused); 1689 return Result; 1690 } 1691 case TargetOpcode::G_FSQRT: { 1692 bool Unused; 1693 Result.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 1694 &Unused); 1695 Result = APFloat(sqrt(Result.convertToDouble())); 1696 break; 1697 } 1698 case TargetOpcode::G_FLOG2: { 1699 bool Unused; 1700 Result.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 1701 &Unused); 1702 Result = APFloat(log2(Result.convertToDouble())); 1703 break; 1704 } 1705 } 1706 // Convert `APFloat` to appropriate IEEE type depending on `DstTy`. Otherwise, 1707 // `buildFConstant` will assert on size mismatch. Only `G_FSQRT`, and 1708 // `G_FLOG2` reach here. 
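// E.g. (illustrative): folding G_FSQRT of the s32 constant 2.0 computes
// sqrt(2.0) in IEEEdouble above; the conversion below narrows the result back
// to the original single-precision semantics so the G_FCONSTANT has the
// expected size.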
1709 bool Unused; 1710 Result.convert(Val.getSemantics(), APFloat::rmNearestTiesToEven, &Unused); 1711 return Result; 1712 } 1713 1714 void CombinerHelper::applyCombineConstantFoldFpUnary( 1715 MachineInstr &MI, const ConstantFP *Cst) const { 1716 APFloat Folded = constantFoldFpUnary(MI, MRI, Cst->getValue()); 1717 const ConstantFP *NewCst = ConstantFP::get(Builder.getContext(), Folded); 1718 Builder.buildFConstant(MI.getOperand(0), *NewCst); 1719 MI.eraseFromParent(); 1720 } 1721 1722 bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI, 1723 PtrAddChain &MatchInfo) const { 1724 // We're trying to match the following pattern: 1725 // %t1 = G_PTR_ADD %base, G_CONSTANT imm1 1726 // %root = G_PTR_ADD %t1, G_CONSTANT imm2 1727 // --> 1728 // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2) 1729 1730 if (MI.getOpcode() != TargetOpcode::G_PTR_ADD) 1731 return false; 1732 1733 Register Add2 = MI.getOperand(1).getReg(); 1734 Register Imm1 = MI.getOperand(2).getReg(); 1735 auto MaybeImmVal = getIConstantVRegValWithLookThrough(Imm1, MRI); 1736 if (!MaybeImmVal) 1737 return false; 1738 1739 MachineInstr *Add2Def = MRI.getVRegDef(Add2); 1740 if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD) 1741 return false; 1742 1743 Register Base = Add2Def->getOperand(1).getReg(); 1744 Register Imm2 = Add2Def->getOperand(2).getReg(); 1745 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(Imm2, MRI); 1746 if (!MaybeImm2Val) 1747 return false; 1748 1749 // Check if the new combined immediate forms an illegal addressing mode. 1750 // Do not combine if it was legal before but would get illegal. 1751 // To do so, we need to find a load/store user of the pointer to get 1752 // the access type. 1753 Type *AccessTy = nullptr; 1754 auto &MF = *MI.getMF(); 1755 for (auto &UseMI : MRI.use_nodbg_instructions(MI.getOperand(0).getReg())) { 1756 if (auto *LdSt = dyn_cast<GLoadStore>(&UseMI)) { 1757 AccessTy = getTypeForLLT(MRI.getType(LdSt->getReg(0)), 1758 MF.getFunction().getContext()); 1759 break; 1760 } 1761 } 1762 TargetLoweringBase::AddrMode AMNew; 1763 APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value; 1764 AMNew.BaseOffs = CombinedImm.getSExtValue(); 1765 if (AccessTy) { 1766 AMNew.HasBaseReg = true; 1767 TargetLoweringBase::AddrMode AMOld; 1768 AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue(); 1769 AMOld.HasBaseReg = true; 1770 unsigned AS = MRI.getType(Add2).getAddressSpace(); 1771 const auto &TLI = *MF.getSubtarget().getTargetLowering(); 1772 if (TLI.isLegalAddressingMode(MF.getDataLayout(), AMOld, AccessTy, AS) && 1773 !TLI.isLegalAddressingMode(MF.getDataLayout(), AMNew, AccessTy, AS)) 1774 return false; 1775 } 1776 1777 // Pass the combined immediate to the apply function. 
1778 MatchInfo.Imm = AMNew.BaseOffs; 1779 MatchInfo.Base = Base; 1780 MatchInfo.Bank = getRegBank(Imm2); 1781 return true; 1782 } 1783 1784 void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI, 1785 PtrAddChain &MatchInfo) const { 1786 assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD"); 1787 MachineIRBuilder MIB(MI); 1788 LLT OffsetTy = MRI.getType(MI.getOperand(2).getReg()); 1789 auto NewOffset = MIB.buildConstant(OffsetTy, MatchInfo.Imm); 1790 setRegBank(NewOffset.getReg(0), MatchInfo.Bank); 1791 Observer.changingInstr(MI); 1792 MI.getOperand(1).setReg(MatchInfo.Base); 1793 MI.getOperand(2).setReg(NewOffset.getReg(0)); 1794 Observer.changedInstr(MI); 1795 } 1796 1797 bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI, 1798 RegisterImmPair &MatchInfo) const { 1799 // We're trying to match the following pattern with any of 1800 // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions: 1801 // %t1 = SHIFT %base, G_CONSTANT imm1 1802 // %root = SHIFT %t1, G_CONSTANT imm2 1803 // --> 1804 // %root = SHIFT %base, G_CONSTANT (imm1 + imm2) 1805 1806 unsigned Opcode = MI.getOpcode(); 1807 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || 1808 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || 1809 Opcode == TargetOpcode::G_USHLSAT) && 1810 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT"); 1811 1812 Register Shl2 = MI.getOperand(1).getReg(); 1813 Register Imm1 = MI.getOperand(2).getReg(); 1814 auto MaybeImmVal = getIConstantVRegValWithLookThrough(Imm1, MRI); 1815 if (!MaybeImmVal) 1816 return false; 1817 1818 MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Shl2); 1819 if (Shl2Def->getOpcode() != Opcode) 1820 return false; 1821 1822 Register Base = Shl2Def->getOperand(1).getReg(); 1823 Register Imm2 = Shl2Def->getOperand(2).getReg(); 1824 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(Imm2, MRI); 1825 if (!MaybeImm2Val) 1826 return false; 1827 1828 // Pass the combined immediate to the apply function. 1829 MatchInfo.Imm = 1830 (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue(); 1831 MatchInfo.Reg = Base; 1832 1833 // There is no simple replacement for a saturating unsigned left shift that 1834 // exceeds the scalar size. 1835 if (Opcode == TargetOpcode::G_USHLSAT && 1836 MatchInfo.Imm >= MRI.getType(Shl2).getScalarSizeInBits()) 1837 return false; 1838 1839 return true; 1840 } 1841 1842 void CombinerHelper::applyShiftImmedChain(MachineInstr &MI, 1843 RegisterImmPair &MatchInfo) const { 1844 unsigned Opcode = MI.getOpcode(); 1845 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || 1846 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || 1847 Opcode == TargetOpcode::G_USHLSAT) && 1848 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT"); 1849 1850 LLT Ty = MRI.getType(MI.getOperand(1).getReg()); 1851 unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits(); 1852 auto Imm = MatchInfo.Imm; 1853 1854 if (Imm >= ScalarSizeInBits) { 1855 // Any logical shift that exceeds scalar size will produce zero. 1856 if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) { 1857 Builder.buildConstant(MI.getOperand(0), 0); 1858 MI.eraseFromParent(); 1859 return; 1860 } 1861 // Arithmetic shift and saturating signed left shift have no effect beyond 1862 // scalar size. 
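// E.g. (illustrative): on s32, chaining (ashr (ashr %x, 20), 20) gives a
// combined amount of 40; clamping it to 31 still yields the fully
// sign-extended result.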
1863 Imm = ScalarSizeInBits - 1; 1864 } 1865 1866 LLT ImmTy = MRI.getType(MI.getOperand(2).getReg()); 1867 Register NewImm = Builder.buildConstant(ImmTy, Imm).getReg(0); 1868 Observer.changingInstr(MI); 1869 MI.getOperand(1).setReg(MatchInfo.Reg); 1870 MI.getOperand(2).setReg(NewImm); 1871 Observer.changedInstr(MI); 1872 } 1873 1874 bool CombinerHelper::matchShiftOfShiftedLogic( 1875 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const { 1876 // We're trying to match the following pattern with any of 1877 // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination 1878 // with any of G_AND/G_OR/G_XOR logic instructions. 1879 // %t1 = SHIFT %X, G_CONSTANT C0 1880 // %t2 = LOGIC %t1, %Y 1881 // %root = SHIFT %t2, G_CONSTANT C1 1882 // --> 1883 // %t3 = SHIFT %X, G_CONSTANT (C0+C1) 1884 // %t4 = SHIFT %Y, G_CONSTANT C1 1885 // %root = LOGIC %t3, %t4 1886 unsigned ShiftOpcode = MI.getOpcode(); 1887 assert((ShiftOpcode == TargetOpcode::G_SHL || 1888 ShiftOpcode == TargetOpcode::G_ASHR || 1889 ShiftOpcode == TargetOpcode::G_LSHR || 1890 ShiftOpcode == TargetOpcode::G_USHLSAT || 1891 ShiftOpcode == TargetOpcode::G_SSHLSAT) && 1892 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT"); 1893 1894 // Match a one-use bitwise logic op. 1895 Register LogicDest = MI.getOperand(1).getReg(); 1896 if (!MRI.hasOneNonDBGUse(LogicDest)) 1897 return false; 1898 1899 MachineInstr *LogicMI = MRI.getUniqueVRegDef(LogicDest); 1900 unsigned LogicOpcode = LogicMI->getOpcode(); 1901 if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR && 1902 LogicOpcode != TargetOpcode::G_XOR) 1903 return false; 1904 1905 // Find a matching one-use shift by constant. 1906 const Register C1 = MI.getOperand(2).getReg(); 1907 auto MaybeImmVal = getIConstantVRegValWithLookThrough(C1, MRI); 1908 if (!MaybeImmVal || MaybeImmVal->Value == 0) 1909 return false; 1910 1911 const uint64_t C1Val = MaybeImmVal->Value.getZExtValue(); 1912 1913 auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) { 1914 // Shift should match previous one and should be a one-use. 1915 if (MI->getOpcode() != ShiftOpcode || 1916 !MRI.hasOneNonDBGUse(MI->getOperand(0).getReg())) 1917 return false; 1918 1919 // Must be a constant. 1920 auto MaybeImmVal = 1921 getIConstantVRegValWithLookThrough(MI->getOperand(2).getReg(), MRI); 1922 if (!MaybeImmVal) 1923 return false; 1924 1925 ShiftVal = MaybeImmVal->Value.getSExtValue(); 1926 return true; 1927 }; 1928 1929 // Logic ops are commutative, so check each operand for a match. 1930 Register LogicMIReg1 = LogicMI->getOperand(1).getReg(); 1931 MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(LogicMIReg1); 1932 Register LogicMIReg2 = LogicMI->getOperand(2).getReg(); 1933 MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(LogicMIReg2); 1934 uint64_t C0Val; 1935 1936 if (matchFirstShift(LogicMIOp1, C0Val)) { 1937 MatchInfo.LogicNonShiftReg = LogicMIReg2; 1938 MatchInfo.Shift2 = LogicMIOp1; 1939 } else if (matchFirstShift(LogicMIOp2, C0Val)) { 1940 MatchInfo.LogicNonShiftReg = LogicMIReg1; 1941 MatchInfo.Shift2 = LogicMIOp2; 1942 } else 1943 return false; 1944 1945 MatchInfo.ValSum = C0Val + C1Val; 1946 1947 // The fold is not valid if the sum of the shift values exceeds bitwidth. 
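// E.g. (illustrative): on s32 with C0 = 20 and C1 = 16 the combined shift
// amount would be 36, which is out of range for the type, so bail out.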
1948 if (MatchInfo.ValSum >= MRI.getType(LogicDest).getScalarSizeInBits())
1949 return false;
1950
1951 MatchInfo.Logic = LogicMI;
1952 return true;
1953 }
1954
1955 void CombinerHelper::applyShiftOfShiftedLogic(
1956 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
1957 unsigned Opcode = MI.getOpcode();
1958 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
1959 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT ||
1960 Opcode == TargetOpcode::G_SSHLSAT) &&
1961 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
1962
1963 LLT ShlType = MRI.getType(MI.getOperand(2).getReg());
1964 LLT DestType = MRI.getType(MI.getOperand(0).getReg());
1965
1966 Register Const = Builder.buildConstant(ShlType, MatchInfo.ValSum).getReg(0);
1967
1968 Register Shift1Base = MatchInfo.Shift2->getOperand(1).getReg();
1969 Register Shift1 =
1970 Builder.buildInstr(Opcode, {DestType}, {Shift1Base, Const}).getReg(0);
1971
1972 // If LogicNonShiftReg is the same as Shift1Base and the new shift2's constant
1973 // matches MatchInfo.Shift2's constant, CSEMIRBuilder will reuse the old
1974 // MatchInfo.Shift2 instruction when building shift2. Erasing MatchInfo.Shift2
1975 // at the end would then delete that reused shift and crash later, so erase it
1976 // here, before shift2 is built.
1977 MatchInfo.Shift2->eraseFromParent();
1978
1979 Register Shift2Const = MI.getOperand(2).getReg();
1980 Register Shift2 = Builder
1981 .buildInstr(Opcode, {DestType},
1982 {MatchInfo.LogicNonShiftReg, Shift2Const})
1983 .getReg(0);
1984
1985 Register Dest = MI.getOperand(0).getReg();
1986 Builder.buildInstr(MatchInfo.Logic->getOpcode(), {Dest}, {Shift1, Shift2});
1987
1988 // The logic op had a single use, so it's safe to remove it.
1989 MatchInfo.Logic->eraseFromParent();
1990
1991 MI.eraseFromParent();
1992 }
1993
1994 bool CombinerHelper::matchCommuteShift(MachineInstr &MI,
1995 BuildFnTy &MatchInfo) const {
1996 assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL");
1997 // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
1998 // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
1999 auto &Shl = cast<GenericMachineInstr>(MI);
2000 Register DstReg = Shl.getReg(0);
2001 Register SrcReg = Shl.getReg(1);
2002 Register ShiftReg = Shl.getReg(2);
2003 Register X, C1;
2004
2005 if (!getTargetLowering().isDesirableToCommuteWithShift(MI, !isPreLegalize()))
2006 return false;
2007
2008 if (!mi_match(SrcReg, MRI,
2009 m_OneNonDBGUse(m_any_of(m_GAdd(m_Reg(X), m_Reg(C1)),
2010 m_GOr(m_Reg(X), m_Reg(C1))))))
2011 return false;
2012
2013 APInt C1Val, C2Val;
2014 if (!mi_match(C1, MRI, m_ICstOrSplat(C1Val)) ||
2015 !mi_match(ShiftReg, MRI, m_ICstOrSplat(C2Val)))
2016 return false;
2017
2018 auto *SrcDef = MRI.getVRegDef(SrcReg);
2019 assert((SrcDef->getOpcode() == TargetOpcode::G_ADD ||
2020 SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op");
2021 LLT SrcTy = MRI.getType(SrcReg);
2022 MatchInfo = [=](MachineIRBuilder &B) {
2023 auto S1 = B.buildShl(SrcTy, X, ShiftReg);
2024 auto S2 = B.buildShl(SrcTy, C1, ShiftReg);
2025 B.buildInstr(SrcDef->getOpcode(), {DstReg}, {S1, S2});
2026 };
2027 return true;
2028 }
2029
2030 bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI,
2031 unsigned &ShiftVal) const {
2032 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2033 auto MaybeImmVal =
2034 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
2035 if (!MaybeImmVal)
2036 return false;
2037
2038 ShiftVal = MaybeImmVal->Value.exactLogBase2();
2039
return (static_cast<int32_t>(ShiftVal) != -1); 2040 } 2041 2042 void CombinerHelper::applyCombineMulToShl(MachineInstr &MI, 2043 unsigned &ShiftVal) const { 2044 assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL"); 2045 MachineIRBuilder MIB(MI); 2046 LLT ShiftTy = MRI.getType(MI.getOperand(0).getReg()); 2047 auto ShiftCst = MIB.buildConstant(ShiftTy, ShiftVal); 2048 Observer.changingInstr(MI); 2049 MI.setDesc(MIB.getTII().get(TargetOpcode::G_SHL)); 2050 MI.getOperand(2).setReg(ShiftCst.getReg(0)); 2051 if (ShiftVal == ShiftTy.getScalarSizeInBits() - 1) 2052 MI.clearFlag(MachineInstr::MIFlag::NoSWrap); 2053 Observer.changedInstr(MI); 2054 } 2055 2056 bool CombinerHelper::matchCombineSubToAdd(MachineInstr &MI, 2057 BuildFnTy &MatchInfo) const { 2058 GSub &Sub = cast<GSub>(MI); 2059 2060 LLT Ty = MRI.getType(Sub.getReg(0)); 2061 2062 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {Ty}})) 2063 return false; 2064 2065 if (!isConstantLegalOrBeforeLegalizer(Ty)) 2066 return false; 2067 2068 APInt Imm = getIConstantFromReg(Sub.getRHSReg(), MRI); 2069 2070 MatchInfo = [=, &MI](MachineIRBuilder &B) { 2071 auto NegCst = B.buildConstant(Ty, -Imm); 2072 Observer.changingInstr(MI); 2073 MI.setDesc(B.getTII().get(TargetOpcode::G_ADD)); 2074 MI.getOperand(2).setReg(NegCst.getReg(0)); 2075 MI.clearFlag(MachineInstr::MIFlag::NoUWrap); 2076 Observer.changedInstr(MI); 2077 }; 2078 return true; 2079 } 2080 2081 // shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source 2082 bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI, 2083 RegisterImmPair &MatchData) const { 2084 assert(MI.getOpcode() == TargetOpcode::G_SHL && KB); 2085 if (!getTargetLowering().isDesirableToPullExtFromShl(MI)) 2086 return false; 2087 2088 Register LHS = MI.getOperand(1).getReg(); 2089 2090 Register ExtSrc; 2091 if (!mi_match(LHS, MRI, m_GAnyExt(m_Reg(ExtSrc))) && 2092 !mi_match(LHS, MRI, m_GZExt(m_Reg(ExtSrc))) && 2093 !mi_match(LHS, MRI, m_GSExt(m_Reg(ExtSrc)))) 2094 return false; 2095 2096 Register RHS = MI.getOperand(2).getReg(); 2097 MachineInstr *MIShiftAmt = MRI.getVRegDef(RHS); 2098 auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(*MIShiftAmt, MRI); 2099 if (!MaybeShiftAmtVal) 2100 return false; 2101 2102 if (LI) { 2103 LLT SrcTy = MRI.getType(ExtSrc); 2104 2105 // We only really care about the legality with the shifted value. We can 2106 // pick any type the constant shift amount, so ask the target what to 2107 // use. Otherwise we would have to guess and hope it is reported as legal. 
2108 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(SrcTy); 2109 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}})) 2110 return false; 2111 } 2112 2113 int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue(); 2114 MatchData.Reg = ExtSrc; 2115 MatchData.Imm = ShiftAmt; 2116 2117 unsigned MinLeadingZeros = KB->getKnownZeroes(ExtSrc).countl_one(); 2118 unsigned SrcTySize = MRI.getType(ExtSrc).getScalarSizeInBits(); 2119 return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize; 2120 } 2121 2122 void CombinerHelper::applyCombineShlOfExtend( 2123 MachineInstr &MI, const RegisterImmPair &MatchData) const { 2124 Register ExtSrcReg = MatchData.Reg; 2125 int64_t ShiftAmtVal = MatchData.Imm; 2126 2127 LLT ExtSrcTy = MRI.getType(ExtSrcReg); 2128 auto ShiftAmt = Builder.buildConstant(ExtSrcTy, ShiftAmtVal); 2129 auto NarrowShift = 2130 Builder.buildShl(ExtSrcTy, ExtSrcReg, ShiftAmt, MI.getFlags()); 2131 Builder.buildZExt(MI.getOperand(0), NarrowShift); 2132 MI.eraseFromParent(); 2133 } 2134 2135 bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI, 2136 Register &MatchInfo) const { 2137 GMerge &Merge = cast<GMerge>(MI); 2138 SmallVector<Register, 16> MergedValues; 2139 for (unsigned I = 0; I < Merge.getNumSources(); ++I) 2140 MergedValues.emplace_back(Merge.getSourceReg(I)); 2141 2142 auto *Unmerge = getOpcodeDef<GUnmerge>(MergedValues[0], MRI); 2143 if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources()) 2144 return false; 2145 2146 for (unsigned I = 0; I < MergedValues.size(); ++I) 2147 if (MergedValues[I] != Unmerge->getReg(I)) 2148 return false; 2149 2150 MatchInfo = Unmerge->getSourceReg(); 2151 return true; 2152 } 2153 2154 static Register peekThroughBitcast(Register Reg, 2155 const MachineRegisterInfo &MRI) { 2156 while (mi_match(Reg, MRI, m_GBitcast(m_Reg(Reg)))) 2157 ; 2158 2159 return Reg; 2160 } 2161 2162 bool CombinerHelper::matchCombineUnmergeMergeToPlainValues( 2163 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const { 2164 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2165 "Expected an unmerge"); 2166 auto &Unmerge = cast<GUnmerge>(MI); 2167 Register SrcReg = peekThroughBitcast(Unmerge.getSourceReg(), MRI); 2168 2169 auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(SrcReg, MRI); 2170 if (!SrcInstr) 2171 return false; 2172 2173 // Check the source type of the merge. 2174 LLT SrcMergeTy = MRI.getType(SrcInstr->getSourceReg(0)); 2175 LLT Dst0Ty = MRI.getType(Unmerge.getReg(0)); 2176 bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits(); 2177 if (SrcMergeTy != Dst0Ty && !SameSize) 2178 return false; 2179 // They are the same now (modulo a bitcast). 2180 // We can collect all the src registers. 
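// Illustrative example (names are made up):
//   %m:_(s64) = G_MERGE_VALUES %a:_(s32), %b:_(s32)
//   %x:_(s32), %y:_(s32) = G_UNMERGE_VALUES %m
// can forward %a and %b directly; when only the sizes match, the apply step
// inserts casts instead of reusing the registers.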
2181 for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx) 2182 Operands.push_back(SrcInstr->getSourceReg(Idx)); 2183 return true; 2184 } 2185 2186 void CombinerHelper::applyCombineUnmergeMergeToPlainValues( 2187 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const { 2188 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2189 "Expected an unmerge"); 2190 assert((MI.getNumOperands() - 1 == Operands.size()) && 2191 "Not enough operands to replace all defs"); 2192 unsigned NumElems = MI.getNumOperands() - 1; 2193 2194 LLT SrcTy = MRI.getType(Operands[0]); 2195 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 2196 bool CanReuseInputDirectly = DstTy == SrcTy; 2197 for (unsigned Idx = 0; Idx < NumElems; ++Idx) { 2198 Register DstReg = MI.getOperand(Idx).getReg(); 2199 Register SrcReg = Operands[Idx]; 2200 2201 // This combine may run after RegBankSelect, so we need to be aware of 2202 // register banks. 2203 const auto &DstCB = MRI.getRegClassOrRegBank(DstReg); 2204 if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(SrcReg)) { 2205 SrcReg = Builder.buildCopy(MRI.getType(SrcReg), SrcReg).getReg(0); 2206 MRI.setRegClassOrRegBank(SrcReg, DstCB); 2207 } 2208 2209 if (CanReuseInputDirectly) 2210 replaceRegWith(MRI, DstReg, SrcReg); 2211 else 2212 Builder.buildCast(DstReg, SrcReg); 2213 } 2214 MI.eraseFromParent(); 2215 } 2216 2217 bool CombinerHelper::matchCombineUnmergeConstant( 2218 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const { 2219 unsigned SrcIdx = MI.getNumOperands() - 1; 2220 Register SrcReg = MI.getOperand(SrcIdx).getReg(); 2221 MachineInstr *SrcInstr = MRI.getVRegDef(SrcReg); 2222 if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT && 2223 SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT) 2224 return false; 2225 // Break down the big constant in smaller ones. 2226 const MachineOperand &CstVal = SrcInstr->getOperand(1); 2227 APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT 2228 ? CstVal.getCImm()->getValue() 2229 : CstVal.getFPImm()->getValueAPF().bitcastToAPInt(); 2230 2231 LLT Dst0Ty = MRI.getType(MI.getOperand(0).getReg()); 2232 unsigned ShiftAmt = Dst0Ty.getSizeInBits(); 2233 // Unmerge a constant. 
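// E.g. (illustrative): unmerging an s32 G_CONSTANT 0xAABBCCDD into two s16
// pieces yields 0xCCDD for the first definition and 0xAABB for the second.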
2234 for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) { 2235 Csts.emplace_back(Val.trunc(ShiftAmt)); 2236 Val = Val.lshr(ShiftAmt); 2237 } 2238 2239 return true; 2240 } 2241 2242 void CombinerHelper::applyCombineUnmergeConstant( 2243 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const { 2244 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2245 "Expected an unmerge"); 2246 assert((MI.getNumOperands() - 1 == Csts.size()) && 2247 "Not enough operands to replace all defs"); 2248 unsigned NumElems = MI.getNumOperands() - 1; 2249 for (unsigned Idx = 0; Idx < NumElems; ++Idx) { 2250 Register DstReg = MI.getOperand(Idx).getReg(); 2251 Builder.buildConstant(DstReg, Csts[Idx]); 2252 } 2253 2254 MI.eraseFromParent(); 2255 } 2256 2257 bool CombinerHelper::matchCombineUnmergeUndef( 2258 MachineInstr &MI, 2259 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 2260 unsigned SrcIdx = MI.getNumOperands() - 1; 2261 Register SrcReg = MI.getOperand(SrcIdx).getReg(); 2262 MatchInfo = [&MI](MachineIRBuilder &B) { 2263 unsigned NumElems = MI.getNumOperands() - 1; 2264 for (unsigned Idx = 0; Idx < NumElems; ++Idx) { 2265 Register DstReg = MI.getOperand(Idx).getReg(); 2266 B.buildUndef(DstReg); 2267 } 2268 }; 2269 return isa<GImplicitDef>(MRI.getVRegDef(SrcReg)); 2270 } 2271 2272 bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc( 2273 MachineInstr &MI) const { 2274 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2275 "Expected an unmerge"); 2276 if (MRI.getType(MI.getOperand(0).getReg()).isVector() || 2277 MRI.getType(MI.getOperand(MI.getNumDefs()).getReg()).isVector()) 2278 return false; 2279 // Check that all the lanes are dead except the first one. 2280 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { 2281 if (!MRI.use_nodbg_empty(MI.getOperand(Idx).getReg())) 2282 return false; 2283 } 2284 return true; 2285 } 2286 2287 void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc( 2288 MachineInstr &MI) const { 2289 Register SrcReg = MI.getOperand(MI.getNumDefs()).getReg(); 2290 Register Dst0Reg = MI.getOperand(0).getReg(); 2291 Builder.buildTrunc(Dst0Reg, SrcReg); 2292 MI.eraseFromParent(); 2293 } 2294 2295 bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) const { 2296 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2297 "Expected an unmerge"); 2298 Register Dst0Reg = MI.getOperand(0).getReg(); 2299 LLT Dst0Ty = MRI.getType(Dst0Reg); 2300 // G_ZEXT on vector applies to each lane, so it will 2301 // affect all destinations. Therefore we won't be able 2302 // to simplify the unmerge to just the first definition. 2303 if (Dst0Ty.isVector()) 2304 return false; 2305 Register SrcReg = MI.getOperand(MI.getNumDefs()).getReg(); 2306 LLT SrcTy = MRI.getType(SrcReg); 2307 if (SrcTy.isVector()) 2308 return false; 2309 2310 Register ZExtSrcReg; 2311 if (!mi_match(SrcReg, MRI, m_GZExt(m_Reg(ZExtSrcReg)))) 2312 return false; 2313 2314 // Finally we can replace the first definition with 2315 // a zext of the source if the definition is big enough to hold 2316 // all of ZExtSrc bits. 
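// Illustrative example (names are made up):
//   %z:_(s64) = G_ZEXT %a:_(s16)
//   %lo:_(s32), %hi:_(s32) = G_UNMERGE_VALUES %z
// becomes %lo = G_ZEXT %a and %hi = G_CONSTANT 0, since the s16 source fits
// entirely in the first s32 definition.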
2317 LLT ZExtSrcTy = MRI.getType(ZExtSrcReg); 2318 return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits(); 2319 } 2320 2321 void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) const { 2322 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2323 "Expected an unmerge"); 2324 2325 Register Dst0Reg = MI.getOperand(0).getReg(); 2326 2327 MachineInstr *ZExtInstr = 2328 MRI.getVRegDef(MI.getOperand(MI.getNumDefs()).getReg()); 2329 assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT && 2330 "Expecting a G_ZEXT"); 2331 2332 Register ZExtSrcReg = ZExtInstr->getOperand(1).getReg(); 2333 LLT Dst0Ty = MRI.getType(Dst0Reg); 2334 LLT ZExtSrcTy = MRI.getType(ZExtSrcReg); 2335 2336 if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) { 2337 Builder.buildZExt(Dst0Reg, ZExtSrcReg); 2338 } else { 2339 assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() && 2340 "ZExt src doesn't fit in destination"); 2341 replaceRegWith(MRI, Dst0Reg, ZExtSrcReg); 2342 } 2343 2344 Register ZeroReg; 2345 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { 2346 if (!ZeroReg) 2347 ZeroReg = Builder.buildConstant(Dst0Ty, 0).getReg(0); 2348 replaceRegWith(MRI, MI.getOperand(Idx).getReg(), ZeroReg); 2349 } 2350 MI.eraseFromParent(); 2351 } 2352 2353 bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI, 2354 unsigned TargetShiftSize, 2355 unsigned &ShiftVal) const { 2356 assert((MI.getOpcode() == TargetOpcode::G_SHL || 2357 MI.getOpcode() == TargetOpcode::G_LSHR || 2358 MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift"); 2359 2360 LLT Ty = MRI.getType(MI.getOperand(0).getReg()); 2361 if (Ty.isVector()) // TODO: 2362 return false; 2363 2364 // Don't narrow further than the requested size. 2365 unsigned Size = Ty.getSizeInBits(); 2366 if (Size <= TargetShiftSize) 2367 return false; 2368 2369 auto MaybeImmVal = 2370 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI); 2371 if (!MaybeImmVal) 2372 return false; 2373 2374 ShiftVal = MaybeImmVal->Value.getSExtValue(); 2375 return ShiftVal >= Size / 2 && ShiftVal < Size; 2376 } 2377 2378 void CombinerHelper::applyCombineShiftToUnmerge( 2379 MachineInstr &MI, const unsigned &ShiftVal) const { 2380 Register DstReg = MI.getOperand(0).getReg(); 2381 Register SrcReg = MI.getOperand(1).getReg(); 2382 LLT Ty = MRI.getType(SrcReg); 2383 unsigned Size = Ty.getSizeInBits(); 2384 unsigned HalfSize = Size / 2; 2385 assert(ShiftVal >= HalfSize); 2386 2387 LLT HalfTy = LLT::scalar(HalfSize); 2388 2389 auto Unmerge = Builder.buildUnmerge(HalfTy, SrcReg); 2390 unsigned NarrowShiftAmt = ShiftVal - HalfSize; 2391 2392 if (MI.getOpcode() == TargetOpcode::G_LSHR) { 2393 Register Narrowed = Unmerge.getReg(1); 2394 2395 // dst = G_LSHR s64:x, C for C >= 32 2396 // => 2397 // lo, hi = G_UNMERGE_VALUES x 2398 // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0 2399 2400 if (NarrowShiftAmt != 0) { 2401 Narrowed = Builder.buildLShr(HalfTy, Narrowed, 2402 Builder.buildConstant(HalfTy, NarrowShiftAmt)).getReg(0); 2403 } 2404 2405 auto Zero = Builder.buildConstant(HalfTy, 0); 2406 Builder.buildMergeLikeInstr(DstReg, {Narrowed, Zero}); 2407 } else if (MI.getOpcode() == TargetOpcode::G_SHL) { 2408 Register Narrowed = Unmerge.getReg(0); 2409 // dst = G_SHL s64:x, C for C >= 32 2410 // => 2411 // lo, hi = G_UNMERGE_VALUES x 2412 // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32) 2413 if (NarrowShiftAmt != 0) { 2414 Narrowed = Builder.buildShl(HalfTy, Narrowed, 2415 Builder.buildConstant(HalfTy, NarrowShiftAmt)).getReg(0); 
2416 } 2417 2418 auto Zero = Builder.buildConstant(HalfTy, 0); 2419 Builder.buildMergeLikeInstr(DstReg, {Zero, Narrowed}); 2420 } else { 2421 assert(MI.getOpcode() == TargetOpcode::G_ASHR); 2422 auto Hi = Builder.buildAShr( 2423 HalfTy, Unmerge.getReg(1), 2424 Builder.buildConstant(HalfTy, HalfSize - 1)); 2425 2426 if (ShiftVal == HalfSize) { 2427 // (G_ASHR i64:x, 32) -> 2428 // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31) 2429 Builder.buildMergeLikeInstr(DstReg, {Unmerge.getReg(1), Hi}); 2430 } else if (ShiftVal == Size - 1) { 2431 // Don't need a second shift. 2432 // (G_ASHR i64:x, 63) -> 2433 // %narrowed = (G_ASHR hi_32(x), 31) 2434 // G_MERGE_VALUES %narrowed, %narrowed 2435 Builder.buildMergeLikeInstr(DstReg, {Hi, Hi}); 2436 } else { 2437 auto Lo = Builder.buildAShr( 2438 HalfTy, Unmerge.getReg(1), 2439 Builder.buildConstant(HalfTy, ShiftVal - HalfSize)); 2440 2441 // (G_ASHR i64:x, C) ->, for C >= 32 2442 // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31) 2443 Builder.buildMergeLikeInstr(DstReg, {Lo, Hi}); 2444 } 2445 } 2446 2447 MI.eraseFromParent(); 2448 } 2449 2450 bool CombinerHelper::tryCombineShiftToUnmerge( 2451 MachineInstr &MI, unsigned TargetShiftAmount) const { 2452 unsigned ShiftAmt; 2453 if (matchCombineShiftToUnmerge(MI, TargetShiftAmount, ShiftAmt)) { 2454 applyCombineShiftToUnmerge(MI, ShiftAmt); 2455 return true; 2456 } 2457 2458 return false; 2459 } 2460 2461 bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI, 2462 Register &Reg) const { 2463 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR"); 2464 Register DstReg = MI.getOperand(0).getReg(); 2465 LLT DstTy = MRI.getType(DstReg); 2466 Register SrcReg = MI.getOperand(1).getReg(); 2467 return mi_match(SrcReg, MRI, 2468 m_GPtrToInt(m_all_of(m_SpecificType(DstTy), m_Reg(Reg)))); 2469 } 2470 2471 void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI, 2472 Register &Reg) const { 2473 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR"); 2474 Register DstReg = MI.getOperand(0).getReg(); 2475 Builder.buildCopy(DstReg, Reg); 2476 MI.eraseFromParent(); 2477 } 2478 2479 void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI, 2480 Register &Reg) const { 2481 assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT"); 2482 Register DstReg = MI.getOperand(0).getReg(); 2483 Builder.buildZExtOrTrunc(DstReg, Reg); 2484 MI.eraseFromParent(); 2485 } 2486 2487 bool CombinerHelper::matchCombineAddP2IToPtrAdd( 2488 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const { 2489 assert(MI.getOpcode() == TargetOpcode::G_ADD); 2490 Register LHS = MI.getOperand(1).getReg(); 2491 Register RHS = MI.getOperand(2).getReg(); 2492 LLT IntTy = MRI.getType(LHS); 2493 2494 // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the 2495 // instruction. 2496 PtrReg.second = false; 2497 for (Register SrcReg : {LHS, RHS}) { 2498 if (mi_match(SrcReg, MRI, m_GPtrToInt(m_Reg(PtrReg.first)))) { 2499 // Don't handle cases where the integer is implicitly converted to the 2500 // pointer width. 
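// E.g. (illustrative): if p0 is 64 bits and the G_PTRTOINT result is s32, the
// add operates on a truncated value; rebuilding it as a 64-bit G_PTR_ADD would
// require guessing how to extend the 32-bit addend, so skip that case.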
2501 LLT PtrTy = MRI.getType(PtrReg.first); 2502 if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits()) 2503 return true; 2504 } 2505 2506 PtrReg.second = true; 2507 } 2508 2509 return false; 2510 } 2511 2512 void CombinerHelper::applyCombineAddP2IToPtrAdd( 2513 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const { 2514 Register Dst = MI.getOperand(0).getReg(); 2515 Register LHS = MI.getOperand(1).getReg(); 2516 Register RHS = MI.getOperand(2).getReg(); 2517 2518 const bool DoCommute = PtrReg.second; 2519 if (DoCommute) 2520 std::swap(LHS, RHS); 2521 LHS = PtrReg.first; 2522 2523 LLT PtrTy = MRI.getType(LHS); 2524 2525 auto PtrAdd = Builder.buildPtrAdd(PtrTy, LHS, RHS); 2526 Builder.buildPtrToInt(Dst, PtrAdd); 2527 MI.eraseFromParent(); 2528 } 2529 2530 bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI, 2531 APInt &NewCst) const { 2532 auto &PtrAdd = cast<GPtrAdd>(MI); 2533 Register LHS = PtrAdd.getBaseReg(); 2534 Register RHS = PtrAdd.getOffsetReg(); 2535 MachineRegisterInfo &MRI = Builder.getMF().getRegInfo(); 2536 2537 if (auto RHSCst = getIConstantVRegVal(RHS, MRI)) { 2538 APInt Cst; 2539 if (mi_match(LHS, MRI, m_GIntToPtr(m_ICst(Cst)))) { 2540 auto DstTy = MRI.getType(PtrAdd.getReg(0)); 2541 // G_INTTOPTR uses zero-extension 2542 NewCst = Cst.zextOrTrunc(DstTy.getSizeInBits()); 2543 NewCst += RHSCst->sextOrTrunc(DstTy.getSizeInBits()); 2544 return true; 2545 } 2546 } 2547 2548 return false; 2549 } 2550 2551 void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI, 2552 APInt &NewCst) const { 2553 auto &PtrAdd = cast<GPtrAdd>(MI); 2554 Register Dst = PtrAdd.getReg(0); 2555 2556 Builder.buildConstant(Dst, NewCst); 2557 PtrAdd.eraseFromParent(); 2558 } 2559 2560 bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI, 2561 Register &Reg) const { 2562 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT"); 2563 Register DstReg = MI.getOperand(0).getReg(); 2564 Register SrcReg = MI.getOperand(1).getReg(); 2565 Register OriginalSrcReg = getSrcRegIgnoringCopies(SrcReg, MRI); 2566 if (OriginalSrcReg.isValid()) 2567 SrcReg = OriginalSrcReg; 2568 LLT DstTy = MRI.getType(DstReg); 2569 return mi_match(SrcReg, MRI, 2570 m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy)))); 2571 } 2572 2573 bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI, 2574 Register &Reg) const { 2575 assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT"); 2576 Register DstReg = MI.getOperand(0).getReg(); 2577 Register SrcReg = MI.getOperand(1).getReg(); 2578 LLT DstTy = MRI.getType(DstReg); 2579 if (mi_match(SrcReg, MRI, 2580 m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy))))) { 2581 unsigned DstSize = DstTy.getScalarSizeInBits(); 2582 unsigned SrcSize = MRI.getType(SrcReg).getScalarSizeInBits(); 2583 return KB->getKnownBits(Reg).countMinLeadingZeros() >= DstSize - SrcSize; 2584 } 2585 return false; 2586 } 2587 2588 static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) { 2589 const unsigned ShiftSize = ShiftTy.getScalarSizeInBits(); 2590 const unsigned TruncSize = TruncTy.getScalarSizeInBits(); 2591 2592 // ShiftTy > 32 > TruncTy -> 32 2593 if (ShiftSize > 32 && TruncSize < 32) 2594 return ShiftTy.changeElementSize(32); 2595 2596 // TODO: We could also reduce to 16 bits, but that's more target-dependent. 2597 // Some targets like it, some don't, some only like it under certain 2598 // conditions/processor versions, etc. 2599 // A TL hook might be needed for this. 
2600 2601 // Don't combine 2602 return ShiftTy; 2603 } 2604 2605 bool CombinerHelper::matchCombineTruncOfShift( 2606 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const { 2607 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC"); 2608 Register DstReg = MI.getOperand(0).getReg(); 2609 Register SrcReg = MI.getOperand(1).getReg(); 2610 2611 if (!MRI.hasOneNonDBGUse(SrcReg)) 2612 return false; 2613 2614 LLT SrcTy = MRI.getType(SrcReg); 2615 LLT DstTy = MRI.getType(DstReg); 2616 2617 MachineInstr *SrcMI = getDefIgnoringCopies(SrcReg, MRI); 2618 const auto &TL = getTargetLowering(); 2619 2620 LLT NewShiftTy; 2621 switch (SrcMI->getOpcode()) { 2622 default: 2623 return false; 2624 case TargetOpcode::G_SHL: { 2625 NewShiftTy = DstTy; 2626 2627 // Make sure new shift amount is legal. 2628 KnownBits Known = KB->getKnownBits(SrcMI->getOperand(2).getReg()); 2629 if (Known.getMaxValue().uge(NewShiftTy.getScalarSizeInBits())) 2630 return false; 2631 break; 2632 } 2633 case TargetOpcode::G_LSHR: 2634 case TargetOpcode::G_ASHR: { 2635 // For right shifts, we conservatively do not do the transform if the TRUNC 2636 // has any STORE users. The reason is that if we change the type of the 2637 // shift, we may break the truncstore combine. 2638 // 2639 // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)). 2640 for (auto &User : MRI.use_instructions(DstReg)) 2641 if (User.getOpcode() == TargetOpcode::G_STORE) 2642 return false; 2643 2644 NewShiftTy = getMidVTForTruncRightShiftCombine(SrcTy, DstTy); 2645 if (NewShiftTy == SrcTy) 2646 return false; 2647 2648 // Make sure we won't lose information by truncating the high bits. 2649 KnownBits Known = KB->getKnownBits(SrcMI->getOperand(2).getReg()); 2650 if (Known.getMaxValue().ugt(NewShiftTy.getScalarSizeInBits() - 2651 DstTy.getScalarSizeInBits())) 2652 return false; 2653 break; 2654 } 2655 } 2656 2657 if (!isLegalOrBeforeLegalizer( 2658 {SrcMI->getOpcode(), 2659 {NewShiftTy, TL.getPreferredShiftAmountTy(NewShiftTy)}})) 2660 return false; 2661 2662 MatchInfo = std::make_pair(SrcMI, NewShiftTy); 2663 return true; 2664 } 2665 2666 void CombinerHelper::applyCombineTruncOfShift( 2667 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const { 2668 MachineInstr *ShiftMI = MatchInfo.first; 2669 LLT NewShiftTy = MatchInfo.second; 2670 2671 Register Dst = MI.getOperand(0).getReg(); 2672 LLT DstTy = MRI.getType(Dst); 2673 2674 Register ShiftAmt = ShiftMI->getOperand(2).getReg(); 2675 Register ShiftSrc = ShiftMI->getOperand(1).getReg(); 2676 ShiftSrc = Builder.buildTrunc(NewShiftTy, ShiftSrc).getReg(0); 2677 2678 Register NewShift = 2679 Builder 2680 .buildInstr(ShiftMI->getOpcode(), {NewShiftTy}, {ShiftSrc, ShiftAmt}) 2681 .getReg(0); 2682 2683 if (NewShiftTy == DstTy) 2684 replaceRegWith(MRI, Dst, NewShift); 2685 else 2686 Builder.buildTrunc(Dst, NewShift); 2687 2688 eraseInst(MI); 2689 } 2690 2691 bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) const { 2692 return any_of(MI.explicit_uses(), [this](const MachineOperand &MO) { 2693 return MO.isReg() && 2694 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI); 2695 }); 2696 } 2697 2698 bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) const { 2699 return all_of(MI.explicit_uses(), [this](const MachineOperand &MO) { 2700 return !MO.isReg() || 2701 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI); 2702 }); 2703 } 2704 2705 bool CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) const { 2706 
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR); 2707 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask(); 2708 return all_of(Mask, [](int Elt) { return Elt < 0; }); 2709 } 2710 2711 bool CombinerHelper::matchUndefStore(MachineInstr &MI) const { 2712 assert(MI.getOpcode() == TargetOpcode::G_STORE); 2713 return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(0).getReg(), 2714 MRI); 2715 } 2716 2717 bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) const { 2718 assert(MI.getOpcode() == TargetOpcode::G_SELECT); 2719 return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(1).getReg(), 2720 MRI); 2721 } 2722 2723 bool CombinerHelper::matchInsertExtractVecEltOutOfBounds( 2724 MachineInstr &MI) const { 2725 assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT || 2726 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) && 2727 "Expected an insert/extract element op"); 2728 LLT VecTy = MRI.getType(MI.getOperand(1).getReg()); 2729 if (VecTy.isScalableVector()) 2730 return false; 2731 2732 unsigned IdxIdx = 2733 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3; 2734 auto Idx = getIConstantVRegVal(MI.getOperand(IdxIdx).getReg(), MRI); 2735 if (!Idx) 2736 return false; 2737 return Idx->getZExtValue() >= VecTy.getNumElements(); 2738 } 2739 2740 bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI, 2741 unsigned &OpIdx) const { 2742 GSelect &SelMI = cast<GSelect>(MI); 2743 auto Cst = 2744 isConstantOrConstantSplatVector(*MRI.getVRegDef(SelMI.getCondReg()), MRI); 2745 if (!Cst) 2746 return false; 2747 OpIdx = Cst->isZero() ? 3 : 2; 2748 return true; 2749 } 2750 2751 void CombinerHelper::eraseInst(MachineInstr &MI) const { MI.eraseFromParent(); } 2752 2753 bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1, 2754 const MachineOperand &MOP2) const { 2755 if (!MOP1.isReg() || !MOP2.isReg()) 2756 return false; 2757 auto InstAndDef1 = getDefSrcRegIgnoringCopies(MOP1.getReg(), MRI); 2758 if (!InstAndDef1) 2759 return false; 2760 auto InstAndDef2 = getDefSrcRegIgnoringCopies(MOP2.getReg(), MRI); 2761 if (!InstAndDef2) 2762 return false; 2763 MachineInstr *I1 = InstAndDef1->MI; 2764 MachineInstr *I2 = InstAndDef2->MI; 2765 2766 // Handle a case like this: 2767 // 2768 // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>) 2769 // 2770 // Even though %0 and %1 are produced by the same instruction they are not 2771 // the same values. 2772 if (I1 == I2) 2773 return MOP1.getReg() == MOP2.getReg(); 2774 2775 // If we have an instruction which loads or stores, we can't guarantee that 2776 // it is identical. 2777 // 2778 // For example, we may have 2779 // 2780 // %x1 = G_LOAD %addr (load N from @somewhere) 2781 // ... 2782 // call @foo 2783 // ... 2784 // %x2 = G_LOAD %addr (load N from @somewhere) 2785 // ... 2786 // %or = G_OR %x1, %x2 2787 // 2788 // It's possible that @foo will modify whatever lives at the address we're 2789 // loading from. To be safe, let's just assume that all loads and stores 2790 // are different (unless we have something which is guaranteed to not 2791 // change.) 2792 if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad()) 2793 return false; 2794 2795 // If both instructions are loads or stores, they are equal only if both 2796 // are dereferenceable invariant loads with the same number of bits. 
2797 if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) { 2798 GLoadStore *LS1 = dyn_cast<GLoadStore>(I1); 2799 GLoadStore *LS2 = dyn_cast<GLoadStore>(I2); 2800 if (!LS1 || !LS2) 2801 return false; 2802 2803 if (!I2->isDereferenceableInvariantLoad() || 2804 (LS1->getMemSizeInBits() != LS2->getMemSizeInBits())) 2805 return false; 2806 } 2807 2808 // Check for physical registers on the instructions first to avoid cases 2809 // like this: 2810 // 2811 // %a = COPY $physreg 2812 // ... 2813 // SOMETHING implicit-def $physreg 2814 // ... 2815 // %b = COPY $physreg 2816 // 2817 // These copies are not equivalent. 2818 if (any_of(I1->uses(), [](const MachineOperand &MO) { 2819 return MO.isReg() && MO.getReg().isPhysical(); 2820 })) { 2821 // Check if we have a case like this: 2822 // 2823 // %a = COPY $physreg 2824 // %b = COPY %a 2825 // 2826 // In this case, I1 and I2 will both be equal to %a = COPY $physreg. 2827 // From that, we know that they must have the same value, since they must 2828 // have come from the same COPY. 2829 return I1->isIdenticalTo(*I2); 2830 } 2831 2832 // We don't have any physical registers, so we don't necessarily need the 2833 // same vreg defs. 2834 // 2835 // On the off-chance that there's some target instruction feeding into the 2836 // instruction, let's use produceSameValue instead of isIdenticalTo. 2837 if (Builder.getTII().produceSameValue(*I1, *I2, &MRI)) { 2838 // Handle instructions with multiple defs that produce same values. Values 2839 // are same for operands with same index. 2840 // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>) 2841 // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>) 2842 // I1 and I2 are different instructions but produce same values, 2843 // %1 and %6 are same, %1 and %7 are not the same value. 
2844 return I1->findRegisterDefOperandIdx(InstAndDef1->Reg, /*TRI=*/nullptr) == 2845 I2->findRegisterDefOperandIdx(InstAndDef2->Reg, /*TRI=*/nullptr); 2846 } 2847 return false; 2848 } 2849 2850 bool CombinerHelper::matchConstantOp(const MachineOperand &MOP, 2851 int64_t C) const { 2852 if (!MOP.isReg()) 2853 return false; 2854 auto *MI = MRI.getVRegDef(MOP.getReg()); 2855 auto MaybeCst = isConstantOrConstantSplatVector(*MI, MRI); 2856 return MaybeCst && MaybeCst->getBitWidth() <= 64 && 2857 MaybeCst->getSExtValue() == C; 2858 } 2859 2860 bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP, 2861 double C) const { 2862 if (!MOP.isReg()) 2863 return false; 2864 std::optional<FPValueAndVReg> MaybeCst; 2865 if (!mi_match(MOP.getReg(), MRI, m_GFCstOrSplat(MaybeCst))) 2866 return false; 2867 2868 return MaybeCst->Value.isExactlyValue(C); 2869 } 2870 2871 void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI, 2872 unsigned OpIdx) const { 2873 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?"); 2874 Register OldReg = MI.getOperand(0).getReg(); 2875 Register Replacement = MI.getOperand(OpIdx).getReg(); 2876 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?"); 2877 replaceRegWith(MRI, OldReg, Replacement); 2878 MI.eraseFromParent(); 2879 } 2880 2881 void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI, 2882 Register Replacement) const { 2883 assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?"); 2884 Register OldReg = MI.getOperand(0).getReg(); 2885 assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?"); 2886 replaceRegWith(MRI, OldReg, Replacement); 2887 MI.eraseFromParent(); 2888 } 2889 2890 bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI, 2891 unsigned ConstIdx) const { 2892 Register ConstReg = MI.getOperand(ConstIdx).getReg(); 2893 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 2894 2895 // Get the shift amount 2896 auto VRegAndVal = getIConstantVRegValWithLookThrough(ConstReg, MRI); 2897 if (!VRegAndVal) 2898 return false; 2899 2900 // Return true of shift amount >= Bitwidth 2901 return (VRegAndVal->Value.uge(DstTy.getSizeInBits())); 2902 } 2903 2904 void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) const { 2905 assert((MI.getOpcode() == TargetOpcode::G_FSHL || 2906 MI.getOpcode() == TargetOpcode::G_FSHR) && 2907 "This is not a funnel shift operation"); 2908 2909 Register ConstReg = MI.getOperand(3).getReg(); 2910 LLT ConstTy = MRI.getType(ConstReg); 2911 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 2912 2913 auto VRegAndVal = getIConstantVRegValWithLookThrough(ConstReg, MRI); 2914 assert((VRegAndVal) && "Value is not a constant"); 2915 2916 // Calculate the new Shift Amount = Old Shift Amount % BitWidth 2917 APInt NewConst = VRegAndVal->Value.urem( 2918 APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits())); 2919 2920 auto NewConstInstr = Builder.buildConstant(ConstTy, NewConst.getZExtValue()); 2921 Builder.buildInstr( 2922 MI.getOpcode(), {MI.getOperand(0)}, 2923 {MI.getOperand(1), MI.getOperand(2), NewConstInstr.getReg(0)}); 2924 2925 MI.eraseFromParent(); 2926 } 2927 2928 bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) const { 2929 assert(MI.getOpcode() == TargetOpcode::G_SELECT); 2930 // Match (cond ? 
x : x) 2931 return matchEqualDefs(MI.getOperand(2), MI.getOperand(3)) && 2932 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(2).getReg(), 2933 MRI); 2934 } 2935 2936 bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) const { 2937 return matchEqualDefs(MI.getOperand(1), MI.getOperand(2)) && 2938 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(1).getReg(), 2939 MRI); 2940 } 2941 2942 bool CombinerHelper::matchOperandIsZero(MachineInstr &MI, 2943 unsigned OpIdx) const { 2944 return matchConstantOp(MI.getOperand(OpIdx), 0) && 2945 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(OpIdx).getReg(), 2946 MRI); 2947 } 2948 2949 bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI, 2950 unsigned OpIdx) const { 2951 MachineOperand &MO = MI.getOperand(OpIdx); 2952 return MO.isReg() && 2953 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI); 2954 } 2955 2956 bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI, 2957 unsigned OpIdx) const { 2958 MachineOperand &MO = MI.getOperand(OpIdx); 2959 return isKnownToBeAPowerOfTwo(MO.getReg(), MRI, KB); 2960 } 2961 2962 void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, 2963 double C) const { 2964 assert(MI.getNumDefs() == 1 && "Expected only one def?"); 2965 Builder.buildFConstant(MI.getOperand(0), C); 2966 MI.eraseFromParent(); 2967 } 2968 2969 void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, 2970 int64_t C) const { 2971 assert(MI.getNumDefs() == 1 && "Expected only one def?"); 2972 Builder.buildConstant(MI.getOperand(0), C); 2973 MI.eraseFromParent(); 2974 } 2975 2976 void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) const { 2977 assert(MI.getNumDefs() == 1 && "Expected only one def?"); 2978 Builder.buildConstant(MI.getOperand(0), C); 2979 MI.eraseFromParent(); 2980 } 2981 2982 void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, 2983 ConstantFP *CFP) const { 2984 assert(MI.getNumDefs() == 1 && "Expected only one def?"); 2985 Builder.buildFConstant(MI.getOperand(0), CFP->getValueAPF()); 2986 MI.eraseFromParent(); 2987 } 2988 2989 void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) const { 2990 assert(MI.getNumDefs() == 1 && "Expected only one def?"); 2991 Builder.buildUndef(MI.getOperand(0)); 2992 MI.eraseFromParent(); 2993 } 2994 2995 bool CombinerHelper::matchSimplifyAddToSub( 2996 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const { 2997 Register LHS = MI.getOperand(1).getReg(); 2998 Register RHS = MI.getOperand(2).getReg(); 2999 Register &NewLHS = std::get<0>(MatchInfo); 3000 Register &NewRHS = std::get<1>(MatchInfo); 3001 3002 // Helper lambda to check for opportunities for 3003 // ((0-A) + B) -> B - A 3004 // (A + (0-B)) -> A - B 3005 auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) { 3006 if (!mi_match(MaybeSub, MRI, m_Neg(m_Reg(NewRHS)))) 3007 return false; 3008 NewLHS = MaybeNewLHS; 3009 return true; 3010 }; 3011 3012 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); 3013 } 3014 3015 bool CombinerHelper::matchCombineInsertVecElts( 3016 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const { 3017 assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT && 3018 "Invalid opcode"); 3019 Register DstReg = MI.getOperand(0).getReg(); 3020 LLT DstTy = MRI.getType(DstReg); 3021 assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?"); 3022 3023 if (DstTy.isScalableVector()) 3024 return false; 3025 3026 unsigned NumElts = DstTy.getNumElements(); 3027 // If this MI is part of a sequence of 
insert_vec_elts, then 3028 // don't do the combine in the middle of the sequence. 3029 if (MRI.hasOneUse(DstReg) && MRI.use_instr_begin(DstReg)->getOpcode() == 3030 TargetOpcode::G_INSERT_VECTOR_ELT) 3031 return false; 3032 MachineInstr *CurrInst = &MI; 3033 MachineInstr *TmpInst; 3034 int64_t IntImm; 3035 Register TmpReg; 3036 MatchInfo.resize(NumElts); 3037 while (mi_match( 3038 CurrInst->getOperand(0).getReg(), MRI, 3039 m_GInsertVecElt(m_MInstr(TmpInst), m_Reg(TmpReg), m_ICst(IntImm)))) { 3040 if (IntImm >= NumElts || IntImm < 0) 3041 return false; 3042 if (!MatchInfo[IntImm]) 3043 MatchInfo[IntImm] = TmpReg; 3044 CurrInst = TmpInst; 3045 } 3046 // Variable index. 3047 if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT) 3048 return false; 3049 if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) { 3050 for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) { 3051 if (!MatchInfo[I - 1].isValid()) 3052 MatchInfo[I - 1] = TmpInst->getOperand(I).getReg(); 3053 } 3054 return true; 3055 } 3056 // If we didn't end in a G_IMPLICIT_DEF and the source is not fully 3057 // overwritten, bail out. 3058 return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF || 3059 all_of(MatchInfo, [](Register Reg) { return !!Reg; }); 3060 } 3061 3062 void CombinerHelper::applyCombineInsertVecElts( 3063 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const { 3064 Register UndefReg; 3065 auto GetUndef = [&]() { 3066 if (UndefReg) 3067 return UndefReg; 3068 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 3069 UndefReg = Builder.buildUndef(DstTy.getScalarType()).getReg(0); 3070 return UndefReg; 3071 }; 3072 for (Register &Reg : MatchInfo) { 3073 if (!Reg) 3074 Reg = GetUndef(); 3075 } 3076 Builder.buildBuildVector(MI.getOperand(0).getReg(), MatchInfo); 3077 MI.eraseFromParent(); 3078 } 3079 3080 void CombinerHelper::applySimplifyAddToSub( 3081 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const { 3082 Register SubLHS, SubRHS; 3083 std::tie(SubLHS, SubRHS) = MatchInfo; 3084 Builder.buildSub(MI.getOperand(0).getReg(), SubLHS, SubRHS); 3085 MI.eraseFromParent(); 3086 } 3087 3088 bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands( 3089 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const { 3090 // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ... 3091 // 3092 // Creates the new hand + logic instruction (but does not insert them.) 3093 // 3094 // On success, MatchInfo is populated with the new instructions. These are 3095 // inserted in applyHoistLogicOpWithSameOpcodeHands. 3096 unsigned LogicOpcode = MI.getOpcode(); 3097 assert(LogicOpcode == TargetOpcode::G_AND || 3098 LogicOpcode == TargetOpcode::G_OR || 3099 LogicOpcode == TargetOpcode::G_XOR); 3100 MachineIRBuilder MIB(MI); 3101 Register Dst = MI.getOperand(0).getReg(); 3102 Register LHSReg = MI.getOperand(1).getReg(); 3103 Register RHSReg = MI.getOperand(2).getReg(); 3104 3105 // Don't recompute anything. 3106 if (!MRI.hasOneNonDBGUse(LHSReg) || !MRI.hasOneNonDBGUse(RHSReg)) 3107 return false; 3108 3109 // Make sure we have (hand x, ...), (hand y, ...) 
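// Illustrative example for the ext case (names are made up):
//   %xe:_(s32) = G_ZEXT %x:_(s8)
//   %ye:_(s32) = G_ZEXT %y:_(s8)
//   %d:_(s32)  = G_OR %xe, %ye
// becomes
//   %l:_(s8)  = G_OR %x, %y
//   %d:_(s32) = G_ZEXT %l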
3110 MachineInstr *LeftHandInst = getDefIgnoringCopies(LHSReg, MRI); 3111 MachineInstr *RightHandInst = getDefIgnoringCopies(RHSReg, MRI); 3112 if (!LeftHandInst || !RightHandInst) 3113 return false; 3114 unsigned HandOpcode = LeftHandInst->getOpcode(); 3115 if (HandOpcode != RightHandInst->getOpcode()) 3116 return false; 3117 if (LeftHandInst->getNumOperands() < 2 || 3118 !LeftHandInst->getOperand(1).isReg() || 3119 RightHandInst->getNumOperands() < 2 || 3120 !RightHandInst->getOperand(1).isReg()) 3121 return false; 3122 3123 // Make sure the types match up, and if we're doing this post-legalization, 3124 // we end up with legal types. 3125 Register X = LeftHandInst->getOperand(1).getReg(); 3126 Register Y = RightHandInst->getOperand(1).getReg(); 3127 LLT XTy = MRI.getType(X); 3128 LLT YTy = MRI.getType(Y); 3129 if (!XTy.isValid() || XTy != YTy) 3130 return false; 3131 3132 // Optional extra source register. 3133 Register ExtraHandOpSrcReg; 3134 switch (HandOpcode) { 3135 default: 3136 return false; 3137 case TargetOpcode::G_ANYEXT: 3138 case TargetOpcode::G_SEXT: 3139 case TargetOpcode::G_ZEXT: { 3140 // Match: logic (ext X), (ext Y) --> ext (logic X, Y) 3141 break; 3142 } 3143 case TargetOpcode::G_TRUNC: { 3144 // Match: logic (trunc X), (trunc Y) -> trunc (logic X, Y) 3145 const MachineFunction *MF = MI.getMF(); 3146 LLVMContext &Ctx = MF->getFunction().getContext(); 3147 3148 LLT DstTy = MRI.getType(Dst); 3149 const TargetLowering &TLI = getTargetLowering(); 3150 3151 // Be extra careful sinking truncate. If it's free, there's no benefit in 3152 // widening a binop. 3153 if (TLI.isZExtFree(DstTy, XTy, Ctx) && TLI.isTruncateFree(XTy, DstTy, Ctx)) 3154 return false; 3155 break; 3156 } 3157 case TargetOpcode::G_AND: 3158 case TargetOpcode::G_ASHR: 3159 case TargetOpcode::G_LSHR: 3160 case TargetOpcode::G_SHL: { 3161 // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z 3162 MachineOperand &ZOp = LeftHandInst->getOperand(2); 3163 if (!matchEqualDefs(ZOp, RightHandInst->getOperand(2))) 3164 return false; 3165 ExtraHandOpSrcReg = ZOp.getReg(); 3166 break; 3167 } 3168 } 3169 3170 if (!isLegalOrBeforeLegalizer({LogicOpcode, {XTy, YTy}})) 3171 return false; 3172 3173 // Record the steps to build the new instructions. 
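  //
  // For example (illustrative only), if the input were
  //   %dst = G_XOR (G_SHL %x, %z), (G_SHL %y, %z)
  // the two InstructionBuildSteps recorded here would describe
  //   %newlogic = G_XOR %x, %y
  //   %dst      = G_SHL %newlogic, %z
  // where %newlogic stands for the fresh register created just below.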
3174 // 3175 // Steps to build (logic x, y) 3176 auto NewLogicDst = MRI.createGenericVirtualRegister(XTy); 3177 OperandBuildSteps LogicBuildSteps = { 3178 [=](MachineInstrBuilder &MIB) { MIB.addDef(NewLogicDst); }, 3179 [=](MachineInstrBuilder &MIB) { MIB.addReg(X); }, 3180 [=](MachineInstrBuilder &MIB) { MIB.addReg(Y); }}; 3181 InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps); 3182 3183 // Steps to build hand (logic x, y), ...z 3184 OperandBuildSteps HandBuildSteps = { 3185 [=](MachineInstrBuilder &MIB) { MIB.addDef(Dst); }, 3186 [=](MachineInstrBuilder &MIB) { MIB.addReg(NewLogicDst); }}; 3187 if (ExtraHandOpSrcReg.isValid()) 3188 HandBuildSteps.push_back( 3189 [=](MachineInstrBuilder &MIB) { MIB.addReg(ExtraHandOpSrcReg); }); 3190 InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps); 3191 3192 MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps}); 3193 return true; 3194 } 3195 3196 void CombinerHelper::applyBuildInstructionSteps( 3197 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const { 3198 assert(MatchInfo.InstrsToBuild.size() && 3199 "Expected at least one instr to build?"); 3200 for (auto &InstrToBuild : MatchInfo.InstrsToBuild) { 3201 assert(InstrToBuild.Opcode && "Expected a valid opcode?"); 3202 assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?"); 3203 MachineInstrBuilder Instr = Builder.buildInstr(InstrToBuild.Opcode); 3204 for (auto &OperandFn : InstrToBuild.OperandFns) 3205 OperandFn(Instr); 3206 } 3207 MI.eraseFromParent(); 3208 } 3209 3210 bool CombinerHelper::matchAshrShlToSextInreg( 3211 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const { 3212 assert(MI.getOpcode() == TargetOpcode::G_ASHR); 3213 int64_t ShlCst, AshrCst; 3214 Register Src; 3215 if (!mi_match(MI.getOperand(0).getReg(), MRI, 3216 m_GAShr(m_GShl(m_Reg(Src), m_ICstOrSplat(ShlCst)), 3217 m_ICstOrSplat(AshrCst)))) 3218 return false; 3219 if (ShlCst != AshrCst) 3220 return false; 3221 if (!isLegalOrBeforeLegalizer( 3222 {TargetOpcode::G_SEXT_INREG, {MRI.getType(Src)}})) 3223 return false; 3224 MatchInfo = std::make_tuple(Src, ShlCst); 3225 return true; 3226 } 3227 3228 void CombinerHelper::applyAshShlToSextInreg( 3229 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const { 3230 assert(MI.getOpcode() == TargetOpcode::G_ASHR); 3231 Register Src; 3232 int64_t ShiftAmt; 3233 std::tie(Src, ShiftAmt) = MatchInfo; 3234 unsigned Size = MRI.getType(Src).getScalarSizeInBits(); 3235 Builder.buildSExtInReg(MI.getOperand(0).getReg(), Src, Size - ShiftAmt); 3236 MI.eraseFromParent(); 3237 } 3238 3239 /// and(and(x, C1), C2) -> C1&C2 ? 
and(x, C1&C2) : 0 3240 bool CombinerHelper::matchOverlappingAnd( 3241 MachineInstr &MI, 3242 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 3243 assert(MI.getOpcode() == TargetOpcode::G_AND); 3244 3245 Register Dst = MI.getOperand(0).getReg(); 3246 LLT Ty = MRI.getType(Dst); 3247 3248 Register R; 3249 int64_t C1; 3250 int64_t C2; 3251 if (!mi_match( 3252 Dst, MRI, 3253 m_GAnd(m_GAnd(m_Reg(R), m_ICst(C1)), m_ICst(C2)))) 3254 return false; 3255 3256 MatchInfo = [=](MachineIRBuilder &B) { 3257 if (C1 & C2) { 3258 B.buildAnd(Dst, R, B.buildConstant(Ty, C1 & C2)); 3259 return; 3260 } 3261 auto Zero = B.buildConstant(Ty, 0); 3262 replaceRegWith(MRI, Dst, Zero->getOperand(0).getReg()); 3263 }; 3264 return true; 3265 } 3266 3267 bool CombinerHelper::matchRedundantAnd(MachineInstr &MI, 3268 Register &Replacement) const { 3269 // Given 3270 // 3271 // %y:_(sN) = G_SOMETHING 3272 // %x:_(sN) = G_SOMETHING 3273 // %res:_(sN) = G_AND %x, %y 3274 // 3275 // Eliminate the G_AND when it is known that x & y == x or x & y == y. 3276 // 3277 // Patterns like this can appear as a result of legalization. E.g. 3278 // 3279 // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y 3280 // %one:_(s32) = G_CONSTANT i32 1 3281 // %and:_(s32) = G_AND %cmp, %one 3282 // 3283 // In this case, G_ICMP only produces a single bit, so x & 1 == x. 3284 assert(MI.getOpcode() == TargetOpcode::G_AND); 3285 if (!KB) 3286 return false; 3287 3288 Register AndDst = MI.getOperand(0).getReg(); 3289 Register LHS = MI.getOperand(1).getReg(); 3290 Register RHS = MI.getOperand(2).getReg(); 3291 3292 // Check the RHS (maybe a constant) first, and if we have no KnownBits there, 3293 // we can't do anything. If we do, then it depends on whether we have 3294 // KnownBits on the LHS. 3295 KnownBits RHSBits = KB->getKnownBits(RHS); 3296 if (RHSBits.isUnknown()) 3297 return false; 3298 3299 KnownBits LHSBits = KB->getKnownBits(LHS); 3300 3301 // Check that x & Mask == x. 3302 // x & 1 == x, always 3303 // x & 0 == x, only if x is also 0 3304 // Meaning Mask has no effect if every bit is either one in Mask or zero in x. 3305 // 3306 // Check if we can replace AndDst with the LHS of the G_AND 3307 if (canReplaceReg(AndDst, LHS, MRI) && 3308 (LHSBits.Zero | RHSBits.One).isAllOnes()) { 3309 Replacement = LHS; 3310 return true; 3311 } 3312 3313 // Check if we can replace AndDst with the RHS of the G_AND 3314 if (canReplaceReg(AndDst, RHS, MRI) && 3315 (LHSBits.One | RHSBits.Zero).isAllOnes()) { 3316 Replacement = RHS; 3317 return true; 3318 } 3319 3320 return false; 3321 } 3322 3323 bool CombinerHelper::matchRedundantOr(MachineInstr &MI, 3324 Register &Replacement) const { 3325 // Given 3326 // 3327 // %y:_(sN) = G_SOMETHING 3328 // %x:_(sN) = G_SOMETHING 3329 // %res:_(sN) = G_OR %x, %y 3330 // 3331 // Eliminate the G_OR when it is known that x | y == x or x | y == y. 3332 assert(MI.getOpcode() == TargetOpcode::G_OR); 3333 if (!KB) 3334 return false; 3335 3336 Register OrDst = MI.getOperand(0).getReg(); 3337 Register LHS = MI.getOperand(1).getReg(); 3338 Register RHS = MI.getOperand(2).getReg(); 3339 3340 KnownBits LHSBits = KB->getKnownBits(LHS); 3341 KnownBits RHSBits = KB->getKnownBits(RHS); 3342 3343 // Check that x | Mask == x. 3344 // x | 0 == x, always 3345 // x | 1 == x, only if x is also 1 3346 // Meaning Mask has no effect if every bit is either zero in Mask or one in x. 
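  //
  // A minimal illustration: if RHS is G_CONSTANT i32 0, RHSBits.Zero is all
  // ones, so x | 0 == x and the G_OR can simply be replaced by its LHS
  // (subject to the canReplaceReg check below).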
3347 // 3348 // Check if we can replace OrDst with the LHS of the G_OR 3349 if (canReplaceReg(OrDst, LHS, MRI) && 3350 (LHSBits.One | RHSBits.Zero).isAllOnes()) { 3351 Replacement = LHS; 3352 return true; 3353 } 3354 3355 // Check if we can replace OrDst with the RHS of the G_OR 3356 if (canReplaceReg(OrDst, RHS, MRI) && 3357 (LHSBits.Zero | RHSBits.One).isAllOnes()) { 3358 Replacement = RHS; 3359 return true; 3360 } 3361 3362 return false; 3363 } 3364 3365 bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) const { 3366 // If the input is already sign extended, just drop the extension. 3367 Register Src = MI.getOperand(1).getReg(); 3368 unsigned ExtBits = MI.getOperand(2).getImm(); 3369 unsigned TypeSize = MRI.getType(Src).getScalarSizeInBits(); 3370 return KB->computeNumSignBits(Src) >= (TypeSize - ExtBits + 1); 3371 } 3372 3373 static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits, 3374 int64_t Cst, bool IsVector, bool IsFP) { 3375 // For i1, Cst will always be -1 regardless of boolean contents. 3376 return (ScalarSizeBits == 1 && Cst == -1) || 3377 isConstTrueVal(TLI, Cst, IsVector, IsFP); 3378 } 3379 3380 // This combine tries to reduce the number of scalarised G_TRUNC instructions by 3381 // using vector truncates instead 3382 // 3383 // EXAMPLE: 3384 // %a(i32), %b(i32) = G_UNMERGE_VALUES %src(<2 x i32>) 3385 // %T_a(i16) = G_TRUNC %a(i32) 3386 // %T_b(i16) = G_TRUNC %b(i32) 3387 // %Undef(i16) = G_IMPLICIT_DEF(i16) 3388 // %dst(v4i16) = G_BUILD_VECTORS %T_a(i16), %T_b(i16), %Undef(i16), %Undef(i16) 3389 // 3390 // ===> 3391 // %Undef(<2 x i32>) = G_IMPLICIT_DEF(<2 x i32>) 3392 // %Mid(<4 x s32>) = G_CONCAT_VECTORS %src(<2 x i32>), %Undef(<2 x i32>) 3393 // %dst(<4 x s16>) = G_TRUNC %Mid(<4 x s32>) 3394 // 3395 // Only matches sources made up of G_TRUNCs followed by G_IMPLICIT_DEFs 3396 bool CombinerHelper::matchUseVectorTruncate(MachineInstr &MI, 3397 Register &MatchInfo) const { 3398 auto BuildMI = cast<GBuildVector>(&MI); 3399 unsigned NumOperands = BuildMI->getNumSources(); 3400 LLT DstTy = MRI.getType(BuildMI->getReg(0)); 3401 3402 // Check the G_BUILD_VECTOR sources 3403 unsigned I; 3404 MachineInstr *UnmergeMI = nullptr; 3405 3406 // Check all source TRUNCs come from the same UNMERGE instruction 3407 for (I = 0; I < NumOperands; ++I) { 3408 auto SrcMI = MRI.getVRegDef(BuildMI->getSourceReg(I)); 3409 auto SrcMIOpc = SrcMI->getOpcode(); 3410 3411 // Check if the G_TRUNC instructions all come from the same MI 3412 if (SrcMIOpc == TargetOpcode::G_TRUNC) { 3413 if (!UnmergeMI) { 3414 UnmergeMI = MRI.getVRegDef(SrcMI->getOperand(1).getReg()); 3415 if (UnmergeMI->getOpcode() != TargetOpcode::G_UNMERGE_VALUES) 3416 return false; 3417 } else { 3418 auto UnmergeSrcMI = MRI.getVRegDef(SrcMI->getOperand(1).getReg()); 3419 if (UnmergeMI != UnmergeSrcMI) 3420 return false; 3421 } 3422 } else { 3423 break; 3424 } 3425 } 3426 if (I < 2) 3427 return false; 3428 3429 // Check the remaining source elements are only G_IMPLICIT_DEF 3430 for (; I < NumOperands; ++I) { 3431 auto SrcMI = MRI.getVRegDef(BuildMI->getSourceReg(I)); 3432 auto SrcMIOpc = SrcMI->getOpcode(); 3433 3434 if (SrcMIOpc != TargetOpcode::G_IMPLICIT_DEF) 3435 return false; 3436 } 3437 3438 // Check the size of unmerge source 3439 MatchInfo = cast<GUnmerge>(UnmergeMI)->getSourceReg(); 3440 LLT UnmergeSrcTy = MRI.getType(MatchInfo); 3441 if (!DstTy.getElementCount().isKnownMultipleOf(UnmergeSrcTy.getNumElements())) 3442 return false; 3443 3444 // Only generate legal instructions post-legalizer 
  if (!IsPreLegalize) {
    LLT MidTy = DstTy.changeElementType(UnmergeSrcTy.getScalarType());

    if (DstTy.getElementCount() != UnmergeSrcTy.getElementCount() &&
        !isLegal({TargetOpcode::G_CONCAT_VECTORS, {MidTy, UnmergeSrcTy}}))
      return false;

    if (!isLegal({TargetOpcode::G_TRUNC, {DstTy, MidTy}}))
      return false;
  }

  return true;
}

void CombinerHelper::applyUseVectorTruncate(MachineInstr &MI,
                                            Register &MatchInfo) const {
  Register MidReg;
  auto BuildMI = cast<GBuildVector>(&MI);
  Register DstReg = BuildMI->getReg(0);
  LLT DstTy = MRI.getType(DstReg);
  LLT UnmergeSrcTy = MRI.getType(MatchInfo);
  unsigned DstTyNumElt = DstTy.getNumElements();
  unsigned UnmergeSrcTyNumElt = UnmergeSrcTy.getNumElements();

  // No need to pad vector if only G_TRUNC is needed
  if (DstTyNumElt / UnmergeSrcTyNumElt == 1) {
    MidReg = MatchInfo;
  } else {
    Register UndefReg = Builder.buildUndef(UnmergeSrcTy).getReg(0);
    SmallVector<Register> ConcatRegs = {MatchInfo};
    for (unsigned I = 1; I < DstTyNumElt / UnmergeSrcTyNumElt; ++I)
      ConcatRegs.push_back(UndefReg);

    auto MidTy = DstTy.changeElementType(UnmergeSrcTy.getScalarType());
    MidReg = Builder.buildConcatVectors(MidTy, ConcatRegs).getReg(0);
  }

  Builder.buildTrunc(DstReg, MidReg);
  MI.eraseFromParent();
}

bool CombinerHelper::matchNotCmp(
    MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const {
  assert(MI.getOpcode() == TargetOpcode::G_XOR);
  LLT Ty = MRI.getType(MI.getOperand(0).getReg());
  const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering();
  Register XorSrc;
  Register CstReg;
  // We match xor(src, true) here.
  if (!mi_match(MI.getOperand(0).getReg(), MRI,
                m_GXor(m_Reg(XorSrc), m_Reg(CstReg))))
    return false;

  if (!MRI.hasOneNonDBGUse(XorSrc))
    return false;

  // Check that XorSrc is the root of a tree of comparisons combined with ANDs
  // and ORs. The suffix of RegsToNegate starting from index I is used as a
  // work list of tree nodes to visit.
  RegsToNegate.push_back(XorSrc);
  // Remember whether the comparisons are all integer or all floating point.
  bool IsInt = false;
  bool IsFP = false;
  for (unsigned I = 0; I < RegsToNegate.size(); ++I) {
    Register Reg = RegsToNegate[I];
    if (!MRI.hasOneNonDBGUse(Reg))
      return false;
    MachineInstr *Def = MRI.getVRegDef(Reg);
    switch (Def->getOpcode()) {
    default:
      // Don't match if the tree contains anything other than ANDs, ORs and
      // comparisons.
      return false;
    case TargetOpcode::G_ICMP:
      if (IsFP)
        return false;
      IsInt = true;
      // When we apply the combine we will invert the predicate.
      break;
    case TargetOpcode::G_FCMP:
      if (IsInt)
        return false;
      IsFP = true;
      // When we apply the combine we will invert the predicate.
      break;
    case TargetOpcode::G_AND:
    case TargetOpcode::G_OR:
      // Implement De Morgan's laws:
      // ~(x & y) -> ~x | ~y
      // ~(x | y) -> ~x & ~y
      // When we apply the combine we will change the opcode and recursively
      // negate the operands.
      RegsToNegate.push_back(Def->getOperand(1).getReg());
      RegsToNegate.push_back(Def->getOperand(2).getReg());
      break;
    }
  }

  // Now we know whether the comparisons are integer or floating point, check
  // the constant in the xor.
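  //
  // For vectors the "true" operand is expected to be a splat of the target's
  // true value, e.g. (illustrative):
  //   %true:_(<4 x s32>) = G_BUILD_VECTOR -1, -1, -1, -1
  //   %not:_(<4 x s32>)  = G_XOR %cmp_tree, %true
  // For scalars a plain G_CONSTANT is matched instead.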
3545 int64_t Cst; 3546 if (Ty.isVector()) { 3547 MachineInstr *CstDef = MRI.getVRegDef(CstReg); 3548 auto MaybeCst = getIConstantSplatSExtVal(*CstDef, MRI); 3549 if (!MaybeCst) 3550 return false; 3551 if (!isConstValidTrue(TLI, Ty.getScalarSizeInBits(), *MaybeCst, true, IsFP)) 3552 return false; 3553 } else { 3554 if (!mi_match(CstReg, MRI, m_ICst(Cst))) 3555 return false; 3556 if (!isConstValidTrue(TLI, Ty.getSizeInBits(), Cst, false, IsFP)) 3557 return false; 3558 } 3559 3560 return true; 3561 } 3562 3563 void CombinerHelper::applyNotCmp( 3564 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const { 3565 for (Register Reg : RegsToNegate) { 3566 MachineInstr *Def = MRI.getVRegDef(Reg); 3567 Observer.changingInstr(*Def); 3568 // For each comparison, invert the opcode. For each AND and OR, change the 3569 // opcode. 3570 switch (Def->getOpcode()) { 3571 default: 3572 llvm_unreachable("Unexpected opcode"); 3573 case TargetOpcode::G_ICMP: 3574 case TargetOpcode::G_FCMP: { 3575 MachineOperand &PredOp = Def->getOperand(1); 3576 CmpInst::Predicate NewP = CmpInst::getInversePredicate( 3577 (CmpInst::Predicate)PredOp.getPredicate()); 3578 PredOp.setPredicate(NewP); 3579 break; 3580 } 3581 case TargetOpcode::G_AND: 3582 Def->setDesc(Builder.getTII().get(TargetOpcode::G_OR)); 3583 break; 3584 case TargetOpcode::G_OR: 3585 Def->setDesc(Builder.getTII().get(TargetOpcode::G_AND)); 3586 break; 3587 } 3588 Observer.changedInstr(*Def); 3589 } 3590 3591 replaceRegWith(MRI, MI.getOperand(0).getReg(), MI.getOperand(1).getReg()); 3592 MI.eraseFromParent(); 3593 } 3594 3595 bool CombinerHelper::matchXorOfAndWithSameReg( 3596 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const { 3597 // Match (xor (and x, y), y) (or any of its commuted cases) 3598 assert(MI.getOpcode() == TargetOpcode::G_XOR); 3599 Register &X = MatchInfo.first; 3600 Register &Y = MatchInfo.second; 3601 Register AndReg = MI.getOperand(1).getReg(); 3602 Register SharedReg = MI.getOperand(2).getReg(); 3603 3604 // Find a G_AND on either side of the G_XOR. 3605 // Look for one of 3606 // 3607 // (xor (and x, y), SharedReg) 3608 // (xor SharedReg, (and x, y)) 3609 if (!mi_match(AndReg, MRI, m_GAnd(m_Reg(X), m_Reg(Y)))) { 3610 std::swap(AndReg, SharedReg); 3611 if (!mi_match(AndReg, MRI, m_GAnd(m_Reg(X), m_Reg(Y)))) 3612 return false; 3613 } 3614 3615 // Only do this if we'll eliminate the G_AND. 3616 if (!MRI.hasOneNonDBGUse(AndReg)) 3617 return false; 3618 3619 // We can combine if SharedReg is the same as either the LHS or RHS of the 3620 // G_AND. 
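  //
  // E.g. (illustrative) %xor = G_XOR (G_AND %a, %b), %b matches with X = %a
  // and Y = %b; the commuted G_XOR %b, (G_AND %a, %b) was already handled by
  // the swap of AndReg/SharedReg above.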
3621 if (Y != SharedReg) 3622 std::swap(X, Y); 3623 return Y == SharedReg; 3624 } 3625 3626 void CombinerHelper::applyXorOfAndWithSameReg( 3627 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const { 3628 // Fold (xor (and x, y), y) -> (and (not x), y) 3629 Register X, Y; 3630 std::tie(X, Y) = MatchInfo; 3631 auto Not = Builder.buildNot(MRI.getType(X), X); 3632 Observer.changingInstr(MI); 3633 MI.setDesc(Builder.getTII().get(TargetOpcode::G_AND)); 3634 MI.getOperand(1).setReg(Not->getOperand(0).getReg()); 3635 MI.getOperand(2).setReg(Y); 3636 Observer.changedInstr(MI); 3637 } 3638 3639 bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) const { 3640 auto &PtrAdd = cast<GPtrAdd>(MI); 3641 Register DstReg = PtrAdd.getReg(0); 3642 LLT Ty = MRI.getType(DstReg); 3643 const DataLayout &DL = Builder.getMF().getDataLayout(); 3644 3645 if (DL.isNonIntegralAddressSpace(Ty.getScalarType().getAddressSpace())) 3646 return false; 3647 3648 if (Ty.isPointer()) { 3649 auto ConstVal = getIConstantVRegVal(PtrAdd.getBaseReg(), MRI); 3650 return ConstVal && *ConstVal == 0; 3651 } 3652 3653 assert(Ty.isVector() && "Expecting a vector type"); 3654 const MachineInstr *VecMI = MRI.getVRegDef(PtrAdd.getBaseReg()); 3655 return isBuildVectorAllZeros(*VecMI, MRI); 3656 } 3657 3658 void CombinerHelper::applyPtrAddZero(MachineInstr &MI) const { 3659 auto &PtrAdd = cast<GPtrAdd>(MI); 3660 Builder.buildIntToPtr(PtrAdd.getReg(0), PtrAdd.getOffsetReg()); 3661 PtrAdd.eraseFromParent(); 3662 } 3663 3664 /// The second source operand is known to be a power of 2. 3665 void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) const { 3666 Register DstReg = MI.getOperand(0).getReg(); 3667 Register Src0 = MI.getOperand(1).getReg(); 3668 Register Pow2Src1 = MI.getOperand(2).getReg(); 3669 LLT Ty = MRI.getType(DstReg); 3670 3671 // Fold (urem x, pow2) -> (and x, pow2-1) 3672 auto NegOne = Builder.buildConstant(Ty, -1); 3673 auto Add = Builder.buildAdd(Ty, Pow2Src1, NegOne); 3674 Builder.buildAnd(DstReg, Src0, Add); 3675 MI.eraseFromParent(); 3676 } 3677 3678 bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI, 3679 unsigned &SelectOpNo) const { 3680 Register LHS = MI.getOperand(1).getReg(); 3681 Register RHS = MI.getOperand(2).getReg(); 3682 3683 Register OtherOperandReg = RHS; 3684 SelectOpNo = 1; 3685 MachineInstr *Select = MRI.getVRegDef(LHS); 3686 3687 // Don't do this unless the old select is going away. We want to eliminate the 3688 // binary operator, not replace a binop with a select. 3689 if (Select->getOpcode() != TargetOpcode::G_SELECT || 3690 !MRI.hasOneNonDBGUse(LHS)) { 3691 OtherOperandReg = LHS; 3692 SelectOpNo = 2; 3693 Select = MRI.getVRegDef(RHS); 3694 if (Select->getOpcode() != TargetOpcode::G_SELECT || 3695 !MRI.hasOneNonDBGUse(RHS)) 3696 return false; 3697 } 3698 3699 MachineInstr *SelectLHS = MRI.getVRegDef(Select->getOperand(2).getReg()); 3700 MachineInstr *SelectRHS = MRI.getVRegDef(Select->getOperand(3).getReg()); 3701 3702 if (!isConstantOrConstantVector(*SelectLHS, MRI, 3703 /*AllowFP*/ true, 3704 /*AllowOpaqueConstants*/ false)) 3705 return false; 3706 if (!isConstantOrConstantVector(*SelectRHS, MRI, 3707 /*AllowFP*/ true, 3708 /*AllowOpaqueConstants*/ false)) 3709 return false; 3710 3711 unsigned BinOpcode = MI.getOpcode(); 3712 3713 // We know that one of the operands is a select of constants. Now verify that 3714 // the other binary operator operand is either a constant, or we can handle a 3715 // variable. 
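  //
  // The "variable" case is limited to selects of all-zeros/all-ones combined
  // with G_AND/G_OR, where the fold still simplifies even though the other
  // operand is unknown. Illustrative sketch:
  //   %s = G_SELECT %c, 0, -1
  //   %r = G_AND %s, %x      --> %r = G_SELECT %c, 0, %x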
  bool CanFoldNonConst =
      (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) &&
      (isNullOrNullSplat(*SelectLHS, MRI) ||
       isAllOnesOrAllOnesSplat(*SelectLHS, MRI)) &&
      (isNullOrNullSplat(*SelectRHS, MRI) ||
       isAllOnesOrAllOnesSplat(*SelectRHS, MRI));
  if (CanFoldNonConst)
    return true;

  return isConstantOrConstantVector(*MRI.getVRegDef(OtherOperandReg), MRI,
                                    /*AllowFP*/ true,
                                    /*AllowOpaqueConstants*/ false);
}

/// \p SelectOperand is the operand in binary operator \p MI that is the select
/// to fold.
void CombinerHelper::applyFoldBinOpIntoSelect(
    MachineInstr &MI, const unsigned &SelectOperand) const {
  Register Dst = MI.getOperand(0).getReg();
  Register LHS = MI.getOperand(1).getReg();
  Register RHS = MI.getOperand(2).getReg();
  MachineInstr *Select = MRI.getVRegDef(MI.getOperand(SelectOperand).getReg());

  Register SelectCond = Select->getOperand(1).getReg();
  Register SelectTrue = Select->getOperand(2).getReg();
  Register SelectFalse = Select->getOperand(3).getReg();

  LLT Ty = MRI.getType(Dst);
  unsigned BinOpcode = MI.getOpcode();

  Register FoldTrue, FoldFalse;

  // We have a select-of-constants followed by a binary operator with a
  // constant. Eliminate the binop by pulling the constant math into the select.
  // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO
  if (SelectOperand == 1) {
    // TODO: SelectionDAG verifies this actually constant folds before
    // committing to the combine.

    FoldTrue = Builder.buildInstr(BinOpcode, {Ty}, {SelectTrue, RHS}).getReg(0);
    FoldFalse =
        Builder.buildInstr(BinOpcode, {Ty}, {SelectFalse, RHS}).getReg(0);
  } else {
    FoldTrue = Builder.buildInstr(BinOpcode, {Ty}, {LHS, SelectTrue}).getReg(0);
    FoldFalse =
        Builder.buildInstr(BinOpcode, {Ty}, {LHS, SelectFalse}).getReg(0);
  }

  Builder.buildSelect(Dst, SelectCond, FoldTrue, FoldFalse, MI.getFlags());
  MI.eraseFromParent();
}

std::optional<SmallVector<Register, 8>>
CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const {
  assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!");
  // We want to detect if Root is part of a tree which represents a bunch
  // of loads being merged into a larger load. We'll try to recognize patterns
  // like, for example:
  //
  //    Reg   Reg
  //      \   /
  //       OR_1   Reg
  //         \   /
  //          OR_2
  //            \   Reg
  //             .. /
  //             Root
  //
  //    Reg   Reg   Reg   Reg
  //      \   /       \   /
  //       OR_1        OR_2
  //          \        /
  //           \      /
  //             ...
  //             Root
  //
  // Each "Reg" may have been produced by a load + some arithmetic. This
  // function will save each of them.
  SmallVector<Register, 8> RegsToVisit;
  SmallVector<const MachineInstr *, 7> Ors = {Root};

  // In the "worst" case, we're dealing with a load for each byte. So, there
  // are at most #bytes - 1 ORs.
  const unsigned MaxIter =
      MRI.getType(Root->getOperand(0).getReg()).getSizeInBytes() - 1;
  for (unsigned Iter = 0; Iter < MaxIter; ++Iter) {
    if (Ors.empty())
      break;
    const MachineInstr *Curr = Ors.pop_back_val();
    Register OrLHS = Curr->getOperand(1).getReg();
    Register OrRHS = Curr->getOperand(2).getReg();

    // In the combine, we want to eliminate the entire tree.
3809 if (!MRI.hasOneNonDBGUse(OrLHS) || !MRI.hasOneNonDBGUse(OrRHS)) 3810 return std::nullopt; 3811 3812 // If it's a G_OR, save it and continue to walk. If it's not, then it's 3813 // something that may be a load + arithmetic. 3814 if (const MachineInstr *Or = getOpcodeDef(TargetOpcode::G_OR, OrLHS, MRI)) 3815 Ors.push_back(Or); 3816 else 3817 RegsToVisit.push_back(OrLHS); 3818 if (const MachineInstr *Or = getOpcodeDef(TargetOpcode::G_OR, OrRHS, MRI)) 3819 Ors.push_back(Or); 3820 else 3821 RegsToVisit.push_back(OrRHS); 3822 } 3823 3824 // We're going to try and merge each register into a wider power-of-2 type, 3825 // so we ought to have an even number of registers. 3826 if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0) 3827 return std::nullopt; 3828 return RegsToVisit; 3829 } 3830 3831 /// Helper function for findLoadOffsetsForLoadOrCombine. 3832 /// 3833 /// Check if \p Reg is the result of loading a \p MemSizeInBits wide value, 3834 /// and then moving that value into a specific byte offset. 3835 /// 3836 /// e.g. x[i] << 24 3837 /// 3838 /// \returns The load instruction and the byte offset it is moved into. 3839 static std::optional<std::pair<GZExtLoad *, int64_t>> 3840 matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits, 3841 const MachineRegisterInfo &MRI) { 3842 assert(MRI.hasOneNonDBGUse(Reg) && 3843 "Expected Reg to only have one non-debug use?"); 3844 Register MaybeLoad; 3845 int64_t Shift; 3846 if (!mi_match(Reg, MRI, 3847 m_OneNonDBGUse(m_GShl(m_Reg(MaybeLoad), m_ICst(Shift))))) { 3848 Shift = 0; 3849 MaybeLoad = Reg; 3850 } 3851 3852 if (Shift % MemSizeInBits != 0) 3853 return std::nullopt; 3854 3855 // TODO: Handle other types of loads. 3856 auto *Load = getOpcodeDef<GZExtLoad>(MaybeLoad, MRI); 3857 if (!Load) 3858 return std::nullopt; 3859 3860 if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits) 3861 return std::nullopt; 3862 3863 return std::make_pair(Load, Shift / MemSizeInBits); 3864 } 3865 3866 std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>> 3867 CombinerHelper::findLoadOffsetsForLoadOrCombine( 3868 SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, 3869 const SmallVector<Register, 8> &RegsToVisit, 3870 const unsigned MemSizeInBits) const { 3871 3872 // Each load found for the pattern. There should be one for each RegsToVisit. 3873 SmallSetVector<const MachineInstr *, 8> Loads; 3874 3875 // The lowest index used in any load. (The lowest "i" for each x[i].) 3876 int64_t LowestIdx = INT64_MAX; 3877 3878 // The load which uses the lowest index. 3879 GZExtLoad *LowestIdxLoad = nullptr; 3880 3881 // Keeps track of the load indices we see. We shouldn't see any indices twice. 3882 SmallSet<int64_t, 8> SeenIdx; 3883 3884 // Ensure each load is in the same MBB. 3885 // TODO: Support multiple MachineBasicBlocks. 3886 MachineBasicBlock *MBB = nullptr; 3887 const MachineMemOperand *MMO = nullptr; 3888 3889 // Earliest instruction-order load in the pattern. 3890 GZExtLoad *EarliestLoad = nullptr; 3891 3892 // Latest instruction-order load in the pattern. 3893 GZExtLoad *LatestLoad = nullptr; 3894 3895 // Base pointer which every load should share. 3896 Register BasePtr; 3897 3898 // We want to find a load for each register. Each load should have some 3899 // appropriate bit twiddling arithmetic. During this loop, we will also keep 3900 // track of the load which uses the lowest index. Later, we will check if we 3901 // can use its pointer in the final, combined load. 
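  //
  // Worked example (little-endian flavoured, purely illustrative): for
  //   %b0 =  G_ZEXTLOAD a[0]
  //   %b1 = (G_ZEXTLOAD a[1]) << 8
  //   %b2 = (G_ZEXTLOAD a[2]) << 16
  //   %b3 = (G_ZEXTLOAD a[3]) << 24
  // this loop records Idx = 0..3 with DstPos = 0..3, LowestIdx ends up as 0,
  // and every load must share the base pointer "a".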
3902 for (auto Reg : RegsToVisit) { 3903 // Find the load, and find the position that it will end up in (e.g. a 3904 // shifted) value. 3905 auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI); 3906 if (!LoadAndPos) 3907 return std::nullopt; 3908 GZExtLoad *Load; 3909 int64_t DstPos; 3910 std::tie(Load, DstPos) = *LoadAndPos; 3911 3912 // TODO: Handle multiple MachineBasicBlocks. Currently not handled because 3913 // it is difficult to check for stores/calls/etc between loads. 3914 MachineBasicBlock *LoadMBB = Load->getParent(); 3915 if (!MBB) 3916 MBB = LoadMBB; 3917 if (LoadMBB != MBB) 3918 return std::nullopt; 3919 3920 // Make sure that the MachineMemOperands of every seen load are compatible. 3921 auto &LoadMMO = Load->getMMO(); 3922 if (!MMO) 3923 MMO = &LoadMMO; 3924 if (MMO->getAddrSpace() != LoadMMO.getAddrSpace()) 3925 return std::nullopt; 3926 3927 // Find out what the base pointer and index for the load is. 3928 Register LoadPtr; 3929 int64_t Idx; 3930 if (!mi_match(Load->getOperand(1).getReg(), MRI, 3931 m_GPtrAdd(m_Reg(LoadPtr), m_ICst(Idx)))) { 3932 LoadPtr = Load->getOperand(1).getReg(); 3933 Idx = 0; 3934 } 3935 3936 // Don't combine things like a[i], a[i] -> a bigger load. 3937 if (!SeenIdx.insert(Idx).second) 3938 return std::nullopt; 3939 3940 // Every load must share the same base pointer; don't combine things like: 3941 // 3942 // a[i], b[i + 1] -> a bigger load. 3943 if (!BasePtr.isValid()) 3944 BasePtr = LoadPtr; 3945 if (BasePtr != LoadPtr) 3946 return std::nullopt; 3947 3948 if (Idx < LowestIdx) { 3949 LowestIdx = Idx; 3950 LowestIdxLoad = Load; 3951 } 3952 3953 // Keep track of the byte offset that this load ends up at. If we have seen 3954 // the byte offset, then stop here. We do not want to combine: 3955 // 3956 // a[i] << 16, a[i + k] << 16 -> a bigger load. 3957 if (!MemOffset2Idx.try_emplace(DstPos, Idx).second) 3958 return std::nullopt; 3959 Loads.insert(Load); 3960 3961 // Keep track of the position of the earliest/latest loads in the pattern. 3962 // We will check that there are no load fold barriers between them later 3963 // on. 3964 // 3965 // FIXME: Is there a better way to check for load fold barriers? 3966 if (!EarliestLoad || dominates(*Load, *EarliestLoad)) 3967 EarliestLoad = Load; 3968 if (!LatestLoad || dominates(*LatestLoad, *Load)) 3969 LatestLoad = Load; 3970 } 3971 3972 // We found a load for each register. Let's check if each load satisfies the 3973 // pattern. 3974 assert(Loads.size() == RegsToVisit.size() && 3975 "Expected to find a load for each register?"); 3976 assert(EarliestLoad != LatestLoad && EarliestLoad && 3977 LatestLoad && "Expected at least two loads?"); 3978 3979 // Check if there are any stores, calls, etc. between any of the loads. If 3980 // there are, then we can't safely perform the combine. 3981 // 3982 // MaxIter is chosen based off the (worst case) number of iterations it 3983 // typically takes to succeed in the LLVM test suite plus some padding. 3984 // 3985 // FIXME: Is there a better way to check for load fold barriers? 
3986 const unsigned MaxIter = 20; 3987 unsigned Iter = 0; 3988 for (const auto &MI : instructionsWithoutDebug(EarliestLoad->getIterator(), 3989 LatestLoad->getIterator())) { 3990 if (Loads.count(&MI)) 3991 continue; 3992 if (MI.isLoadFoldBarrier()) 3993 return std::nullopt; 3994 if (Iter++ == MaxIter) 3995 return std::nullopt; 3996 } 3997 3998 return std::make_tuple(LowestIdxLoad, LowestIdx, LatestLoad); 3999 } 4000 4001 bool CombinerHelper::matchLoadOrCombine( 4002 MachineInstr &MI, 4003 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4004 assert(MI.getOpcode() == TargetOpcode::G_OR); 4005 MachineFunction &MF = *MI.getMF(); 4006 // Assuming a little-endian target, transform: 4007 // s8 *a = ... 4008 // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24) 4009 // => 4010 // s32 val = *((i32)a) 4011 // 4012 // s8 *a = ... 4013 // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3] 4014 // => 4015 // s32 val = BSWAP(*((s32)a)) 4016 Register Dst = MI.getOperand(0).getReg(); 4017 LLT Ty = MRI.getType(Dst); 4018 if (Ty.isVector()) 4019 return false; 4020 4021 // We need to combine at least two loads into this type. Since the smallest 4022 // possible load is into a byte, we need at least a 16-bit wide type. 4023 const unsigned WideMemSizeInBits = Ty.getSizeInBits(); 4024 if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0) 4025 return false; 4026 4027 // Match a collection of non-OR instructions in the pattern. 4028 auto RegsToVisit = findCandidatesForLoadOrCombine(&MI); 4029 if (!RegsToVisit) 4030 return false; 4031 4032 // We have a collection of non-OR instructions. Figure out how wide each of 4033 // the small loads should be based off of the number of potential loads we 4034 // found. 4035 const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size(); 4036 if (NarrowMemSizeInBits % 8 != 0) 4037 return false; 4038 4039 // Check if each register feeding into each OR is a load from the same 4040 // base pointer + some arithmetic. 4041 // 4042 // e.g. a[0], a[1] << 8, a[2] << 16, etc. 4043 // 4044 // Also verify that each of these ends up putting a[i] into the same memory 4045 // offset as a load into a wide type would. 4046 SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx; 4047 GZExtLoad *LowestIdxLoad, *LatestLoad; 4048 int64_t LowestIdx; 4049 auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine( 4050 MemOffset2Idx, *RegsToVisit, NarrowMemSizeInBits); 4051 if (!MaybeLoadInfo) 4052 return false; 4053 std::tie(LowestIdxLoad, LowestIdx, LatestLoad) = *MaybeLoadInfo; 4054 4055 // We have a bunch of loads being OR'd together. Using the addresses + offsets 4056 // we found before, check if this corresponds to a big or little endian byte 4057 // pattern. If it does, then we can represent it using a load + possibly a 4058 // BSWAP. 4059 bool IsBigEndianTarget = MF.getDataLayout().isBigEndian(); 4060 std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx); 4061 if (!IsBigEndian) 4062 return false; 4063 bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian; 4064 if (NeedsBSwap && !isLegalOrBeforeLegalizer({TargetOpcode::G_BSWAP, {Ty}})) 4065 return false; 4066 4067 // Make sure that the load from the lowest index produces offset 0 in the 4068 // final value. 
  //
  // This ensures that we won't combine something like this:
  //
  // load x[i]   -> byte 2
  // load x[i+1] -> byte 0 ---> wide_load x[i]
  // load x[i+2] -> byte 1
  const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits;
  const unsigned ZeroByteOffset =
      *IsBigEndian
          ? bigEndianByteAt(NumLoadsInTy, 0)
          : littleEndianByteAt(NumLoadsInTy, 0);
  auto ZeroOffsetIdx = MemOffset2Idx.find(ZeroByteOffset);
  if (ZeroOffsetIdx == MemOffset2Idx.end() ||
      ZeroOffsetIdx->second != LowestIdx)
    return false;

  // We will reuse the pointer from the load which ends up at byte offset 0. It
  // may not use index 0.
  Register Ptr = LowestIdxLoad->getPointerReg();
  const MachineMemOperand &MMO = LowestIdxLoad->getMMO();
  LegalityQuery::MemDesc MMDesc(MMO);
  MMDesc.MemoryTy = Ty;
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_LOAD, {Ty, MRI.getType(Ptr)}, {MMDesc}}))
    return false;
  auto PtrInfo = MMO.getPointerInfo();
  auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, WideMemSizeInBits / 8);

  // Load must be allowed and fast on the target.
  LLVMContext &C = MF.getFunction().getContext();
  auto &DL = MF.getDataLayout();
  unsigned Fast = 0;
  if (!getTargetLowering().allowsMemoryAccess(C, DL, Ty, *NewMMO, &Fast) ||
      !Fast)
    return false;

  MatchInfo = [=](MachineIRBuilder &MIB) {
    MIB.setInstrAndDebugLoc(*LatestLoad);
    Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(Dst) : Dst;
    MIB.buildLoad(LoadDst, Ptr, *NewMMO);
    if (NeedsBSwap)
      MIB.buildBSwap(Dst, LoadDst);
  };
  return true;
}

bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI,
                                            MachineInstr *&ExtMI) const {
  auto &PHI = cast<GPhi>(MI);
  Register DstReg = PHI.getReg(0);

  // TODO: Extending a vector may be expensive, don't do this until heuristics
  // are better.
  if (MRI.getType(DstReg).isVector())
    return false;

  // Try to match a phi, whose only use is an extend.
  if (!MRI.hasOneNonDBGUse(DstReg))
    return false;
  ExtMI = &*MRI.use_instr_nodbg_begin(DstReg);
  switch (ExtMI->getOpcode()) {
  case TargetOpcode::G_ANYEXT:
    return true; // G_ANYEXT is usually free.
  case TargetOpcode::G_ZEXT:
  case TargetOpcode::G_SEXT:
    break;
  default:
    return false;
  }

  // If the target is likely to fold this extend away, don't propagate.
  if (Builder.getTII().isExtendLikelyToBeFolded(*ExtMI, MRI))
    return false;

  // We don't want to propagate the extends unless there's a good chance that
  // they'll be optimized in some way.
  // Collect the unique incoming values.
  SmallPtrSet<MachineInstr *, 4> InSrcs;
  for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
    auto *DefMI = getDefIgnoringCopies(PHI.getIncomingValue(I), MRI);
    switch (DefMI->getOpcode()) {
    case TargetOpcode::G_LOAD:
    case TargetOpcode::G_TRUNC:
    case TargetOpcode::G_SEXT:
    case TargetOpcode::G_ZEXT:
    case TargetOpcode::G_ANYEXT:
    case TargetOpcode::G_CONSTANT:
      InSrcs.insert(DefMI);
      // Don't try to propagate if there are too many places to create new
      // extends, chances are it'll increase code size.
      if (InSrcs.size() > 2)
        return false;
      break;
    default:
      return false;
    }
  }
  return true;
}

void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI,
                                            MachineInstr *&ExtMI) const {
  auto &PHI = cast<GPhi>(MI);
  Register DstReg = ExtMI->getOperand(0).getReg();
  LLT ExtTy = MRI.getType(DstReg);

  // Propagate the extension into each incoming reg's block.
  // Use a SetVector here because PHIs can have duplicate edges, and we want
  // deterministic iteration order.
  SmallSetVector<MachineInstr *, 8> SrcMIs;
  SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap;
  for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) {
    auto SrcReg = PHI.getIncomingValue(I);
    auto *SrcMI = MRI.getVRegDef(SrcReg);
    if (!SrcMIs.insert(SrcMI))
      continue;

    // Build an extend after each src inst.
    auto *MBB = SrcMI->getParent();
    MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator();
    if (InsertPt != MBB->end() && InsertPt->isPHI())
      InsertPt = MBB->getFirstNonPHI();

    Builder.setInsertPt(*SrcMI->getParent(), InsertPt);
    Builder.setDebugLoc(MI.getDebugLoc());
    auto NewExt = Builder.buildExtOrTrunc(ExtMI->getOpcode(), ExtTy, SrcReg);
    OldToNewSrcMap[SrcMI] = NewExt;
  }

  // Create a new phi with the extended inputs.
  Builder.setInstrAndDebugLoc(MI);
  auto NewPhi = Builder.buildInstrNoInsert(TargetOpcode::G_PHI);
  NewPhi.addDef(DstReg);
  for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) {
    if (!MO.isReg()) {
      NewPhi.addMBB(MO.getMBB());
      continue;
    }
    auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(MO.getReg())];
    NewPhi.addUse(NewSrc->getOperand(0).getReg());
  }
  Builder.insertInstr(NewPhi);
  ExtMI->eraseFromParent();
}

bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI,
                                                Register &Reg) const {
  assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT);
  // If we have a constant index, look for a G_BUILD_VECTOR source
  // and find the source register that the index maps to.
  Register SrcVec = MI.getOperand(1).getReg();
  LLT SrcTy = MRI.getType(SrcVec);
  if (SrcTy.isScalableVector())
    return false;

  auto Cst = getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
  if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements())
    return false;

  unsigned VecIdx = Cst->Value.getZExtValue();

  // Check if we have a build_vector or build_vector_trunc with an optional
  // trunc in front.
  MachineInstr *SrcVecMI = MRI.getVRegDef(SrcVec);
  if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) {
    SrcVecMI = MRI.getVRegDef(SrcVecMI->getOperand(1).getReg());
  }

  if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR &&
      SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC)
    return false;

  EVT Ty(getMVTForLLT(SrcTy));
  if (!MRI.hasOneNonDBGUse(SrcVec) &&
      !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty))
    return false;

  Reg = SrcVecMI->getOperand(VecIdx + 1).getReg();
  return true;
}

void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI,
                                                Register &Reg) const {
  // Check the type of the register, since it may have come from a
  // G_BUILD_VECTOR_TRUNC.
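  //
  // E.g. (illustrative) for
  //   %vec:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %a:_(s64), %b:_(s64)
  //   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec, 0
  // the matched Reg is %a(s64), so a G_TRUNC down to s32 is still required.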
4254 LLT ScalarTy = MRI.getType(Reg); 4255 Register DstReg = MI.getOperand(0).getReg(); 4256 LLT DstTy = MRI.getType(DstReg); 4257 4258 if (ScalarTy != DstTy) { 4259 assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits()); 4260 Builder.buildTrunc(DstReg, Reg); 4261 MI.eraseFromParent(); 4262 return; 4263 } 4264 replaceSingleDefInstWithReg(MI, Reg); 4265 } 4266 4267 bool CombinerHelper::matchExtractAllEltsFromBuildVector( 4268 MachineInstr &MI, 4269 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const { 4270 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); 4271 // This combine tries to find build_vector's which have every source element 4272 // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like 4273 // the masked load scalarization is run late in the pipeline. There's already 4274 // a combine for a similar pattern starting from the extract, but that 4275 // doesn't attempt to do it if there are multiple uses of the build_vector, 4276 // which in this case is true. Starting the combine from the build_vector 4277 // feels more natural than trying to find sibling nodes of extracts. 4278 // E.g. 4279 // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4 4280 // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0 4281 // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1 4282 // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2 4283 // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3 4284 // ==> 4285 // replace ext{1,2,3,4} with %s{1,2,3,4} 4286 4287 Register DstReg = MI.getOperand(0).getReg(); 4288 LLT DstTy = MRI.getType(DstReg); 4289 unsigned NumElts = DstTy.getNumElements(); 4290 4291 SmallBitVector ExtractedElts(NumElts); 4292 for (MachineInstr &II : MRI.use_nodbg_instructions(DstReg)) { 4293 if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT) 4294 return false; 4295 auto Cst = getIConstantVRegVal(II.getOperand(2).getReg(), MRI); 4296 if (!Cst) 4297 return false; 4298 unsigned Idx = Cst->getZExtValue(); 4299 if (Idx >= NumElts) 4300 return false; // Out of range. 4301 ExtractedElts.set(Idx); 4302 SrcDstPairs.emplace_back( 4303 std::make_pair(MI.getOperand(Idx + 1).getReg(), &II)); 4304 } 4305 // Match if every element was extracted. 4306 return ExtractedElts.all(); 4307 } 4308 4309 void CombinerHelper::applyExtractAllEltsFromBuildVector( 4310 MachineInstr &MI, 4311 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const { 4312 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); 4313 for (auto &Pair : SrcDstPairs) { 4314 auto *ExtMI = Pair.second; 4315 replaceRegWith(MRI, ExtMI->getOperand(0).getReg(), Pair.first); 4316 ExtMI->eraseFromParent(); 4317 } 4318 MI.eraseFromParent(); 4319 } 4320 4321 void CombinerHelper::applyBuildFn( 4322 MachineInstr &MI, 4323 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4324 applyBuildFnNoErase(MI, MatchInfo); 4325 MI.eraseFromParent(); 4326 } 4327 4328 void CombinerHelper::applyBuildFnNoErase( 4329 MachineInstr &MI, 4330 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4331 MatchInfo(Builder); 4332 } 4333 4334 bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI, 4335 BuildFnTy &MatchInfo) const { 4336 assert(MI.getOpcode() == TargetOpcode::G_OR); 4337 4338 Register Dst = MI.getOperand(0).getReg(); 4339 LLT Ty = MRI.getType(Dst); 4340 unsigned BitWidth = Ty.getScalarSizeInBits(); 4341 4342 Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt; 4343 unsigned FshOpc = 0; 4344 4345 // Match (or (shl ...), (lshr ...)). 4346 if (!mi_match(Dst, MRI, 4347 // m_GOr() handles the commuted version as well. 
4348 m_GOr(m_GShl(m_Reg(ShlSrc), m_Reg(ShlAmt)), 4349 m_GLShr(m_Reg(LShrSrc), m_Reg(LShrAmt))))) 4350 return false; 4351 4352 // Given constants C0 and C1 such that C0 + C1 is bit-width: 4353 // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1) 4354 int64_t CstShlAmt, CstLShrAmt; 4355 if (mi_match(ShlAmt, MRI, m_ICstOrSplat(CstShlAmt)) && 4356 mi_match(LShrAmt, MRI, m_ICstOrSplat(CstLShrAmt)) && 4357 CstShlAmt + CstLShrAmt == BitWidth) { 4358 FshOpc = TargetOpcode::G_FSHR; 4359 Amt = LShrAmt; 4360 4361 } else if (mi_match(LShrAmt, MRI, 4362 m_GSub(m_SpecificICstOrSplat(BitWidth), m_Reg(Amt))) && 4363 ShlAmt == Amt) { 4364 // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt) 4365 FshOpc = TargetOpcode::G_FSHL; 4366 4367 } else if (mi_match(ShlAmt, MRI, 4368 m_GSub(m_SpecificICstOrSplat(BitWidth), m_Reg(Amt))) && 4369 LShrAmt == Amt) { 4370 // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt) 4371 FshOpc = TargetOpcode::G_FSHR; 4372 4373 } else { 4374 return false; 4375 } 4376 4377 LLT AmtTy = MRI.getType(Amt); 4378 if (!isLegalOrBeforeLegalizer({FshOpc, {Ty, AmtTy}})) 4379 return false; 4380 4381 MatchInfo = [=](MachineIRBuilder &B) { 4382 B.buildInstr(FshOpc, {Dst}, {ShlSrc, LShrSrc, Amt}); 4383 }; 4384 return true; 4385 } 4386 4387 /// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate. 4388 bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) const { 4389 unsigned Opc = MI.getOpcode(); 4390 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); 4391 Register X = MI.getOperand(1).getReg(); 4392 Register Y = MI.getOperand(2).getReg(); 4393 if (X != Y) 4394 return false; 4395 unsigned RotateOpc = 4396 Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR; 4397 return isLegalOrBeforeLegalizer({RotateOpc, {MRI.getType(X), MRI.getType(Y)}}); 4398 } 4399 4400 void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) const { 4401 unsigned Opc = MI.getOpcode(); 4402 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); 4403 bool IsFSHL = Opc == TargetOpcode::G_FSHL; 4404 Observer.changingInstr(MI); 4405 MI.setDesc(Builder.getTII().get(IsFSHL ? 
TargetOpcode::G_ROTL 4406 : TargetOpcode::G_ROTR)); 4407 MI.removeOperand(2); 4408 Observer.changedInstr(MI); 4409 } 4410 4411 // Fold (rot x, c) -> (rot x, c % BitSize) 4412 bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) const { 4413 assert(MI.getOpcode() == TargetOpcode::G_ROTL || 4414 MI.getOpcode() == TargetOpcode::G_ROTR); 4415 unsigned Bitsize = 4416 MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits(); 4417 Register AmtReg = MI.getOperand(2).getReg(); 4418 bool OutOfRange = false; 4419 auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) { 4420 if (auto *CI = dyn_cast<ConstantInt>(C)) 4421 OutOfRange |= CI->getValue().uge(Bitsize); 4422 return true; 4423 }; 4424 return matchUnaryPredicate(MRI, AmtReg, MatchOutOfRange) && OutOfRange; 4425 } 4426 4427 void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) const { 4428 assert(MI.getOpcode() == TargetOpcode::G_ROTL || 4429 MI.getOpcode() == TargetOpcode::G_ROTR); 4430 unsigned Bitsize = 4431 MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits(); 4432 Register Amt = MI.getOperand(2).getReg(); 4433 LLT AmtTy = MRI.getType(Amt); 4434 auto Bits = Builder.buildConstant(AmtTy, Bitsize); 4435 Amt = Builder.buildURem(AmtTy, MI.getOperand(2).getReg(), Bits).getReg(0); 4436 Observer.changingInstr(MI); 4437 MI.getOperand(2).setReg(Amt); 4438 Observer.changedInstr(MI); 4439 } 4440 4441 bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI, 4442 int64_t &MatchInfo) const { 4443 assert(MI.getOpcode() == TargetOpcode::G_ICMP); 4444 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate()); 4445 4446 // We want to avoid calling KnownBits on the LHS if possible, as this combine 4447 // has no filter and runs on every G_ICMP instruction. We can avoid calling 4448 // KnownBits on the LHS in two cases: 4449 // 4450 // - The RHS is unknown: Constants are always on RHS. If the RHS is unknown 4451 // we cannot do any transforms so we can safely bail out early. 4452 // - The RHS is zero: we don't need to know the LHS to do unsigned <0 and 4453 // >=0. 4454 auto KnownRHS = KB->getKnownBits(MI.getOperand(3).getReg()); 4455 if (KnownRHS.isUnknown()) 4456 return false; 4457 4458 std::optional<bool> KnownVal; 4459 if (KnownRHS.isZero()) { 4460 // ? uge 0 -> always true 4461 // ? ult 0 -> always false 4462 if (Pred == CmpInst::ICMP_UGE) 4463 KnownVal = true; 4464 else if (Pred == CmpInst::ICMP_ULT) 4465 KnownVal = false; 4466 } 4467 4468 if (!KnownVal) { 4469 auto KnownLHS = KB->getKnownBits(MI.getOperand(2).getReg()); 4470 KnownVal = ICmpInst::compare(KnownLHS, KnownRHS, Pred); 4471 } 4472 4473 if (!KnownVal) 4474 return false; 4475 MatchInfo = 4476 *KnownVal 4477 ? getICmpTrueVal(getTargetLowering(), 4478 /*IsVector = */ 4479 MRI.getType(MI.getOperand(0).getReg()).isVector(), 4480 /* IsFP = */ false) 4481 : 0; 4482 return true; 4483 } 4484 4485 bool CombinerHelper::matchICmpToLHSKnownBits( 4486 MachineInstr &MI, 4487 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4488 assert(MI.getOpcode() == TargetOpcode::G_ICMP); 4489 // Given: 4490 // 4491 // %x = G_WHATEVER (... x is known to be 0 or 1 ...) 4492 // %cmp = G_ICMP ne %x, 0 4493 // 4494 // Or: 4495 // 4496 // %x = G_WHATEVER (... x is known to be 0 or 1 ...) 4497 // %cmp = G_ICMP eq %x, 1 4498 // 4499 // We can replace %cmp with %x assuming true is 1 on the target. 
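  //
  // If the types differ, the replacement goes through a cast rather than a
  // plain copy, e.g. (illustrative):
  //   %x:_(s32) = ... (known to be 0 or 1)
  //   %cmp:_(s1) = G_ICMP ne %x, 0   -->   %cmp:_(s1) = G_TRUNC %x:_(s32)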
4500 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate()); 4501 if (!CmpInst::isEquality(Pred)) 4502 return false; 4503 Register Dst = MI.getOperand(0).getReg(); 4504 LLT DstTy = MRI.getType(Dst); 4505 if (getICmpTrueVal(getTargetLowering(), DstTy.isVector(), 4506 /* IsFP = */ false) != 1) 4507 return false; 4508 int64_t OneOrZero = Pred == CmpInst::ICMP_EQ; 4509 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICst(OneOrZero))) 4510 return false; 4511 Register LHS = MI.getOperand(2).getReg(); 4512 auto KnownLHS = KB->getKnownBits(LHS); 4513 if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1) 4514 return false; 4515 // Make sure replacing Dst with the LHS is a legal operation. 4516 LLT LHSTy = MRI.getType(LHS); 4517 unsigned LHSSize = LHSTy.getSizeInBits(); 4518 unsigned DstSize = DstTy.getSizeInBits(); 4519 unsigned Op = TargetOpcode::COPY; 4520 if (DstSize != LHSSize) 4521 Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT; 4522 if (!isLegalOrBeforeLegalizer({Op, {DstTy, LHSTy}})) 4523 return false; 4524 MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Op, {Dst}, {LHS}); }; 4525 return true; 4526 } 4527 4528 // Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0 4529 bool CombinerHelper::matchAndOrDisjointMask( 4530 MachineInstr &MI, 4531 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4532 assert(MI.getOpcode() == TargetOpcode::G_AND); 4533 4534 // Ignore vector types to simplify matching the two constants. 4535 // TODO: do this for vectors and scalars via a demanded bits analysis. 4536 LLT Ty = MRI.getType(MI.getOperand(0).getReg()); 4537 if (Ty.isVector()) 4538 return false; 4539 4540 Register Src; 4541 Register AndMaskReg; 4542 int64_t AndMaskBits; 4543 int64_t OrMaskBits; 4544 if (!mi_match(MI, MRI, 4545 m_GAnd(m_GOr(m_Reg(Src), m_ICst(OrMaskBits)), 4546 m_all_of(m_ICst(AndMaskBits), m_Reg(AndMaskReg))))) 4547 return false; 4548 4549 // Check if OrMask could turn on any bits in Src. 4550 if (AndMaskBits & OrMaskBits) 4551 return false; 4552 4553 MatchInfo = [=, &MI](MachineIRBuilder &B) { 4554 Observer.changingInstr(MI); 4555 // Canonicalize the result to have the constant on the RHS. 4556 if (MI.getOperand(1).getReg() == AndMaskReg) 4557 MI.getOperand(2).setReg(AndMaskReg); 4558 MI.getOperand(1).setReg(Src); 4559 Observer.changedInstr(MI); 4560 }; 4561 return true; 4562 } 4563 4564 /// Form a G_SBFX from a G_SEXT_INREG fed by a right shift. 
4565 bool CombinerHelper::matchBitfieldExtractFromSExtInReg( 4566 MachineInstr &MI, 4567 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4568 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); 4569 Register Dst = MI.getOperand(0).getReg(); 4570 Register Src = MI.getOperand(1).getReg(); 4571 LLT Ty = MRI.getType(Src); 4572 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 4573 if (!LI || !LI->isLegalOrCustom({TargetOpcode::G_SBFX, {Ty, ExtractTy}})) 4574 return false; 4575 int64_t Width = MI.getOperand(2).getImm(); 4576 Register ShiftSrc; 4577 int64_t ShiftImm; 4578 if (!mi_match( 4579 Src, MRI, 4580 m_OneNonDBGUse(m_any_of(m_GAShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)), 4581 m_GLShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)))))) 4582 return false; 4583 if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits()) 4584 return false; 4585 4586 MatchInfo = [=](MachineIRBuilder &B) { 4587 auto Cst1 = B.buildConstant(ExtractTy, ShiftImm); 4588 auto Cst2 = B.buildConstant(ExtractTy, Width); 4589 B.buildSbfx(Dst, ShiftSrc, Cst1, Cst2); 4590 }; 4591 return true; 4592 } 4593 4594 /// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants. 4595 bool CombinerHelper::matchBitfieldExtractFromAnd(MachineInstr &MI, 4596 BuildFnTy &MatchInfo) const { 4597 GAnd *And = cast<GAnd>(&MI); 4598 Register Dst = And->getReg(0); 4599 LLT Ty = MRI.getType(Dst); 4600 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 4601 // Note that isLegalOrBeforeLegalizer is stricter and does not take custom 4602 // into account. 4603 if (LI && !LI->isLegalOrCustom({TargetOpcode::G_UBFX, {Ty, ExtractTy}})) 4604 return false; 4605 4606 int64_t AndImm, LSBImm; 4607 Register ShiftSrc; 4608 const unsigned Size = Ty.getScalarSizeInBits(); 4609 if (!mi_match(And->getReg(0), MRI, 4610 m_GAnd(m_OneNonDBGUse(m_GLShr(m_Reg(ShiftSrc), m_ICst(LSBImm))), 4611 m_ICst(AndImm)))) 4612 return false; 4613 4614 // The mask is a mask of the low bits iff imm & (imm+1) == 0. 4615 auto MaybeMask = static_cast<uint64_t>(AndImm); 4616 if (MaybeMask & (MaybeMask + 1)) 4617 return false; 4618 4619 // LSB must fit within the register. 4620 if (static_cast<uint64_t>(LSBImm) >= Size) 4621 return false; 4622 4623 uint64_t Width = APInt(Size, AndImm).countr_one(); 4624 MatchInfo = [=](MachineIRBuilder &B) { 4625 auto WidthCst = B.buildConstant(ExtractTy, Width); 4626 auto LSBCst = B.buildConstant(ExtractTy, LSBImm); 4627 B.buildInstr(TargetOpcode::G_UBFX, {Dst}, {ShiftSrc, LSBCst, WidthCst}); 4628 }; 4629 return true; 4630 } 4631 4632 bool CombinerHelper::matchBitfieldExtractFromShr( 4633 MachineInstr &MI, 4634 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4635 const unsigned Opcode = MI.getOpcode(); 4636 assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR); 4637 4638 const Register Dst = MI.getOperand(0).getReg(); 4639 4640 const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR 4641 ? 
TargetOpcode::G_SBFX 4642 : TargetOpcode::G_UBFX; 4643 4644 // Check if the type we would use for the extract is legal 4645 LLT Ty = MRI.getType(Dst); 4646 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 4647 if (!LI || !LI->isLegalOrCustom({ExtrOpcode, {Ty, ExtractTy}})) 4648 return false; 4649 4650 Register ShlSrc; 4651 int64_t ShrAmt; 4652 int64_t ShlAmt; 4653 const unsigned Size = Ty.getScalarSizeInBits(); 4654 4655 // Try to match shr (shl x, c1), c2 4656 if (!mi_match(Dst, MRI, 4657 m_BinOp(Opcode, 4658 m_OneNonDBGUse(m_GShl(m_Reg(ShlSrc), m_ICst(ShlAmt))), 4659 m_ICst(ShrAmt)))) 4660 return false; 4661 4662 // Make sure that the shift sizes can fit a bitfield extract 4663 if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size) 4664 return false; 4665 4666 // Skip this combine if the G_SEXT_INREG combine could handle it 4667 if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt) 4668 return false; 4669 4670 // Calculate start position and width of the extract 4671 const int64_t Pos = ShrAmt - ShlAmt; 4672 const int64_t Width = Size - ShrAmt; 4673 4674 MatchInfo = [=](MachineIRBuilder &B) { 4675 auto WidthCst = B.buildConstant(ExtractTy, Width); 4676 auto PosCst = B.buildConstant(ExtractTy, Pos); 4677 B.buildInstr(ExtrOpcode, {Dst}, {ShlSrc, PosCst, WidthCst}); 4678 }; 4679 return true; 4680 } 4681 4682 bool CombinerHelper::matchBitfieldExtractFromShrAnd( 4683 MachineInstr &MI, 4684 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4685 const unsigned Opcode = MI.getOpcode(); 4686 assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR); 4687 4688 const Register Dst = MI.getOperand(0).getReg(); 4689 LLT Ty = MRI.getType(Dst); 4690 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 4691 if (LI && !LI->isLegalOrCustom({TargetOpcode::G_UBFX, {Ty, ExtractTy}})) 4692 return false; 4693 4694 // Try to match shr (and x, c1), c2 4695 Register AndSrc; 4696 int64_t ShrAmt; 4697 int64_t SMask; 4698 if (!mi_match(Dst, MRI, 4699 m_BinOp(Opcode, 4700 m_OneNonDBGUse(m_GAnd(m_Reg(AndSrc), m_ICst(SMask))), 4701 m_ICst(ShrAmt)))) 4702 return false; 4703 4704 const unsigned Size = Ty.getScalarSizeInBits(); 4705 if (ShrAmt < 0 || ShrAmt >= Size) 4706 return false; 4707 4708 // If the shift subsumes the mask, emit the 0 directly. 4709 if (0 == (SMask >> ShrAmt)) { 4710 MatchInfo = [=](MachineIRBuilder &B) { 4711 B.buildConstant(Dst, 0); 4712 }; 4713 return true; 4714 } 4715 4716 // Check that ubfx can do the extraction, with no holes in the mask. 4717 uint64_t UMask = SMask; 4718 UMask |= maskTrailingOnes<uint64_t>(ShrAmt); 4719 UMask &= maskTrailingOnes<uint64_t>(Size); 4720 if (!isMask_64(UMask)) 4721 return false; 4722 4723 // Calculate start position and width of the extract. 4724 const int64_t Pos = ShrAmt; 4725 const int64_t Width = llvm::countr_one(UMask) - ShrAmt; 4726 4727 // It's preferable to keep the shift, rather than form G_SBFX. 4728 // TODO: remove the G_AND via demanded bits analysis. 
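  //
  // Illustrative arithmetic: lshr (and %x, 0x0FF0), 4 on s32 gives
  // UMask = 0x0FFF, Pos = 4 and Width = 12 - 4 = 8, i.e. G_UBFX %x, 4, 8.
  // For ashr (and %x, 0xFFF0), 4 on s16 the mask keeps the sign bit
  // (Width + ShrAmt == Size), so the check below keeps the shift instead.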
4729 if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size) 4730 return false; 4731 4732 MatchInfo = [=](MachineIRBuilder &B) { 4733 auto WidthCst = B.buildConstant(ExtractTy, Width); 4734 auto PosCst = B.buildConstant(ExtractTy, Pos); 4735 B.buildInstr(TargetOpcode::G_UBFX, {Dst}, {AndSrc, PosCst, WidthCst}); 4736 }; 4737 return true; 4738 } 4739 4740 bool CombinerHelper::reassociationCanBreakAddressingModePattern( 4741 MachineInstr &MI) const { 4742 auto &PtrAdd = cast<GPtrAdd>(MI); 4743 4744 Register Src1Reg = PtrAdd.getBaseReg(); 4745 auto *Src1Def = getOpcodeDef<GPtrAdd>(Src1Reg, MRI); 4746 if (!Src1Def) 4747 return false; 4748 4749 Register Src2Reg = PtrAdd.getOffsetReg(); 4750 4751 if (MRI.hasOneNonDBGUse(Src1Reg)) 4752 return false; 4753 4754 auto C1 = getIConstantVRegVal(Src1Def->getOffsetReg(), MRI); 4755 if (!C1) 4756 return false; 4757 auto C2 = getIConstantVRegVal(Src2Reg, MRI); 4758 if (!C2) 4759 return false; 4760 4761 const APInt &C1APIntVal = *C1; 4762 const APInt &C2APIntVal = *C2; 4763 const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue(); 4764 4765 for (auto &UseMI : MRI.use_nodbg_instructions(PtrAdd.getReg(0))) { 4766 // This combine may end up running before ptrtoint/inttoptr combines 4767 // manage to eliminate redundant conversions, so try to look through them. 4768 MachineInstr *ConvUseMI = &UseMI; 4769 unsigned ConvUseOpc = ConvUseMI->getOpcode(); 4770 while (ConvUseOpc == TargetOpcode::G_INTTOPTR || 4771 ConvUseOpc == TargetOpcode::G_PTRTOINT) { 4772 Register DefReg = ConvUseMI->getOperand(0).getReg(); 4773 if (!MRI.hasOneNonDBGUse(DefReg)) 4774 break; 4775 ConvUseMI = &*MRI.use_instr_nodbg_begin(DefReg); 4776 ConvUseOpc = ConvUseMI->getOpcode(); 4777 } 4778 auto *LdStMI = dyn_cast<GLoadStore>(ConvUseMI); 4779 if (!LdStMI) 4780 continue; 4781 // Is x[offset2] already not a legal addressing mode? If so then 4782 // reassociating the constants breaks nothing (we test offset2 because 4783 // that's the one we hope to fold into the load or store). 4784 TargetLoweringBase::AddrMode AM; 4785 AM.HasBaseReg = true; 4786 AM.BaseOffs = C2APIntVal.getSExtValue(); 4787 unsigned AS = MRI.getType(LdStMI->getPointerReg()).getAddressSpace(); 4788 Type *AccessTy = getTypeForLLT(LdStMI->getMMO().getMemoryType(), 4789 PtrAdd.getMF()->getFunction().getContext()); 4790 const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering(); 4791 if (!TLI.isLegalAddressingMode(PtrAdd.getMF()->getDataLayout(), AM, 4792 AccessTy, AS)) 4793 continue; 4794 4795 // Would x[offset1+offset2] still be a legal addressing mode? 
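    // If x[offset1+offset2] is no longer legal while x[offset2] is, folding
    // the constants would replace an offset the target can fold into the
    // load/store with one it cannot, so report the reassociation as harmful.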
    AM.BaseOffs = CombinedValue;
    if (!TLI.isLegalAddressingMode(PtrAdd.getMF()->getDataLayout(), AM,
                                   AccessTy, AS))
      return true;
  }

  return false;
}

bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI,
                                                  MachineInstr *RHS,
                                                  BuildFnTy &MatchInfo) const {
  // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
  Register Src1Reg = MI.getOperand(1).getReg();
  if (RHS->getOpcode() != TargetOpcode::G_ADD)
    return false;
  auto C2 = getIConstantVRegVal(RHS->getOperand(2).getReg(), MRI);
  if (!C2)
    return false;

  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    LLT PtrTy = MRI.getType(MI.getOperand(0).getReg());

    auto NewBase =
        Builder.buildPtrAdd(PtrTy, Src1Reg, RHS->getOperand(1).getReg());
    Observer.changingInstr(MI);
    MI.getOperand(1).setReg(NewBase.getReg(0));
    MI.getOperand(2).setReg(RHS->getOperand(2).getReg());
    Observer.changedInstr(MI);
  };
  return !reassociationCanBreakAddressingModePattern(MI);
}

bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI,
                                                  MachineInstr *LHS,
                                                  MachineInstr *RHS,
                                                  BuildFnTy &MatchInfo) const {
  // G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
  // if and only if (G_PTR_ADD X, C) has one use.
  Register LHSBase;
  std::optional<ValueAndVReg> LHSCstOff;
  if (!mi_match(MI.getBaseReg(), MRI,
                m_OneNonDBGUse(m_GPtrAdd(m_Reg(LHSBase), m_GCst(LHSCstOff)))))
    return false;

  auto *LHSPtrAdd = cast<GPtrAdd>(LHS);
  MatchInfo = [=, &MI](MachineIRBuilder &B) {
    // When we change LHSPtrAdd's offset register we might cause it to use a
    // register before its def. Sink the instruction to just before the outer
    // G_PTR_ADD to ensure this doesn't happen.
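    // Sinking is safe here: the m_OneNonDBGUse check above guarantees that MI
    // is the only non-debug user of LHSPtrAdd's result, so no non-debug use
    // can end up above the moved definition.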
4846 LHSPtrAdd->moveBefore(&MI); 4847 Register RHSReg = MI.getOffsetReg(); 4848 // set VReg will cause type mismatch if it comes from extend/trunc 4849 auto NewCst = B.buildConstant(MRI.getType(RHSReg), LHSCstOff->Value); 4850 Observer.changingInstr(MI); 4851 MI.getOperand(2).setReg(NewCst.getReg(0)); 4852 Observer.changedInstr(MI); 4853 Observer.changingInstr(*LHSPtrAdd); 4854 LHSPtrAdd->getOperand(2).setReg(RHSReg); 4855 Observer.changedInstr(*LHSPtrAdd); 4856 }; 4857 return !reassociationCanBreakAddressingModePattern(MI); 4858 } 4859 4860 bool CombinerHelper::matchReassocFoldConstantsInSubTree( 4861 GPtrAdd &MI, MachineInstr *LHS, MachineInstr *RHS, 4862 BuildFnTy &MatchInfo) const { 4863 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2) 4864 auto *LHSPtrAdd = dyn_cast<GPtrAdd>(LHS); 4865 if (!LHSPtrAdd) 4866 return false; 4867 4868 Register Src2Reg = MI.getOperand(2).getReg(); 4869 Register LHSSrc1 = LHSPtrAdd->getBaseReg(); 4870 Register LHSSrc2 = LHSPtrAdd->getOffsetReg(); 4871 auto C1 = getIConstantVRegVal(LHSSrc2, MRI); 4872 if (!C1) 4873 return false; 4874 auto C2 = getIConstantVRegVal(Src2Reg, MRI); 4875 if (!C2) 4876 return false; 4877 4878 MatchInfo = [=, &MI](MachineIRBuilder &B) { 4879 auto NewCst = B.buildConstant(MRI.getType(Src2Reg), *C1 + *C2); 4880 Observer.changingInstr(MI); 4881 MI.getOperand(1).setReg(LHSSrc1); 4882 MI.getOperand(2).setReg(NewCst.getReg(0)); 4883 Observer.changedInstr(MI); 4884 }; 4885 return !reassociationCanBreakAddressingModePattern(MI); 4886 } 4887 4888 bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI, 4889 BuildFnTy &MatchInfo) const { 4890 auto &PtrAdd = cast<GPtrAdd>(MI); 4891 // We're trying to match a few pointer computation patterns here for 4892 // re-association opportunities. 4893 // 1) Isolating a constant operand to be on the RHS, e.g.: 4894 // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C) 4895 // 4896 // 2) Folding two constants in each sub-tree as long as such folding 4897 // doesn't break a legal addressing mode. 4898 // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2) 4899 // 4900 // 3) Move a constant from the LHS of an inner op to the RHS of the outer. 4901 // G_PTR_ADD (G_PTR_ADD X, C), Y) -> G_PTR_ADD (G_PTR_ADD(X, Y), C) 4902 // iif (G_PTR_ADD X, C) has one use. 4903 MachineInstr *LHS = MRI.getVRegDef(PtrAdd.getBaseReg()); 4904 MachineInstr *RHS = MRI.getVRegDef(PtrAdd.getOffsetReg()); 4905 4906 // Try to match example 2. 4907 if (matchReassocFoldConstantsInSubTree(PtrAdd, LHS, RHS, MatchInfo)) 4908 return true; 4909 4910 // Try to match example 3. 4911 if (matchReassocConstantInnerLHS(PtrAdd, LHS, RHS, MatchInfo)) 4912 return true; 4913 4914 // Try to match example 1. 4915 if (matchReassocConstantInnerRHS(PtrAdd, RHS, MatchInfo)) 4916 return true; 4917 4918 return false; 4919 } 4920 bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg, 4921 Register OpLHS, Register OpRHS, 4922 BuildFnTy &MatchInfo) const { 4923 LLT OpRHSTy = MRI.getType(OpRHS); 4924 MachineInstr *OpLHSDef = MRI.getVRegDef(OpLHS); 4925 4926 if (OpLHSDef->getOpcode() != Opc) 4927 return false; 4928 4929 MachineInstr *OpRHSDef = MRI.getVRegDef(OpRHS); 4930 Register OpLHSLHS = OpLHSDef->getOperand(1).getReg(); 4931 Register OpLHSRHS = OpLHSDef->getOperand(2).getReg(); 4932 4933 // If the inner op is (X op C), pull the constant out so it can be folded with 4934 // other constants in the expression tree. Folding is not guaranteed so we 4935 // might have (C1 op C2). 
In that case do not pull a constant out because it 4936 // won't help and can lead to infinite loops. 4937 if (isConstantOrConstantSplatVector(*MRI.getVRegDef(OpLHSRHS), MRI) && 4938 !isConstantOrConstantSplatVector(*MRI.getVRegDef(OpLHSLHS), MRI)) { 4939 if (isConstantOrConstantSplatVector(*OpRHSDef, MRI)) { 4940 // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2)) 4941 MatchInfo = [=](MachineIRBuilder &B) { 4942 auto NewCst = B.buildInstr(Opc, {OpRHSTy}, {OpLHSRHS, OpRHS}); 4943 B.buildInstr(Opc, {DstReg}, {OpLHSLHS, NewCst}); 4944 }; 4945 return true; 4946 } 4947 if (getTargetLowering().isReassocProfitable(MRI, OpLHS, OpRHS)) { 4948 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1) 4949 // iff (op x, c1) has one use 4950 MatchInfo = [=](MachineIRBuilder &B) { 4951 auto NewLHSLHS = B.buildInstr(Opc, {OpRHSTy}, {OpLHSLHS, OpRHS}); 4952 B.buildInstr(Opc, {DstReg}, {NewLHSLHS, OpLHSRHS}); 4953 }; 4954 return true; 4955 } 4956 } 4957 4958 return false; 4959 } 4960 4961 bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI, 4962 BuildFnTy &MatchInfo) const { 4963 // We don't check if the reassociation will break a legal addressing mode 4964 // here since pointer arithmetic is handled by G_PTR_ADD. 4965 unsigned Opc = MI.getOpcode(); 4966 Register DstReg = MI.getOperand(0).getReg(); 4967 Register LHSReg = MI.getOperand(1).getReg(); 4968 Register RHSReg = MI.getOperand(2).getReg(); 4969 4970 if (tryReassocBinOp(Opc, DstReg, LHSReg, RHSReg, MatchInfo)) 4971 return true; 4972 if (tryReassocBinOp(Opc, DstReg, RHSReg, LHSReg, MatchInfo)) 4973 return true; 4974 return false; 4975 } 4976 4977 bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI, 4978 APInt &MatchInfo) const { 4979 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 4980 Register SrcOp = MI.getOperand(1).getReg(); 4981 4982 if (auto MaybeCst = ConstantFoldCastOp(MI.getOpcode(), DstTy, SrcOp, MRI)) { 4983 MatchInfo = *MaybeCst; 4984 return true; 4985 } 4986 4987 return false; 4988 } 4989 4990 bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI, 4991 APInt &MatchInfo) const { 4992 Register Op1 = MI.getOperand(1).getReg(); 4993 Register Op2 = MI.getOperand(2).getReg(); 4994 auto MaybeCst = ConstantFoldBinOp(MI.getOpcode(), Op1, Op2, MRI); 4995 if (!MaybeCst) 4996 return false; 4997 MatchInfo = *MaybeCst; 4998 return true; 4999 } 5000 5001 bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI, 5002 ConstantFP *&MatchInfo) const { 5003 Register Op1 = MI.getOperand(1).getReg(); 5004 Register Op2 = MI.getOperand(2).getReg(); 5005 auto MaybeCst = ConstantFoldFPBinOp(MI.getOpcode(), Op1, Op2, MRI); 5006 if (!MaybeCst) 5007 return false; 5008 MatchInfo = 5009 ConstantFP::get(MI.getMF()->getFunction().getContext(), *MaybeCst); 5010 return true; 5011 } 5012 5013 bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI, 5014 ConstantFP *&MatchInfo) const { 5015 assert(MI.getOpcode() == TargetOpcode::G_FMA || 5016 MI.getOpcode() == TargetOpcode::G_FMAD); 5017 auto [_, Op1, Op2, Op3] = MI.getFirst4Regs(); 5018 5019 const ConstantFP *Op3Cst = getConstantFPVRegVal(Op3, MRI); 5020 if (!Op3Cst) 5021 return false; 5022 5023 const ConstantFP *Op2Cst = getConstantFPVRegVal(Op2, MRI); 5024 if (!Op2Cst) 5025 return false; 5026 5027 const ConstantFP *Op1Cst = getConstantFPVRegVal(Op1, MRI); 5028 if (!Op1Cst) 5029 return false; 5030 5031 APFloat Op1F = Op1Cst->getValueAPF(); 5032 Op1F.fusedMultiplyAdd(Op2Cst->getValueAPF(), Op3Cst->getValueAPF(), 5033 APFloat::rmNearestTiesToEven); 5034 MatchInfo = 
ConstantFP::get(MI.getMF()->getFunction().getContext(), Op1F); 5035 return true; 5036 } 5037 5038 bool CombinerHelper::matchNarrowBinopFeedingAnd( 5039 MachineInstr &MI, 5040 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 5041 // Look for a binop feeding into an AND with a mask: 5042 // 5043 // %add = G_ADD %lhs, %rhs 5044 // %and = G_AND %add, 000...11111111 5045 // 5046 // Check if it's possible to perform the binop at a narrower width and zext 5047 // back to the original width like so: 5048 // 5049 // %narrow_lhs = G_TRUNC %lhs 5050 // %narrow_rhs = G_TRUNC %rhs 5051 // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs 5052 // %new_add = G_ZEXT %narrow_add 5053 // %and = G_AND %new_add, 000...11111111 5054 // 5055 // This can allow later combines to eliminate the G_AND if it turns out 5056 // that the mask is irrelevant. 5057 assert(MI.getOpcode() == TargetOpcode::G_AND); 5058 Register Dst = MI.getOperand(0).getReg(); 5059 Register AndLHS = MI.getOperand(1).getReg(); 5060 Register AndRHS = MI.getOperand(2).getReg(); 5061 LLT WideTy = MRI.getType(Dst); 5062 5063 // If the potential binop has more than one use, then it's possible that one 5064 // of those uses will need its full width. 5065 if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(AndLHS)) 5066 return false; 5067 5068 // Check if the LHS feeding the AND is impacted by the high bits that we're 5069 // masking out. 5070 // 5071 // e.g. for 64-bit x, y: 5072 // 5073 // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535 5074 MachineInstr *LHSInst = getDefIgnoringCopies(AndLHS, MRI); 5075 if (!LHSInst) 5076 return false; 5077 unsigned LHSOpc = LHSInst->getOpcode(); 5078 switch (LHSOpc) { 5079 default: 5080 return false; 5081 case TargetOpcode::G_ADD: 5082 case TargetOpcode::G_SUB: 5083 case TargetOpcode::G_MUL: 5084 case TargetOpcode::G_AND: 5085 case TargetOpcode::G_OR: 5086 case TargetOpcode::G_XOR: 5087 break; 5088 } 5089 5090 // Find the mask on the RHS. 5091 auto Cst = getIConstantVRegValWithLookThrough(AndRHS, MRI); 5092 if (!Cst) 5093 return false; 5094 auto Mask = Cst->Value; 5095 if (!Mask.isMask()) 5096 return false; 5097 5098 // No point in combining if there's nothing to truncate. 5099 unsigned NarrowWidth = Mask.countr_one(); 5100 if (NarrowWidth == WideTy.getSizeInBits()) 5101 return false; 5102 LLT NarrowTy = LLT::scalar(NarrowWidth); 5103 5104 // Check if adding the zext + truncates could be harmful. 
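  // That is, bail out if the target does not consider the truncate and the
  // zero-extend free at these types, or if G_TRUNC/G_ZEXT on them would not be
  // legal once we are past the legalizer.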
5105 auto &MF = *MI.getMF(); 5106 const auto &TLI = getTargetLowering(); 5107 LLVMContext &Ctx = MF.getFunction().getContext(); 5108 if (!TLI.isTruncateFree(WideTy, NarrowTy, Ctx) || 5109 !TLI.isZExtFree(NarrowTy, WideTy, Ctx)) 5110 return false; 5111 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) || 5112 !isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {WideTy, NarrowTy}})) 5113 return false; 5114 Register BinOpLHS = LHSInst->getOperand(1).getReg(); 5115 Register BinOpRHS = LHSInst->getOperand(2).getReg(); 5116 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5117 auto NarrowLHS = Builder.buildTrunc(NarrowTy, BinOpLHS); 5118 auto NarrowRHS = Builder.buildTrunc(NarrowTy, BinOpRHS); 5119 auto NarrowBinOp = 5120 Builder.buildInstr(LHSOpc, {NarrowTy}, {NarrowLHS, NarrowRHS}); 5121 auto Ext = Builder.buildZExt(WideTy, NarrowBinOp); 5122 Observer.changingInstr(MI); 5123 MI.getOperand(1).setReg(Ext.getReg(0)); 5124 Observer.changedInstr(MI); 5125 }; 5126 return true; 5127 } 5128 5129 bool CombinerHelper::matchMulOBy2(MachineInstr &MI, 5130 BuildFnTy &MatchInfo) const { 5131 unsigned Opc = MI.getOpcode(); 5132 assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO); 5133 5134 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(2))) 5135 return false; 5136 5137 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5138 Observer.changingInstr(MI); 5139 unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO 5140 : TargetOpcode::G_SADDO; 5141 MI.setDesc(Builder.getTII().get(NewOpc)); 5142 MI.getOperand(3).setReg(MI.getOperand(2).getReg()); 5143 Observer.changedInstr(MI); 5144 }; 5145 return true; 5146 } 5147 5148 bool CombinerHelper::matchMulOBy0(MachineInstr &MI, 5149 BuildFnTy &MatchInfo) const { 5150 // (G_*MULO x, 0) -> 0 + no carry out 5151 assert(MI.getOpcode() == TargetOpcode::G_UMULO || 5152 MI.getOpcode() == TargetOpcode::G_SMULO); 5153 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(0))) 5154 return false; 5155 Register Dst = MI.getOperand(0).getReg(); 5156 Register Carry = MI.getOperand(1).getReg(); 5157 if (!isConstantLegalOrBeforeLegalizer(MRI.getType(Dst)) || 5158 !isConstantLegalOrBeforeLegalizer(MRI.getType(Carry))) 5159 return false; 5160 MatchInfo = [=](MachineIRBuilder &B) { 5161 B.buildConstant(Dst, 0); 5162 B.buildConstant(Carry, 0); 5163 }; 5164 return true; 5165 } 5166 5167 bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, 5168 BuildFnTy &MatchInfo) const { 5169 // (G_*ADDE x, y, 0) -> (G_*ADDO x, y) 5170 // (G_*SUBE x, y, 0) -> (G_*SUBO x, y) 5171 assert(MI.getOpcode() == TargetOpcode::G_UADDE || 5172 MI.getOpcode() == TargetOpcode::G_SADDE || 5173 MI.getOpcode() == TargetOpcode::G_USUBE || 5174 MI.getOpcode() == TargetOpcode::G_SSUBE); 5175 if (!mi_match(MI.getOperand(4).getReg(), MRI, m_SpecificICstOrSplat(0))) 5176 return false; 5177 MatchInfo = [&](MachineIRBuilder &B) { 5178 unsigned NewOpcode; 5179 switch (MI.getOpcode()) { 5180 case TargetOpcode::G_UADDE: 5181 NewOpcode = TargetOpcode::G_UADDO; 5182 break; 5183 case TargetOpcode::G_SADDE: 5184 NewOpcode = TargetOpcode::G_SADDO; 5185 break; 5186 case TargetOpcode::G_USUBE: 5187 NewOpcode = TargetOpcode::G_USUBO; 5188 break; 5189 case TargetOpcode::G_SSUBE: 5190 NewOpcode = TargetOpcode::G_SSUBO; 5191 break; 5192 } 5193 Observer.changingInstr(MI); 5194 MI.setDesc(B.getTII().get(NewOpcode)); 5195 MI.removeOperand(4); 5196 Observer.changedInstr(MI); 5197 }; 5198 return true; 5199 } 5200 5201 bool 
CombinerHelper::matchSubAddSameReg(MachineInstr &MI, 5202 BuildFnTy &MatchInfo) const { 5203 assert(MI.getOpcode() == TargetOpcode::G_SUB); 5204 Register Dst = MI.getOperand(0).getReg(); 5205 // (x + y) - z -> x (if y == z) 5206 // (x + y) - z -> y (if x == z) 5207 Register X, Y, Z; 5208 if (mi_match(Dst, MRI, m_GSub(m_GAdd(m_Reg(X), m_Reg(Y)), m_Reg(Z)))) { 5209 Register ReplaceReg; 5210 int64_t CstX, CstY; 5211 if (Y == Z || (mi_match(Y, MRI, m_ICstOrSplat(CstY)) && 5212 mi_match(Z, MRI, m_SpecificICstOrSplat(CstY)))) 5213 ReplaceReg = X; 5214 else if (X == Z || (mi_match(X, MRI, m_ICstOrSplat(CstX)) && 5215 mi_match(Z, MRI, m_SpecificICstOrSplat(CstX)))) 5216 ReplaceReg = Y; 5217 if (ReplaceReg) { 5218 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, ReplaceReg); }; 5219 return true; 5220 } 5221 } 5222 5223 // x - (y + z) -> 0 - y (if x == z) 5224 // x - (y + z) -> 0 - z (if x == y) 5225 if (mi_match(Dst, MRI, m_GSub(m_Reg(X), m_GAdd(m_Reg(Y), m_Reg(Z))))) { 5226 Register ReplaceReg; 5227 int64_t CstX; 5228 if (X == Z || (mi_match(X, MRI, m_ICstOrSplat(CstX)) && 5229 mi_match(Z, MRI, m_SpecificICstOrSplat(CstX)))) 5230 ReplaceReg = Y; 5231 else if (X == Y || (mi_match(X, MRI, m_ICstOrSplat(CstX)) && 5232 mi_match(Y, MRI, m_SpecificICstOrSplat(CstX)))) 5233 ReplaceReg = Z; 5234 if (ReplaceReg) { 5235 MatchInfo = [=](MachineIRBuilder &B) { 5236 auto Zero = B.buildConstant(MRI.getType(Dst), 0); 5237 B.buildSub(Dst, Zero, ReplaceReg); 5238 }; 5239 return true; 5240 } 5241 } 5242 return false; 5243 } 5244 5245 MachineInstr *CombinerHelper::buildUDivUsingMul(MachineInstr &MI) const { 5246 assert(MI.getOpcode() == TargetOpcode::G_UDIV); 5247 auto &UDiv = cast<GenericMachineInstr>(MI); 5248 Register Dst = UDiv.getReg(0); 5249 Register LHS = UDiv.getReg(1); 5250 Register RHS = UDiv.getReg(2); 5251 LLT Ty = MRI.getType(Dst); 5252 LLT ScalarTy = Ty.getScalarType(); 5253 const unsigned EltBits = ScalarTy.getScalarSizeInBits(); 5254 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 5255 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); 5256 5257 auto &MIB = Builder; 5258 5259 bool UseSRL = false; 5260 SmallVector<Register, 16> Shifts, Factors; 5261 auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI)); 5262 bool IsSplat = getIConstantSplatVal(*RHSDefInstr, MRI).has_value(); 5263 5264 auto BuildExactUDIVPattern = [&](const Constant *C) { 5265 // Don't recompute inverses for each splat element. 5266 if (IsSplat && !Factors.empty()) { 5267 Shifts.push_back(Shifts[0]); 5268 Factors.push_back(Factors[0]); 5269 return true; 5270 } 5271 5272 auto *CI = cast<ConstantInt>(C); 5273 APInt Divisor = CI->getValue(); 5274 unsigned Shift = Divisor.countr_zero(); 5275 if (Shift) { 5276 Divisor.lshrInPlace(Shift); 5277 UseSRL = true; 5278 } 5279 5280 // Calculate the multiplicative inverse modulo BW. 5281 APInt Factor = Divisor.multiplicativeInverse(); 5282 Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0)); 5283 Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0)); 5284 return true; 5285 }; 5286 5287 if (MI.getFlag(MachineInstr::MIFlag::IsExact)) { 5288 // Collect all magic values from the build vector. 
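    // Illustrative example (s32): an exact udiv by 6 splits the divisor into
    // 2 * 3; we shift right by 1 and multiply by the inverse of 3 modulo 2^32,
    // i.e. x /u 6 == (x >> 1) * 0xaaaaaaab when the division is exact.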
5289 if (!matchUnaryPredicate(MRI, RHS, BuildExactUDIVPattern)) 5290 llvm_unreachable("Expected unary predicate match to succeed"); 5291 5292 Register Shift, Factor; 5293 if (Ty.isVector()) { 5294 Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0); 5295 Factor = MIB.buildBuildVector(Ty, Factors).getReg(0); 5296 } else { 5297 Shift = Shifts[0]; 5298 Factor = Factors[0]; 5299 } 5300 5301 Register Res = LHS; 5302 5303 if (UseSRL) 5304 Res = MIB.buildLShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0); 5305 5306 return MIB.buildMul(Ty, Res, Factor); 5307 } 5308 5309 unsigned KnownLeadingZeros = 5310 KB ? KB->getKnownBits(LHS).countMinLeadingZeros() : 0; 5311 5312 bool UseNPQ = false; 5313 SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors; 5314 auto BuildUDIVPattern = [&](const Constant *C) { 5315 auto *CI = cast<ConstantInt>(C); 5316 const APInt &Divisor = CI->getValue(); 5317 5318 bool SelNPQ = false; 5319 APInt Magic(Divisor.getBitWidth(), 0); 5320 unsigned PreShift = 0, PostShift = 0; 5321 5322 // Magic algorithm doesn't work for division by 1. We need to emit a select 5323 // at the end. 5324 // TODO: Use undef values for divisor of 1. 5325 if (!Divisor.isOne()) { 5326 5327 // UnsignedDivisionByConstantInfo doesn't work correctly if leading zeros 5328 // in the dividend exceeds the leading zeros for the divisor. 5329 UnsignedDivisionByConstantInfo magics = 5330 UnsignedDivisionByConstantInfo::get( 5331 Divisor, std::min(KnownLeadingZeros, Divisor.countl_zero())); 5332 5333 Magic = std::move(magics.Magic); 5334 5335 assert(magics.PreShift < Divisor.getBitWidth() && 5336 "We shouldn't generate an undefined shift!"); 5337 assert(magics.PostShift < Divisor.getBitWidth() && 5338 "We shouldn't generate an undefined shift!"); 5339 assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift"); 5340 PreShift = magics.PreShift; 5341 PostShift = magics.PostShift; 5342 SelNPQ = magics.IsAdd; 5343 } 5344 5345 PreShifts.push_back( 5346 MIB.buildConstant(ScalarShiftAmtTy, PreShift).getReg(0)); 5347 MagicFactors.push_back(MIB.buildConstant(ScalarTy, Magic).getReg(0)); 5348 NPQFactors.push_back( 5349 MIB.buildConstant(ScalarTy, 5350 SelNPQ ? APInt::getOneBitSet(EltBits, EltBits - 1) 5351 : APInt::getZero(EltBits)) 5352 .getReg(0)); 5353 PostShifts.push_back( 5354 MIB.buildConstant(ScalarShiftAmtTy, PostShift).getReg(0)); 5355 UseNPQ |= SelNPQ; 5356 return true; 5357 }; 5358 5359 // Collect the shifts/magic values from each element. 5360 bool Matched = matchUnaryPredicate(MRI, RHS, BuildUDIVPattern); 5361 (void)Matched; 5362 assert(Matched && "Expected unary predicate match to succeed"); 5363 5364 Register PreShift, PostShift, MagicFactor, NPQFactor; 5365 auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI); 5366 if (RHSDef) { 5367 PreShift = MIB.buildBuildVector(ShiftAmtTy, PreShifts).getReg(0); 5368 MagicFactor = MIB.buildBuildVector(Ty, MagicFactors).getReg(0); 5369 NPQFactor = MIB.buildBuildVector(Ty, NPQFactors).getReg(0); 5370 PostShift = MIB.buildBuildVector(ShiftAmtTy, PostShifts).getReg(0); 5371 } else { 5372 assert(MRI.getType(RHS).isScalar() && 5373 "Non-build_vector operation should have been a scalar"); 5374 PreShift = PreShifts[0]; 5375 MagicFactor = MagicFactors[0]; 5376 PostShift = PostShifts[0]; 5377 } 5378 5379 Register Q = LHS; 5380 Q = MIB.buildLShr(Ty, Q, PreShift).getReg(0); 5381 5382 // Multiply the numerator (operand 0) by the magic value. 
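  // G_UMULH produces the high half of the double-width product, i.e.
  // (Q * MagicFactor) >> EltBits.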
5383 Q = MIB.buildUMulH(Ty, Q, MagicFactor).getReg(0); 5384 5385 if (UseNPQ) { 5386 Register NPQ = MIB.buildSub(Ty, LHS, Q).getReg(0); 5387 5388 // For vectors we might have a mix of non-NPQ/NPQ paths, so use 5389 // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero. 5390 if (Ty.isVector()) 5391 NPQ = MIB.buildUMulH(Ty, NPQ, NPQFactor).getReg(0); 5392 else 5393 NPQ = MIB.buildLShr(Ty, NPQ, MIB.buildConstant(ShiftAmtTy, 1)).getReg(0); 5394 5395 Q = MIB.buildAdd(Ty, NPQ, Q).getReg(0); 5396 } 5397 5398 Q = MIB.buildLShr(Ty, Q, PostShift).getReg(0); 5399 auto One = MIB.buildConstant(Ty, 1); 5400 auto IsOne = MIB.buildICmp( 5401 CmpInst::Predicate::ICMP_EQ, 5402 Ty.isScalar() ? LLT::scalar(1) : Ty.changeElementSize(1), RHS, One); 5403 return MIB.buildSelect(Ty, IsOne, LHS, Q); 5404 } 5405 5406 bool CombinerHelper::matchUDivByConst(MachineInstr &MI) const { 5407 assert(MI.getOpcode() == TargetOpcode::G_UDIV); 5408 Register Dst = MI.getOperand(0).getReg(); 5409 Register RHS = MI.getOperand(2).getReg(); 5410 LLT DstTy = MRI.getType(Dst); 5411 5412 auto &MF = *MI.getMF(); 5413 AttributeList Attr = MF.getFunction().getAttributes(); 5414 const auto &TLI = getTargetLowering(); 5415 LLVMContext &Ctx = MF.getFunction().getContext(); 5416 if (TLI.isIntDivCheap(getApproximateEVTForLLT(DstTy, Ctx), Attr)) 5417 return false; 5418 5419 // Don't do this for minsize because the instruction sequence is usually 5420 // larger. 5421 if (MF.getFunction().hasMinSize()) 5422 return false; 5423 5424 if (MI.getFlag(MachineInstr::MIFlag::IsExact)) { 5425 return matchUnaryPredicate( 5426 MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); }); 5427 } 5428 5429 auto *RHSDef = MRI.getVRegDef(RHS); 5430 if (!isConstantOrConstantVector(*RHSDef, MRI)) 5431 return false; 5432 5433 // Don't do this if the types are not going to be legal. 5434 if (LI) { 5435 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_MUL, {DstTy, DstTy}})) 5436 return false; 5437 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMULH, {DstTy}})) 5438 return false; 5439 if (!isLegalOrBeforeLegalizer( 5440 {TargetOpcode::G_ICMP, 5441 {DstTy.isVector() ? DstTy.changeElementSize(1) : LLT::scalar(1), 5442 DstTy}})) 5443 return false; 5444 } 5445 5446 return matchUnaryPredicate( 5447 MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); }); 5448 } 5449 5450 void CombinerHelper::applyUDivByConst(MachineInstr &MI) const { 5451 auto *NewMI = buildUDivUsingMul(MI); 5452 replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg()); 5453 } 5454 5455 bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const { 5456 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV"); 5457 Register Dst = MI.getOperand(0).getReg(); 5458 Register RHS = MI.getOperand(2).getReg(); 5459 LLT DstTy = MRI.getType(Dst); 5460 5461 auto &MF = *MI.getMF(); 5462 AttributeList Attr = MF.getFunction().getAttributes(); 5463 const auto &TLI = getTargetLowering(); 5464 LLVMContext &Ctx = MF.getFunction().getContext(); 5465 if (TLI.isIntDivCheap(getApproximateEVTForLLT(DstTy, Ctx), Attr)) 5466 return false; 5467 5468 // Don't do this for minsize because the instruction sequence is usually 5469 // larger. 5470 if (MF.getFunction().hasMinSize()) 5471 return false; 5472 5473 // If the sdiv has an 'exact' flag we can use a simpler lowering. 5474 if (MI.getFlag(MachineInstr::MIFlag::IsExact)) { 5475 return matchUnaryPredicate( 5476 MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); }); 5477 } 5478 5479 // Don't support the general case for now. 
5480 return false; 5481 } 5482 5483 void CombinerHelper::applySDivByConst(MachineInstr &MI) const { 5484 auto *NewMI = buildSDivUsingMul(MI); 5485 replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg()); 5486 } 5487 5488 MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const { 5489 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV"); 5490 auto &SDiv = cast<GenericMachineInstr>(MI); 5491 Register Dst = SDiv.getReg(0); 5492 Register LHS = SDiv.getReg(1); 5493 Register RHS = SDiv.getReg(2); 5494 LLT Ty = MRI.getType(Dst); 5495 LLT ScalarTy = Ty.getScalarType(); 5496 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 5497 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); 5498 auto &MIB = Builder; 5499 5500 bool UseSRA = false; 5501 SmallVector<Register, 16> Shifts, Factors; 5502 5503 auto *RHSDef = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI)); 5504 bool IsSplat = getIConstantSplatVal(*RHSDef, MRI).has_value(); 5505 5506 auto BuildSDIVPattern = [&](const Constant *C) { 5507 // Don't recompute inverses for each splat element. 5508 if (IsSplat && !Factors.empty()) { 5509 Shifts.push_back(Shifts[0]); 5510 Factors.push_back(Factors[0]); 5511 return true; 5512 } 5513 5514 auto *CI = cast<ConstantInt>(C); 5515 APInt Divisor = CI->getValue(); 5516 unsigned Shift = Divisor.countr_zero(); 5517 if (Shift) { 5518 Divisor.ashrInPlace(Shift); 5519 UseSRA = true; 5520 } 5521 5522 // Calculate the multiplicative inverse modulo BW. 5523 // 2^W requires W + 1 bits, so we have to extend and then truncate. 5524 APInt Factor = Divisor.multiplicativeInverse(); 5525 Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0)); 5526 Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0)); 5527 return true; 5528 }; 5529 5530 // Collect all magic values from the build vector. 5531 bool Matched = matchUnaryPredicate(MRI, RHS, BuildSDIVPattern); 5532 (void)Matched; 5533 assert(Matched && "Expected unary predicate match to succeed"); 5534 5535 Register Shift, Factor; 5536 if (Ty.isVector()) { 5537 Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0); 5538 Factor = MIB.buildBuildVector(Ty, Factors).getReg(0); 5539 } else { 5540 Shift = Shifts[0]; 5541 Factor = Factors[0]; 5542 } 5543 5544 Register Res = LHS; 5545 5546 if (UseSRA) 5547 Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0); 5548 5549 return MIB.buildMul(Ty, Res, Factor); 5550 } 5551 5552 bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const { 5553 assert((MI.getOpcode() == TargetOpcode::G_SDIV || 5554 MI.getOpcode() == TargetOpcode::G_UDIV) && 5555 "Expected SDIV or UDIV"); 5556 auto &Div = cast<GenericMachineInstr>(MI); 5557 Register RHS = Div.getReg(2); 5558 auto MatchPow2 = [&](const Constant *C) { 5559 auto *CI = dyn_cast<ConstantInt>(C); 5560 return CI && (CI->getValue().isPowerOf2() || 5561 (IsSigned && CI->getValue().isNegatedPowerOf2())); 5562 }; 5563 return matchUnaryPredicate(MRI, RHS, MatchPow2, /*AllowUndefs=*/false); 5564 } 5565 5566 void CombinerHelper::applySDivByPow2(MachineInstr &MI) const { 5567 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV"); 5568 auto &SDiv = cast<GenericMachineInstr>(MI); 5569 Register Dst = SDiv.getReg(0); 5570 Register LHS = SDiv.getReg(1); 5571 Register RHS = SDiv.getReg(2); 5572 LLT Ty = MRI.getType(Dst); 5573 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 5574 LLT CCVT = 5575 Ty.isVector() ? 
      LLT::vector(Ty.getElementCount(), 1) : LLT::scalar(1);

  // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of 2,
  // to the following version:
  //
  // %c1 = G_CTTZ %rhs
  // %inexact = G_SUB $bitwidth, %c1
  // %sign = G_ASHR %lhs, $(bitwidth - 1)
  // %lshr = G_LSHR %sign, %inexact
  // %add = G_ADD %lhs, %lshr
  // %ashr = G_ASHR %add, %c1
  // %isoneorminusone = G_OR (G_ICMP EQ %rhs, 1), (G_ICMP EQ %rhs, -1)
  // %ashr = G_SELECT %isoneorminusone, %lhs, %ashr
  // %zero = G_CONSTANT $0
  // %neg = G_NEG %ashr
  // %isneg = G_ICMP SLT %rhs, %zero
  // %res = G_SELECT %isneg, %neg, %ashr

  unsigned BitWidth = Ty.getScalarSizeInBits();
  auto Zero = Builder.buildConstant(Ty, 0);

  auto Bits = Builder.buildConstant(ShiftAmtTy, BitWidth);
  auto C1 = Builder.buildCTTZ(ShiftAmtTy, RHS);
  auto Inexact = Builder.buildSub(ShiftAmtTy, Bits, C1);
  // Splat the sign bit into the register.
  auto Sign = Builder.buildAShr(
      Ty, LHS, Builder.buildConstant(ShiftAmtTy, BitWidth - 1));

  // Add (LHS < 0) ? |RHS| - 1 : 0 so that the shift rounds toward zero.
  auto LSrl = Builder.buildLShr(Ty, Sign, Inexact);
  auto Add = Builder.buildAdd(Ty, LHS, LSrl);
  auto AShr = Builder.buildAShr(Ty, Add, C1);

  // Special case: (sdiv X, 1) -> X
  // Special case: (sdiv X, -1) -> 0-X
  auto One = Builder.buildConstant(Ty, 1);
  auto MinusOne = Builder.buildConstant(Ty, -1);
  auto IsOne = Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, One);
  auto IsMinusOne =
      Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, MinusOne);
  auto IsOneOrMinusOne = Builder.buildOr(CCVT, IsOne, IsMinusOne);
  AShr = Builder.buildSelect(Ty, IsOneOrMinusOne, LHS, AShr);

  // If divided by a positive value, we're done. Otherwise, the result must be
  // negated.
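  // Illustrative example (s32), %x sdiv 8: %c1 = 3 and %inexact = 29, so the
  // rounding term (%x >>s 31) >>u 29 is 7 for negative %x and 0 otherwise, and
  // %ashr = (%x + term) >>s 3. A negative divisor such as -8 takes the negated
  // result selected below.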
5619 auto Neg = Builder.buildNeg(Ty, AShr); 5620 auto IsNeg = Builder.buildICmp(CmpInst::Predicate::ICMP_SLT, CCVT, RHS, Zero); 5621 Builder.buildSelect(MI.getOperand(0).getReg(), IsNeg, Neg, AShr); 5622 MI.eraseFromParent(); 5623 } 5624 5625 void CombinerHelper::applyUDivByPow2(MachineInstr &MI) const { 5626 assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV"); 5627 auto &UDiv = cast<GenericMachineInstr>(MI); 5628 Register Dst = UDiv.getReg(0); 5629 Register LHS = UDiv.getReg(1); 5630 Register RHS = UDiv.getReg(2); 5631 LLT Ty = MRI.getType(Dst); 5632 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 5633 5634 auto C1 = Builder.buildCTTZ(ShiftAmtTy, RHS); 5635 Builder.buildLShr(MI.getOperand(0).getReg(), LHS, C1); 5636 MI.eraseFromParent(); 5637 } 5638 5639 bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) const { 5640 assert(MI.getOpcode() == TargetOpcode::G_UMULH); 5641 Register RHS = MI.getOperand(2).getReg(); 5642 Register Dst = MI.getOperand(0).getReg(); 5643 LLT Ty = MRI.getType(Dst); 5644 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 5645 auto MatchPow2ExceptOne = [&](const Constant *C) { 5646 if (auto *CI = dyn_cast<ConstantInt>(C)) 5647 return CI->getValue().isPowerOf2() && !CI->getValue().isOne(); 5648 return false; 5649 }; 5650 if (!matchUnaryPredicate(MRI, RHS, MatchPow2ExceptOne, false)) 5651 return false; 5652 return isLegalOrBeforeLegalizer({TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}}); 5653 } 5654 5655 void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) const { 5656 Register LHS = MI.getOperand(1).getReg(); 5657 Register RHS = MI.getOperand(2).getReg(); 5658 Register Dst = MI.getOperand(0).getReg(); 5659 LLT Ty = MRI.getType(Dst); 5660 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 5661 unsigned NumEltBits = Ty.getScalarSizeInBits(); 5662 5663 auto LogBase2 = buildLogBase2(RHS, Builder); 5664 auto ShiftAmt = 5665 Builder.buildSub(Ty, Builder.buildConstant(Ty, NumEltBits), LogBase2); 5666 auto Trunc = Builder.buildZExtOrTrunc(ShiftAmtTy, ShiftAmt); 5667 Builder.buildLShr(Dst, LHS, Trunc); 5668 MI.eraseFromParent(); 5669 } 5670 5671 bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI, 5672 BuildFnTy &MatchInfo) const { 5673 unsigned Opc = MI.getOpcode(); 5674 assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB || 5675 Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV || 5676 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA); 5677 5678 Register Dst = MI.getOperand(0).getReg(); 5679 Register X = MI.getOperand(1).getReg(); 5680 Register Y = MI.getOperand(2).getReg(); 5681 LLT Type = MRI.getType(Dst); 5682 5683 // fold (fadd x, fneg(y)) -> (fsub x, y) 5684 // fold (fadd fneg(y), x) -> (fsub x, y) 5685 // G_ADD is commutative so both cases are checked by m_GFAdd 5686 if (mi_match(Dst, MRI, m_GFAdd(m_Reg(X), m_GFNeg(m_Reg(Y)))) && 5687 isLegalOrBeforeLegalizer({TargetOpcode::G_FSUB, {Type}})) { 5688 Opc = TargetOpcode::G_FSUB; 5689 } 5690 /// fold (fsub x, fneg(y)) -> (fadd x, y) 5691 else if (mi_match(Dst, MRI, m_GFSub(m_Reg(X), m_GFNeg(m_Reg(Y)))) && 5692 isLegalOrBeforeLegalizer({TargetOpcode::G_FADD, {Type}})) { 5693 Opc = TargetOpcode::G_FADD; 5694 } 5695 // fold (fmul fneg(x), fneg(y)) -> (fmul x, y) 5696 // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y) 5697 // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z) 5698 // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z) 5699 else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV || 5700 Opc == 
TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) && 5701 mi_match(X, MRI, m_GFNeg(m_Reg(X))) && 5702 mi_match(Y, MRI, m_GFNeg(m_Reg(Y)))) { 5703 // no opcode change 5704 } else 5705 return false; 5706 5707 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5708 Observer.changingInstr(MI); 5709 MI.setDesc(B.getTII().get(Opc)); 5710 MI.getOperand(1).setReg(X); 5711 MI.getOperand(2).setReg(Y); 5712 Observer.changedInstr(MI); 5713 }; 5714 return true; 5715 } 5716 5717 bool CombinerHelper::matchFsubToFneg(MachineInstr &MI, 5718 Register &MatchInfo) const { 5719 assert(MI.getOpcode() == TargetOpcode::G_FSUB); 5720 5721 Register LHS = MI.getOperand(1).getReg(); 5722 MatchInfo = MI.getOperand(2).getReg(); 5723 LLT Ty = MRI.getType(MI.getOperand(0).getReg()); 5724 5725 const auto LHSCst = Ty.isVector() 5726 ? getFConstantSplat(LHS, MRI, /* allowUndef */ true) 5727 : getFConstantVRegValWithLookThrough(LHS, MRI); 5728 if (!LHSCst) 5729 return false; 5730 5731 // -0.0 is always allowed 5732 if (LHSCst->Value.isNegZero()) 5733 return true; 5734 5735 // +0.0 is only allowed if nsz is set. 5736 if (LHSCst->Value.isPosZero()) 5737 return MI.getFlag(MachineInstr::FmNsz); 5738 5739 return false; 5740 } 5741 5742 void CombinerHelper::applyFsubToFneg(MachineInstr &MI, 5743 Register &MatchInfo) const { 5744 Register Dst = MI.getOperand(0).getReg(); 5745 Builder.buildFNeg( 5746 Dst, Builder.buildFCanonicalize(MRI.getType(Dst), MatchInfo).getReg(0)); 5747 eraseInst(MI); 5748 } 5749 5750 /// Checks if \p MI is TargetOpcode::G_FMUL and contractable either 5751 /// due to global flags or MachineInstr flags. 5752 static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) { 5753 if (MI.getOpcode() != TargetOpcode::G_FMUL) 5754 return false; 5755 return AllowFusionGlobally || MI.getFlag(MachineInstr::MIFlag::FmContract); 5756 } 5757 5758 static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1, 5759 const MachineRegisterInfo &MRI) { 5760 return std::distance(MRI.use_instr_nodbg_begin(MI0.getOperand(0).getReg()), 5761 MRI.use_instr_nodbg_end()) > 5762 std::distance(MRI.use_instr_nodbg_begin(MI1.getOperand(0).getReg()), 5763 MRI.use_instr_nodbg_end()); 5764 } 5765 5766 bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI, 5767 bool &AllowFusionGlobally, 5768 bool &HasFMAD, bool &Aggressive, 5769 bool CanReassociate) const { 5770 5771 auto *MF = MI.getMF(); 5772 const auto &TLI = *MF->getSubtarget().getTargetLowering(); 5773 const TargetOptions &Options = MF->getTarget().Options; 5774 LLT DstType = MRI.getType(MI.getOperand(0).getReg()); 5775 5776 if (CanReassociate && 5777 !(Options.UnsafeFPMath || MI.getFlag(MachineInstr::MIFlag::FmReassoc))) 5778 return false; 5779 5780 // Floating-point multiply-add with intermediate rounding. 5781 HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, DstType)); 5782 // Floating-point multiply-add without intermediate rounding. 5783 bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(*MF, DstType) && 5784 isLegalOrBeforeLegalizer({TargetOpcode::G_FMA, {DstType}}); 5785 // No valid opcode, do not combine. 5786 if (!HasFMAD && !HasFMA) 5787 return false; 5788 5789 AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast || 5790 Options.UnsafeFPMath || HasFMAD; 5791 // If the addition is not contractable, do not combine. 
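  // The addition is contractable when fusion is allowed globally
  // (FPOpFusion::Fast, unsafe-fp-math, or an available FMAD) or when the
  // instruction itself carries the contract fast-math flag.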
5792 if (!AllowFusionGlobally && !MI.getFlag(MachineInstr::MIFlag::FmContract)) 5793 return false; 5794 5795 Aggressive = TLI.enableAggressiveFMAFusion(DstType); 5796 return true; 5797 } 5798 5799 bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA( 5800 MachineInstr &MI, 5801 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 5802 assert(MI.getOpcode() == TargetOpcode::G_FADD); 5803 5804 bool AllowFusionGlobally, HasFMAD, Aggressive; 5805 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 5806 return false; 5807 5808 Register Op1 = MI.getOperand(1).getReg(); 5809 Register Op2 = MI.getOperand(2).getReg(); 5810 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1}; 5811 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2}; 5812 unsigned PreferredFusedOpcode = 5813 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 5814 5815 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), 5816 // prefer to fold the multiply with fewer uses. 5817 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) && 5818 isContractableFMul(*RHS.MI, AllowFusionGlobally)) { 5819 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI)) 5820 std::swap(LHS, RHS); 5821 } 5822 5823 // fold (fadd (fmul x, y), z) -> (fma x, y, z) 5824 if (isContractableFMul(*LHS.MI, AllowFusionGlobally) && 5825 (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg))) { 5826 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5827 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 5828 {LHS.MI->getOperand(1).getReg(), 5829 LHS.MI->getOperand(2).getReg(), RHS.Reg}); 5830 }; 5831 return true; 5832 } 5833 5834 // fold (fadd x, (fmul y, z)) -> (fma y, z, x) 5835 if (isContractableFMul(*RHS.MI, AllowFusionGlobally) && 5836 (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg))) { 5837 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5838 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 5839 {RHS.MI->getOperand(1).getReg(), 5840 RHS.MI->getOperand(2).getReg(), LHS.Reg}); 5841 }; 5842 return true; 5843 } 5844 5845 return false; 5846 } 5847 5848 bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA( 5849 MachineInstr &MI, 5850 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 5851 assert(MI.getOpcode() == TargetOpcode::G_FADD); 5852 5853 bool AllowFusionGlobally, HasFMAD, Aggressive; 5854 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 5855 return false; 5856 5857 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); 5858 Register Op1 = MI.getOperand(1).getReg(); 5859 Register Op2 = MI.getOperand(2).getReg(); 5860 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1}; 5861 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2}; 5862 LLT DstType = MRI.getType(MI.getOperand(0).getReg()); 5863 5864 unsigned PreferredFusedOpcode = 5865 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 5866 5867 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), 5868 // prefer to fold the multiply with fewer uses. 
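  // The multiply with fewer uses is the better candidate because it is more
  // likely to become dead once it has been folded into the fused operation.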
5869 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) && 5870 isContractableFMul(*RHS.MI, AllowFusionGlobally)) { 5871 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI)) 5872 std::swap(LHS, RHS); 5873 } 5874 5875 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) 5876 MachineInstr *FpExtSrc; 5877 if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) && 5878 isContractableFMul(*FpExtSrc, AllowFusionGlobally) && 5879 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 5880 MRI.getType(FpExtSrc->getOperand(1).getReg()))) { 5881 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5882 auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg()); 5883 auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg()); 5884 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 5885 {FpExtX.getReg(0), FpExtY.getReg(0), RHS.Reg}); 5886 }; 5887 return true; 5888 } 5889 5890 // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z) 5891 // Note: Commutes FADD operands. 5892 if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) && 5893 isContractableFMul(*FpExtSrc, AllowFusionGlobally) && 5894 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 5895 MRI.getType(FpExtSrc->getOperand(1).getReg()))) { 5896 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5897 auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg()); 5898 auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg()); 5899 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 5900 {FpExtX.getReg(0), FpExtY.getReg(0), LHS.Reg}); 5901 }; 5902 return true; 5903 } 5904 5905 return false; 5906 } 5907 5908 bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA( 5909 MachineInstr &MI, 5910 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 5911 assert(MI.getOpcode() == TargetOpcode::G_FADD); 5912 5913 bool AllowFusionGlobally, HasFMAD, Aggressive; 5914 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, true)) 5915 return false; 5916 5917 Register Op1 = MI.getOperand(1).getReg(); 5918 Register Op2 = MI.getOperand(2).getReg(); 5919 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1}; 5920 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2}; 5921 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 5922 5923 unsigned PreferredFusedOpcode = 5924 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 5925 5926 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), 5927 // prefer to fold the multiply with fewer uses. 
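  // Note that the folds below reassociate the additions in the fused chain
  // ((x*y + u*v) + z becomes x*y + (u*v + z)), which is why CanReassociate was
  // required in the legality check above.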
5928 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) && 5929 isContractableFMul(*RHS.MI, AllowFusionGlobally)) { 5930 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI)) 5931 std::swap(LHS, RHS); 5932 } 5933 5934 MachineInstr *FMA = nullptr; 5935 Register Z; 5936 // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z)) 5937 if (LHS.MI->getOpcode() == PreferredFusedOpcode && 5938 (MRI.getVRegDef(LHS.MI->getOperand(3).getReg())->getOpcode() == 5939 TargetOpcode::G_FMUL) && 5940 MRI.hasOneNonDBGUse(LHS.MI->getOperand(0).getReg()) && 5941 MRI.hasOneNonDBGUse(LHS.MI->getOperand(3).getReg())) { 5942 FMA = LHS.MI; 5943 Z = RHS.Reg; 5944 } 5945 // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z)) 5946 else if (RHS.MI->getOpcode() == PreferredFusedOpcode && 5947 (MRI.getVRegDef(RHS.MI->getOperand(3).getReg())->getOpcode() == 5948 TargetOpcode::G_FMUL) && 5949 MRI.hasOneNonDBGUse(RHS.MI->getOperand(0).getReg()) && 5950 MRI.hasOneNonDBGUse(RHS.MI->getOperand(3).getReg())) { 5951 Z = LHS.Reg; 5952 FMA = RHS.MI; 5953 } 5954 5955 if (FMA) { 5956 MachineInstr *FMulMI = MRI.getVRegDef(FMA->getOperand(3).getReg()); 5957 Register X = FMA->getOperand(1).getReg(); 5958 Register Y = FMA->getOperand(2).getReg(); 5959 Register U = FMulMI->getOperand(1).getReg(); 5960 Register V = FMulMI->getOperand(2).getReg(); 5961 5962 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5963 Register InnerFMA = MRI.createGenericVirtualRegister(DstTy); 5964 B.buildInstr(PreferredFusedOpcode, {InnerFMA}, {U, V, Z}); 5965 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 5966 {X, Y, InnerFMA}); 5967 }; 5968 return true; 5969 } 5970 5971 return false; 5972 } 5973 5974 bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive( 5975 MachineInstr &MI, 5976 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 5977 assert(MI.getOpcode() == TargetOpcode::G_FADD); 5978 5979 bool AllowFusionGlobally, HasFMAD, Aggressive; 5980 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 5981 return false; 5982 5983 if (!Aggressive) 5984 return false; 5985 5986 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); 5987 LLT DstType = MRI.getType(MI.getOperand(0).getReg()); 5988 Register Op1 = MI.getOperand(1).getReg(); 5989 Register Op2 = MI.getOperand(2).getReg(); 5990 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1}; 5991 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2}; 5992 5993 unsigned PreferredFusedOpcode = 5994 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 5995 5996 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), 5997 // prefer to fold the multiply with fewer uses. 
5998 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) && 5999 isContractableFMul(*RHS.MI, AllowFusionGlobally)) { 6000 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI)) 6001 std::swap(LHS, RHS); 6002 } 6003 6004 // Builds: (fma x, y, (fma (fpext u), (fpext v), z)) 6005 auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X, 6006 Register Y, MachineIRBuilder &B) { 6007 Register FpExtU = B.buildFPExt(DstType, U).getReg(0); 6008 Register FpExtV = B.buildFPExt(DstType, V).getReg(0); 6009 Register InnerFMA = 6010 B.buildInstr(PreferredFusedOpcode, {DstType}, {FpExtU, FpExtV, Z}) 6011 .getReg(0); 6012 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6013 {X, Y, InnerFMA}); 6014 }; 6015 6016 MachineInstr *FMulMI, *FMAMI; 6017 // fold (fadd (fma x, y, (fpext (fmul u, v))), z) 6018 // -> (fma x, y, (fma (fpext u), (fpext v), z)) 6019 if (LHS.MI->getOpcode() == PreferredFusedOpcode && 6020 mi_match(LHS.MI->getOperand(3).getReg(), MRI, 6021 m_GFPExt(m_MInstr(FMulMI))) && 6022 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6023 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 6024 MRI.getType(FMulMI->getOperand(0).getReg()))) { 6025 MatchInfo = [=](MachineIRBuilder &B) { 6026 buildMatchInfo(FMulMI->getOperand(1).getReg(), 6027 FMulMI->getOperand(2).getReg(), RHS.Reg, 6028 LHS.MI->getOperand(1).getReg(), 6029 LHS.MI->getOperand(2).getReg(), B); 6030 }; 6031 return true; 6032 } 6033 6034 // fold (fadd (fpext (fma x, y, (fmul u, v))), z) 6035 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) 6036 // FIXME: This turns two single-precision and one double-precision 6037 // operation into two double-precision operations, which might not be 6038 // interesting for all targets, especially GPUs. 6039 if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) && 6040 FMAMI->getOpcode() == PreferredFusedOpcode) { 6041 MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg()); 6042 if (isContractableFMul(*FMulMI, AllowFusionGlobally) && 6043 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 6044 MRI.getType(FMAMI->getOperand(0).getReg()))) { 6045 MatchInfo = [=](MachineIRBuilder &B) { 6046 Register X = FMAMI->getOperand(1).getReg(); 6047 Register Y = FMAMI->getOperand(2).getReg(); 6048 X = B.buildFPExt(DstType, X).getReg(0); 6049 Y = B.buildFPExt(DstType, Y).getReg(0); 6050 buildMatchInfo(FMulMI->getOperand(1).getReg(), 6051 FMulMI->getOperand(2).getReg(), RHS.Reg, X, Y, B); 6052 }; 6053 6054 return true; 6055 } 6056 } 6057 6058 // fold (fadd z, (fma x, y, (fpext (fmul u, v))) 6059 // -> (fma x, y, (fma (fpext u), (fpext v), z)) 6060 if (RHS.MI->getOpcode() == PreferredFusedOpcode && 6061 mi_match(RHS.MI->getOperand(3).getReg(), MRI, 6062 m_GFPExt(m_MInstr(FMulMI))) && 6063 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6064 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 6065 MRI.getType(FMulMI->getOperand(0).getReg()))) { 6066 MatchInfo = [=](MachineIRBuilder &B) { 6067 buildMatchInfo(FMulMI->getOperand(1).getReg(), 6068 FMulMI->getOperand(2).getReg(), LHS.Reg, 6069 RHS.MI->getOperand(1).getReg(), 6070 RHS.MI->getOperand(2).getReg(), B); 6071 }; 6072 return true; 6073 } 6074 6075 // fold (fadd z, (fpext (fma x, y, (fmul u, v))) 6076 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) 6077 // FIXME: This turns two single-precision and one double-precision 6078 // operation into two double-precision operations, which might not be 6079 // interesting for all targets, especially GPUs. 
6080 if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) && 6081 FMAMI->getOpcode() == PreferredFusedOpcode) { 6082 MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg()); 6083 if (isContractableFMul(*FMulMI, AllowFusionGlobally) && 6084 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 6085 MRI.getType(FMAMI->getOperand(0).getReg()))) { 6086 MatchInfo = [=](MachineIRBuilder &B) { 6087 Register X = FMAMI->getOperand(1).getReg(); 6088 Register Y = FMAMI->getOperand(2).getReg(); 6089 X = B.buildFPExt(DstType, X).getReg(0); 6090 Y = B.buildFPExt(DstType, Y).getReg(0); 6091 buildMatchInfo(FMulMI->getOperand(1).getReg(), 6092 FMulMI->getOperand(2).getReg(), LHS.Reg, X, Y, B); 6093 }; 6094 return true; 6095 } 6096 } 6097 6098 return false; 6099 } 6100 6101 bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA( 6102 MachineInstr &MI, 6103 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6104 assert(MI.getOpcode() == TargetOpcode::G_FSUB); 6105 6106 bool AllowFusionGlobally, HasFMAD, Aggressive; 6107 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 6108 return false; 6109 6110 Register Op1 = MI.getOperand(1).getReg(); 6111 Register Op2 = MI.getOperand(2).getReg(); 6112 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1}; 6113 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2}; 6114 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 6115 6116 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), 6117 // prefer to fold the multiply with fewer uses. 6118 int FirstMulHasFewerUses = true; 6119 if (isContractableFMul(*LHS.MI, AllowFusionGlobally) && 6120 isContractableFMul(*RHS.MI, AllowFusionGlobally) && 6121 hasMoreUses(*LHS.MI, *RHS.MI, MRI)) 6122 FirstMulHasFewerUses = false; 6123 6124 unsigned PreferredFusedOpcode = 6125 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6126 6127 // fold (fsub (fmul x, y), z) -> (fma x, y, -z) 6128 if (FirstMulHasFewerUses && 6129 (isContractableFMul(*LHS.MI, AllowFusionGlobally) && 6130 (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg)))) { 6131 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6132 Register NegZ = B.buildFNeg(DstTy, RHS.Reg).getReg(0); 6133 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6134 {LHS.MI->getOperand(1).getReg(), 6135 LHS.MI->getOperand(2).getReg(), NegZ}); 6136 }; 6137 return true; 6138 } 6139 // fold (fsub x, (fmul y, z)) -> (fma -y, z, x) 6140 else if ((isContractableFMul(*RHS.MI, AllowFusionGlobally) && 6141 (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg)))) { 6142 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6143 Register NegY = 6144 B.buildFNeg(DstTy, RHS.MI->getOperand(1).getReg()).getReg(0); 6145 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6146 {NegY, RHS.MI->getOperand(2).getReg(), LHS.Reg}); 6147 }; 6148 return true; 6149 } 6150 6151 return false; 6152 } 6153 6154 bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA( 6155 MachineInstr &MI, 6156 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6157 assert(MI.getOpcode() == TargetOpcode::G_FSUB); 6158 6159 bool AllowFusionGlobally, HasFMAD, Aggressive; 6160 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 6161 return false; 6162 6163 Register LHSReg = MI.getOperand(1).getReg(); 6164 Register RHSReg = MI.getOperand(2).getReg(); 6165 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 6166 6167 unsigned PreferredFusedOpcode = 6168 HasFMAD ? 
TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6169 6170 MachineInstr *FMulMI; 6171 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z)) 6172 if (mi_match(LHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) && 6173 (Aggressive || (MRI.hasOneNonDBGUse(LHSReg) && 6174 MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) && 6175 isContractableFMul(*FMulMI, AllowFusionGlobally)) { 6176 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6177 Register NegX = 6178 B.buildFNeg(DstTy, FMulMI->getOperand(1).getReg()).getReg(0); 6179 Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0); 6180 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6181 {NegX, FMulMI->getOperand(2).getReg(), NegZ}); 6182 }; 6183 return true; 6184 } 6185 6186 // fold (fsub x, (fneg (fmul, y, z))) -> (fma y, z, x) 6187 if (mi_match(RHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) && 6188 (Aggressive || (MRI.hasOneNonDBGUse(RHSReg) && 6189 MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) && 6190 isContractableFMul(*FMulMI, AllowFusionGlobally)) { 6191 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6192 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6193 {FMulMI->getOperand(1).getReg(), 6194 FMulMI->getOperand(2).getReg(), LHSReg}); 6195 }; 6196 return true; 6197 } 6198 6199 return false; 6200 } 6201 6202 bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA( 6203 MachineInstr &MI, 6204 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6205 assert(MI.getOpcode() == TargetOpcode::G_FSUB); 6206 6207 bool AllowFusionGlobally, HasFMAD, Aggressive; 6208 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 6209 return false; 6210 6211 Register LHSReg = MI.getOperand(1).getReg(); 6212 Register RHSReg = MI.getOperand(2).getReg(); 6213 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 6214 6215 unsigned PreferredFusedOpcode = 6216 HasFMAD ? 
TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6217 6218 MachineInstr *FMulMI; 6219 // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z)) 6220 if (mi_match(LHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) && 6221 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6222 (Aggressive || MRI.hasOneNonDBGUse(LHSReg))) { 6223 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6224 Register FpExtX = 6225 B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0); 6226 Register FpExtY = 6227 B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0); 6228 Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0); 6229 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6230 {FpExtX, FpExtY, NegZ}); 6231 }; 6232 return true; 6233 } 6234 6235 // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x) 6236 if (mi_match(RHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) && 6237 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6238 (Aggressive || MRI.hasOneNonDBGUse(RHSReg))) { 6239 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6240 Register FpExtY = 6241 B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0); 6242 Register NegY = B.buildFNeg(DstTy, FpExtY).getReg(0); 6243 Register FpExtZ = 6244 B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0); 6245 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6246 {NegY, FpExtZ, LHSReg}); 6247 }; 6248 return true; 6249 } 6250 6251 return false; 6252 } 6253 6254 bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA( 6255 MachineInstr &MI, 6256 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6257 assert(MI.getOpcode() == TargetOpcode::G_FSUB); 6258 6259 bool AllowFusionGlobally, HasFMAD, Aggressive; 6260 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 6261 return false; 6262 6263 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); 6264 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 6265 Register LHSReg = MI.getOperand(1).getReg(); 6266 Register RHSReg = MI.getOperand(2).getReg(); 6267 6268 unsigned PreferredFusedOpcode = 6269 HasFMAD ? 
TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6270 6271 auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z, 6272 MachineIRBuilder &B) { 6273 Register FpExtX = B.buildFPExt(DstTy, X).getReg(0); 6274 Register FpExtY = B.buildFPExt(DstTy, Y).getReg(0); 6275 B.buildInstr(PreferredFusedOpcode, {Dst}, {FpExtX, FpExtY, Z}); 6276 }; 6277 6278 MachineInstr *FMulMI; 6279 // fold (fsub (fpext (fneg (fmul x, y))), z) -> 6280 // (fneg (fma (fpext x), (fpext y), z)) 6281 // fold (fsub (fneg (fpext (fmul x, y))), z) -> 6282 // (fneg (fma (fpext x), (fpext y), z)) 6283 if ((mi_match(LHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) || 6284 mi_match(LHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) && 6285 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6286 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy, 6287 MRI.getType(FMulMI->getOperand(0).getReg()))) { 6288 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6289 Register FMAReg = MRI.createGenericVirtualRegister(DstTy); 6290 buildMatchInfo(FMAReg, FMulMI->getOperand(1).getReg(), 6291 FMulMI->getOperand(2).getReg(), RHSReg, B); 6292 B.buildFNeg(MI.getOperand(0).getReg(), FMAReg); 6293 }; 6294 return true; 6295 } 6296 6297 // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x) 6298 // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x) 6299 if ((mi_match(RHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) || 6300 mi_match(RHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) && 6301 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6302 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy, 6303 MRI.getType(FMulMI->getOperand(0).getReg()))) { 6304 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6305 buildMatchInfo(MI.getOperand(0).getReg(), FMulMI->getOperand(1).getReg(), 6306 FMulMI->getOperand(2).getReg(), LHSReg, B); 6307 }; 6308 return true; 6309 } 6310 6311 return false; 6312 } 6313 6314 bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI, 6315 unsigned &IdxToPropagate) const { 6316 bool PropagateNaN; 6317 switch (MI.getOpcode()) { 6318 default: 6319 return false; 6320 case TargetOpcode::G_FMINNUM: 6321 case TargetOpcode::G_FMAXNUM: 6322 PropagateNaN = false; 6323 break; 6324 case TargetOpcode::G_FMINIMUM: 6325 case TargetOpcode::G_FMAXIMUM: 6326 PropagateNaN = true; 6327 break; 6328 } 6329 6330 auto MatchNaN = [&](unsigned Idx) { 6331 Register MaybeNaNReg = MI.getOperand(Idx).getReg(); 6332 const ConstantFP *MaybeCst = getConstantFPVRegVal(MaybeNaNReg, MRI); 6333 if (!MaybeCst || !MaybeCst->getValueAPF().isNaN()) 6334 return false; 6335 IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 
2 : 1); 6336 return true; 6337 }; 6338 6339 return MatchNaN(1) || MatchNaN(2); 6340 } 6341 6342 bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) const { 6343 assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD"); 6344 Register LHS = MI.getOperand(1).getReg(); 6345 Register RHS = MI.getOperand(2).getReg(); 6346 6347 // Helper lambda to check for opportunities for 6348 // A + (B - A) -> B 6349 // (B - A) + A -> B 6350 auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) { 6351 Register Reg; 6352 return mi_match(MaybeSub, MRI, m_GSub(m_Reg(Src), m_Reg(Reg))) && 6353 Reg == MaybeSameReg; 6354 }; 6355 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); 6356 } 6357 6358 bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI, 6359 Register &MatchInfo) const { 6360 // This combine folds the following patterns: 6361 // 6362 // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k)) 6363 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k))) 6364 // into 6365 // x 6366 // if 6367 // k == sizeof(VecEltTy)/2 6368 // type(x) == type(dst) 6369 // 6370 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef) 6371 // into 6372 // x 6373 // if 6374 // type(x) == type(dst) 6375 6376 LLT DstVecTy = MRI.getType(MI.getOperand(0).getReg()); 6377 LLT DstEltTy = DstVecTy.getElementType(); 6378 6379 Register Lo, Hi; 6380 6381 if (mi_match( 6382 MI, MRI, 6383 m_GBuildVector(m_GTrunc(m_GBitcast(m_Reg(Lo))), m_GImplicitDef()))) { 6384 MatchInfo = Lo; 6385 return MRI.getType(MatchInfo) == DstVecTy; 6386 } 6387 6388 std::optional<ValueAndVReg> ShiftAmount; 6389 const auto LoPattern = m_GBitcast(m_Reg(Lo)); 6390 const auto HiPattern = m_GLShr(m_GBitcast(m_Reg(Hi)), m_GCst(ShiftAmount)); 6391 if (mi_match( 6392 MI, MRI, 6393 m_any_of(m_GBuildVectorTrunc(LoPattern, HiPattern), 6394 m_GBuildVector(m_GTrunc(LoPattern), m_GTrunc(HiPattern))))) { 6395 if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) { 6396 MatchInfo = Lo; 6397 return MRI.getType(MatchInfo) == DstVecTy; 6398 } 6399 } 6400 6401 return false; 6402 } 6403 6404 bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI, 6405 Register &MatchInfo) const { 6406 // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y)) with just x 6407 // if type(x) == type(G_TRUNC) 6408 if (!mi_match(MI.getOperand(1).getReg(), MRI, 6409 m_GBitcast(m_GBuildVector(m_Reg(MatchInfo), m_Reg())))) 6410 return false; 6411 6412 return MRI.getType(MatchInfo) == MRI.getType(MI.getOperand(0).getReg()); 6413 } 6414 6415 bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI, 6416 Register &MatchInfo) const { 6417 // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with 6418 // y if K == size of vector element type 6419 std::optional<ValueAndVReg> ShiftAmt; 6420 if (!mi_match(MI.getOperand(1).getReg(), MRI, 6421 m_GLShr(m_GBitcast(m_GBuildVector(m_Reg(), m_Reg(MatchInfo))), 6422 m_GCst(ShiftAmt)))) 6423 return false; 6424 6425 LLT MatchTy = MRI.getType(MatchInfo); 6426 return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() && 6427 MatchTy == MRI.getType(MI.getOperand(0).getReg()); 6428 } 6429 6430 unsigned CombinerHelper::getFPMinMaxOpcForSelect( 6431 CmpInst::Predicate Pred, LLT DstTy, 6432 SelectPatternNaNBehaviour VsNaNRetVal) const { 6433 assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE && 6434 "Expected a NaN behaviour?"); 6435 // Choose an opcode based off of legality or the behaviour when one of the 6436 // LHS/RHS may be NaN. 
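// E.g. for a greater-than style predicate we prefer G_FMAXNUM when the
// compare is known to return the non-NaN operand, G_FMAXIMUM when it is
// known to return the NaN operand, and otherwise fall back to whichever of
// the two is legal for DstTy. The switch below implements that policy.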
6437 switch (Pred) { 6438 default: 6439 return 0; 6440 case CmpInst::FCMP_UGT: 6441 case CmpInst::FCMP_UGE: 6442 case CmpInst::FCMP_OGT: 6443 case CmpInst::FCMP_OGE: 6444 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) 6445 return TargetOpcode::G_FMAXNUM; 6446 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) 6447 return TargetOpcode::G_FMAXIMUM; 6448 if (isLegal({TargetOpcode::G_FMAXNUM, {DstTy}})) 6449 return TargetOpcode::G_FMAXNUM; 6450 if (isLegal({TargetOpcode::G_FMAXIMUM, {DstTy}})) 6451 return TargetOpcode::G_FMAXIMUM; 6452 return 0; 6453 case CmpInst::FCMP_ULT: 6454 case CmpInst::FCMP_ULE: 6455 case CmpInst::FCMP_OLT: 6456 case CmpInst::FCMP_OLE: 6457 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) 6458 return TargetOpcode::G_FMINNUM; 6459 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) 6460 return TargetOpcode::G_FMINIMUM; 6461 if (isLegal({TargetOpcode::G_FMINNUM, {DstTy}})) 6462 return TargetOpcode::G_FMINNUM; 6463 if (!isLegal({TargetOpcode::G_FMINIMUM, {DstTy}})) 6464 return 0; 6465 return TargetOpcode::G_FMINIMUM; 6466 } 6467 } 6468 6469 CombinerHelper::SelectPatternNaNBehaviour 6470 CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS, 6471 bool IsOrderedComparison) const { 6472 bool LHSSafe = isKnownNeverNaN(LHS, MRI); 6473 bool RHSSafe = isKnownNeverNaN(RHS, MRI); 6474 // Completely unsafe. 6475 if (!LHSSafe && !RHSSafe) 6476 return SelectPatternNaNBehaviour::NOT_APPLICABLE; 6477 if (LHSSafe && RHSSafe) 6478 return SelectPatternNaNBehaviour::RETURNS_ANY; 6479 // An ordered comparison will return false when given a NaN, so it 6480 // returns the RHS. 6481 if (IsOrderedComparison) 6482 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN 6483 : SelectPatternNaNBehaviour::RETURNS_OTHER; 6484 // An unordered comparison will return true when given a NaN, so it 6485 // returns the LHS. 6486 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER 6487 : SelectPatternNaNBehaviour::RETURNS_NAN; 6488 } 6489 6490 bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond, 6491 Register TrueVal, Register FalseVal, 6492 BuildFnTy &MatchInfo) const { 6493 // Match: select (fcmp cond x, y) x, y 6494 // select (fcmp cond x, y) y, x 6495 // And turn it into fminnum/fmaxnum or fmin/fmax based off of the condition. 6496 LLT DstTy = MRI.getType(Dst); 6497 // Bail out early on pointers, since we'll never want to fold to a min/max. 6498 if (DstTy.isPointer()) 6499 return false; 6500 // Match a floating point compare with a less-than/greater-than predicate. 6501 // TODO: Allow multiple users of the compare if they are all selects. 
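// Illustrative MIR for the pattern matched below (register names are made
// up):
//   %c:_(s1)   = G_FCMP floatpred(ogt), %x:_(s32), %y:_(s32)
//   %sel:_(s32) = G_SELECT %c, %x, %y
// which, when the NaN behaviour and legality checks allow it, can become:
//   %sel:_(s32) = G_FMAXNUM %x, %y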
6502 CmpInst::Predicate Pred; 6503 Register CmpLHS, CmpRHS; 6504 if (!mi_match(Cond, MRI, 6505 m_OneNonDBGUse( 6506 m_GFCmp(m_Pred(Pred), m_Reg(CmpLHS), m_Reg(CmpRHS)))) || 6507 CmpInst::isEquality(Pred)) 6508 return false; 6509 SelectPatternNaNBehaviour ResWithKnownNaNInfo = 6510 computeRetValAgainstNaN(CmpLHS, CmpRHS, CmpInst::isOrdered(Pred)); 6511 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE) 6512 return false; 6513 if (TrueVal == CmpRHS && FalseVal == CmpLHS) { 6514 std::swap(CmpLHS, CmpRHS); 6515 Pred = CmpInst::getSwappedPredicate(Pred); 6516 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN) 6517 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER; 6518 else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER) 6519 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN; 6520 } 6521 if (TrueVal != CmpLHS || FalseVal != CmpRHS) 6522 return false; 6523 // Decide what type of max/min this should be based off of the predicate. 6524 unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, ResWithKnownNaNInfo); 6525 if (!Opc || !isLegal({Opc, {DstTy}})) 6526 return false; 6527 // Comparisons between signed zero and zero may have different results... 6528 // unless we have fmaximum/fminimum. In that case, we know -0 < 0. 6529 if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) { 6530 // We don't know if a comparison between two 0s will give us a consistent 6531 // result. Be conservative and only proceed if at least one side is 6532 // non-zero. 6533 auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpLHS, MRI); 6534 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) { 6535 KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpRHS, MRI); 6536 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) 6537 return false; 6538 } 6539 } 6540 MatchInfo = [=](MachineIRBuilder &B) { 6541 B.buildInstr(Opc, {Dst}, {CmpLHS, CmpRHS}); 6542 }; 6543 return true; 6544 } 6545 6546 bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI, 6547 BuildFnTy &MatchInfo) const { 6548 // TODO: Handle integer cases. 6549 assert(MI.getOpcode() == TargetOpcode::G_SELECT); 6550 // Condition may be fed by a truncated compare. 6551 Register Cond = MI.getOperand(1).getReg(); 6552 Register MaybeTrunc; 6553 if (mi_match(Cond, MRI, m_OneNonDBGUse(m_GTrunc(m_Reg(MaybeTrunc))))) 6554 Cond = MaybeTrunc; 6555 Register Dst = MI.getOperand(0).getReg(); 6556 Register TrueVal = MI.getOperand(2).getReg(); 6557 Register FalseVal = MI.getOperand(3).getReg(); 6558 return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo); 6559 } 6560 6561 bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI, 6562 BuildFnTy &MatchInfo) const { 6563 assert(MI.getOpcode() == TargetOpcode::G_ICMP); 6564 // (X + Y) == X --> Y == 0 6565 // (X + Y) != X --> Y != 0 6566 // (X - Y) == X --> Y == 0 6567 // (X - Y) != X --> Y != 0 6568 // (X ^ Y) == X --> Y == 0 6569 // (X ^ Y) != X --> Y != 0 6570 Register Dst = MI.getOperand(0).getReg(); 6571 CmpInst::Predicate Pred; 6572 Register X, Y, OpLHS, OpRHS; 6573 bool MatchedSub = mi_match( 6574 Dst, MRI, 6575 m_c_GICmp(m_Pred(Pred), m_Reg(X), m_GSub(m_Reg(OpLHS), m_Reg(Y)))); 6576 if (MatchedSub && X != OpLHS) 6577 return false; 6578 if (!MatchedSub) { 6579 if (!mi_match(Dst, MRI, 6580 m_c_GICmp(m_Pred(Pred), m_Reg(X), 6581 m_any_of(m_GAdd(m_Reg(OpLHS), m_Reg(OpRHS)), 6582 m_GXor(m_Reg(OpLHS), m_Reg(OpRHS)))))) 6583 return false; 6584 Y = X == OpLHS ? 
OpRHS : X == OpRHS ? OpLHS : Register();
  }
  MatchInfo = [=](MachineIRBuilder &B) {
    auto Zero = B.buildConstant(MRI.getType(Y), 0);
    B.buildICmp(Pred, Dst, Y, Zero);
  };
  return CmpInst::isEquality(Pred) && Y.isValid();
}

/// Return the minimum useless shift amount, i.e. the smallest shift amount
/// that results in complete loss of the source value. \p Result receives the
/// constant the shift degenerates to (0, or -1 for an arithmetic shift of a
/// negative value), or std::nullopt when that constant cannot be determined.
static std::optional<unsigned>
getMinUselessShift(KnownBits ValueKB, unsigned Opcode,
                   std::optional<int64_t> &Result) {
  assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR ||
          Opcode == TargetOpcode::G_ASHR) &&
         "Expect G_SHL, G_LSHR or G_ASHR.");
  unsigned SignificantBits = 0;
  switch (Opcode) {
  case TargetOpcode::G_SHL:
    SignificantBits = ValueKB.countMinTrailingZeros();
    Result = 0;
    break;
  case TargetOpcode::G_LSHR:
    Result = 0;
    SignificantBits = ValueKB.countMinLeadingZeros();
    break;
  case TargetOpcode::G_ASHR:
    if (ValueKB.isNonNegative()) {
      SignificantBits = ValueKB.countMinLeadingZeros();
      Result = 0;
    } else if (ValueKB.isNegative()) {
      SignificantBits = ValueKB.countMinLeadingOnes();
      Result = -1;
    } else {
      // Cannot determine shift result.
      Result = std::nullopt;
    }
    break;
  default:
    break;
  }
  return ValueKB.getBitWidth() - SignificantBits;
}

bool CombinerHelper::matchShiftsTooBig(
    MachineInstr &MI, std::optional<int64_t> &MatchInfo) const {
  Register ShiftVal = MI.getOperand(1).getReg();
  Register ShiftReg = MI.getOperand(2).getReg();
  LLT ResTy = MRI.getType(MI.getOperand(0).getReg());
  auto IsShiftTooBig = [&](const Constant *C) {
    auto *CI = dyn_cast<ConstantInt>(C);
    if (!CI)
      return false;
    if (CI->uge(ResTy.getScalarSizeInBits())) {
      MatchInfo = std::nullopt;
      return true;
    }
    auto OptMaxUsefulShift = getMinUselessShift(KB->getKnownBits(ShiftVal),
                                                MI.getOpcode(), MatchInfo);
    return OptMaxUsefulShift && CI->uge(*OptMaxUsefulShift);
  };
  return matchUnaryPredicate(MRI, ShiftReg, IsShiftTooBig);
}

bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) const {
  unsigned LHSOpndIdx = 1;
  unsigned RHSOpndIdx = 2;
  switch (MI.getOpcode()) {
  case TargetOpcode::G_UADDO:
  case TargetOpcode::G_SADDO:
  case TargetOpcode::G_UMULO:
  case TargetOpcode::G_SMULO:
    LHSOpndIdx = 2;
    RHSOpndIdx = 3;
    break;
  default:
    break;
  }
  Register LHS = MI.getOperand(LHSOpndIdx).getReg();
  Register RHS = MI.getOperand(RHSOpndIdx).getReg();
  if (!getIConstantVRegVal(LHS, MRI)) {
    // Skip commuting if LHS is not a constant. But, LHS may be a
    // G_CONSTANT_FOLD_BARRIER. If so we commute as long as we don't already
    // have a constant on the RHS.
    if (MRI.getVRegDef(LHS)->getOpcode() !=
        TargetOpcode::G_CONSTANT_FOLD_BARRIER)
      return false;
  }
  // Commute as long as RHS is not a constant or G_CONSTANT_FOLD_BARRIER.
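  // Otherwise the commuted form would match this combine again and we would
  // keep swapping the two operands back and forth without making progress.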
6673 return MRI.getVRegDef(RHS)->getOpcode() != 6674 TargetOpcode::G_CONSTANT_FOLD_BARRIER && 6675 !getIConstantVRegVal(RHS, MRI); 6676 } 6677 6678 bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) const { 6679 Register LHS = MI.getOperand(1).getReg(); 6680 Register RHS = MI.getOperand(2).getReg(); 6681 std::optional<FPValueAndVReg> ValAndVReg; 6682 if (!mi_match(LHS, MRI, m_GFCstOrSplat(ValAndVReg))) 6683 return false; 6684 return !mi_match(RHS, MRI, m_GFCstOrSplat(ValAndVReg)); 6685 } 6686 6687 void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) const { 6688 Observer.changingInstr(MI); 6689 unsigned LHSOpndIdx = 1; 6690 unsigned RHSOpndIdx = 2; 6691 switch (MI.getOpcode()) { 6692 case TargetOpcode::G_UADDO: 6693 case TargetOpcode::G_SADDO: 6694 case TargetOpcode::G_UMULO: 6695 case TargetOpcode::G_SMULO: 6696 LHSOpndIdx = 2; 6697 RHSOpndIdx = 3; 6698 break; 6699 default: 6700 break; 6701 } 6702 Register LHSReg = MI.getOperand(LHSOpndIdx).getReg(); 6703 Register RHSReg = MI.getOperand(RHSOpndIdx).getReg(); 6704 MI.getOperand(LHSOpndIdx).setReg(RHSReg); 6705 MI.getOperand(RHSOpndIdx).setReg(LHSReg); 6706 Observer.changedInstr(MI); 6707 } 6708 6709 bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) const { 6710 LLT SrcTy = MRI.getType(Src); 6711 if (SrcTy.isFixedVector()) 6712 return isConstantSplatVector(Src, 1, AllowUndefs); 6713 if (SrcTy.isScalar()) { 6714 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr) 6715 return true; 6716 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); 6717 return IConstant && IConstant->Value == 1; 6718 } 6719 return false; // scalable vector 6720 } 6721 6722 bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) const { 6723 LLT SrcTy = MRI.getType(Src); 6724 if (SrcTy.isFixedVector()) 6725 return isConstantSplatVector(Src, 0, AllowUndefs); 6726 if (SrcTy.isScalar()) { 6727 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr) 6728 return true; 6729 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); 6730 return IConstant && IConstant->Value == 0; 6731 } 6732 return false; // scalable vector 6733 } 6734 6735 // Ignores COPYs during conformance checks. 6736 // FIXME scalable vectors. 6737 bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue, 6738 bool AllowUndefs) const { 6739 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI); 6740 if (!BuildVector) 6741 return false; 6742 unsigned NumSources = BuildVector->getNumSources(); 6743 6744 for (unsigned I = 0; I < NumSources; ++I) { 6745 GImplicitDef *ImplicitDef = 6746 getOpcodeDef<GImplicitDef>(BuildVector->getSourceReg(I), MRI); 6747 if (ImplicitDef && AllowUndefs) 6748 continue; 6749 if (ImplicitDef && !AllowUndefs) 6750 return false; 6751 std::optional<ValueAndVReg> IConstant = 6752 getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI); 6753 if (IConstant && IConstant->Value == SplatValue) 6754 continue; 6755 return false; 6756 } 6757 return true; 6758 } 6759 6760 // Ignores COPYs during lookups. 
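// Returns the constant if Src is a G_CONSTANT, the common splat value if it
// is a G_BUILD_VECTOR of identical constants, and std::nullopt otherwise.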
6761 // FIXME scalable vectors 6762 std::optional<APInt> 6763 CombinerHelper::getConstantOrConstantSplatVector(Register Src) const { 6764 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); 6765 if (IConstant) 6766 return IConstant->Value; 6767 6768 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI); 6769 if (!BuildVector) 6770 return std::nullopt; 6771 unsigned NumSources = BuildVector->getNumSources(); 6772 6773 std::optional<APInt> Value = std::nullopt; 6774 for (unsigned I = 0; I < NumSources; ++I) { 6775 std::optional<ValueAndVReg> IConstant = 6776 getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI); 6777 if (!IConstant) 6778 return std::nullopt; 6779 if (!Value) 6780 Value = IConstant->Value; 6781 else if (*Value != IConstant->Value) 6782 return std::nullopt; 6783 } 6784 return Value; 6785 } 6786 6787 // FIXME G_SPLAT_VECTOR 6788 bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const { 6789 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); 6790 if (IConstant) 6791 return true; 6792 6793 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI); 6794 if (!BuildVector) 6795 return false; 6796 6797 unsigned NumSources = BuildVector->getNumSources(); 6798 for (unsigned I = 0; I < NumSources; ++I) { 6799 std::optional<ValueAndVReg> IConstant = 6800 getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI); 6801 if (!IConstant) 6802 return false; 6803 } 6804 return true; 6805 } 6806 6807 // TODO: use knownbits to determine zeros 6808 bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select, 6809 BuildFnTy &MatchInfo) const { 6810 uint32_t Flags = Select->getFlags(); 6811 Register Dest = Select->getReg(0); 6812 Register Cond = Select->getCondReg(); 6813 Register True = Select->getTrueReg(); 6814 Register False = Select->getFalseReg(); 6815 LLT CondTy = MRI.getType(Select->getCondReg()); 6816 LLT TrueTy = MRI.getType(Select->getTrueReg()); 6817 6818 // We only do this combine for scalar boolean conditions. 6819 if (CondTy != LLT::scalar(1)) 6820 return false; 6821 6822 if (TrueTy.isPointer()) 6823 return false; 6824 6825 // Both are scalars. 
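  // The folds below rewrite selects between two constants; e.g. the first
  // case turns select %c, 1, 0 into a plain zero-extension of the i1
  // condition %c (the register name is illustrative).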
6826 std::optional<ValueAndVReg> TrueOpt = 6827 getIConstantVRegValWithLookThrough(True, MRI); 6828 std::optional<ValueAndVReg> FalseOpt = 6829 getIConstantVRegValWithLookThrough(False, MRI); 6830 6831 if (!TrueOpt || !FalseOpt) 6832 return false; 6833 6834 APInt TrueValue = TrueOpt->Value; 6835 APInt FalseValue = FalseOpt->Value; 6836 6837 // select Cond, 1, 0 --> zext (Cond) 6838 if (TrueValue.isOne() && FalseValue.isZero()) { 6839 MatchInfo = [=](MachineIRBuilder &B) { 6840 B.setInstrAndDebugLoc(*Select); 6841 B.buildZExtOrTrunc(Dest, Cond); 6842 }; 6843 return true; 6844 } 6845 6846 // select Cond, -1, 0 --> sext (Cond) 6847 if (TrueValue.isAllOnes() && FalseValue.isZero()) { 6848 MatchInfo = [=](MachineIRBuilder &B) { 6849 B.setInstrAndDebugLoc(*Select); 6850 B.buildSExtOrTrunc(Dest, Cond); 6851 }; 6852 return true; 6853 } 6854 6855 // select Cond, 0, 1 --> zext (!Cond) 6856 if (TrueValue.isZero() && FalseValue.isOne()) { 6857 MatchInfo = [=](MachineIRBuilder &B) { 6858 B.setInstrAndDebugLoc(*Select); 6859 Register Inner = MRI.createGenericVirtualRegister(CondTy); 6860 B.buildNot(Inner, Cond); 6861 B.buildZExtOrTrunc(Dest, Inner); 6862 }; 6863 return true; 6864 } 6865 6866 // select Cond, 0, -1 --> sext (!Cond) 6867 if (TrueValue.isZero() && FalseValue.isAllOnes()) { 6868 MatchInfo = [=](MachineIRBuilder &B) { 6869 B.setInstrAndDebugLoc(*Select); 6870 Register Inner = MRI.createGenericVirtualRegister(CondTy); 6871 B.buildNot(Inner, Cond); 6872 B.buildSExtOrTrunc(Dest, Inner); 6873 }; 6874 return true; 6875 } 6876 6877 // select Cond, C1, C1-1 --> add (zext Cond), C1-1 6878 if (TrueValue - 1 == FalseValue) { 6879 MatchInfo = [=](MachineIRBuilder &B) { 6880 B.setInstrAndDebugLoc(*Select); 6881 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 6882 B.buildZExtOrTrunc(Inner, Cond); 6883 B.buildAdd(Dest, Inner, False); 6884 }; 6885 return true; 6886 } 6887 6888 // select Cond, C1, C1+1 --> add (sext Cond), C1+1 6889 if (TrueValue + 1 == FalseValue) { 6890 MatchInfo = [=](MachineIRBuilder &B) { 6891 B.setInstrAndDebugLoc(*Select); 6892 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 6893 B.buildSExtOrTrunc(Inner, Cond); 6894 B.buildAdd(Dest, Inner, False); 6895 }; 6896 return true; 6897 } 6898 6899 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2) 6900 if (TrueValue.isPowerOf2() && FalseValue.isZero()) { 6901 MatchInfo = [=](MachineIRBuilder &B) { 6902 B.setInstrAndDebugLoc(*Select); 6903 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 6904 B.buildZExtOrTrunc(Inner, Cond); 6905 // The shift amount must be scalar. 6906 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy; 6907 auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2()); 6908 B.buildShl(Dest, Inner, ShAmtC, Flags); 6909 }; 6910 return true; 6911 } 6912 6913 // select Cond, 0, Pow2 --> (zext (!Cond)) << log2(Pow2) 6914 if (FalseValue.isPowerOf2() && TrueValue.isZero()) { 6915 MatchInfo = [=](MachineIRBuilder &B) { 6916 B.setInstrAndDebugLoc(*Select); 6917 Register Not = MRI.createGenericVirtualRegister(CondTy); 6918 B.buildNot(Not, Cond); 6919 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 6920 B.buildZExtOrTrunc(Inner, Not); 6921 // The shift amount must be scalar. 6922 LLT ShiftTy = TrueTy.isVector() ? 
TrueTy.getElementType() : TrueTy; 6923 auto ShAmtC = B.buildConstant(ShiftTy, FalseValue.exactLogBase2()); 6924 B.buildShl(Dest, Inner, ShAmtC, Flags); 6925 }; 6926 return true; 6927 } 6928 6929 // select Cond, -1, C --> or (sext Cond), C 6930 if (TrueValue.isAllOnes()) { 6931 MatchInfo = [=](MachineIRBuilder &B) { 6932 B.setInstrAndDebugLoc(*Select); 6933 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 6934 B.buildSExtOrTrunc(Inner, Cond); 6935 B.buildOr(Dest, Inner, False, Flags); 6936 }; 6937 return true; 6938 } 6939 6940 // select Cond, C, -1 --> or (sext (not Cond)), C 6941 if (FalseValue.isAllOnes()) { 6942 MatchInfo = [=](MachineIRBuilder &B) { 6943 B.setInstrAndDebugLoc(*Select); 6944 Register Not = MRI.createGenericVirtualRegister(CondTy); 6945 B.buildNot(Not, Cond); 6946 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 6947 B.buildSExtOrTrunc(Inner, Not); 6948 B.buildOr(Dest, Inner, True, Flags); 6949 }; 6950 return true; 6951 } 6952 6953 return false; 6954 } 6955 6956 // TODO: use knownbits to determine zeros 6957 bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select, 6958 BuildFnTy &MatchInfo) const { 6959 uint32_t Flags = Select->getFlags(); 6960 Register DstReg = Select->getReg(0); 6961 Register Cond = Select->getCondReg(); 6962 Register True = Select->getTrueReg(); 6963 Register False = Select->getFalseReg(); 6964 LLT CondTy = MRI.getType(Select->getCondReg()); 6965 LLT TrueTy = MRI.getType(Select->getTrueReg()); 6966 6967 // Boolean or fixed vector of booleans. 6968 if (CondTy.isScalableVector() || 6969 (CondTy.isFixedVector() && 6970 CondTy.getElementType().getScalarSizeInBits() != 1) || 6971 CondTy.getScalarSizeInBits() != 1) 6972 return false; 6973 6974 if (CondTy != TrueTy) 6975 return false; 6976 6977 // select Cond, Cond, F --> or Cond, F 6978 // select Cond, 1, F --> or Cond, F 6979 if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) { 6980 MatchInfo = [=](MachineIRBuilder &B) { 6981 B.setInstrAndDebugLoc(*Select); 6982 Register Ext = MRI.createGenericVirtualRegister(TrueTy); 6983 B.buildZExtOrTrunc(Ext, Cond); 6984 auto FreezeFalse = B.buildFreeze(TrueTy, False); 6985 B.buildOr(DstReg, Ext, FreezeFalse, Flags); 6986 }; 6987 return true; 6988 } 6989 6990 // select Cond, T, Cond --> and Cond, T 6991 // select Cond, T, 0 --> and Cond, T 6992 if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) { 6993 MatchInfo = [=](MachineIRBuilder &B) { 6994 B.setInstrAndDebugLoc(*Select); 6995 Register Ext = MRI.createGenericVirtualRegister(TrueTy); 6996 B.buildZExtOrTrunc(Ext, Cond); 6997 auto FreezeTrue = B.buildFreeze(TrueTy, True); 6998 B.buildAnd(DstReg, Ext, FreezeTrue); 6999 }; 7000 return true; 7001 } 7002 7003 // select Cond, T, 1 --> or (not Cond), T 7004 if (isOneOrOneSplat(False, /* AllowUndefs */ true)) { 7005 MatchInfo = [=](MachineIRBuilder &B) { 7006 B.setInstrAndDebugLoc(*Select); 7007 // First the not. 7008 Register Inner = MRI.createGenericVirtualRegister(CondTy); 7009 B.buildNot(Inner, Cond); 7010 // Then an ext to match the destination register. 7011 Register Ext = MRI.createGenericVirtualRegister(TrueTy); 7012 B.buildZExtOrTrunc(Ext, Inner); 7013 auto FreezeTrue = B.buildFreeze(TrueTy, True); 7014 B.buildOr(DstReg, Ext, FreezeTrue, Flags); 7015 }; 7016 return true; 7017 } 7018 7019 // select Cond, 0, F --> and (not Cond), F 7020 if (isZeroOrZeroSplat(True, /* AllowUndefs */ true)) { 7021 MatchInfo = [=](MachineIRBuilder &B) { 7022 B.setInstrAndDebugLoc(*Select); 7023 // First the not. 
7024 Register Inner = MRI.createGenericVirtualRegister(CondTy); 7025 B.buildNot(Inner, Cond); 7026 // Then an ext to match the destination register. 7027 Register Ext = MRI.createGenericVirtualRegister(TrueTy); 7028 B.buildZExtOrTrunc(Ext, Inner); 7029 auto FreezeFalse = B.buildFreeze(TrueTy, False); 7030 B.buildAnd(DstReg, Ext, FreezeFalse); 7031 }; 7032 return true; 7033 } 7034 7035 return false; 7036 } 7037 7038 bool CombinerHelper::matchSelectIMinMax(const MachineOperand &MO, 7039 BuildFnTy &MatchInfo) const { 7040 GSelect *Select = cast<GSelect>(MRI.getVRegDef(MO.getReg())); 7041 GICmp *Cmp = cast<GICmp>(MRI.getVRegDef(Select->getCondReg())); 7042 7043 Register DstReg = Select->getReg(0); 7044 Register True = Select->getTrueReg(); 7045 Register False = Select->getFalseReg(); 7046 LLT DstTy = MRI.getType(DstReg); 7047 7048 if (DstTy.isPointer()) 7049 return false; 7050 7051 // We want to fold the icmp and replace the select. 7052 if (!MRI.hasOneNonDBGUse(Cmp->getReg(0))) 7053 return false; 7054 7055 CmpInst::Predicate Pred = Cmp->getCond(); 7056 // We need a larger or smaller predicate for 7057 // canonicalization. 7058 if (CmpInst::isEquality(Pred)) 7059 return false; 7060 7061 Register CmpLHS = Cmp->getLHSReg(); 7062 Register CmpRHS = Cmp->getRHSReg(); 7063 7064 // We can swap CmpLHS and CmpRHS for higher hitrate. 7065 if (True == CmpRHS && False == CmpLHS) { 7066 std::swap(CmpLHS, CmpRHS); 7067 Pred = CmpInst::getSwappedPredicate(Pred); 7068 } 7069 7070 // (icmp X, Y) ? X : Y -> integer minmax. 7071 // see matchSelectPattern in ValueTracking. 7072 // Legality between G_SELECT and integer minmax can differ. 7073 if (True != CmpLHS || False != CmpRHS) 7074 return false; 7075 7076 switch (Pred) { 7077 case ICmpInst::ICMP_UGT: 7078 case ICmpInst::ICMP_UGE: { 7079 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMAX, DstTy})) 7080 return false; 7081 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMax(DstReg, True, False); }; 7082 return true; 7083 } 7084 case ICmpInst::ICMP_SGT: 7085 case ICmpInst::ICMP_SGE: { 7086 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SMAX, DstTy})) 7087 return false; 7088 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMax(DstReg, True, False); }; 7089 return true; 7090 } 7091 case ICmpInst::ICMP_ULT: 7092 case ICmpInst::ICMP_ULE: { 7093 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMIN, DstTy})) 7094 return false; 7095 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMin(DstReg, True, False); }; 7096 return true; 7097 } 7098 case ICmpInst::ICMP_SLT: 7099 case ICmpInst::ICMP_SLE: { 7100 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SMIN, DstTy})) 7101 return false; 7102 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMin(DstReg, True, False); }; 7103 return true; 7104 } 7105 default: 7106 return false; 7107 } 7108 } 7109 7110 // (neg (min/max x, (neg x))) --> (max/min x, (neg x)) 7111 bool CombinerHelper::matchSimplifyNegMinMax(MachineInstr &MI, 7112 BuildFnTy &MatchInfo) const { 7113 assert(MI.getOpcode() == TargetOpcode::G_SUB); 7114 Register DestReg = MI.getOperand(0).getReg(); 7115 LLT DestTy = MRI.getType(DestReg); 7116 7117 Register X; 7118 Register Sub0; 7119 auto NegPattern = m_all_of(m_Neg(m_DeferredReg(X)), m_Reg(Sub0)); 7120 if (mi_match(DestReg, MRI, 7121 m_Neg(m_OneUse(m_any_of(m_GSMin(m_Reg(X), NegPattern), 7122 m_GSMax(m_Reg(X), NegPattern), 7123 m_GUMin(m_Reg(X), NegPattern), 7124 m_GUMax(m_Reg(X), NegPattern)))))) { 7125 MachineInstr *MinMaxMI = MRI.getVRegDef(MI.getOperand(2).getReg()); 7126 unsigned NewOpc = 
getInverseGMinMaxOpcode(MinMaxMI->getOpcode()); 7127 if (isLegal({NewOpc, {DestTy}})) { 7128 MatchInfo = [=](MachineIRBuilder &B) { 7129 B.buildInstr(NewOpc, {DestReg}, {X, Sub0}); 7130 }; 7131 return true; 7132 } 7133 } 7134 7135 return false; 7136 } 7137 7138 bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) const { 7139 GSelect *Select = cast<GSelect>(&MI); 7140 7141 if (tryFoldSelectOfConstants(Select, MatchInfo)) 7142 return true; 7143 7144 if (tryFoldBoolSelectToLogic(Select, MatchInfo)) 7145 return true; 7146 7147 return false; 7148 } 7149 7150 /// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2) 7151 /// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2) 7152 /// into a single comparison using range-based reasoning. 7153 /// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges. 7154 bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges( 7155 GLogicalBinOp *Logic, BuildFnTy &MatchInfo) const { 7156 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor"); 7157 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND; 7158 Register DstReg = Logic->getReg(0); 7159 Register LHS = Logic->getLHSReg(); 7160 Register RHS = Logic->getRHSReg(); 7161 unsigned Flags = Logic->getFlags(); 7162 7163 // We need an G_ICMP on the LHS register. 7164 GICmp *Cmp1 = getOpcodeDef<GICmp>(LHS, MRI); 7165 if (!Cmp1) 7166 return false; 7167 7168 // We need an G_ICMP on the RHS register. 7169 GICmp *Cmp2 = getOpcodeDef<GICmp>(RHS, MRI); 7170 if (!Cmp2) 7171 return false; 7172 7173 // We want to fold the icmps. 7174 if (!MRI.hasOneNonDBGUse(Cmp1->getReg(0)) || 7175 !MRI.hasOneNonDBGUse(Cmp2->getReg(0))) 7176 return false; 7177 7178 APInt C1; 7179 APInt C2; 7180 std::optional<ValueAndVReg> MaybeC1 = 7181 getIConstantVRegValWithLookThrough(Cmp1->getRHSReg(), MRI); 7182 if (!MaybeC1) 7183 return false; 7184 C1 = MaybeC1->Value; 7185 7186 std::optional<ValueAndVReg> MaybeC2 = 7187 getIConstantVRegValWithLookThrough(Cmp2->getRHSReg(), MRI); 7188 if (!MaybeC2) 7189 return false; 7190 C2 = MaybeC2->Value; 7191 7192 Register R1 = Cmp1->getLHSReg(); 7193 Register R2 = Cmp2->getLHSReg(); 7194 CmpInst::Predicate Pred1 = Cmp1->getCond(); 7195 CmpInst::Predicate Pred2 = Cmp2->getCond(); 7196 LLT CmpTy = MRI.getType(Cmp1->getReg(0)); 7197 LLT CmpOperandTy = MRI.getType(R1); 7198 7199 if (CmpOperandTy.isPointer()) 7200 return false; 7201 7202 // We build ands, adds, and constants of type CmpOperandTy. 7203 // They must be legal to build. 7204 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_AND, CmpOperandTy}) || 7205 !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, CmpOperandTy}) || 7206 !isConstantLegalOrBeforeLegalizer(CmpOperandTy)) 7207 return false; 7208 7209 // Look through add of a constant offset on R1, R2, or both operands. This 7210 // allows us to interpret the R + C' < C'' range idiom into a proper range. 
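  // E.g. with illustrative constants, icmp ult (add %r, 5), 10 constrains
  // %r + 5 to [0, 10), so subtracting the offset of 5 gives the range
  // [-5, 5) for %r itself.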
7211 std::optional<APInt> Offset1; 7212 std::optional<APInt> Offset2; 7213 if (R1 != R2) { 7214 if (GAdd *Add = getOpcodeDef<GAdd>(R1, MRI)) { 7215 std::optional<ValueAndVReg> MaybeOffset1 = 7216 getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI); 7217 if (MaybeOffset1) { 7218 R1 = Add->getLHSReg(); 7219 Offset1 = MaybeOffset1->Value; 7220 } 7221 } 7222 if (GAdd *Add = getOpcodeDef<GAdd>(R2, MRI)) { 7223 std::optional<ValueAndVReg> MaybeOffset2 = 7224 getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI); 7225 if (MaybeOffset2) { 7226 R2 = Add->getLHSReg(); 7227 Offset2 = MaybeOffset2->Value; 7228 } 7229 } 7230 } 7231 7232 if (R1 != R2) 7233 return false; 7234 7235 // We calculate the icmp ranges including maybe offsets. 7236 ConstantRange CR1 = ConstantRange::makeExactICmpRegion( 7237 IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, C1); 7238 if (Offset1) 7239 CR1 = CR1.subtract(*Offset1); 7240 7241 ConstantRange CR2 = ConstantRange::makeExactICmpRegion( 7242 IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, C2); 7243 if (Offset2) 7244 CR2 = CR2.subtract(*Offset2); 7245 7246 bool CreateMask = false; 7247 APInt LowerDiff; 7248 std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2); 7249 if (!CR) { 7250 // We need non-wrapping ranges. 7251 if (CR1.isWrappedSet() || CR2.isWrappedSet()) 7252 return false; 7253 7254 // Check whether we have equal-size ranges that only differ by one bit. 7255 // In that case we can apply a mask to map one range onto the other. 7256 LowerDiff = CR1.getLower() ^ CR2.getLower(); 7257 APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1); 7258 APInt CR1Size = CR1.getUpper() - CR1.getLower(); 7259 if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff || 7260 CR1Size != CR2.getUpper() - CR2.getLower()) 7261 return false; 7262 7263 CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2; 7264 CreateMask = true; 7265 } 7266 7267 if (IsAnd) 7268 CR = CR->inverse(); 7269 7270 CmpInst::Predicate NewPred; 7271 APInt NewC, Offset; 7272 CR->getEquivalentICmp(NewPred, NewC, Offset); 7273 7274 // We take the result type of one of the original icmps, CmpTy, for 7275 // the to be build icmp. The operand type, CmpOperandTy, is used for 7276 // the other instructions and constants to be build. The types of 7277 // the parameters and output are the same for add and and. CmpTy 7278 // and the type of DstReg might differ. That is why we zext or trunc 7279 // the icmp into the destination register. 7280 7281 MatchInfo = [=](MachineIRBuilder &B) { 7282 if (CreateMask && Offset != 0) { 7283 auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff); 7284 auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask. 7285 auto OffsetC = B.buildConstant(CmpOperandTy, Offset); 7286 auto Add = B.buildAdd(CmpOperandTy, And, OffsetC, Flags); 7287 auto NewCon = B.buildConstant(CmpOperandTy, NewC); 7288 auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon); 7289 B.buildZExtOrTrunc(DstReg, ICmp); 7290 } else if (CreateMask && Offset == 0) { 7291 auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff); 7292 auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask. 
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, And, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else if (!CreateMask && Offset != 0) {
      auto OffsetC = B.buildConstant(CmpOperandTy, Offset);
      auto Add = B.buildAdd(CmpOperandTy, R1, OffsetC, Flags);
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else if (!CreateMask && Offset == 0) {
      auto NewCon = B.buildConstant(CmpOperandTy, NewC);
      auto ICmp = B.buildICmp(NewPred, CmpTy, R1, NewCon);
      B.buildZExtOrTrunc(DstReg, ICmp);
    } else {
      llvm_unreachable("unexpected configuration of CreateMask and Offset");
    }
  };
  return true;
}

bool CombinerHelper::tryFoldLogicOfFCmps(GLogicalBinOp *Logic,
                                         BuildFnTy &MatchInfo) const {
  assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor");
  Register DestReg = Logic->getReg(0);
  Register LHS = Logic->getLHSReg();
  Register RHS = Logic->getRHSReg();
  bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND;

  // We need a compare on the LHS register.
  GFCmp *Cmp1 = getOpcodeDef<GFCmp>(LHS, MRI);
  if (!Cmp1)
    return false;

  // We need a compare on the RHS register.
  GFCmp *Cmp2 = getOpcodeDef<GFCmp>(RHS, MRI);
  if (!Cmp2)
    return false;

  LLT CmpTy = MRI.getType(Cmp1->getReg(0));
  LLT CmpOperandTy = MRI.getType(Cmp1->getLHSReg());

  // We build one fcmp, want to fold the fcmps, replace the logic op,
  // and the fcmps must have the same shape.
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_FCMP, {CmpTy, CmpOperandTy}}) ||
      !MRI.hasOneNonDBGUse(Logic->getReg(0)) ||
      !MRI.hasOneNonDBGUse(Cmp1->getReg(0)) ||
      !MRI.hasOneNonDBGUse(Cmp2->getReg(0)) ||
      MRI.getType(Cmp1->getLHSReg()) != MRI.getType(Cmp2->getLHSReg()))
    return false;

  CmpInst::Predicate PredL = Cmp1->getCond();
  CmpInst::Predicate PredR = Cmp2->getCond();
  Register LHS0 = Cmp1->getLHSReg();
  Register LHS1 = Cmp1->getRHSReg();
  Register RHS0 = Cmp2->getLHSReg();
  Register RHS1 = Cmp2->getRHSReg();

  if (LHS0 == RHS1 && LHS1 == RHS0) {
    // Swap RHS operands to match LHS.
    PredR = CmpInst::getSwappedPredicate(PredR);
    std::swap(RHS0, RHS1);
  }

  if (LHS0 == RHS0 && LHS1 == RHS1) {
    // We determine the new predicate.
    unsigned CmpCodeL = getFCmpCode(PredL);
    unsigned CmpCodeR = getFCmpCode(PredR);
    unsigned NewPred = IsAnd ? CmpCodeL & CmpCodeR : CmpCodeL | CmpCodeR;
    unsigned Flags = Cmp1->getFlags() | Cmp2->getFlags();
    MatchInfo = [=](MachineIRBuilder &B) {
      // The fcmp predicates fill the lower part of the enum.
      FCmpInst::Predicate Pred = static_cast<FCmpInst::Predicate>(NewPred);
      if (Pred == FCmpInst::FCMP_FALSE &&
          isConstantLegalOrBeforeLegalizer(CmpTy)) {
        auto False = B.buildConstant(CmpTy, 0);
        B.buildZExtOrTrunc(DestReg, False);
      } else if (Pred == FCmpInst::FCMP_TRUE &&
                 isConstantLegalOrBeforeLegalizer(CmpTy)) {
        auto True =
            B.buildConstant(CmpTy, getICmpTrueVal(getTargetLowering(),
                                                  CmpTy.isVector() /*isVector*/,
                                                  true /*isFP*/));
        B.buildZExtOrTrunc(DestReg, True);
      } else { // We take the predicate without predicate optimizations.
7378 auto Cmp = B.buildFCmp(Pred, CmpTy, LHS0, LHS1, Flags); 7379 B.buildZExtOrTrunc(DestReg, Cmp); 7380 } 7381 }; 7382 return true; 7383 } 7384 7385 return false; 7386 } 7387 7388 bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) const { 7389 GAnd *And = cast<GAnd>(&MI); 7390 7391 if (tryFoldAndOrOrICmpsUsingRanges(And, MatchInfo)) 7392 return true; 7393 7394 if (tryFoldLogicOfFCmps(And, MatchInfo)) 7395 return true; 7396 7397 return false; 7398 } 7399 7400 bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) const { 7401 GOr *Or = cast<GOr>(&MI); 7402 7403 if (tryFoldAndOrOrICmpsUsingRanges(Or, MatchInfo)) 7404 return true; 7405 7406 if (tryFoldLogicOfFCmps(Or, MatchInfo)) 7407 return true; 7408 7409 return false; 7410 } 7411 7412 bool CombinerHelper::matchAddOverflow(MachineInstr &MI, 7413 BuildFnTy &MatchInfo) const { 7414 GAddCarryOut *Add = cast<GAddCarryOut>(&MI); 7415 7416 // Addo has no flags 7417 Register Dst = Add->getReg(0); 7418 Register Carry = Add->getReg(1); 7419 Register LHS = Add->getLHSReg(); 7420 Register RHS = Add->getRHSReg(); 7421 bool IsSigned = Add->isSigned(); 7422 LLT DstTy = MRI.getType(Dst); 7423 LLT CarryTy = MRI.getType(Carry); 7424 7425 // Fold addo, if the carry is dead -> add, undef. 7426 if (MRI.use_nodbg_empty(Carry) && 7427 isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}})) { 7428 MatchInfo = [=](MachineIRBuilder &B) { 7429 B.buildAdd(Dst, LHS, RHS); 7430 B.buildUndef(Carry); 7431 }; 7432 return true; 7433 } 7434 7435 // Canonicalize constant to RHS. 7436 if (isConstantOrConstantVectorI(LHS) && !isConstantOrConstantVectorI(RHS)) { 7437 if (IsSigned) { 7438 MatchInfo = [=](MachineIRBuilder &B) { 7439 B.buildSAddo(Dst, Carry, RHS, LHS); 7440 }; 7441 return true; 7442 } 7443 // !IsSigned 7444 MatchInfo = [=](MachineIRBuilder &B) { 7445 B.buildUAddo(Dst, Carry, RHS, LHS); 7446 }; 7447 return true; 7448 } 7449 7450 std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(LHS); 7451 std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(RHS); 7452 7453 // Fold addo(c1, c2) -> c3, carry. 7454 if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(DstTy) && 7455 isConstantLegalOrBeforeLegalizer(CarryTy)) { 7456 bool Overflow; 7457 APInt Result = IsSigned ? MaybeLHS->sadd_ov(*MaybeRHS, Overflow) 7458 : MaybeLHS->uadd_ov(*MaybeRHS, Overflow); 7459 MatchInfo = [=](MachineIRBuilder &B) { 7460 B.buildConstant(Dst, Result); 7461 B.buildConstant(Carry, Overflow); 7462 }; 7463 return true; 7464 } 7465 7466 // Fold (addo x, 0) -> x, no carry 7467 if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(CarryTy)) { 7468 MatchInfo = [=](MachineIRBuilder &B) { 7469 B.buildCopy(Dst, LHS); 7470 B.buildConstant(Carry, 0); 7471 }; 7472 return true; 7473 } 7474 7475 // Given 2 constant operands whose sum does not overflow: 7476 // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1 7477 // saddo (X +nsw C0), C1 -> saddo X, C0 + C1 7478 GAdd *AddLHS = getOpcodeDef<GAdd>(LHS, MRI); 7479 if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(Add->getReg(0)) && 7480 ((IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoSWrap)) || 7481 (!IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoUWrap)))) { 7482 std::optional<APInt> MaybeAddRHS = 7483 getConstantOrConstantSplatVector(AddLHS->getRHSReg()); 7484 if (MaybeAddRHS) { 7485 bool Overflow; 7486 APInt NewC = IsSigned ? 
MaybeAddRHS->sadd_ov(*MaybeRHS, Overflow) 7487 : MaybeAddRHS->uadd_ov(*MaybeRHS, Overflow); 7488 if (!Overflow && isConstantLegalOrBeforeLegalizer(DstTy)) { 7489 if (IsSigned) { 7490 MatchInfo = [=](MachineIRBuilder &B) { 7491 auto ConstRHS = B.buildConstant(DstTy, NewC); 7492 B.buildSAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS); 7493 }; 7494 return true; 7495 } 7496 // !IsSigned 7497 MatchInfo = [=](MachineIRBuilder &B) { 7498 auto ConstRHS = B.buildConstant(DstTy, NewC); 7499 B.buildUAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS); 7500 }; 7501 return true; 7502 } 7503 } 7504 }; 7505 7506 // We try to combine addo to non-overflowing add. 7507 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}}) || 7508 !isConstantLegalOrBeforeLegalizer(CarryTy)) 7509 return false; 7510 7511 // We try to combine uaddo to non-overflowing add. 7512 if (!IsSigned) { 7513 ConstantRange CRLHS = 7514 ConstantRange::fromKnownBits(KB->getKnownBits(LHS), /*IsSigned=*/false); 7515 ConstantRange CRRHS = 7516 ConstantRange::fromKnownBits(KB->getKnownBits(RHS), /*IsSigned=*/false); 7517 7518 switch (CRLHS.unsignedAddMayOverflow(CRRHS)) { 7519 case ConstantRange::OverflowResult::MayOverflow: 7520 return false; 7521 case ConstantRange::OverflowResult::NeverOverflows: { 7522 MatchInfo = [=](MachineIRBuilder &B) { 7523 B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap); 7524 B.buildConstant(Carry, 0); 7525 }; 7526 return true; 7527 } 7528 case ConstantRange::OverflowResult::AlwaysOverflowsLow: 7529 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: { 7530 MatchInfo = [=](MachineIRBuilder &B) { 7531 B.buildAdd(Dst, LHS, RHS); 7532 B.buildConstant(Carry, 1); 7533 }; 7534 return true; 7535 } 7536 } 7537 return false; 7538 } 7539 7540 // We try to combine saddo to non-overflowing add. 7541 7542 // If LHS and RHS each have at least two sign bits, then there is no signed 7543 // overflow. 
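  // (With at least two sign bits each operand lies in [-2^(w-2), 2^(w-2) - 1],
  // so the sum lies within [-2^(w-1), 2^(w-1) - 2] and cannot wrap a w-bit
  // signed value.)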
7544 if (KB->computeNumSignBits(RHS) > 1 && KB->computeNumSignBits(LHS) > 1) { 7545 MatchInfo = [=](MachineIRBuilder &B) { 7546 B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap); 7547 B.buildConstant(Carry, 0); 7548 }; 7549 return true; 7550 } 7551 7552 ConstantRange CRLHS = 7553 ConstantRange::fromKnownBits(KB->getKnownBits(LHS), /*IsSigned=*/true); 7554 ConstantRange CRRHS = 7555 ConstantRange::fromKnownBits(KB->getKnownBits(RHS), /*IsSigned=*/true); 7556 7557 switch (CRLHS.signedAddMayOverflow(CRRHS)) { 7558 case ConstantRange::OverflowResult::MayOverflow: 7559 return false; 7560 case ConstantRange::OverflowResult::NeverOverflows: { 7561 MatchInfo = [=](MachineIRBuilder &B) { 7562 B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap); 7563 B.buildConstant(Carry, 0); 7564 }; 7565 return true; 7566 } 7567 case ConstantRange::OverflowResult::AlwaysOverflowsLow: 7568 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: { 7569 MatchInfo = [=](MachineIRBuilder &B) { 7570 B.buildAdd(Dst, LHS, RHS); 7571 B.buildConstant(Carry, 1); 7572 }; 7573 return true; 7574 } 7575 } 7576 7577 return false; 7578 } 7579 7580 void CombinerHelper::applyBuildFnMO(const MachineOperand &MO, 7581 BuildFnTy &MatchInfo) const { 7582 MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI); 7583 MatchInfo(Builder); 7584 Root->eraseFromParent(); 7585 } 7586 7587 bool CombinerHelper::matchFPowIExpansion(MachineInstr &MI, 7588 int64_t Exponent) const { 7589 bool OptForSize = MI.getMF()->getFunction().hasOptSize(); 7590 return getTargetLowering().isBeneficialToExpandPowI(Exponent, OptForSize); 7591 } 7592 7593 void CombinerHelper::applyExpandFPowI(MachineInstr &MI, 7594 int64_t Exponent) const { 7595 auto [Dst, Base] = MI.getFirst2Regs(); 7596 LLT Ty = MRI.getType(Dst); 7597 int64_t ExpVal = Exponent; 7598 7599 if (ExpVal == 0) { 7600 Builder.buildFConstant(Dst, 1.0); 7601 MI.removeFromParent(); 7602 return; 7603 } 7604 7605 if (ExpVal < 0) 7606 ExpVal = -ExpVal; 7607 7608 // We use the simple binary decomposition method from SelectionDAG ExpandPowI 7609 // to generate the multiply sequence. There are more optimal ways to do this 7610 // (for example, powi(x,15) generates one more multiply than it should), but 7611 // this has the benefit of being both really simple and much better than a 7612 // libcall. 7613 std::optional<SrcOp> Res; 7614 SrcOp CurSquare = Base; 7615 while (ExpVal > 0) { 7616 if (ExpVal & 1) { 7617 if (!Res) 7618 Res = CurSquare; 7619 else 7620 Res = Builder.buildFMul(Ty, *Res, CurSquare); 7621 } 7622 7623 CurSquare = Builder.buildFMul(Ty, CurSquare, CurSquare); 7624 ExpVal >>= 1; 7625 } 7626 7627 // If the original exponent was negative, invert the result, producing 7628 // 1/(x*x*x). 
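  // E.g. for powi(x, -3) the loop above computed x*x*x, so we emit
  // 1.0 / (x*x*x) here; the fdiv reuses the instruction's FP flags.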
7629 if (Exponent < 0) 7630 Res = Builder.buildFDiv(Ty, Builder.buildFConstant(Ty, 1.0), *Res, 7631 MI.getFlags()); 7632 7633 Builder.buildCopy(Dst, *Res); 7634 MI.eraseFromParent(); 7635 } 7636 7637 bool CombinerHelper::matchFoldAPlusC1MinusC2(const MachineInstr &MI, 7638 BuildFnTy &MatchInfo) const { 7639 // fold (A+C1)-C2 -> A+(C1-C2) 7640 const GSub *Sub = cast<GSub>(&MI); 7641 GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getLHSReg())); 7642 7643 if (!MRI.hasOneNonDBGUse(Add->getReg(0))) 7644 return false; 7645 7646 APInt C2 = getIConstantFromReg(Sub->getRHSReg(), MRI); 7647 APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI); 7648 7649 Register Dst = Sub->getReg(0); 7650 LLT DstTy = MRI.getType(Dst); 7651 7652 MatchInfo = [=](MachineIRBuilder &B) { 7653 auto Const = B.buildConstant(DstTy, C1 - C2); 7654 B.buildAdd(Dst, Add->getLHSReg(), Const); 7655 }; 7656 7657 return true; 7658 } 7659 7660 bool CombinerHelper::matchFoldC2MinusAPlusC1(const MachineInstr &MI, 7661 BuildFnTy &MatchInfo) const { 7662 // fold C2-(A+C1) -> (C2-C1)-A 7663 const GSub *Sub = cast<GSub>(&MI); 7664 GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getRHSReg())); 7665 7666 if (!MRI.hasOneNonDBGUse(Add->getReg(0))) 7667 return false; 7668 7669 APInt C2 = getIConstantFromReg(Sub->getLHSReg(), MRI); 7670 APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI); 7671 7672 Register Dst = Sub->getReg(0); 7673 LLT DstTy = MRI.getType(Dst); 7674 7675 MatchInfo = [=](MachineIRBuilder &B) { 7676 auto Const = B.buildConstant(DstTy, C2 - C1); 7677 B.buildSub(Dst, Const, Add->getLHSReg()); 7678 }; 7679 7680 return true; 7681 } 7682 7683 bool CombinerHelper::matchFoldAMinusC1MinusC2(const MachineInstr &MI, 7684 BuildFnTy &MatchInfo) const { 7685 // fold (A-C1)-C2 -> A-(C1+C2) 7686 const GSub *Sub1 = cast<GSub>(&MI); 7687 GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg())); 7688 7689 if (!MRI.hasOneNonDBGUse(Sub2->getReg(0))) 7690 return false; 7691 7692 APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI); 7693 APInt C1 = getIConstantFromReg(Sub2->getRHSReg(), MRI); 7694 7695 Register Dst = Sub1->getReg(0); 7696 LLT DstTy = MRI.getType(Dst); 7697 7698 MatchInfo = [=](MachineIRBuilder &B) { 7699 auto Const = B.buildConstant(DstTy, C1 + C2); 7700 B.buildSub(Dst, Sub2->getLHSReg(), Const); 7701 }; 7702 7703 return true; 7704 } 7705 7706 bool CombinerHelper::matchFoldC1Minus2MinusC2(const MachineInstr &MI, 7707 BuildFnTy &MatchInfo) const { 7708 // fold (C1-A)-C2 -> (C1-C2)-A 7709 const GSub *Sub1 = cast<GSub>(&MI); 7710 GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg())); 7711 7712 if (!MRI.hasOneNonDBGUse(Sub2->getReg(0))) 7713 return false; 7714 7715 APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI); 7716 APInt C1 = getIConstantFromReg(Sub2->getLHSReg(), MRI); 7717 7718 Register Dst = Sub1->getReg(0); 7719 LLT DstTy = MRI.getType(Dst); 7720 7721 MatchInfo = [=](MachineIRBuilder &B) { 7722 auto Const = B.buildConstant(DstTy, C1 - C2); 7723 B.buildSub(Dst, Const, Sub2->getRHSReg()); 7724 }; 7725 7726 return true; 7727 } 7728 7729 bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI, 7730 BuildFnTy &MatchInfo) const { 7731 // fold ((A-C1)+C2) -> (A+(C2-C1)) 7732 const GAdd *Add = cast<GAdd>(&MI); 7733 GSub *Sub = cast<GSub>(MRI.getVRegDef(Add->getLHSReg())); 7734 7735 if (!MRI.hasOneNonDBGUse(Sub->getReg(0))) 7736 return false; 7737 7738 APInt C2 = getIConstantFromReg(Add->getRHSReg(), MRI); 7739 APInt C1 = getIConstantFromReg(Sub->getRHSReg(), MRI); 7740 7741 Register Dst = Add->getReg(0); 7742 LLT 
DstTy = MRI.getType(Dst); 7743 7744 MatchInfo = [=](MachineIRBuilder &B) { 7745 auto Const = B.buildConstant(DstTy, C2 - C1); 7746 B.buildAdd(Dst, Sub->getLHSReg(), Const); 7747 }; 7748 7749 return true; 7750 } 7751 7752 bool CombinerHelper::matchUnmergeValuesAnyExtBuildVector( 7753 const MachineInstr &MI, BuildFnTy &MatchInfo) const { 7754 const GUnmerge *Unmerge = cast<GUnmerge>(&MI); 7755 7756 if (!MRI.hasOneNonDBGUse(Unmerge->getSourceReg())) 7757 return false; 7758 7759 const MachineInstr *Source = MRI.getVRegDef(Unmerge->getSourceReg()); 7760 7761 LLT DstTy = MRI.getType(Unmerge->getReg(0)); 7762 7763 // $bv:_(<8 x s8>) = G_BUILD_VECTOR .... 7764 // $any:_(<8 x s16>) = G_ANYEXT $bv 7765 // $uv:_(<4 x s16>), $uv1:_(<4 x s16>) = G_UNMERGE_VALUES $any 7766 // 7767 // -> 7768 // 7769 // $any:_(s16) = G_ANYEXT $bv[0] 7770 // $any1:_(s16) = G_ANYEXT $bv[1] 7771 // $any2:_(s16) = G_ANYEXT $bv[2] 7772 // $any3:_(s16) = G_ANYEXT $bv[3] 7773 // $any4:_(s16) = G_ANYEXT $bv[4] 7774 // $any5:_(s16) = G_ANYEXT $bv[5] 7775 // $any6:_(s16) = G_ANYEXT $bv[6] 7776 // $any7:_(s16) = G_ANYEXT $bv[7] 7777 // $uv:_(<4 x s16>) = G_BUILD_VECTOR $any, $any1, $any2, $any3 7778 // $uv1:_(<4 x s16>) = G_BUILD_VECTOR $any4, $any5, $any6, $any7 7779 7780 // We want to unmerge into vectors. 7781 if (!DstTy.isFixedVector()) 7782 return false; 7783 7784 const GAnyExt *Any = dyn_cast<GAnyExt>(Source); 7785 if (!Any) 7786 return false; 7787 7788 const MachineInstr *NextSource = MRI.getVRegDef(Any->getSrcReg()); 7789 7790 if (const GBuildVector *BV = dyn_cast<GBuildVector>(NextSource)) { 7791 // G_UNMERGE_VALUES G_ANYEXT G_BUILD_VECTOR 7792 7793 if (!MRI.hasOneNonDBGUse(BV->getReg(0))) 7794 return false; 7795 7796 // FIXME: check element types? 7797 if (BV->getNumSources() % Unmerge->getNumDefs() != 0) 7798 return false; 7799 7800 LLT BigBvTy = MRI.getType(BV->getReg(0)); 7801 LLT SmallBvTy = DstTy; 7802 LLT SmallBvElemenTy = SmallBvTy.getElementType(); 7803 7804 if (!isLegalOrBeforeLegalizer( 7805 {TargetOpcode::G_BUILD_VECTOR, {SmallBvTy, SmallBvElemenTy}})) 7806 return false; 7807 7808 // We check the legality of scalar anyext. 7809 if (!isLegalOrBeforeLegalizer( 7810 {TargetOpcode::G_ANYEXT, 7811 {SmallBvElemenTy, BigBvTy.getElementType()}})) 7812 return false; 7813 7814 MatchInfo = [=](MachineIRBuilder &B) { 7815 // Build into each G_UNMERGE_VALUES def 7816 // a small build vector with anyext from the source build vector. 7817 for (unsigned I = 0; I < Unmerge->getNumDefs(); ++I) { 7818 SmallVector<Register> Ops; 7819 for (unsigned J = 0; J < SmallBvTy.getNumElements(); ++J) { 7820 Register SourceArray = 7821 BV->getSourceReg(I * SmallBvTy.getNumElements() + J); 7822 auto AnyExt = B.buildAnyExt(SmallBvElemenTy, SourceArray); 7823 Ops.push_back(AnyExt.getReg(0)); 7824 } 7825 B.buildBuildVector(Unmerge->getOperand(I).getReg(), Ops); 7826 }; 7827 }; 7828 return true; 7829 }; 7830 7831 return false; 7832 } 7833 7834 bool CombinerHelper::matchShuffleUndefRHS(MachineInstr &MI, 7835 BuildFnTy &MatchInfo) const { 7836 7837 bool Changed = false; 7838 auto &Shuffle = cast<GShuffleVector>(MI); 7839 ArrayRef<int> OrigMask = Shuffle.getMask(); 7840 SmallVector<int, 16> NewMask; 7841 const LLT SrcTy = MRI.getType(Shuffle.getSrc1Reg()); 7842 const unsigned NumSrcElems = SrcTy.isVector() ? 
bool CombinerHelper::matchShuffleUndefRHS(MachineInstr &MI,
                                          BuildFnTy &MatchInfo) const {
  bool Changed = false;
  auto &Shuffle = cast<GShuffleVector>(MI);
  ArrayRef<int> OrigMask = Shuffle.getMask();
  SmallVector<int, 16> NewMask;
  const LLT SrcTy = MRI.getType(Shuffle.getSrc1Reg());
  const unsigned NumSrcElems = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
  const unsigned NumDstElts = OrigMask.size();
  for (unsigned i = 0; i != NumDstElts; ++i) {
    int Idx = OrigMask[i];
    if (Idx >= (int)NumSrcElems) {
      Idx = -1;
      Changed = true;
    }
    NewMask.push_back(Idx);
  }

  if (!Changed)
    return false;

  MatchInfo = [&, NewMask = std::move(NewMask)](MachineIRBuilder &B) {
    B.buildShuffleVector(MI.getOperand(0), MI.getOperand(1), MI.getOperand(2),
                         std::move(NewMask));
  };

  return true;
}

static void commuteMask(MutableArrayRef<int> Mask, const unsigned NumElems) {
  const unsigned MaskSize = Mask.size();
  for (unsigned I = 0; I < MaskSize; ++I) {
    int Idx = Mask[I];
    if (Idx < 0)
      continue;

    if (Idx < (int)NumElems)
      Mask[I] = Idx + NumElems;
    else
      Mask[I] = Idx - NumElems;
  }
}

bool CombinerHelper::matchShuffleDisjointMask(MachineInstr &MI,
                                              BuildFnTy &MatchInfo) const {
  auto &Shuffle = cast<GShuffleVector>(MI);
  // If either of the two inputs is already undef, don't check the mask again
  // to prevent an infinite loop.
  if (getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, Shuffle.getSrc1Reg(), MRI))
    return false;

  if (getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, Shuffle.getSrc2Reg(), MRI))
    return false;

  const LLT DstTy = MRI.getType(Shuffle.getReg(0));
  const LLT Src1Ty = MRI.getType(Shuffle.getSrc1Reg());
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_SHUFFLE_VECTOR, {DstTy, Src1Ty}}))
    return false;

  ArrayRef<int> Mask = Shuffle.getMask();
  const unsigned NumSrcElems = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1;

  bool TouchesSrc1 = false;
  bool TouchesSrc2 = false;
  const unsigned NumElems = Mask.size();
  for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
    if (Mask[Idx] < 0)
      continue;

    if (Mask[Idx] < (int)NumSrcElems)
      TouchesSrc1 = true;
    else
      TouchesSrc2 = true;
  }

  if (TouchesSrc1 == TouchesSrc2)
    return false;

  Register NewSrc1 = Shuffle.getSrc1Reg();
  SmallVector<int, 16> NewMask(Mask);
  if (TouchesSrc2) {
    NewSrc1 = Shuffle.getSrc2Reg();
    commuteMask(NewMask, NumSrcElems);
  }

  MatchInfo = [=, &Shuffle](MachineIRBuilder &B) {
    auto Undef = B.buildUndef(Src1Ty);
    B.buildShuffleVector(Shuffle.getReg(0), NewSrc1, Undef, NewMask);
  };

  return true;
}
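
// Fold G_USUBO/G_SSUBO into a plain G_SUB plus a constant carry-out when
// the known-bits ranges of the operands prove the subtraction never (or
// always) overflows. Illustrative sketch only: for a G_USUBO whose known
// bits give LHS in [64, 127] and RHS in [0, 15], the subtraction can never
// wrap, so it becomes
//
//   %res:_(s32) = nuw G_SUB %lhs, %rhs
//   %carry:_(s1) = G_CONSTANT i1 0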
bool CombinerHelper::matchSuboCarryOut(const MachineInstr &MI,
                                       BuildFnTy &MatchInfo) const {
  const GSubCarryOut *Subo = cast<GSubCarryOut>(&MI);

  Register Dst = Subo->getReg(0);
  Register LHS = Subo->getLHSReg();
  Register RHS = Subo->getRHSReg();
  Register Carry = Subo->getCarryOutReg();
  LLT DstTy = MRI.getType(Dst);
  LLT CarryTy = MRI.getType(Carry);

  // Check legality before known bits.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SUB, {DstTy}}) ||
      !isConstantLegalOrBeforeLegalizer(CarryTy))
    return false;

  ConstantRange KBLHS =
      ConstantRange::fromKnownBits(KB->getKnownBits(LHS),
                                   /*IsSigned=*/Subo->isSigned());
  ConstantRange KBRHS =
      ConstantRange::fromKnownBits(KB->getKnownBits(RHS),
                                   /*IsSigned=*/Subo->isSigned());

  if (Subo->isSigned()) {
    // G_SSUBO
    switch (KBLHS.signedSubMayOverflow(KBRHS)) {
    case ConstantRange::OverflowResult::MayOverflow:
      return false;
    case ConstantRange::OverflowResult::NeverOverflows: {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSub(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
        B.buildConstant(Carry, 0);
      };
      return true;
    }
    case ConstantRange::OverflowResult::AlwaysOverflowsLow:
    case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSub(Dst, LHS, RHS);
        B.buildConstant(Carry, getICmpTrueVal(getTargetLowering(),
                                              /*isVector=*/CarryTy.isVector(),
                                              /*isFP=*/false));
      };
      return true;
    }
    }
    return false;
  }

  // G_USUBO
  switch (KBLHS.unsignedSubMayOverflow(KBRHS)) {
  case ConstantRange::OverflowResult::MayOverflow:
    return false;
  case ConstantRange::OverflowResult::NeverOverflows: {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildSub(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap);
      B.buildConstant(Carry, 0);
    };
    return true;
  }
  case ConstantRange::OverflowResult::AlwaysOverflowsLow:
  case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildSub(Dst, LHS, RHS);
      B.buildConstant(Carry, getICmpTrueVal(getTargetLowering(),
                                            /*isVector=*/CarryTy.isVector(),
                                            /*isFP=*/false));
    };
    return true;
  }
  }

  return false;
}