//===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// Finalize v8.1-m low-overhead loops by converting the associated pseudo
/// instructions into machine operations.
/// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
/// - t2LoopEnd - the loop latch terminator.
///
/// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
/// loop form.
///
/// Assumptions and Dependencies:
/// Low-overhead loops are constructed and executed using a setup instruction:
/// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
/// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
/// but fixed polarity: WLS can only branch forwards and LE can only branch
/// backwards. These restrictions mean that this pass is dependent upon block
/// layout and block sizes, which is why it's the last pass to run. The same is
/// true for ConstantIslands, but this pass does not increase the size of the
/// basic blocks, nor does it change the CFG. Instructions are mainly removed
/// during the transform and pseudo instructions are replaced by real ones. In
/// some cases, when we have to revert to a 'normal' loop, we have to introduce
/// multiple instructions for a single pseudo (see RevertWhile and
/// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
/// are defined to be as large as this maximum sequence of replacement
/// instructions.
///
/// A note on VPR.P0 (the lane mask):
/// VPT, VCMP, VPNOT and VCTP won't overwrite VPR.P0 when they update it in a
/// "VPT Active" context (which includes low-overhead loops and vpt blocks).
/// They will simply "and" the result of their calculation with the current
/// value of VPR.P0. You can think of it like this:
/// \verbatim
/// if VPT active:    ; Between a DLSTP/LETP, or for predicated instrs
///   VPR.P0 &= Value
/// else
///   VPR.P0 = Value
/// \endverbatim
/// When we're inside the low-overhead loop (between DLSTP and LETP), we always
/// fall in the "VPT active" case, so we can consider that all VPR writes by
/// one of those instructions are actually an "and".
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMBaseRegisterInfo.h"
#include "ARMBasicBlockInfo.h"
#include "ARMSubtarget.h"
#include "Thumb2InstrInfo.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineLoopUtils.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/ReachingDefAnalysis.h"
#include "llvm/MC/MCInstrDesc.h"

using namespace llvm;

#define DEBUG_TYPE "arm-low-overhead-loops"
#define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass"

// Is MI itself *executed* under a VPT predicate? True when MI has a VPT
// predicate operand and the operand after it names VPR as the mask register.
static bool isVectorPredicated(MachineInstr *MI) {
  int PIdx = llvm::findFirstVPTPredOperandIdx(*MI);
  return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR;
}

// Does MI *define* the vector predicate register (VPR)? Note the contrast
// with isVectorPredicated above, which checks whether MI is predicated.
static bool isVectorPredicate(MachineInstr *MI) {
  return MI->findRegisterDefOperandIdx(ARM::VPR) != -1;
}

// Does MI read VPR through any of its operands?
static bool hasVPRUse(MachineInstr *MI) {
  return MI->findRegisterUseOperandIdx(ARM::VPR) != -1;
}

// Is MI an MVE instruction, according to its instruction domain TSFlags?
static bool isDomainMVE(MachineInstr *MI) {
  uint64_t Domain = MI->getDesc().TSFlags & ARMII::DomainMask;
  return Domain == ARMII::DomainMVE;
}

// Filter for the instructions this pass has to reason about: anything MVE,
// or anything that reads or writes the vector predicate.
static bool shouldInspect(MachineInstr &MI) {
  return isDomainMVE(&MI) || isVectorPredicate(&MI) ||
    hasVPRUse(&MI);
}

namespace {

using InstSet = SmallPtrSetImpl<MachineInstr *>;

// Produces a post-order visitation ordering of the loop blocks, plus the
// exit blocks and the chain of single predecessors above the preheader.
class PostOrderLoopTraversal {
  MachineLoop &ML;
  MachineLoopInfo &MLI;
  SmallPtrSet<MachineBasicBlock*, 4> Visited;
  SmallVector<MachineBasicBlock*, 4> Order;

public:
  PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI)
    : ML(ML), MLI(MLI) { }

  const SmallVectorImpl<MachineBasicBlock*> &getOrder() const {
    return Order;
  }

  // Visit all the blocks within the loop, as well as exit blocks and any
  // blocks properly dominating the header.
  void ProcessLoop() {
    // Depth-first search over the successors that stay inside the loop,
    // appending each block after its successors (post-order).
    std::function<void(MachineBasicBlock*)> Search = [this, &Search]
      (MachineBasicBlock *MBB) -> void {
      if (Visited.count(MBB))
        return;

      Visited.insert(MBB);
      for (auto *Succ : MBB->successors()) {
        if (!ML.contains(Succ))
          continue;
        Search(Succ);
      }
      Order.push_back(MBB);
    };

    // Insert exit blocks.
    SmallVector<MachineBasicBlock*, 2> ExitBlocks;
    ML.getExitBlocks(ExitBlocks);
    for (auto *MBB : ExitBlocks)
      Order.push_back(MBB);

    // Then add the loop body.
    Search(ML.getHeader());

    // Then try the preheader and its predecessors.
    // Walks up through the unique-predecessor chain; stops at any merge
    // point (pred_size() != 1).
    std::function<void(MachineBasicBlock*)> GetPredecessor =
      [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void {
      Order.push_back(MBB);
      if (MBB->pred_size() == 1)
        GetPredecessor(*MBB->pred_begin());
    };

    // Fall back to a "speculative" preheader (see findLoopPreheader) when
    // the loop has no canonical one.
    if (auto *Preheader = ML.getLoopPreheader())
      GetPredecessor(Preheader);
    else if (auto *Preheader = MLI.findLoopPreheader(&ML, true))
      GetPredecessor(Preheader);
  }
};

// An instruction together with the set of VPR-producing instructions that
// were live when it was recorded, i.e. the AND of which forms its predicate.
struct PredicatedMI {
  MachineInstr *MI = nullptr;
  SetVector<MachineInstr*> Predicates;

public:
  PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) {
    assert(I && "Instruction must not be null!");
    Predicates.insert(Preds.begin(), Preds.end());
  }
};

// Represent the current state of the VPR and hold all instances which
// represent a VPT block, which is a list of instructions that begins with a
// VPT/VPST and has a maximum of four following instructions. All
// instructions within the block are predicated upon the vpr and we allow
// instructions to define the vpr within the block too.
class VPTState {
  friend struct LowOverheadLoop;

  // The instructions of this VPT block: the VPT/VPST first, then up to four
  // predicated instructions.
  SmallVector<MachineInstr *, 4> Insts;

  // NOTE: the following members are static, so VPTState effectively tracks
  // one loop at a time; LowOverheadLoop's constructor calls reset().
  static SmallVector<VPTState, 4> Blocks;
  static SetVector<MachineInstr *> CurrentPredicates;
  static std::map<MachineInstr *,
                  std::unique_ptr<PredicatedMI>> PredicatedInsts;

  // Open a new VPT block starting at MI and record the predicate set that
  // governs it.
  static void CreateVPTBlock(MachineInstr *MI) {
    assert(CurrentPredicates.size() && "Can't begin VPT without predicate");
    Blocks.emplace_back(MI);
    // The execution of MI is predicated upon the current set of instructions
    // that are AND'ed together to form the VPR predicate value. In the case
    // that MI is a VPT, CurrentPredicates will also just be MI.
    PredicatedInsts.emplace(
        MI, std::make_unique<PredicatedMI>(MI, CurrentPredicates));
  }

  // Clear all per-loop state; called when starting to analyse a new loop.
  static void reset() {
    Blocks.clear();
    PredicatedInsts.clear();
    CurrentPredicates.clear();
  }

