1 //===- MergeICmps.cpp - Optimize chains of integer comparisons ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass turns chains of integer comparisons into memcmp (the memcmp is 10 // later typically inlined as a chain of efficient hardware comparisons). This 11 // typically benefits c++ member or nonmember operator==(). 12 // 13 // The basic idea is to replace a longer chain of integer comparisons loaded 14 // from contiguous memory locations into a shorter chain of larger integer 15 // comparisons. Benefits are double: 16 // - There are less jumps, and therefore less opportunities for mispredictions 17 // and I-cache misses. 18 // - Code size is smaller, both because jumps are removed and because the 19 // encoding of a 2*n byte compare is smaller than that of two n-byte 20 // compares. 21 // 22 // Example: 23 // 24 // struct S { 25 // int a; 26 // char b; 27 // char c; 28 // uint16_t d; 29 // bool operator==(const S& o) const { 30 // return a == o.a && b == o.b && c == o.c && d == o.d; 31 // } 32 // }; 33 // 34 // Is optimized as : 35 // 36 // bool S::operator==(const S& o) const { 37 // return memcmp(this, &o, 8) == 0; 38 // } 39 // 40 // Which will later be expanded (ExpandMemCmp) as a single 8-bytes icmp. 
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

using namespace llvm;

namespace {

#define DEBUG_TYPE "mergeicmps"

// Returns true if the instruction is a simple (non-volatile, non-atomic,
// unordered) load or a simple store.
static bool isSimpleLoadOrStore(const Instruction *I) {
  if (const LoadInst *LI = dyn_cast<LoadInst>(I))
    return LI->isSimple();
  if (const StoreInst *SI = dyn_cast<StoreInst>(I))
    return SI->isSimple();
  return false;
}

// A BCE atom ("Binary Compare Expression Atom") represents an integer load
// that is a constant offset from a base value, e.g. `a` or `o.c` in the
// example at the top.
struct BCEAtom {
  BCEAtom() = default;
  BCEAtom(GetElementPtrInst *GEP, LoadInst *LoadI, int BaseId, APInt Offset)
      : GEP(GEP), LoadI(LoadI), BaseId(BaseId), Offset(Offset) {}

  // We want to order BCEAtoms by (Base, Offset). However we cannot use
  // the pointer values for Base because these are non-deterministic.
  // To make sure that the sort order is stable, we first assign to each atom
  // base value an index based on its order of appearance in the chain of
  // comparisons. We call this index `BaseOrdering`. For example, for:
  //    b[3] == c[2] && a[1] == d[1] && b[4] == c[3]
  //    |  block 1 |    |  block 2 |    | block 3 |
  // b gets assigned index 0 and a index 1, because b appears as LHS in block 1,
  // which is before block 2.
  // We then sort by (BaseOrdering[LHS.Base()], LHS.Offset), which is stable.
  bool operator<(const BCEAtom &O) const {
    return BaseId != O.BaseId ? BaseId < O.BaseId : Offset.slt(O.Offset);
  }

  GetElementPtrInst *GEP = nullptr;
  LoadInst *LoadI = nullptr;
  // 0 means "invalid atom" (see BaseIdentifier: real ids start at 1).
  unsigned BaseId = 0;
  APInt Offset;
};

// A class that assigns increasing ids to values in the order in which they
// are seen. See comment in `BCEAtom::operator<()`.
class BaseIdentifier {
 public:
  // Returns the id for value `Base`, after assigning one if `Base` has not
  // been seen before.
  int getBaseId(const Value *Base) {
    assert(Base && "invalid base");
    const auto Insertion = BaseToIndex.try_emplace(Base, Order);
    if (Insertion.second)
      ++Order;
    return Insertion.first->second;
  }