  // Append MI to the currently-open VPT block.
  static void addInst(MachineInstr *MI) {
    Blocks.back().insert(MI);
    PredicatedInsts.emplace(
        MI, std::make_unique<PredicatedMI>(MI, CurrentPredicates));
  }

  // MI ANDs into the current predicate ("VPT active" semantics, see the
  // file header), so add it to the live predicate set.
  static void addPredicate(MachineInstr *MI) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Adding VPT Predicate: " << *MI);
    CurrentPredicates.insert(MI);
  }

  // MI overwrites VPR.P0 entirely, so it becomes the sole predicate.
  static void resetPredicate(MachineInstr *MI) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Resetting VPT Predicate: " << *MI);
    CurrentPredicates.clear();
    CurrentPredicates.insert(MI);
  }

public:
  // Have we found an instruction within the block which defines the vpr? If
  // so, not all the instructions in the block will have the same predicate.
  static bool hasUniformPredicate(VPTState &Block) {
    return getDivergent(Block) == nullptr;
  }

  // If it exists, return the first internal instruction which modifies the
  // VPR. Starts at index 1, skipping the VPT/VPST that opens the block.
  static MachineInstr *getDivergent(VPTState &Block) {
    SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
    for (unsigned i = 1; i < Insts.size(); ++i) {
      MachineInstr *Next = Insts[i];
      if (isVectorPredicate(Next))
        return Next; // Found an instruction altering the vpr.
    }
    return nullptr;
  }

  // Return whether the given instruction is predicated upon a VCTP.
  // With Exclusive set, require the VCTP to be the *only* predicate.
  static bool isPredicatedOnVCTP(MachineInstr *MI, bool Exclusive = false) {
    SetVector<MachineInstr *> &Predicates = PredicatedInsts[MI]->Predicates;
    if (Exclusive && Predicates.size() != 1)
      return false;
    for (auto *PredMI : Predicates)
      if (isVCTP(PredMI))
        return true;
    return false;
  }

  // Is the VPST, controlling the block entry, predicated upon a VCTP.
  static bool isEntryPredicatedOnVCTP(VPTState &Block,
                                      bool Exclusive = false) {
    SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
    return isPredicatedOnVCTP(Insts.front(), Exclusive);
  }

  static bool isValid() {
    // All predication within the loop should be based on vctp. If the block
    // isn't predicated on entry, check whether the vctp is within the block
    // and that all other instructions are then predicated on it.
    for (auto &Block : Blocks) {
      if (isEntryPredicatedOnVCTP(Block))
        continue;

      SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();
      for (auto *MI : Insts) {
        // Check that any internal VCTPs are 'Then' predicated.
        if (isVCTP(MI) && getVPTInstrPredicate(*MI) != ARMVCC::Then)
          return false;
        // Skip other instructions that build up the predicate.
        if (MI->getOpcode() == ARM::MVE_VPST || isVectorPredicate(MI))
          continue;
        // Check that any other instructions are predicated upon a vctp.
        // TODO: We could infer when VPTs are implicitly predicated on the
        // vctp (when the operands are predicated).
        if (!isPredicatedOnVCTP(MI)) {
          LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *MI);
          return false;
        }
      }
    }
    return true;
  }

  VPTState(MachineInstr *MI) { Insts.push_back(MI); }

  void insert(MachineInstr *MI) {
    Insts.push_back(MI);
    // VPT/VPST + 4 predicated instructions.
    assert(Insts.size() <= 5 && "Too many instructions in VPT block!");
  }

  // Does this block contain a VCTP anywhere, including the opening
  // instruction?
  bool containsVCTP() const {
    for (auto *MI : Insts)
      if (isVCTP(MI))
        return true;
    return false;
  }

  unsigned size() const { return Insts.size(); }
  SmallVectorImpl<MachineInstr *> &getInsts() { return Insts; }
};

// Bundles everything known about one candidate low-overhead loop: the
// pseudo instructions found (Start/Dec/End), any VCTPs, the analyses used,
// and the decisions made (Revert / CannotTailPredicate).
struct LowOverheadLoop {

  MachineLoop &ML;
  MachineBasicBlock *Preheader = nullptr;
  MachineLoopInfo &MLI;
  ReachingDefAnalysis &RDA;
  const TargetRegisterInfo &TRI;
  const ARMBaseInstrInfo &TII;
  MachineFunction *MF = nullptr;
  MachineInstr *InsertPt = nullptr;
  MachineInstr *Start = nullptr;
  MachineInstr *Dec = nullptr;
  MachineInstr *End = nullptr;
  // Operand to feed to the [W|D]LSTP: the element count (starts out as a
  // dummy immediate; set from the VCTP operand in ValidateTailPredicate).
  MachineOperand TPNumElements;
  SmallVector<MachineInstr*, 4> VCTPs;
  SmallPtrSet<MachineInstr*, 4> ToRemove;
  SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute;
  bool Revert = false;
  bool CannotTailPredicate = false;

  LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI,
                  ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI,
                  const ARMBaseInstrInfo &TII)
      : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII),
        TPNumElements(MachineOperand::CreateImm(0)) {
    MF = ML.getHeader()->getParent();
    if (auto *MBB = ML.getLoopPreheader())
      Preheader = MBB;
    else if (auto *MBB = MLI.findLoopPreheader(&ML, true))
      Preheader = MBB;
    // VPTState is static, so wipe any state left over from a previous loop.
    VPTState::reset();
  }

  // If this is an MVE instruction, check that we know how to use tail
  // predication with it. Record VPT blocks and return whether the
  // instruction is valid for tail predication.
  bool ValidateMVEInst(MachineInstr *MI);

  void AnalyseMVEInst(MachineInstr *MI) {
    CannotTailPredicate = !ValidateMVEInst(MI);
  }

  bool IsTailPredicationLegal() const {
    // For now, let's keep things really simple and only support a single
    // block for tail predication.
    return !Revert && FoundAllComponents() && !VCTPs.empty() &&
           !CannotTailPredicate && ML.getNumBlocks() == 1;
  }

  // Given that MI is a VCTP, check that is equivalent to any other VCTPs
  // found.
  bool AddVCTP(MachineInstr *MI);

  // Check that the predication in the loop will be equivalent once we
  // perform the conversion. Also ensure that we can provide the number
  // of elements to the loop start instruction.
  bool ValidateTailPredicate(MachineInstr *StartInsertPt);

  // Check that any values available outside of the loop will be the same
  // after tail predication conversion.
  bool ValidateLiveOuts();

  // Is it safe to define LR with DLS/WLS?
  // LR can be defined if it is the operand to start, because it's the same
  // value, or if it's going to be equivalent to the operand to Start.
  MachineInstr *isSafeToDefineLR();

  // Check the branch targets are within range and we satisfy our
  // restrictions.
  void CheckLegality(ARMBasicBlockUtils *BBUtils);

  bool FoundAllComponents() const {
    return Start && Dec && End;
  }

  SmallVectorImpl<VPTState> &getVPTBlocks() {
    return VPTState::Blocks;
  }

  // Return the operand for the loop start instruction. This will be the loop
  // iteration count, or the number of elements if we're tail predicating.
  MachineOperand &getLoopStartOperand() {
    return IsTailPredicationLegal() ? TPNumElements : Start->getOperand(0);
  }

  unsigned getStartOpcode() const {
    bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
    if (!IsTailPredicationLegal())
      return IsDo ? ARM::t2DLS : ARM::t2WLS;

    return VCTPOpcodeToLSTP(VCTPs.back()->getOpcode(), IsDo);
  }

  void dump() const {
    if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start;
    if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec;
    if (End) dbgs() << "ARM Loops: Found Loop End: " << *End;
    if (!VCTPs.empty()) {
      dbgs() << "ARM Loops: Found VCTP(s):\n";
      for (auto *MI : VCTPs)
        dbgs() << " - " << *MI;
    }
    if (!FoundAllComponents())
      dbgs() << "ARM Loops: Not a low-overhead loop.\n";
    else if (!(Start && Dec && End))
      // NOTE(review): this branch is unreachable - FoundAllComponents() is
      // exactly (Start && Dec && End), so the condition here can never hold
      // when the branch above was not taken. Consider removing it.
      dbgs() << "ARM Loops: Failed to find all loop components.\n";
  }
};

// The pass itself: finds candidate loops, validates them, and either
// expands the pseudos into DLS(TP)/WLS(TP)/LE(TP) or reverts them to
// ordinary compare-and-branch sequences.
class ARMLowOverheadLoops : public MachineFunctionPass {
  MachineFunction *MF = nullptr;
  MachineLoopInfo *MLI = nullptr;
  ReachingDefAnalysis *RDA = nullptr;
  const ARMBaseInstrInfo *TII = nullptr;
  MachineRegisterInfo *MRI = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr;

public:
  static char ID;

  ARMLowOverheadLoops() : MachineFunctionPass(ID) { }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<MachineLoopInfo>();
    AU.addRequired<ReachingDefAnalysis>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().set(
        MachineFunctionProperties::Property::NoVRegs).set(
        MachineFunctionProperties::Property::TracksLiveness);
  }

  StringRef getPassName() const override {
    return ARM_LOW_OVERHEAD_LOOPS_NAME;
  }

private:
  bool ProcessLoop(MachineLoop *ML);

  bool RevertNonLoops();

  void RevertWhile(MachineInstr *MI) const;

  bool RevertLoopDec(MachineInstr *MI) const;

  void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const;

  void ConvertVPTBlocks(LowOverheadLoop &LoLoop);

  MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop);

  void Expand(LowOverheadLoop &LoLoop);

  void IterationCountDCE(LowOverheadLoop &LoLoop);
};
} // end anonymous namespace

char ARMLowOverheadLoops::ID = 0;

// Definitions for VPTState's static (per-loop) storage.
SmallVector<VPTState, 4> VPTState::Blocks;
SetVector<MachineInstr *> VPTState::CurrentPredicates;
std::map<MachineInstr *,
         std::unique_ptr<PredicatedMI>> VPTState::PredicatedInsts;

INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME,
                false, false)

MachineInstr *LowOverheadLoop::isSafeToDefineLR() {
  // We can define LR because LR already contains the same value.
  if (Start->getOperand(0).getReg() == ARM::LR)
    return Start;

  unsigned CountReg = Start->getOperand(0).getReg();
  auto IsMoveLR = [&CountReg](MachineInstr *MI) {
    return MI->getOpcode() == ARM::tMOVr &&
           MI->getOperand(0).getReg() == ARM::LR &&
           MI->getOperand(1).getReg() == CountReg &&
           MI->getOperand(2).getImm() == ARMCC::AL;
  };

  MachineBasicBlock *MBB = Start->getParent();

  // Find an insertion point:
  // - Is there a (mov lr, Count) before Start? If so, and nothing else writes
  //   to Count before Start, we can insert at that mov.
  if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR))
    if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
      return LRDef;

  // - Is there a (mov lr, Count) after Start? If so, and nothing else writes
  //   to Count after Start, we can insert at that mov.
  if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR))
    if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg))
      return LRDef;

  // We've found no suitable LR def and Start doesn't use LR directly. Can we
  // just define LR anyway?
  return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr;
}

bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) {
  assert(!VCTPs.empty() && "VCTP instruction expected but is not set");

  // All VPT-block predication in the loop must derive from the VCTP.
  if (!VPTState::isValid())
    return false;

  if (!ValidateLiveOuts()) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Invalid live outs.\n");
    return false;
  }

  // For tail predication, we need to provide the number of elements, instead
  // of the iteration count, to the loop start instruction. The number of
  // elements is provided to the vctp instruction, so we need to check that
  // we can use this register at InsertPt.
  MachineInstr *VCTP = VCTPs.back();
  TPNumElements = VCTP->getOperand(1);
  Register NumElements = TPNumElements.getReg();

  // If the register is defined within loop, then we can't perform TP.
  // TODO: Check whether this is just a mov of a register that would be
  // available.
  if (RDA.hasLocalDefBefore(VCTP, NumElements)) {
    LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n");
    return false;
  }

  // The element count register maybe defined after InsertPt, in which case we
  // need to try to move either InsertPt or the def so that the [w|d]lstp can
  // use the value.
  MachineBasicBlock *InsertBB = StartInsertPt->getParent();

  if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) {
    if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) {
      if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) {
        ElemDef->removeFromParent();
        InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef);
        LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: "
                   << *ElemDef);
      } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) {
        StartInsertPt->removeFromParent();
        InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef),
                              StartInsertPt);
        LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef);
      } else {
        // If we fail to move an instruction and the element count is provided
        // by a mov, use the mov operand if it will have the same value at the
        // insertion point
        MachineOperand Operand = ElemDef->getOperand(1);
        if (isMovRegOpcode(ElemDef->getOpcode()) &&
            RDA.getUniqueReachingMIDef(ElemDef, Operand.getReg()) ==
            RDA.getUniqueReachingMIDef(StartInsertPt, Operand.getReg())) {
          TPNumElements = Operand;
          NumElements = TPNumElements.getReg();
        } else {
          LLVM_DEBUG(dbgs()
                     << "ARM Loops: Unable to move element count to loop "
                     << "start instruction.\n");
          return false;
        }
      }
    }
  }

  // Could inserting the [W|D]LSTP cause some unintended affects? In a perfect
  // world the [w|d]lstp instruction would be last instruction in the preheader
  // and so it would only affect instructions within the loop body. But due to
  // scheduling, and/or the logic in this pass (above), the insertion point can
  // be moved earlier. So if the Loop Start isn't the last instruction in the
  // preheader, and if the initial element count is smaller than the vector
  // width, the Loop Start instruction will immediately generate one or more
  // false lane mask which can, incorrectly, affect the proceeding MVE
  // instructions in the preheader.
  auto cannotInsertWDLSTPBetween = [](MachineInstr *Begin,
                                      MachineInstr *End) {
    // Note: [Begin, End) - the final instruction (the terminator) is not
    // inspected.
    auto I = MachineBasicBlock::iterator(Begin);
    auto E = MachineBasicBlock::iterator(End);
    for (; I != E; ++I)
      if (shouldInspect(*I))
        return true;
    return false;
  };

  if (cannotInsertWDLSTPBetween(StartInsertPt, &InsertBB->back()))
    return false;

  // Especially in the case of while loops, InsertBB may not be the
  // preheader, so we need to check that the register isn't redefined
  // before entering the loop.
  auto CannotProvideElements = [this](MachineBasicBlock *MBB,
                                      Register NumElements) {
    // NumElements is redefined in this block.
    if (RDA.hasLocalDefBefore(&MBB->back(), NumElements))
      return true;

    // Don't continue searching up through multiple predecessors.
    if (MBB->pred_size() > 1)
      return true;

    return false;
  };

  // First, find the block that looks like the preheader.
  MachineBasicBlock *MBB = Preheader;
  if (!MBB) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n");
    return false;
  }

  // Then search backwards for a def, until we get to InsertBB.
  while (MBB != InsertBB) {
    if (CannotProvideElements(MBB, NumElements)) {
      LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n");
      return false;
    }
    MBB = *MBB->pred_begin();
  }

  // Check that the value change of the element count is what we expect and
  // that the predication will be equivalent. For this we need:
  // NumElements = NumElements - VectorWidth. The sub will be a sub immediate
  // and we can also allow register copies within the chain too.
  auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) {
    return -getAddSubImmediate(*MI) == ExpectedVecWidth;
  };

  MBB = VCTP->getParent();
  // Remove modifications to the element count since they have no purpose in a
  // tail predicated loop. Explicitly refer to the vctp operand no matter which
  // register NumElements has been assigned to, since that is what the
  // modifications will be using
  if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(),
                                             VCTP->getOperand(1).getReg())) {
    SmallPtrSet<MachineInstr*, 2> ElementChain;
    SmallPtrSet<MachineInstr*, 2> Ignore;
    unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode());

    Ignore.insert(VCTPs.begin(), VCTPs.end());

    if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) {
      bool FoundSub = false;

      // The chain may contain register moves and exactly one sub-immediate
      // of the vector width; anything else makes it invalid.
      for (auto *MI : ElementChain) {
        if (isMovRegOpcode(MI->getOpcode()))
          continue;

        if (isSubImmOpcode(MI->getOpcode())) {
          if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth))
            return false;
          FoundSub = true;
        } else
          return false;
      }

      LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n";
                 for (auto *MI : ElementChain)
                   dbgs() << " - " << *MI);
      ToRemove.insert(ElementChain.begin(), ElementChain.end());
    }
  }
  return true;
}