 private:
  // Ids start at 1 so that 0 can denote an invalid/default-constructed atom.
  unsigned Order = 1;
  DenseMap<const Value *, int> BaseToIndex;
};

// If this value is a load from a constant offset w.r.t. a base address, and
// there are no other users of the load or address, returns the base address
// and the offset.
BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
  auto *const LoadI = dyn_cast<LoadInst>(Val);
  if (!LoadI)
    return {};
  LLVM_DEBUG(dbgs() << "load\n");
  // The load must feed only this comparison; uses outside the block would
  // become orphaned once the chain is merged.
  if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) {
    LLVM_DEBUG(dbgs() << "used outside of block\n");
    return {};
  }
  // Do not optimize atomic loads to non-atomic memcmp.
  if (!LoadI->isSimple()) {
    LLVM_DEBUG(dbgs() << "volatile or atomic\n");
    return {};
  }
  Value *const Addr = LoadI->getOperand(0);
  auto *const GEP = dyn_cast<GetElementPtrInst>(Addr);
  if (!GEP)
    return {};
  LLVM_DEBUG(dbgs() << "GEP\n");
  if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) {
    LLVM_DEBUG(dbgs() << "used outside of block\n");
    return {};
  }
  const auto &DL = GEP->getModule()->getDataLayout();
  if (!isDereferenceablePointer(GEP, DL)) {
    LLVM_DEBUG(dbgs() << "not dereferenceable\n");
    // We need to make sure that we can do comparison in any order, so we
    // require memory to be unconditionally dereferenceable.
    return {};
  }
  // The offset must be a compile-time constant for the atoms to be sortable
  // and mergeable.
  APInt Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0);
  if (!GEP->accumulateConstantOffset(DL, Offset))
    return {};
  return BCEAtom(GEP, LoadI, BaseId.getBaseId(GEP->getPointerOperand()),
                 Offset);
}