// Is MO a register operand whose register belongs to Class?
static bool isRegInClass(const MachineOperand &MO,
                         const TargetRegisterClass *Class) {
  return MO.isReg() && MO.getReg() && Class->contains(MO.getReg());
}

// MVE 'narrowing' operate on half a lane, reading from half and writing
// to half, which are referred to as the top and bottom half. The other
// half retains its previous value.
678 static bool retainsPreviousHalfElement(const MachineInstr &MI) { 679 const MCInstrDesc &MCID = MI.getDesc(); 680 uint64_t Flags = MCID.TSFlags; 681 return (Flags & ARMII::RetainsPreviousHalfElement) != 0; 682 } 683 684 // Some MVE instructions read from the top/bottom halves of their operand(s) 685 // and generate a vector result with result elements that are double the 686 // width of the input. 687 static bool producesDoubleWidthResult(const MachineInstr &MI) { 688 const MCInstrDesc &MCID = MI.getDesc(); 689 uint64_t Flags = MCID.TSFlags; 690 return (Flags & ARMII::DoubleWidthResult) != 0; 691 } 692 693 static bool isHorizontalReduction(const MachineInstr &MI) { 694 const MCInstrDesc &MCID = MI.getDesc(); 695 uint64_t Flags = MCID.TSFlags; 696 return (Flags & ARMII::HorizontalReduction) != 0; 697 } 698 699 // Can this instruction generate a non-zero result when given only zeroed 700 // operands? This allows us to know that, given operands with false bytes 701 // zeroed by masked loads, that the result will also contain zeros in those 702 // bytes. 703 static bool canGenerateNonZeros(const MachineInstr &MI) { 704 705 // Check for instructions which can write into a larger element size, 706 // possibly writing into a previous zero'd lane. 707 if (producesDoubleWidthResult(MI)) 708 return true; 709 710 switch (MI.getOpcode()) { 711 default: 712 break; 713 // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow 714 // fp16 -> fp32 vector conversions. 715 // Instructions that perform a NOT will generate 1s from 0s. 716 case ARM::MVE_VMVN: 717 case ARM::MVE_VORN: 718 // Count leading zeros will do just that! 719 case ARM::MVE_VCLZs8: 720 case ARM::MVE_VCLZs16: 721 case ARM::MVE_VCLZs32: 722 return true; 723 } 724 return false; 725 } 726 727 // Look at its register uses to see if it only can only receive zeros 728 // into its false lanes which would then produce zeros. 
Also check that 729 // the output register is also defined by an FalseLanesZero instruction 730 // so that if tail-predication happens, the lanes that aren't updated will 731 // still be zeros. 732 static bool producesFalseLanesZero(MachineInstr &MI, 733 const TargetRegisterClass *QPRs, 734 const ReachingDefAnalysis &RDA, 735 InstSet &FalseLanesZero) { 736 if (canGenerateNonZeros(MI)) 737 return false; 738 739 bool isPredicated = isVectorPredicated(&MI); 740 // Predicated loads will write zeros to the falsely predicated bytes of the 741 // destination register. 742 if (MI.mayLoad()) 743 return isPredicated; 744 745 auto IsZeroInit = [](MachineInstr *Def) { 746 return !isVectorPredicated(Def) && 747 Def->getOpcode() == ARM::MVE_VMOVimmi32 && 748 Def->getOperand(1).getImm() == 0; 749 }; 750 751 bool AllowScalars = isHorizontalReduction(MI); 752 for (auto &MO : MI.operands()) { 753 if (!MO.isReg() || !MO.getReg()) 754 continue; 755 if (!isRegInClass(MO, QPRs) && AllowScalars) 756 continue; 757 758 // Check that this instruction will produce zeros in its false lanes: 759 // - If it only consumes false lanes zero or constant 0 (vmov #0) 760 // - If it's predicated, it only matters that it's def register already has 761 // false lane zeros, so we can ignore the uses. 762 SmallPtrSet<MachineInstr *, 2> Defs; 763 RDA.getGlobalReachingDefs(&MI, MO.getReg(), Defs); 764 for (auto *Def : Defs) { 765 if (Def == &MI || FalseLanesZero.count(Def) || IsZeroInit(Def)) 766 continue; 767 if (MO.isUse() && isPredicated) 768 continue; 769 return false; 770 } 771 } 772 LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI); 773 return true; 774 } 775 776 bool LowOverheadLoop::ValidateLiveOuts() { 777 // We want to find out if the tail-predicated version of this loop will 778 // produce the same values as the loop in its original form. For this to 779 // be true, the newly inserted implicit predication must not change the 780 // the (observable) results. 
781 // We're doing this because many instructions in the loop will not be 782 // predicated and so the conversion from VPT predication to tail-predication 783 // can result in different values being produced; due to the tail-predication 784 // preventing many instructions from updating their falsely predicated 785 // lanes. This analysis assumes that all the instructions perform lane-wise 786 // operations and don't perform any exchanges. 787 // A masked load, whether through VPT or tail predication, will write zeros 788 // to any of the falsely predicated bytes. So, from the loads, we know that 789 // the false lanes are zeroed and here we're trying to track that those false 790 // lanes remain zero, or where they change, the differences are masked away 791 // by their user(s). 792 // All MVE stores have to be predicated, so we know that any predicate load 793 // operands, or stored results are equivalent already. Other explicitly 794 // predicated instructions will perform the same operation in the original 795 // loop and the tail-predicated form too. Because of this, we can insert 796 // loads, stores and other predicated instructions into our Predicated 797 // set and build from there. 
798 const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID); 799 SetVector<MachineInstr *> FalseLanesUnknown; 800 SmallPtrSet<MachineInstr *, 4> FalseLanesZero; 801 SmallPtrSet<MachineInstr *, 4> Predicated; 802 MachineBasicBlock *Header = ML.getHeader(); 803 804 for (auto &MI : *Header) { 805 if (!shouldInspect(MI)) 806 continue; 807 808 if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode())) 809 continue; 810 811 bool isPredicated = isVectorPredicated(&MI); 812 bool retainsOrReduces = 813 retainsPreviousHalfElement(MI) || isHorizontalReduction(MI); 814 815 if (isPredicated) 816 Predicated.insert(&MI); 817 if (producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero)) 818 FalseLanesZero.insert(&MI); 819 else if (MI.getNumDefs() == 0) 820 continue; 821 else if (!isPredicated && retainsOrReduces) 822 return false; 823 else if (!isPredicated) 824 FalseLanesUnknown.insert(&MI); 825 } 826 827 auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO, 828 SmallPtrSetImpl<MachineInstr *> &Predicated) { 829 SmallPtrSet<MachineInstr *, 2> Uses; 830 RDA.getGlobalUses(MI, MO.getReg(), Uses); 831 for (auto *Use : Uses) { 832 if (Use != MI && !Predicated.count(Use)) 833 return false; 834 } 835 return true; 836 }; 837 838 // Visit the unknowns in reverse so that we can start at the values being 839 // stored and then we can work towards the leaves, hopefully adding more 840 // instructions to Predicated. Successfully terminating the loop means that 841 // all the unknown values have to found to be masked by predicated user(s). 842 // For any unpredicated values, we store them in NonPredicated so that we 843 // can later check whether these form a reduction. 
844 SmallPtrSet<MachineInstr*, 2> NonPredicated; 845 for (auto *MI : reverse(FalseLanesUnknown)) { 846 for (auto &MO : MI->operands()) { 847 if (!isRegInClass(MO, QPRs) || !MO.isDef()) 848 continue; 849 if (!HasPredicatedUsers(MI, MO, Predicated)) { 850 LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : " 851 << TRI.getRegAsmName(MO.getReg()) << " at " << *MI); 852 NonPredicated.insert(MI); 853 break; 854 } 855 } 856 // Any unknown false lanes have been masked away by the user(s). 857 if (!NonPredicated.contains(MI)) 858 Predicated.insert(MI); 859 } 860 861 SmallPtrSet<MachineInstr *, 2> LiveOutMIs; 862 SmallVector<MachineBasicBlock *, 2> ExitBlocks; 863 ML.getExitBlocks(ExitBlocks); 864 assert(ML.getNumBlocks() == 1 && "Expected single block loop!"); 865 assert(ExitBlocks.size() == 1 && "Expected a single exit block"); 866 MachineBasicBlock *ExitBB = ExitBlocks.front(); 867 for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) { 868 // TODO: Instead of blocking predication, we could move the vctp to the exit 869 // block and calculate it's operand there in or the preheader. 870 if (RegMask.PhysReg == ARM::VPR) 871 return false; 872 // Check Q-regs that are live in the exit blocks. We don't collect scalars 873 // because they won't be affected by lane predication. 874 if (QPRs->contains(RegMask.PhysReg)) 875 if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg)) 876 LiveOutMIs.insert(MI); 877 } 878 879 // We've already validated that any VPT predication within the loop will be 880 // equivalent when we perform the predication transformation; so we know that 881 // any VPT predicated instruction is predicated upon VCTP. Any live-out 882 // instruction needs to be predicated, so check this here. The instructions 883 // in NonPredicated have been found to be a reduction that we can ensure its 884 // legality. 
  // Reject any live-out value whose false lanes are still unknown and whose
  // producer is not predicated.
  for (auto *MI : LiveOutMIs) {
    if (NonPredicated.count(MI) && FalseLanesUnknown.contains(MI)) {
      LLVM_DEBUG(dbgs() << "ARM Loops: Unable to handle live out: " << *MI);
      return false;
    }
  }

  return true;
}

// Check that the LoopEnd (and, for while loops, the LoopStart) branches are
// within range of their 12-bit offsets, find a safe point to define LR, and
// then try to validate tail predication. Sets Revert when the low-overhead
// form has to be abandoned.
void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
  if (Revert)
    return;

  if (!End->getOperand(1).isMBB())
    report_fatal_error("Expected LoopEnd to target basic block");

  // TODO Maybe there's cases where the target doesn't have to be the header,
  // but for now be safe and revert.
  if (End->getOperand(1).getMBB() != ML.getHeader()) {
    LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
    Revert = true;
    return;
  }

  // The WLS and LE instructions have 12-bits for the label offset. WLS
  // requires a positive offset, while LE uses negative.
  if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
      !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
    LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
    Revert = true;
    return;
  }

  if (Start->getOpcode() == ARM::t2WhileLoopStart &&
      (BBUtils->getOffsetOf(Start) >
         BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
       !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
    LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
    Revert = true;
    return;
  }

  // Find somewhere safe to define LR for the start instruction; revert if
  // that is not possible.
  InsertPt = Revert ? nullptr : isSafeToDefineLR();
  if (!InsertPt) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
    Revert = true;
    return;
  } else
    LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);

  if (!IsTailPredicationLegal()) {
    LLVM_DEBUG(if (VCTPs.empty())
                 dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
               dbgs() << "ARM Loops: Tail-predication is not valid.\n");
    return;
  }

  assert(ML.getBlocks().size() == 1 &&
         "Shouldn't be processing a loop with more than one block");
  CannotTailPredicate = !ValidateTailPredicate(InsertPt);
  LLVM_DEBUG(if (CannotTailPredicate)
               dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
}

// Record a VCTP instruction. All VCTPs in the loop must use the same
// element-count operand (with the same reaching definition) as the first
// one recorded; otherwise the new VCTP is refused.
bool LowOverheadLoop::AddVCTP(MachineInstr *MI) {
  LLVM_DEBUG(dbgs() << "ARM Loops: Adding VCTP: " << *MI);
  if (VCTPs.empty()) {
    VCTPs.push_back(MI);
    return true;
  }

  // If we find another VCTP, check whether it uses the same value as the main
  // VCTP. If it does, store it in the VCTPs set, else refuse it.
  MachineInstr *Prev = VCTPs.back();
  // Reject a VCTP whose element-count operand differs, or whose operand has
  // a different reaching definition, from the previously recorded VCTP.
  if (!Prev->getOperand(1).isIdenticalTo(MI->getOperand(1)) ||
      !RDA.hasSameReachingDef(Prev, MI, MI->getOperand(1).getReg())) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching "
                         "definition from the main VCTP");
    return false;
  }
  VCTPs.push_back(MI);
  return true;
}

// Inspect an MVE instruction and decide whether it is compatible with tail
// predication, recording VCTPs and building up VPT block state along the
// way. Returns false when the instruction prevents tail predication.
bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
  if (CannotTailPredicate)
    return false;

  if (!shouldInspect(*MI))
    return true;

  if (MI->getOpcode() == ARM::MVE_VPSEL ||
      MI->getOpcode() == ARM::MVE_VPNOT) {
    // TODO: Allow VPSEL and VPNOT, we currently cannot because:
    // 1) It will use the VPR as a predicate operand, but doesn't have to be
    //    instead a VPT block, which means we can assert while building up
    //    the VPT block because we don't find another VPT or VPST to begin a
    //    new one.
    // 2) VPSEL still requires a VPR operand even after tail predicating,
    //    which means we can't remove it unless there is another
    //    instruction, such as vcmp, that can provide the VPR def.
    return false;
  }

  // Record all VCTPs and check that they're equivalent to one another.
  if (isVCTP(MI) && !AddVCTP(MI))
    return false;

  // Inspect uses first so that any instructions that alter the VPR don't
  // alter the predicate upon themselves.
  const MCInstrDesc &MCID = MI->getDesc();
  bool IsUse = false;
  unsigned LastOpIdx = MI->getNumOperands() - 1;
  // Walk the operands in reverse, looking for uses of VPR: a vpred operand
  // means this instruction is predicated within a VPT block; any other VPR
  // use (except on VPST itself) blocks tail predication.
  for (auto &Op : enumerate(reverse(MCID.operands()))) {
    const MachineOperand &MO = MI->getOperand(LastOpIdx - Op.index());
    if (!MO.isReg() || !MO.isUse() || MO.getReg() != ARM::VPR)
      continue;

    if (ARM::isVpred(Op.value().OperandType)) {
      VPTState::addInst(MI);
      IsUse = true;
    } else if (MI->getOpcode() != ARM::MVE_VPST) {
      LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI);
      return false;
    }
  }

  // If we find an instruction that has been marked as not valid for tail
  // predication, only allow the instruction if it's contained within a valid
  // VPT block.
  bool RequiresExplicitPredication =
    (MCID.TSFlags & ARMII::ValidForTailPredication) == 0;
  if (isDomainMVE(MI) && RequiresExplicitPredication) {
    LLVM_DEBUG(if (!IsUse)
                 dbgs() << "ARM Loops: Can't tail predicate: " << *MI);
    return IsUse;
  }

  // If the instruction is already explicitly predicated, then the conversion
  // will be fine, but ensure that all store operations are predicated.
  if (MI->mayStore())
    return IsUse;

  // If this instruction defines the VPR, update the predicate for the
  // subsequent instructions.
  if (isVectorPredicate(MI)) {
    // Clear the existing predicate when we're not in "VPT Active" state
    // (see the file header note on VPR.P0), otherwise we add to it.
    if (!isVectorPredicated(MI))
      VPTState::resetPredicate(MI);
    else
      VPTState::addPredicate(MI);
  }

  // Finally once the predicate has been modified, we can start a new VPT
  // block if necessary.
  if (isVPTOpcode(MI->getOpcode()))
    VPTState::CreateVPTBlock(MI);

  return true;
}