// A basic block with a comparison between two BCE atoms, e.g. `a == o.a` in
// the example at the top.
// The block might do extra work besides the atom comparison, in which case
// doesOtherWork() returns true. Under some conditions, the block can be
// split into the atom comparison part and the "other work" part
// (see canSplit()).
// Note: the terminology is misleading: the comparison is symmetric, so there
// is no real {l/r}hs. What we want though is to have the same base on the
// left (resp. right), so that we can detect consecutive loads. To ensure this
// we put the smallest atom on the left.
170 class BCECmpBlock { 171 public: 172 BCECmpBlock() {} 173 174 BCECmpBlock(BCEAtom L, BCEAtom R, int SizeBits) 175 : Lhs_(L), Rhs_(R), SizeBits_(SizeBits) { 176 if (Rhs_ < Lhs_) std::swap(Rhs_, Lhs_); 177 } 178 179 bool IsValid() const { return Lhs_.BaseId != 0 && Rhs_.BaseId != 0; } 180 181 // Assert the block is consistent: If valid, it should also have 182 // non-null members besides Lhs_ and Rhs_. 183 void AssertConsistent() const { 184 if (IsValid()) { 185 assert(BB); 186 assert(CmpI); 187 assert(BranchI); 188 } 189 } 190 191 const BCEAtom &Lhs() const { return Lhs_; } 192 const BCEAtom &Rhs() const { return Rhs_; } 193 int SizeBits() const { return SizeBits_; } 194 195 // Returns true if the block does other works besides comparison. 196 bool doesOtherWork() const; 197 198 // Returns true if the non-BCE-cmp instructions can be separated from BCE-cmp 199 // instructions in the block. 200 bool canSplit(AliasAnalysis *AA) const; 201 202 // Return true if this all the relevant instructions in the BCE-cmp-block can 203 // be sunk below this instruction. By doing this, we know we can separate the 204 // BCE-cmp-block instructions from the non-BCE-cmp-block instructions in the 205 // block. 206 bool canSinkBCECmpInst(const Instruction *, DenseSet<Instruction *> &, 207 AliasAnalysis *AA) const; 208 209 // We can separate the BCE-cmp-block instructions and the non-BCE-cmp-block 210 // instructions. Split the old block and move all non-BCE-cmp-insts into the 211 // new parent block. 212 void split(BasicBlock *NewParent, AliasAnalysis *AA) const; 213 214 // The basic block where this comparison happens. 215 BasicBlock *BB = nullptr; 216 // The ICMP for this comparison. 217 ICmpInst *CmpI = nullptr; 218 // The terminating branch. 219 BranchInst *BranchI = nullptr; 220 // The block requires splitting. 
221 bool RequireSplit = false; 222 223 private: 224 BCEAtom Lhs_; 225 BCEAtom Rhs_; 226 int SizeBits_ = 0; 227 }; 228 229 bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, 230 DenseSet<Instruction *> &BlockInsts, 231 AliasAnalysis *AA) const { 232 // If this instruction has side effects and its in middle of the BCE cmp block 233 // instructions, then bail for now. 234 if (Inst->mayHaveSideEffects()) { 235 // Bail if this is not a simple load or store 236 if (!isSimpleLoadOrStore(Inst)) 237 return false; 238 // Disallow stores that might alias the BCE operands 239 MemoryLocation LLoc = MemoryLocation::get(Lhs_.LoadI); 240 MemoryLocation RLoc = MemoryLocation::get(Rhs_.LoadI); 241 if (isModSet(AA->getModRefInfo(Inst, LLoc)) || 242 isModSet(AA->getModRefInfo(Inst, RLoc))) 243 return false; 244 } 245 // Make sure this instruction does not use any of the BCE cmp block 246 // instructions as operand. 247 for (auto BI : BlockInsts) { 248 if (is_contained(Inst->operands(), BI)) 249 return false; 250 } 251 return true; 252 } 253 254 void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis *AA) const { 255 DenseSet<Instruction *> BlockInsts( 256 {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI}); 257 llvm::SmallVector<Instruction *, 4> OtherInsts; 258 for (Instruction &Inst : *BB) { 259 if (BlockInsts.count(&Inst)) 260 continue; 261 assert(canSinkBCECmpInst(&Inst, BlockInsts, AA) && 262 "Split unsplittable block"); 263 // This is a non-BCE-cmp-block instruction. And it can be separated 264 // from the BCE-cmp-block instruction. 265 OtherInsts.push_back(&Inst); 266 } 267 268 // Do the actual spliting. 
269 for (Instruction *Inst : reverse(OtherInsts)) { 270 Inst->moveBefore(&*NewParent->begin()); 271 } 272 } 273 274 bool BCECmpBlock::canSplit(AliasAnalysis *AA) const { 275 DenseSet<Instruction *> BlockInsts( 276 {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI}); 277 for (Instruction &Inst : *BB) { 278 if (!BlockInsts.count(&Inst)) { 279 if (!canSinkBCECmpInst(&Inst, BlockInsts, AA)) 280 return false; 281 } 282 } 283 return true; 284 } 285 286 bool BCECmpBlock::doesOtherWork() const { 287 AssertConsistent(); 288 // All the instructions we care about in the BCE cmp block. 289 DenseSet<Instruction *> BlockInsts( 290 {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI}); 291 // TODO(courbet): Can we allow some other things ? This is very conservative. 292 // We might be able to get away with anything does not have any side 293 // effects outside of the basic block. 294 // Note: The GEPs and/or loads are not necessarily in the same block. 295 for (const Instruction &Inst : *BB) { 296 if (!BlockInsts.count(&Inst)) 297 return true; 298 } 299 return false; 300 } 301 302 // Visit the given comparison. If this is a comparison between two valid 303 // BCE atoms, returns the comparison. 304 BCECmpBlock visitICmp(const ICmpInst *const CmpI, 305 const ICmpInst::Predicate ExpectedPredicate, 306 BaseIdentifier &BaseId) { 307 // The comparison can only be used once: 308 // - For intermediate blocks, as a branch condition. 309 // - For the final block, as an incoming value for the Phi. 310 // If there are any other uses of the comparison, we cannot merge it with 311 // other comparisons as we would create an orphan use of the value. 312 if (!CmpI->hasOneUse()) { 313 LLVM_DEBUG(dbgs() << "cmp has several uses\n"); 314 return {}; 315 } 316 if (CmpI->getPredicate() != ExpectedPredicate) 317 return {}; 318 LLVM_DEBUG(dbgs() << "cmp " 319 << (ExpectedPredicate == ICmpInst::ICMP_EQ ? 
"eq" : "ne") 320 << "\n"); 321 auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0), BaseId); 322 if (!Lhs.BaseId) 323 return {}; 324 auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1), BaseId); 325 if (!Rhs.BaseId) 326 return {}; 327 const auto &DL = CmpI->getModule()->getDataLayout(); 328 return BCECmpBlock(std::move(Lhs), std::move(Rhs), 329 DL.getTypeSizeInBits(CmpI->getOperand(0)->getType())); 330 } 331 332 // Visit the given comparison block. If this is a comparison between two valid 333 // BCE atoms, returns the comparison. 334 BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, 335 const BasicBlock *const PhiBlock, 336 BaseIdentifier &BaseId) { 337 if (Block->empty()) return {}; 338 auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator()); 339 if (!BranchI) return {}; 340 LLVM_DEBUG(dbgs() << "branch\n"); 341 if (BranchI->isUnconditional()) { 342 // In this case, we expect an incoming value which is the result of the 343 // comparison. This is the last link in the chain of comparisons (note 344 // that this does not mean that this is the last incoming value, blocks 345 // can be reordered). 346 auto *const CmpI = dyn_cast<ICmpInst>(Val); 347 if (!CmpI) return {}; 348 LLVM_DEBUG(dbgs() << "icmp\n"); 349 auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ, BaseId); 350 Result.CmpI = CmpI; 351 Result.BranchI = BranchI; 352 return Result; 353 } else { 354 // In this case, we expect a constant incoming value (the comparison is 355 // chained). 356 const auto *const Const = dyn_cast<ConstantInt>(Val); 357 LLVM_DEBUG(dbgs() << "const\n"); 358 if (!Const->isZero()) return {}; 359 LLVM_DEBUG(dbgs() << "false\n"); 360 auto *const CmpI = dyn_cast<ICmpInst>(BranchI->getCondition()); 361 if (!CmpI) return {}; 362 LLVM_DEBUG(dbgs() << "icmp\n"); 363 assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch"); 364 BasicBlock *const FalseBlock = BranchI->getSuccessor(1); 365 auto Result = visitICmp( 366 CmpI, FalseBlock == PhiBlock ? 
ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, 367 BaseId); 368 Result.CmpI = CmpI; 369 Result.BranchI = BranchI; 370 return Result; 371 } 372 return {}; 373 } 374 375 static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons, 376 BCECmpBlock &Comparison) { 377 LLVM_DEBUG(dbgs() << "Block '" << Comparison.BB->getName() 378 << "': Found cmp of " << Comparison.SizeBits() 379 << " bits between " << Comparison.Lhs().BaseId << " + " 380 << Comparison.Lhs().Offset << " and " 381 << Comparison.Rhs().BaseId << " + " 382 << Comparison.Rhs().Offset << "\n"); 383 LLVM_DEBUG(dbgs() << "\n"); 384 Comparisons.push_back(Comparison); 385 } 386 387 // A chain of comparisons. 388 class BCECmpChain { 389 public: 390 BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, 391 AliasAnalysis *AA); 392 393 int size() const { return Comparisons_.size(); } 394 395 #ifdef MERGEICMPS_DOT_ON 396 void dump() const; 397 #endif // MERGEICMPS_DOT_ON 398 399 bool simplify(const TargetLibraryInfo *const TLI, AliasAnalysis *AA); 400 401 private: 402 static bool IsContiguous(const BCECmpBlock &First, 403 const BCECmpBlock &Second) { 404 return First.Lhs().BaseId == Second.Lhs().BaseId && 405 First.Rhs().BaseId == Second.Rhs().BaseId && 406 First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset && 407 First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset; 408 } 409 410 PHINode &Phi_; 411 std::vector<BCECmpBlock> Comparisons_; 412 // The original entry block (before sorting); 413 BasicBlock *EntryBlock_; 414 }; 415 416 BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, 417 AliasAnalysis *AA) 418 : Phi_(Phi) { 419 assert(!Blocks.empty() && "a chain should have at least one block"); 420 // Now look inside blocks to check for BCE comparisons. 
421 std::vector<BCECmpBlock> Comparisons; 422 BaseIdentifier BaseId; 423 for (size_t BlockIdx = 0; BlockIdx < Blocks.size(); ++BlockIdx) { 424 BasicBlock *const Block = Blocks[BlockIdx]; 425 assert(Block && "invalid block"); 426 BCECmpBlock Comparison = visitCmpBlock(Phi.getIncomingValueForBlock(Block), 427 Block, Phi.getParent(), BaseId); 428 Comparison.BB = Block; 429 if (!Comparison.IsValid()) { 430 LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n"); 431 return; 432 } 433 if (Comparison.doesOtherWork()) { 434 LLVM_DEBUG(dbgs() << "block '" << Comparison.BB->getName() 435 << "' does extra work besides compare\n"); 436 if (Comparisons.empty()) { 437 // This is the initial block in the chain, in case this block does other 438 // work, we can try to split the block and move the irrelevant 439 // instructions to the predecessor. 440 // 441 // If this is not the initial block in the chain, splitting it wont 442 // work. 443 // 444 // As once split, there will still be instructions before the BCE cmp 445 // instructions that do other work in program order, i.e. within the 446 // chain before sorting. Unless we can abort the chain at this point 447 // and start anew. 448 // 449 // NOTE: we only handle blocks a with single predecessor for now. 450 if (Comparison.canSplit(AA)) { 451 LLVM_DEBUG(dbgs() 452 << "Split initial block '" << Comparison.BB->getName() 453 << "' that does extra work besides compare\n"); 454 Comparison.RequireSplit = true; 455 enqueueBlock(Comparisons, Comparison); 456 } else { 457 LLVM_DEBUG(dbgs() 458 << "ignoring initial block '" << Comparison.BB->getName() 459 << "' that does extra work besides compare\n"); 460 } 461 continue; 462 } 463 // TODO(courbet): Right now we abort the whole chain. We could be 464 // merging only the blocks that don't do other work and resume the 465 // chain from there. 
For example: 466 // if (a[0] == b[0]) { // bb1 467 // if (a[1] == b[1]) { // bb2 468 // some_value = 3; //bb3 469 // if (a[2] == b[2]) { //bb3 470 // do a ton of stuff //bb4 471 // } 472 // } 473 // } 474 // 475 // This is: 476 // 477 // bb1 --eq--> bb2 --eq--> bb3* -eq--> bb4 --+ 478 // \ \ \ \ 479 // ne ne ne \ 480 // \ \ \ v 481 // +------------+-----------+----------> bb_phi 482 // 483 // We can only merge the first two comparisons, because bb3* does 484 // "other work" (setting some_value to 3). 485 // We could still merge bb1 and bb2 though. 486 return; 487 } 488 enqueueBlock(Comparisons, Comparison); 489 } 490 491 // It is possible we have no suitable comparison to merge. 492 if (Comparisons.empty()) { 493 LLVM_DEBUG(dbgs() << "chain with no BCE basic blocks, no merge\n"); 494 return; 495 } 496 EntryBlock_ = Comparisons[0].BB; 497 Comparisons_ = std::move(Comparisons); 498 #ifdef MERGEICMPS_DOT_ON 499 errs() << "BEFORE REORDERING:\n\n"; 500 dump(); 501 #endif // MERGEICMPS_DOT_ON 502 // Reorder blocks by LHS. We can do that without changing the 503 // semantics because we are only accessing dereferencable memory. 
504 llvm::sort(Comparisons_, 505 [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) { 506 return LhsBlock.Lhs() < RhsBlock.Lhs(); 507 }); 508 #ifdef MERGEICMPS_DOT_ON 509 errs() << "AFTER REORDERING:\n\n"; 510 dump(); 511 #endif // MERGEICMPS_DOT_ON 512 } 513 514 #ifdef MERGEICMPS_DOT_ON 515 void BCECmpChain::dump() const { 516 errs() << "digraph dag {\n"; 517 errs() << " graph [bgcolor=transparent];\n"; 518 errs() << " node [color=black,style=filled,fillcolor=lightyellow];\n"; 519 errs() << " edge [color=black];\n"; 520 for (size_t I = 0; I < Comparisons_.size(); ++I) { 521 const auto &Comparison = Comparisons_[I]; 522 errs() << " \"" << I << "\" [label=\"%" 523 << Comparison.Lhs().Base()->getName() << " + " 524 << Comparison.Lhs().Offset << " == %" 525 << Comparison.Rhs().Base()->getName() << " + " 526 << Comparison.Rhs().Offset << " (" << (Comparison.SizeBits() / 8) 527 << " bytes)\"];\n"; 528 const Value *const Val = Phi_.getIncomingValueForBlock(Comparison.BB); 529 if (I > 0) errs() << " \"" << (I - 1) << "\" -> \"" << I << "\";\n"; 530 errs() << " \"" << I << "\" -> \"Phi\" [label=\"" << *Val << "\"];\n"; 531 } 532 errs() << " \"Phi\" [label=\"Phi\"];\n"; 533 errs() << "}\n\n"; 534 } 535 #endif // MERGEICMPS_DOT_ON 536 537 namespace { 538 539 // A class to compute the name of a set of merged basic blocks. 540 // This is optimized for the common case of no block names. 541 class MergedBlockName { 542 // Storage for the uncommon case of several named blocks. 543 SmallString<16> Scratch; 544 545 public: 546 explicit MergedBlockName(ArrayRef<BCECmpBlock> Comparisons) 547 : Name(makeName(Comparisons)) {} 548 const StringRef Name; 549 550 private: 551 StringRef makeName(ArrayRef<BCECmpBlock> Comparisons) { 552 assert(!Comparisons.empty() && "no basic block"); 553 // Fast path: only one block, or no names at all. 
554 if (Comparisons.size() == 1) 555 return Comparisons[0].BB->getName(); 556 const int size = std::accumulate(Comparisons.begin(), Comparisons.end(), 0, 557 [](int i, const BCECmpBlock &Cmp) { 558 return i + Cmp.BB->getName().size(); 559 }); 560 if (size == 0) 561 return StringRef("", 0); 562 563 // Slow path: at least two blocks, at least one block with a name. 564 Scratch.clear(); 565 // We'll have `size` bytes for name and `Comparisons.size() - 1` bytes for 566 // separators. 567 Scratch.reserve(size + Comparisons.size() - 1); 568 const auto append = [this](StringRef str) { 569 Scratch.append(str.begin(), str.end()); 570 }; 571 append(Comparisons[0].BB->getName()); 572 for (int I = 1, E = Comparisons.size(); I < E; ++I) { 573 const BasicBlock *const BB = Comparisons[I].BB; 574 if (!BB->getName().empty()) { 575 append("+"); 576 append(BB->getName()); 577 } 578 } 579 return StringRef(Scratch); 580 } 581 }; 582 } // namespace 583 584 // Merges the given contiguous comparison blocks into one memcmp block. 585 static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, 586 BasicBlock *const NextCmpBlock, 587 PHINode &Phi, 588 const TargetLibraryInfo *const TLI, 589 AliasAnalysis *AA) { 590 assert(!Comparisons.empty() && "merging zero comparisons"); 591 LLVMContext &Context = NextCmpBlock->getContext(); 592 const BCECmpBlock &FirstCmp = Comparisons[0]; 593 594 // Create a new cmp block before next cmp block. 595 BasicBlock *const BB = 596 BasicBlock::Create(Context, MergedBlockName(Comparisons).Name, 597 NextCmpBlock->getParent(), NextCmpBlock); 598 IRBuilder<> Builder(BB); 599 // Add the GEPs from the first BCECmpBlock. 
600 Value *const Lhs = Builder.Insert(FirstCmp.Lhs().GEP->clone()); 601 Value *const Rhs = Builder.Insert(FirstCmp.Rhs().GEP->clone()); 602 603 Value *IsEqual = nullptr; 604 if (Comparisons.size() == 1) { 605 LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n"); 606 Value *const LhsLoad = 607 Builder.CreateLoad(FirstCmp.Lhs().LoadI->getType(), Lhs); 608 Value *const RhsLoad = 609 Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs); 610 // There are no blocks to merge, just do the comparison. 611 IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad); 612 } else { 613 LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons\n"); 614 615 // If there is one block that requires splitting, we do it now, i.e. 616 // just before we know we will collapse the chain. The instructions 617 // can be executed before any of the instructions in the chain. 618 const auto ToSplit = 619 std::find_if(Comparisons.begin(), Comparisons.end(), 620 [](const BCECmpBlock &B) { return B.RequireSplit; }); 621 if (ToSplit != Comparisons.end()) { 622 LLVM_DEBUG(dbgs() << "Splitting non_BCE work to header\n"); 623 ToSplit->split(BB, AA); 624 } 625 626 const unsigned TotalSizeBits = std::accumulate( 627 Comparisons.begin(), Comparisons.end(), 0u, 628 [](int Size, const BCECmpBlock &C) { return Size + C.SizeBits(); }); 629 630 // Create memcmp() == 0. 631 const auto &DL = Phi.getModule()->getDataLayout(); 632 Value *const MemCmpCall = emitMemCmp( 633 Lhs, Rhs, 634 ConstantInt::get(DL.getIntPtrType(Context), TotalSizeBits / 8), Builder, 635 DL, TLI); 636 IsEqual = Builder.CreateICmpEQ( 637 MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0)); 638 } 639 640 BasicBlock *const PhiBB = Phi.getParent(); 641 // Add a branch to the next basic block in the chain. 642 if (NextCmpBlock == PhiBB) { 643 // Continue to phi, passing it the comparison result. 
644 Builder.CreateBr(Phi.getParent()); 645 Phi.addIncoming(IsEqual, BB); 646 } else { 647 // Continue to next block if equal, exit to phi else. 648 Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB); 649 Phi.addIncoming(ConstantInt::getFalse(Context), BB); 650 } 651 return BB; 652 } 653 654 bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI, 655 AliasAnalysis *AA) { 656 assert(Comparisons_.size() >= 2 && "simplifying trivial BCECmpChain"); 657 // First pass to check if there is at least one merge. If not, we don't do 658 // anything and we keep analysis passes intact. 659 const auto AtLeastOneMerged = [this]() { 660 for (size_t I = 1; I < Comparisons_.size(); ++I) { 661 if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) 662 return true; 663 } 664 return false; 665 }; 666 if (!AtLeastOneMerged()) 667 return false; 668 669 LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block " 670 << EntryBlock_->getName() << "\n"); 671 672 // Effectively merge blocks. We go in the reverse direction from the phi block 673 // so that the next block is always available to branch to. 674 const auto mergeRange = [this, TLI, AA](int I, int Num, BasicBlock *Next) { 675 return mergeComparisons(makeArrayRef(Comparisons_).slice(I, Num), Next, 676 Phi_, TLI, AA); 677 }; 678 int NumMerged = 1; 679 BasicBlock *NextCmpBlock = Phi_.getParent(); 680 for (int I = static_cast<int>(Comparisons_.size()) - 2; I >= 0; --I) { 681 if (IsContiguous(Comparisons_[I], Comparisons_[I + 1])) { 682 LLVM_DEBUG(dbgs() << "Merging block " << Comparisons_[I].BB->getName() 683 << " into " << Comparisons_[I + 1].BB->getName() 684 << "\n"); 685 ++NumMerged; 686 } else { 687 NextCmpBlock = mergeRange(I + 1, NumMerged, NextCmpBlock); 688 NumMerged = 1; 689 } 690 } 691 NextCmpBlock = mergeRange(0, NumMerged, NextCmpBlock); 692 693 // Replace the original cmp chain with the new cmp chain by pointing all 694 // predecessors of EntryBlock_ to NextCmpBlock instead. 
This makes all cmp 695 // blocks in the old chain unreachable. 696 while (!pred_empty(EntryBlock_)) { 697 BasicBlock* const Pred = *pred_begin(EntryBlock_); 698 LLVM_DEBUG(dbgs() << "Updating jump into old chain from " << Pred->getName() 699 << "\n"); 700 Pred->getTerminator()->replaceUsesOfWith(EntryBlock_, NextCmpBlock); 701 } 702 EntryBlock_ = nullptr; 703 704 // Delete merged blocks. This also removes incoming values in phi. 705 SmallVector<BasicBlock *, 16> DeadBlocks; 706 for (auto &Cmp : Comparisons_) { 707 LLVM_DEBUG(dbgs() << "Deleting merged block " << Cmp.BB->getName() << "\n"); 708 DeadBlocks.push_back(Cmp.BB); 709 } 710 DeleteDeadBlocks(DeadBlocks); 711 712 Comparisons_.clear(); 713 return true; 714 } 715 716 std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, 717 BasicBlock *const LastBlock, 718 int NumBlocks) { 719 // Walk up from the last block to find other blocks. 720 std::vector<BasicBlock *> Blocks(NumBlocks); 721 assert(LastBlock && "invalid last block"); 722 BasicBlock *CurBlock = LastBlock; 723 for (int BlockIndex = NumBlocks - 1; BlockIndex > 0; --BlockIndex) { 724 if (CurBlock->hasAddressTaken()) { 725 // Somebody is jumping to the block through an address, all bets are 726 // off. 727 LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex 728 << " has its address taken\n"); 729 return {}; 730 } 731 Blocks[BlockIndex] = CurBlock; 732 auto *SinglePredecessor = CurBlock->getSinglePredecessor(); 733 if (!SinglePredecessor) { 734 // The block has two or more predecessors. 735 LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex 736 << " has two or more predecessors\n"); 737 return {}; 738 } 739 if (Phi.getBasicBlockIndex(SinglePredecessor) < 0) { 740 // The block does not link back to the phi. 
741 LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex 742 << " does not link back to the phi\n"); 743 return {}; 744 } 745 CurBlock = SinglePredecessor; 746 } 747 Blocks[0] = CurBlock; 748 return Blocks; 749 } 750 751 bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI, 752 AliasAnalysis *AA) { 753 LLVM_DEBUG(dbgs() << "processPhi()\n"); 754 if (Phi.getNumIncomingValues() <= 1) { 755 LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n"); 756 return false; 757 } 758 // We are looking for something that has the following structure: 759 // bb1 --eq--> bb2 --eq--> bb3 --eq--> bb4 --+ 760 // \ \ \ \ 761 // ne ne ne \ 762 // \ \ \ v 763 // +------------+-----------+----------> bb_phi 764 // 765 // - The last basic block (bb4 here) must branch unconditionally to bb_phi. 766 // It's the only block that contributes a non-constant value to the Phi. 767 // - All other blocks (b1, b2, b3) must have exactly two successors, one of 768 // them being the phi block. 769 // - All intermediate blocks (bb2, bb3) must have only one predecessor. 770 // - Blocks cannot do other work besides the comparison, see doesOtherWork() 771 772 // The blocks are not necessarily ordered in the phi, so we start from the 773 // last block and reconstruct the order. 774 BasicBlock *LastBlock = nullptr; 775 for (unsigned I = 0; I < Phi.getNumIncomingValues(); ++I) { 776 if (isa<ConstantInt>(Phi.getIncomingValue(I))) continue; 777 if (LastBlock) { 778 // There are several non-constant values. 779 LLVM_DEBUG(dbgs() << "skip: several non-constant values\n"); 780 return false; 781 } 782 if (!isa<ICmpInst>(Phi.getIncomingValue(I)) || 783 cast<ICmpInst>(Phi.getIncomingValue(I))->getParent() != 784 Phi.getIncomingBlock(I)) { 785 // Non-constant incoming value is not from a cmp instruction or not 786 // produced by the last block. We could end up processing the value 787 // producing block more than once. 788 // 789 // This is an uncommon case, so we bail. 
790 LLVM_DEBUG( 791 dbgs() 792 << "skip: non-constant value not from cmp or not from last block.\n"); 793 return false; 794 } 795 LastBlock = Phi.getIncomingBlock(I); 796 } 797 if (!LastBlock) { 798 // There is no non-constant block. 799 LLVM_DEBUG(dbgs() << "skip: no non-constant block\n"); 800 return false; 801 } 802 if (LastBlock->getSingleSuccessor() != Phi.getParent()) { 803 LLVM_DEBUG(dbgs() << "skip: last block non-phi successor\n"); 804 return false; 805 } 806 807 const auto Blocks = 808 getOrderedBlocks(Phi, LastBlock, Phi.getNumIncomingValues()); 809 if (Blocks.empty()) return false; 810 BCECmpChain CmpChain(Blocks, Phi, AA); 811 812 if (CmpChain.size() < 2) { 813 LLVM_DEBUG(dbgs() << "skip: only one compare block\n"); 814 return false; 815 } 816 817 return CmpChain.simplify(TLI, AA); 818 } 819 820 class MergeICmps : public FunctionPass { 821 public: 822 static char ID; 823 824 MergeICmps() : FunctionPass(ID) { 825 initializeMergeICmpsPass(*PassRegistry::getPassRegistry()); 826 } 827 828 bool runOnFunction(Function &F) override { 829 if (skipFunction(F)) return false; 830 const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); 831 const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 832 AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); 833 auto PA = runImpl(F, &TLI, &TTI, AA); 834 return !PA.areAllPreserved(); 835 } 836 837 private: 838 void getAnalysisUsage(AnalysisUsage &AU) const override { 839 AU.addRequired<TargetLibraryInfoWrapperPass>(); 840 AU.addRequired<TargetTransformInfoWrapperPass>(); 841 AU.addRequired<AAResultsWrapperPass>(); 842 } 843 844 PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI, 845 const TargetTransformInfo *TTI, AliasAnalysis *AA); 846 }; 847 848 PreservedAnalyses MergeICmps::runImpl(Function &F, const TargetLibraryInfo *TLI, 849 const TargetTransformInfo *TTI, 850 AliasAnalysis *AA) { 851 LLVM_DEBUG(dbgs() << "MergeICmpsPass: " << F.getName() << 
"\n"); 852 853 // We only try merging comparisons if the target wants to expand memcmp later. 854 // The rationale is to avoid turning small chains into memcmp calls. 855 if (!TTI->enableMemCmpExpansion(true)) return PreservedAnalyses::all(); 856 857 // If we don't have memcmp avaiable we can't emit calls to it. 858 if (!TLI->has(LibFunc_memcmp)) 859 return PreservedAnalyses::all(); 860 861 bool MadeChange = false; 862 863 for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) { 864 // A Phi operation is always first in a basic block. 865 if (auto *const Phi = dyn_cast<PHINode>(&*BBIt->begin())) 866 MadeChange |= processPhi(*Phi, TLI, AA); 867 } 868 869 if (MadeChange) return PreservedAnalyses::none(); 870 return PreservedAnalyses::all(); 871 } 872 873 } // namespace 874 875 char MergeICmps::ID = 0; 876 INITIALIZE_PASS_BEGIN(MergeICmps, "mergeicmps", 877 "Merge contiguous icmps into a memcmp", false, false) 878 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 879 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 880 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 881 INITIALIZE_PASS_END(MergeICmps, "mergeicmps", 882 "Merge contiguous icmps into a memcmp", false, false) 883 884 Pass *llvm::createMergeICmpsPass() { return new MergeICmps(); } 885