// Pass entry point: process every outermost loop in the function and then
// revert any loop pseudo instructions found outside of recognised loops.
bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) {
  const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget());
  // The low-overhead branch extension provides the DLS/WLS/LE instructions;
  // without it there is nothing to do.
  if (!ST.hasLOB())
    return false;

  MF = &mf;
  LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n");

  MLI = &getAnalysis<MachineLoopInfo>();
  RDA = &getAnalysis<ReachingDefAnalysis>();
  MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness);
  MRI = &MF->getRegInfo();
  TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo());
  TRI = ST.getRegisterInfo();
  // Block offsets/sizes are needed to range-check the WLS/LE branches.
  BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF));
  BBUtils->computeAllBlockSizes();
  BBUtils->adjustBBOffsetsAfter(&MF->front());

  bool Changed = false;
  for (auto ML : *MLI) {
    if (ML->isOutermost())
      Changed |= ProcessLoop(ML);
  }
  Changed |= RevertNonLoops();
  return Changed;
}

// Find the components of a low-overhead loop (start, dec, end and any VCTPs),
// check their legality, and then expand the pseudos or revert to a 'normal'
// loop.
bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) {

  bool Changed = false;

  // Process inner loops first.
  for (auto I = ML->begin(), E = ML->end(); I != E; ++I)
    Changed |= ProcessLoop(*I);

  LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n";
             if (auto *Preheader = ML->getLoopPreheader())
               dbgs() << " - " << Preheader->getName() << "\n";
             else if (auto *Preheader = MLI->findLoopPreheader(ML))
               dbgs() << " - " << Preheader->getName() << "\n";
             else if (auto *Preheader = MLI->findLoopPreheader(ML, true))
               dbgs() << " - " << Preheader->getName() << "\n";
             for (auto *MBB : ML->getBlocks())
               dbgs() << " - " << MBB->getName() << "\n";
             );

  // Search the given block for a loop start instruction. If one isn't found,
  // and there's only one predecessor block, search that one too.
  // Recursive so that the search can walk up through a chain of
  // single-predecessor blocks (the while form's start may live in the
  // preheader's only predecessor, see the file header).
  std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart =
    [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* {
    for (auto &MI : *MBB) {
      if (isLoopStart(MI))
        return &MI;
    }
    if (MBB->pred_size() == 1)
      return SearchForStart(*MBB->pred_begin());
    return nullptr;
  };

  LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII);
  // Search the preheader for the start intrinsic.
  // FIXME: I don't see why we shouldn't be supporting multiple predecessors
  // with potentially multiple set.loop.iterations, so we need to enable this.
  if (LoLoop.Preheader)
    LoLoop.Start = SearchForStart(LoLoop.Preheader);
  else
    return false;

  // Find the low-overhead loop components and decide whether or not to fall
  // back to a normal loop. Also look for vctp instructions and decide
  // whether we can convert that predicate using tail predication.
  for (auto *MBB : reverse(ML->getBlocks())) {
    for (auto &MI : *MBB) {
      if (MI.isDebugValue())
        continue;
      else if (MI.getOpcode() == ARM::t2LoopDec)
        LoLoop.Dec = &MI;
      else if (MI.getOpcode() == ARM::t2LoopEnd)
        LoLoop.End = &MI;
      else if (isLoopStart(MI))
        LoLoop.Start = &MI;
      else if (MI.getDesc().isCall()) {
        // TODO: Though the call will require LE to execute again, does this
        // mean we should revert? Always executing LE hopefully should be
        // faster than performing a sub,cmp,br or even subs,br.
        LoLoop.Revert = true;
        LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n");
      } else {
        // Record VPR defs and build up their corresponding vpt blocks.
        // Check we know how to tail predicate any mve instructions.
        LoLoop.AnalyseMVEInst(&MI);
      }
    }
  }

  LLVM_DEBUG(LoLoop.dump());
  if (!LoLoop.FoundAllComponents()) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n");
    return false;
  }

  // Check that the only instruction using LoopDec is LoopEnd.
  // TODO: Check for copy chains that really have no effect.
  SmallPtrSet<MachineInstr*, 2> Uses;
  RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses);
  if (Uses.size() > 1 || !Uses.count(LoLoop.End)) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n");
    LoLoop.Revert = true;
  }
  LoLoop.CheckLegality(BBUtils.get());
  Expand(LoLoop);
  return true;
}

// WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a
// beq that branches to the exit branch.
// TODO: We could also try to generate a cbz if the value in LR is also in
// another low register.
void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const {
  LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI);
  MachineBasicBlock *MBB = MI->getParent();
  MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
                                    TII->get(ARM::t2CMPri));
  MIB.add(MI->getOperand(0));
  MIB.addImm(0);
  MIB.addImm(ARMCC::AL);
  MIB.addReg(ARM::NoRegister);

  // Use the narrow tBcc encoding when the destination is in range, otherwise
  // fall back to the 32-bit t2Bcc.
  MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
  unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
    ARM::tBcc : ARM::t2Bcc;

  MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
  MIB.add(MI->getOperand(1));   // branch target
  MIB.addImm(ARMCC::EQ);        // condition code
  MIB.addReg(ARM::CPSR);
  MI->eraseFromParent();
}

// Revert a t2LoopDec to a sub, or a subs when CPSR is known to be safe to
// define between here and the LoopEnd. Returns true when the flags were set
// (so the following LoopEnd reversion can skip its cmp).
bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const {
  LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI);
  MachineBasicBlock *MBB = MI->getParent();
  // The LoopEnd consumes the flags, so ignore it when checking whether CPSR
  // may be defined here.
  SmallPtrSet<MachineInstr*, 1> Ignore;
  for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) {
    if (I->getOpcode() == ARM::t2LoopEnd) {
      Ignore.insert(&*I);
      break;
    }
  }

  // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS.
  bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore);

  MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
                                    TII->get(ARM::t2SUBri));
  MIB.addDef(ARM::LR);
  MIB.add(MI->getOperand(1));
  MIB.add(MI->getOperand(2));
  MIB.addImm(ARMCC::AL);
  MIB.addReg(0);

  if (SetFlags) {
    // Make the sub a subs by turning the trailing CPSR operand into a def.
    MIB.addReg(ARM::CPSR);
    MIB->getOperand(5).setIsDef(true);
  } else
    MIB.addReg(0);

  MI->eraseFromParent();
  return SetFlags;
}

// Generate a subs, or sub and cmp, and a branch instead of an LE.
void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const {
  LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI);

  MachineBasicBlock *MBB = MI->getParent();
  // Create cmp; skipped when RevertLoopDec already produced a flag-setting
  // subs.
  if (!SkipCmp) {
    MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(),
                                      TII->get(ARM::t2CMPri));
    MIB.addReg(ARM::LR);
    MIB.addImm(0);
    MIB.addImm(ARMCC::AL);
    MIB.addReg(ARM::NoRegister);
  }

  // As in RevertWhile, prefer the narrow branch encoding when in range.
  MachineBasicBlock *DestBB = MI->getOperand(1).getMBB();
  unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ?
    ARM::tBcc : ARM::t2Bcc;

  // Create bne
  MachineInstrBuilder MIB =
    BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc));
  MIB.add(MI->getOperand(1));   // branch target
  MIB.addImm(ARMCC::NE);        // condition code
  MIB.addReg(ARM::CPSR);
  MI->eraseFromParent();
}

// Perform dead code elimination on the loop iteration count setup expression.
// If we are tail-predicating, the number of elements to be processed is the
// operand of the VCTP instruction in the vector body, see getCount(), which is
// register $r3 in this example:
//
//   $lr = big-itercount-expression
//   ..
//   t2DoLoopStart renamable $lr
//   vector.body:
//     ..
//     $vpr = MVE_VCTP32 renamable $r3
//     renamable $lr = t2LoopDec killed renamable $lr, 1
//     t2LoopEnd renamable $lr, %vector.body
//     tB %end
//
// What we would like to achieve here is to replace the do-loop start pseudo
// instruction t2DoLoopStart with:
//
//   $lr = MVE_DLSTP_32 killed renamable $r3
//
// Thus, $r3 which defines the number of elements, is written to $lr,
// and then we want to delete the whole chain that used to define $lr,
// see the comment below how this chain could look like.
//
void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) {
  if (!LoLoop.IsTailPredicationLegal())
    return;

  LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n");

  // The iteration count is whatever defines the start pseudo's operand 0.
  MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0);
  if (!Def) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n");
    return;
  }

  // Collect and remove the users of iteration count.
  // Seed the kill set with the loop pseudos (and the insertion point) whose
  // uses of the count are going away with the transformation.
  SmallPtrSet<MachineInstr*, 4> Killed = { LoLoop.Start, LoLoop.Dec,
                                           LoLoop.End, LoLoop.InsertPt };
  SmallPtrSet<MachineInstr*, 2> Remove;
  if (RDA->isSafeToRemove(Def, Remove, Killed))
    LoLoop.ToRemove.insert(Remove.begin(), Remove.end());
  else {
    LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n");
    return;
  }

  // Collect the dead code and the MBBs in which they reside.
  RDA->collectKilledOperands(Def, Killed);
  SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks;
  for (auto *MI : Killed)
    BasicBlocks.insert(MI->getParent());

  // Collect IT blocks in all affected basic blocks. Each t2IT maps to the
  // set of instructions predicated by its ITSTATE.
  std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks;
  for (auto *MBB : BasicBlocks) {
    for (auto &MI : *MBB) {
      if (MI.getOpcode() != ARM::t2IT)
        continue;
      RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]);
    }
  }

  // If we're removing all of the instructions within an IT block, then
  // also remove the IT instruction.
  SmallPtrSet<MachineInstr*, 2> ModifiedITs;
  for (auto *MI : Killed) {
    if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) {
      MachineInstr *IT = RDA->getMIOperand(MI, *MO);
      auto &CurrentBlock = ITBlocks[IT];
      CurrentBlock.erase(MI);
      // An IT block that becomes empty is fully removed, not merely
      // modified, so it needs no mask fixup.
      if (CurrentBlock.empty())
        ModifiedITs.erase(IT);
      else
        ModifiedITs.insert(IT);
    }
  }

  // Delete the killed instructions only if we don't have any IT blocks that
  // need to be modified because we need to fixup the mask.
  // TODO: Handle cases where IT blocks are modified.
  if (ModifiedITs.empty()) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n";
               for (auto *MI : Killed)
                 dbgs() << " - " << *MI);
    LoLoop.ToRemove.insert(Killed.begin(), Killed.end());
  } else
    LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n");
}

// Replace the loop start pseudo with the real loop start instruction (opcode
// chosen by getStartOpcode()), inserting it at the insertion point found by
// CheckLegality.
MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) {
  LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n");
  // When using tail-predication, try to delete the dead code that was used to
  // calculate the number of loop iterations.
  IterationCountDCE(LoLoop);

  MachineInstr *InsertPt = LoLoop.InsertPt;
  MachineInstr *Start = LoLoop.Start;
  MachineBasicBlock *MBB = InsertPt->getParent();
  bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart;
  unsigned Opc = LoLoop.getStartOpcode();
  MachineOperand &Count = LoLoop.getLoopStartOperand();

  MachineInstrBuilder MIB =
    BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc));

  MIB.addDef(ARM::LR);
  MIB.add(Count);
  // The while form additionally carries the branch target (operand 1).
  if (!IsDo)
    MIB.add(Start->getOperand(1));

  // If we're inserting at a mov lr, then remove it as it's redundant.
  if (InsertPt != Start)
    LoLoop.ToRemove.insert(InsertPt);
  LoLoop.ToRemove.insert(Start);
  LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB);
  return &*MIB;
}

// Rewrite the loop's VPT blocks now that the VCTP-generated predicate is
// provided implicitly by the tail-predicated loop: unpredicate instructions
// that were guarded only by the vctp and keep VPT blocks only where the
// predicate diverges from it.
void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) {
  // Strip the Then predicate from an instruction, making it unconditional.
  // NOTE(review): findFirstVPTPredOperandIdx returning -1 (no predicate
  // operand) is truthy here, so the llvm_unreachable branch only fires for
  // index 0 -- confirm this is the intended guard.
  auto RemovePredicate = [](MachineInstr *MI) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI);
    if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) {
      assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then &&
             "Expected Then predicate!");
      MI->getOperand(PIdx).setImm(ARMVCC::None);
      MI->getOperand(PIdx+1).setReg(0);
    } else
      llvm_unreachable("trying to unpredicate a non-predicated instruction");
  };

  for (auto &Block : LoLoop.getVPTBlocks()) {
    SmallVectorImpl<MachineInstr *> &Insts = Block.getInsts();

    if (VPTState::isEntryPredicatedOnVCTP(Block, /*exclusive*/true)) {
      if (VPTState::hasUniformPredicate(Block)) {
        // A vpt block starting with VPST that is only predicated upon vctp
        // and has no internal vpr defs:
        // - Remove vpst.
        // - Unpredicate the remaining instructions.
        LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Insts.front());
        LoLoop.ToRemove.insert(Insts.front());
        for (unsigned i = 1; i < Insts.size(); ++i)
          RemovePredicate(Insts[i]);
      } else {
        // The VPT block has a non-uniform predicate but it uses a vpst and
        // its entry is guarded only by a vctp, which means we:
        // - Need to remove the original vpst.
        // - Then need to unpredicate any following instructions, until
        //   we come across the divergent vpr def.
        // - Insert a new vpst to predicate the instruction(s) that follow
        //   the divergent vpr def.
        // TODO: We could be producing more VPT blocks than necessary and
        //   could fold the newly created one into a proceeding one.
        MachineInstr *Divergent = VPTState::getDivergent(Block);
        // Unpredicate everything from the instruction after the VPST up to
        // and including the divergent def.
        for (auto I = ++MachineBasicBlock::iterator(Insts.front()),
                  E = ++MachineBasicBlock::iterator(Divergent); I != E; ++I)
          RemovePredicate(&*I);

        // Check if the instruction defining vpr is a vcmp so it can be
        // combined with the VPST. This should be the divergent instruction.
        MachineInstr *VCMP = VCMPOpcodeToVPT(Divergent->getOpcode()) != 0
                               ? Divergent
                               : nullptr;

        // Walk backwards from the end of the block so that InsertAt ends up
        // as the first instruction after the divergent def -- the place for
        // the new VPST/VPT.
        unsigned Size = 0;
        auto E = MachineBasicBlock::reverse_iterator(Divergent);
        auto I = MachineBasicBlock::reverse_iterator(Insts.back());
        MachineInstr *InsertAt = nullptr;
        while (I != E) {
          InsertAt = &*I;
          ++Size;
          ++I;
        }

        MachineInstrBuilder MIB;
        if (VCMP) {
          // Combine the VPST and VCMP into a VPT
          MIB =
            BuildMI(*InsertAt->getParent(), InsertAt, InsertAt->getDebugLoc(),
                    TII->get(VCMPOpcodeToVPT(VCMP->getOpcode())));
          MIB.addImm(ARMVCC::Then);
          // Register one
          MIB.add(VCMP->getOperand(1));
          // Register two
          MIB.add(VCMP->getOperand(2));
          // The comparison code, e.g. ge, eq, lt
          MIB.add(VCMP->getOperand(3));
          LLVM_DEBUG(dbgs()
                     << "ARM Loops: Combining with VCMP to VPT: " << *MIB);
          LoLoop.ToRemove.insert(VCMP);
        } else {
          // Create a VPST with a null mask for now; the real mask is
          // recomputed later via BlockMasksToRecompute.
          MIB = BuildMI(*InsertAt->getParent(), InsertAt,
                        InsertAt->getDebugLoc(), TII->get(ARM::MVE_VPST));
          MIB.addImm(0);
          LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB);
        }
        LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Insts.front());
        LoLoop.ToRemove.insert(Insts.front());
        LoLoop.BlockMasksToRecompute.insert(MIB.getInstr());
      }
    } else if (Block.containsVCTP()) {
      // The vctp will be removed, so the block mask of the vp(s)t will need
      // to be recomputed.
      LoLoop.BlockMasksToRecompute.insert(Insts.front());
    }
  }

  // All of the loop's VCTPs become redundant and are queued for removal.
  LoLoop.ToRemove.insert(LoLoop.VCTPs.begin(), LoLoop.VCTPs.end());
}

// Perform the final transformation: either revert all the pseudos to a
// 'normal' loop, or expand them into the low-overhead (and possibly
// tail-predicated) form, then fix up liveness.
void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) {

  // Combine the LoopDec and LoopEnd instructions into LE(TP).
  auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) {
    MachineInstr *End = LoLoop.End;
    MachineBasicBlock *MBB = End->getParent();
    unsigned Opc = LoLoop.IsTailPredicationLegal() ?
      ARM::MVE_LETP : ARM::t2LEUpdate;
    MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(),
                                      TII->get(Opc));
    MIB.addDef(ARM::LR);
    MIB.add(End->getOperand(0));
    MIB.add(End->getOperand(1));
    LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB);
    LoLoop.ToRemove.insert(LoLoop.Dec);
    LoLoop.ToRemove.insert(End);
    return &*MIB;
  };

  // TODO: We should be able to automatically remove these branches before we
  // get here - probably by teaching analyzeBranch about the pseudo
  // instructions.
  // If there is an unconditional branch, after I, that just branches to the
  // next block, remove it.
  auto RemoveDeadBranch = [](MachineInstr *I) {
    MachineBasicBlock *BB = I->getParent();
    MachineInstr *Terminator = &BB->instr_back();
    if (Terminator->isUnconditionalBranch() && I != Terminator) {
      MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB();
      if (BB->isLayoutSuccessor(Succ)) {
        LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator);
        Terminator->eraseFromParent();
      }
    }
  };

  if (LoLoop.Revert) {
    // Revert path: replace the pseudos with 'normal' compare/branch/sub
    // sequences.
    if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart)
      RevertWhile(LoLoop.Start);
    else
      LoLoop.Start->eraseFromParent();
    bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec);
    RevertLoopEnd(LoLoop.End, FlagsAlreadySet);
  } else {
    LoLoop.Start = ExpandLoopStart(LoLoop);
    RemoveDeadBranch(LoLoop.Start);
    LoLoop.End = ExpandLoopEnd(LoLoop);
    RemoveDeadBranch(LoLoop.End);
    if (LoLoop.IsTailPredicationLegal())
      ConvertVPTBlocks(LoLoop);
    for (auto *I : LoLoop.ToRemove) {
      LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I);
      I->eraseFromParent();
    }
    for (auto *I : LoLoop.BlockMasksToRecompute) {
      LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I);
      recomputeVPTBlockMask(*I);
      LLVM_DEBUG(dbgs() << " ... done: " << *I);
    }
  }

  // Recompute liveness for the blocks of the loop now that instructions have
  // been added and removed.
  PostOrderLoopTraversal DFS(LoLoop.ML, *MLI);
  DFS.ProcessLoop();
  const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder();
  for (auto *MBB : PostOrder) {
    recomputeLiveIns(*MBB);
    // FIXME: For some reason, the live-in print order is non-deterministic
    // for our tests and I can't work out why... So just sort them.
    MBB->sortUniqueLiveIns();
  }

  for (auto *MBB : reverse(PostOrder))
    recomputeLivenessFlags(*MBB);

  // We've moved, removed and inserted new instructions, so update RDA.
  RDA->reset();
}

// Scan every block of the function for any remaining loop pseudo
// instructions (start/dec/end) and revert them to 'normal' instructions.
bool ARMLowOverheadLoops::RevertNonLoops() {
  LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n");
  bool Changed = false;

  for (auto &MBB : *MF) {
    SmallVector<MachineInstr*, 4> Starts;
    SmallVector<MachineInstr*, 4> Decs;
    SmallVector<MachineInstr*, 4> Ends;

    for (auto &I : MBB) {
      if (isLoopStart(I))
        Starts.push_back(&I);
      else if (I.getOpcode() == ARM::t2LoopDec)
        Decs.push_back(&I);
      else if (I.getOpcode() == ARM::t2LoopEnd)
        Ends.push_back(&I);
    }

    if (Starts.empty() && Decs.empty() && Ends.empty())
      continue;

    Changed = true;

    for (auto *Start : Starts) {
      // Only the while form expands to a cmp/beq pair; the do form is simply
      // deleted.
      if (Start->getOpcode() == ARM::t2WhileLoopStart)
        RevertWhile(Start);
      else
        Start->eraseFromParent();
    }
    for (auto *Dec : Decs)
      RevertLoopDec(Dec);

    for (auto *End : Ends)
      RevertLoopEnd(End);
  }
  return Changed;
}

// Pass factory, referenced by the ARM target's pass registration.
FunctionPass *llvm::createARMLowOverheadLoopsPass() {
  return new ARMLowOverheadLoops();
}