//===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// Finalize v8.1-m low-overhead loops by converting the associated pseudo
/// instructions into machine operations.
/// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
/// - t2LoopEnd - the loop latch terminator.
///
/// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
/// loop form.
///
/// Assumptions and Dependencies:
/// Low-overhead loops are constructed and executed using a setup instruction:
/// DLS, WLS, DLSTP or WLSTP and an instruction that loops back: LE or LETP.
/// WLS(TP) and LE(TP) are branching instructions with a (large) limited range
/// but fixed polarity: WLS can only branch forwards and LE can only branch
/// backwards. These restrictions mean that this pass is dependent upon block
/// layout and block sizes, which is why it's the last pass to run. The same is
/// true for ConstantIslands, but this pass does not increase the size of the
/// basic blocks, nor does it change the CFG. Instructions are mainly removed
/// during the transform and pseudo instructions are replaced by real ones. In
/// some cases, when we have to revert to a 'normal' loop, we have to introduce
/// multiple instructions for a single pseudo (see RevertWhile and
/// RevertLoopEnd). To handle this situation, t2WhileLoopStart and t2LoopEnd
/// are defined to be as large as this maximum sequence of replacement
/// instructions.
///
/// A note on VPR.P0 (the lane mask):
/// VPT, VCMP, VPNOT and VCTP won't overwrite VPR.P0 when they update it in a
/// "VPT Active" context (which includes low-overhead loops and vpt blocks).
/// They will simply "and" the result of their calculation with the current
/// value of VPR.P0. You can think of it like this:
/// \verbatim
/// if VPT active:    ; Between a DLSTP/LETP, or for predicated instrs
///   VPR.P0 &= Value
/// else
///   VPR.P0 = Value
/// \endverbatim
/// When we're inside the low-overhead loop (between DLSTP and LETP), we always
/// fall in the "VPT active" case, so we can consider that every VPR write by
/// one of those instructions is actually an "and".
52 //===----------------------------------------------------------------------===// 53 54 #include "ARM.h" 55 #include "ARMBaseInstrInfo.h" 56 #include "ARMBaseRegisterInfo.h" 57 #include "ARMBasicBlockInfo.h" 58 #include "ARMSubtarget.h" 59 #include "Thumb2InstrInfo.h" 60 #include "llvm/ADT/SetOperations.h" 61 #include "llvm/ADT/SmallSet.h" 62 #include "llvm/CodeGen/LivePhysRegs.h" 63 #include "llvm/CodeGen/MachineFunctionPass.h" 64 #include "llvm/CodeGen/MachineLoopInfo.h" 65 #include "llvm/CodeGen/MachineLoopUtils.h" 66 #include "llvm/CodeGen/MachineRegisterInfo.h" 67 #include "llvm/CodeGen/Passes.h" 68 #include "llvm/CodeGen/ReachingDefAnalysis.h" 69 #include "llvm/MC/MCInstrDesc.h" 70 71 using namespace llvm; 72 73 #define DEBUG_TYPE "arm-low-overhead-loops" 74 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass" 75 76 static bool isVectorPredicated(MachineInstr *MI) { 77 int PIdx = llvm::findFirstVPTPredOperandIdx(*MI); 78 return PIdx != -1 && MI->getOperand(PIdx + 1).getReg() == ARM::VPR; 79 } 80 81 static bool isVectorPredicate(MachineInstr *MI) { 82 return MI->findRegisterDefOperandIdx(ARM::VPR) != -1; 83 } 84 85 static bool hasVPRUse(MachineInstr *MI) { 86 return MI->findRegisterUseOperandIdx(ARM::VPR) != -1; 87 } 88 89 static bool isDomainMVE(MachineInstr *MI) { 90 uint64_t Domain = MI->getDesc().TSFlags & ARMII::DomainMask; 91 return Domain == ARMII::DomainMVE; 92 } 93 94 static bool shouldInspect(MachineInstr &MI) { 95 return isDomainMVE(&MI) || isVectorPredicate(&MI) || 96 hasVPRUse(&MI); 97 } 98 99 namespace { 100 101 using InstSet = SmallPtrSetImpl<MachineInstr *>; 102 103 class PostOrderLoopTraversal { 104 MachineLoop &ML; 105 MachineLoopInfo &MLI; 106 SmallPtrSet<MachineBasicBlock*, 4> Visited; 107 SmallVector<MachineBasicBlock*, 4> Order; 108 109 public: 110 PostOrderLoopTraversal(MachineLoop &ML, MachineLoopInfo &MLI) 111 : ML(ML), MLI(MLI) { } 112 113 const SmallVectorImpl<MachineBasicBlock*> &getOrder() const { 114 return Order; 
115 } 116 117 // Visit all the blocks within the loop, as well as exit blocks and any 118 // blocks properly dominating the header. 119 void ProcessLoop() { 120 std::function<void(MachineBasicBlock*)> Search = [this, &Search] 121 (MachineBasicBlock *MBB) -> void { 122 if (Visited.count(MBB)) 123 return; 124 125 Visited.insert(MBB); 126 for (auto *Succ : MBB->successors()) { 127 if (!ML.contains(Succ)) 128 continue; 129 Search(Succ); 130 } 131 Order.push_back(MBB); 132 }; 133 134 // Insert exit blocks. 135 SmallVector<MachineBasicBlock*, 2> ExitBlocks; 136 ML.getExitBlocks(ExitBlocks); 137 for (auto *MBB : ExitBlocks) 138 Order.push_back(MBB); 139 140 // Then add the loop body. 141 Search(ML.getHeader()); 142 143 // Then try the preheader and its predecessors. 144 std::function<void(MachineBasicBlock*)> GetPredecessor = 145 [this, &GetPredecessor] (MachineBasicBlock *MBB) -> void { 146 Order.push_back(MBB); 147 if (MBB->pred_size() == 1) 148 GetPredecessor(*MBB->pred_begin()); 149 }; 150 151 if (auto *Preheader = ML.getLoopPreheader()) 152 GetPredecessor(Preheader); 153 else if (auto *Preheader = MLI.findLoopPreheader(&ML, true)) 154 GetPredecessor(Preheader); 155 } 156 }; 157 158 struct PredicatedMI { 159 MachineInstr *MI = nullptr; 160 SetVector<MachineInstr*> Predicates; 161 162 public: 163 PredicatedMI(MachineInstr *I, SetVector<MachineInstr *> &Preds) : MI(I) { 164 assert(I && "Instruction must not be null!"); 165 Predicates.insert(Preds.begin(), Preds.end()); 166 } 167 }; 168 169 // Represent a VPT block, a list of instructions that begins with a VPT/VPST 170 // and has a maximum of four proceeding instructions. All instructions within 171 // the block are predicated upon the vpr and we allow instructions to define 172 // the vpr within in the block too. 173 class VPTBlock { 174 // The predicate then instruction, which is either a VPT, or a VPST 175 // instruction. 
176 std::unique_ptr<PredicatedMI> PredicateThen; 177 PredicatedMI *Divergent = nullptr; 178 SmallVector<PredicatedMI, 4> Insts; 179 180 public: 181 VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { 182 PredicateThen = std::make_unique<PredicatedMI>(MI, Preds); 183 } 184 185 void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { 186 LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI); 187 if (!Divergent && !set_difference(Preds, PredicateThen->Predicates).empty()) { 188 Divergent = &Insts.back(); 189 LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI); 190 } 191 Insts.emplace_back(MI, Preds); 192 assert(Insts.size() <= 4 && "Too many instructions in VPT block!"); 193 } 194 195 // Have we found an instruction within the block which defines the vpr? If 196 // so, not all the instructions in the block will have the same predicate. 197 bool HasNonUniformPredicate() const { 198 return Divergent != nullptr; 199 } 200 201 // Is the given instruction part of the predicate set controlling the entry 202 // to the block. 203 bool IsPredicatedOn(MachineInstr *MI) const { 204 return PredicateThen->Predicates.count(MI); 205 } 206 207 // Returns true if this is a VPT instruction. 208 bool isVPT() const { return !isVPST(); } 209 210 // Returns true if this is a VPST instruction. 211 bool isVPST() const { 212 return PredicateThen->MI->getOpcode() == ARM::MVE_VPST; 213 } 214 215 // Is the given instruction the only predicate which controls the entry to 216 // the block. 
217 bool IsOnlyPredicatedOn(MachineInstr *MI) const { 218 return IsPredicatedOn(MI) && PredicateThen->Predicates.size() == 1; 219 } 220 221 unsigned size() const { return Insts.size(); } 222 SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; } 223 MachineInstr *getPredicateThen() const { return PredicateThen->MI; } 224 PredicatedMI *getDivergent() const { return Divergent; } 225 }; 226 227 struct LowOverheadLoop { 228 229 MachineLoop &ML; 230 MachineBasicBlock *Preheader = nullptr; 231 MachineLoopInfo &MLI; 232 ReachingDefAnalysis &RDA; 233 const TargetRegisterInfo &TRI; 234 const ARMBaseInstrInfo &TII; 235 MachineFunction *MF = nullptr; 236 MachineInstr *InsertPt = nullptr; 237 MachineInstr *Start = nullptr; 238 MachineInstr *Dec = nullptr; 239 MachineInstr *End = nullptr; 240 MachineInstr *VCTP = nullptr; 241 MachineOperand TPNumElements; 242 SmallPtrSet<MachineInstr*, 4> SecondaryVCTPs; 243 VPTBlock *CurrentBlock = nullptr; 244 SetVector<MachineInstr*> CurrentPredicate; 245 SmallVector<VPTBlock, 4> VPTBlocks; 246 SmallPtrSet<MachineInstr*, 4> ToRemove; 247 SmallPtrSet<MachineInstr*, 4> BlockMasksToRecompute; 248 bool Revert = false; 249 bool CannotTailPredicate = false; 250 251 LowOverheadLoop(MachineLoop &ML, MachineLoopInfo &MLI, 252 ReachingDefAnalysis &RDA, const TargetRegisterInfo &TRI, 253 const ARMBaseInstrInfo &TII) 254 : ML(ML), MLI(MLI), RDA(RDA), TRI(TRI), TII(TII), 255 TPNumElements(MachineOperand::CreateImm(0)) { 256 MF = ML.getHeader()->getParent(); 257 if (auto *MBB = ML.getLoopPreheader()) 258 Preheader = MBB; 259 else if (auto *MBB = MLI.findLoopPreheader(&ML, true)) 260 Preheader = MBB; 261 } 262 263 // If this is an MVE instruction, check that we know how to use tail 264 // predication with it. Record VPT blocks and return whether the 265 // instruction is valid for tail predication. 
266 bool ValidateMVEInst(MachineInstr *MI); 267 268 void AnalyseMVEInst(MachineInstr *MI) { 269 CannotTailPredicate = !ValidateMVEInst(MI); 270 } 271 272 bool IsTailPredicationLegal() const { 273 // For now, let's keep things really simple and only support a single 274 // block for tail predication. 275 return !Revert && FoundAllComponents() && VCTP && 276 !CannotTailPredicate && ML.getNumBlocks() == 1; 277 } 278 279 // Check that the predication in the loop will be equivalent once we 280 // perform the conversion. Also ensure that we can provide the number 281 // of elements to the loop start instruction. 282 bool ValidateTailPredicate(MachineInstr *StartInsertPt); 283 284 // Check that any values available outside of the loop will be the same 285 // after tail predication conversion. 286 bool ValidateLiveOuts(); 287 288 // Is it safe to define LR with DLS/WLS? 289 // LR can be defined if it is the operand to start, because it's the same 290 // value, or if it's going to be equivalent to the operand to Start. 291 MachineInstr *isSafeToDefineLR(); 292 293 // Check the branch targets are within range and we satisfy our 294 // restrictions. 295 void CheckLegality(ARMBasicBlockUtils *BBUtils); 296 297 bool FoundAllComponents() const { 298 return Start && Dec && End; 299 } 300 301 SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; } 302 303 // Return the operand for the loop start instruction. This will be the loop 304 // iteration count, or the number of elements if we're tail predicating. 305 MachineOperand &getLoopStartOperand() { 306 return IsTailPredicationLegal() ? TPNumElements : Start->getOperand(0); 307 } 308 309 unsigned getStartOpcode() const { 310 bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart; 311 if (!IsTailPredicationLegal()) 312 return IsDo ? 
ARM::t2DLS : ARM::t2WLS; 313 314 return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo); 315 } 316 317 void dump() const { 318 if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start; 319 if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec; 320 if (End) dbgs() << "ARM Loops: Found Loop End: " << *End; 321 if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP; 322 if (!FoundAllComponents()) 323 dbgs() << "ARM Loops: Not a low-overhead loop.\n"; 324 else if (!(Start && Dec && End)) 325 dbgs() << "ARM Loops: Failed to find all loop components.\n"; 326 } 327 }; 328 329 class ARMLowOverheadLoops : public MachineFunctionPass { 330 MachineFunction *MF = nullptr; 331 MachineLoopInfo *MLI = nullptr; 332 ReachingDefAnalysis *RDA = nullptr; 333 const ARMBaseInstrInfo *TII = nullptr; 334 MachineRegisterInfo *MRI = nullptr; 335 const TargetRegisterInfo *TRI = nullptr; 336 std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr; 337 338 public: 339 static char ID; 340 341 ARMLowOverheadLoops() : MachineFunctionPass(ID) { } 342 343 void getAnalysisUsage(AnalysisUsage &AU) const override { 344 AU.setPreservesCFG(); 345 AU.addRequired<MachineLoopInfo>(); 346 AU.addRequired<ReachingDefAnalysis>(); 347 MachineFunctionPass::getAnalysisUsage(AU); 348 } 349 350 bool runOnMachineFunction(MachineFunction &MF) override; 351 352 MachineFunctionProperties getRequiredProperties() const override { 353 return MachineFunctionProperties().set( 354 MachineFunctionProperties::Property::NoVRegs).set( 355 MachineFunctionProperties::Property::TracksLiveness); 356 } 357 358 StringRef getPassName() const override { 359 return ARM_LOW_OVERHEAD_LOOPS_NAME; 360 } 361 362 private: 363 bool ProcessLoop(MachineLoop *ML); 364 365 bool RevertNonLoops(); 366 367 void RevertWhile(MachineInstr *MI) const; 368 369 bool RevertLoopDec(MachineInstr *MI) const; 370 371 void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const; 372 373 void ConvertVPTBlocks(LowOverheadLoop &LoLoop); 374 375 MachineInstr 
*ExpandLoopStart(LowOverheadLoop &LoLoop); 376 377 void Expand(LowOverheadLoop &LoLoop); 378 379 void IterationCountDCE(LowOverheadLoop &LoLoop); 380 }; 381 } 382 383 char ARMLowOverheadLoops::ID = 0; 384 385 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME, 386 false, false) 387 388 MachineInstr *LowOverheadLoop::isSafeToDefineLR() { 389 // We can define LR because LR already contains the same value. 390 if (Start->getOperand(0).getReg() == ARM::LR) 391 return Start; 392 393 unsigned CountReg = Start->getOperand(0).getReg(); 394 auto IsMoveLR = [&CountReg](MachineInstr *MI) { 395 return MI->getOpcode() == ARM::tMOVr && 396 MI->getOperand(0).getReg() == ARM::LR && 397 MI->getOperand(1).getReg() == CountReg && 398 MI->getOperand(2).getImm() == ARMCC::AL; 399 }; 400 401 MachineBasicBlock *MBB = Start->getParent(); 402 403 // Find an insertion point: 404 // - Is there a (mov lr, Count) before Start? If so, and nothing else writes 405 // to Count before Start, we can insert at that mov. 406 if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR)) 407 if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg)) 408 return LRDef; 409 410 // - Is there a (mov lr, Count) after Start? If so, and nothing else writes 411 // to Count after Start, we can insert at that mov. 412 if (auto *LRDef = RDA.getLocalLiveOutMIDef(MBB, ARM::LR)) 413 if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg)) 414 return LRDef; 415 416 // We've found no suitable LR def and Start doesn't use LR directly. Can we 417 // just define LR anyway? 418 return RDA.isSafeToDefRegAt(Start, ARM::LR) ? Start : nullptr; 419 } 420 421 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt) { 422 assert(VCTP && "VCTP instruction expected but is not set"); 423 // All predication within the loop should be based on vctp. 
If the block 424 // isn't predicated on entry, check whether the vctp is within the block 425 // and that all other instructions are then predicated on it. 426 for (auto &Block : VPTBlocks) { 427 if (Block.IsPredicatedOn(VCTP)) 428 continue; 429 if (Block.HasNonUniformPredicate() && !isVCTP(Block.getDivergent()->MI)) { 430 LLVM_DEBUG(dbgs() << "ARM Loops: Found unsupported diverging predicate: " 431 << *Block.getDivergent()->MI); 432 return false; 433 } 434 SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); 435 for (auto &PredMI : Insts) { 436 // Check the instructions in the block and only allow: 437 // - VCTPs 438 // - Instructions predicated on the main VCTP 439 // - Any VCMP 440 // - VCMPs just "and" their result with VPR.P0. Whether they are 441 // located before/after the VCTP is irrelevant - the end result will 442 // be the same in both cases, so there's no point in requiring them 443 // to be located after the VCTP! 444 if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI) || 445 VCMPOpcodeToVPT(PredMI.MI->getOpcode()) != 0) 446 continue; 447 LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI 448 << " - which is predicated on:\n"; 449 for (auto *MI : PredMI.Predicates) 450 dbgs() << " - " << *MI); 451 return false; 452 } 453 } 454 455 if (!ValidateLiveOuts()) { 456 LLVM_DEBUG(dbgs() << "ARM Loops: Invalid live outs.\n"); 457 return false; 458 } 459 460 // For tail predication, we need to provide the number of elements, instead 461 // of the iteration count, to the loop start instruction. The number of 462 // elements is provided to the vctp instruction, so we need to check that 463 // we can use this register at InsertPt. 464 TPNumElements = VCTP->getOperand(1); 465 Register NumElements = TPNumElements.getReg(); 466 467 // If the register is defined within loop, then we can't perform TP. 468 // TODO: Check whether this is just a mov of a register that would be 469 // available. 
470 if (RDA.hasLocalDefBefore(VCTP, NumElements)) { 471 LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n"); 472 return false; 473 } 474 475 // The element count register maybe defined after InsertPt, in which case we 476 // need to try to move either InsertPt or the def so that the [w|d]lstp can 477 // use the value. 478 MachineBasicBlock *InsertBB = StartInsertPt->getParent(); 479 480 if (!RDA.isReachingDefLiveOut(StartInsertPt, NumElements)) { 481 if (auto *ElemDef = RDA.getLocalLiveOutMIDef(InsertBB, NumElements)) { 482 if (RDA.isSafeToMoveForwards(ElemDef, StartInsertPt)) { 483 ElemDef->removeFromParent(); 484 InsertBB->insert(MachineBasicBlock::iterator(StartInsertPt), ElemDef); 485 LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: " 486 << *ElemDef); 487 } else if (RDA.isSafeToMoveBackwards(StartInsertPt, ElemDef)) { 488 StartInsertPt->removeFromParent(); 489 InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef), 490 StartInsertPt); 491 LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef); 492 } else { 493 // If we fail to move an instruction and the element count is provided 494 // by a mov, use the mov operand if it will have the same value at the 495 // insertion point 496 MachineOperand Operand = ElemDef->getOperand(1); 497 if (isMovRegOpcode(ElemDef->getOpcode()) && 498 RDA.getUniqueReachingMIDef(ElemDef, Operand.getReg()) == 499 RDA.getUniqueReachingMIDef(StartInsertPt, Operand.getReg())) { 500 TPNumElements = Operand; 501 NumElements = TPNumElements.getReg(); 502 } else { 503 LLVM_DEBUG(dbgs() 504 << "ARM Loops: Unable to move element count to loop " 505 << "start instruction.\n"); 506 return false; 507 } 508 } 509 } 510 } 511 512 // Especially in the case of while loops, InsertBB may not be the 513 // preheader, so we need to check that the register isn't redefined 514 // before entering the loop. 
515 auto CannotProvideElements = [this](MachineBasicBlock *MBB, 516 Register NumElements) { 517 // NumElements is redefined in this block. 518 if (RDA.hasLocalDefBefore(&MBB->back(), NumElements)) 519 return true; 520 521 // Don't continue searching up through multiple predecessors. 522 if (MBB->pred_size() > 1) 523 return true; 524 525 return false; 526 }; 527 528 // First, find the block that looks like the preheader. 529 MachineBasicBlock *MBB = Preheader; 530 if (!MBB) { 531 LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find preheader.\n"); 532 return false; 533 } 534 535 // Then search backwards for a def, until we get to InsertBB. 536 while (MBB != InsertBB) { 537 if (CannotProvideElements(MBB, NumElements)) { 538 LLVM_DEBUG(dbgs() << "ARM Loops: Unable to provide element count.\n"); 539 return false; 540 } 541 MBB = *MBB->pred_begin(); 542 } 543 544 // Check that the value change of the element count is what we expect and 545 // that the predication will be equivalent. For this we need: 546 // NumElements = NumElements - VectorWidth. The sub will be a sub immediate 547 // and we can also allow register copies within the chain too. 548 auto IsValidSub = [](MachineInstr *MI, int ExpectedVecWidth) { 549 return -getAddSubImmediate(*MI) == ExpectedVecWidth; 550 }; 551 552 MBB = VCTP->getParent(); 553 // Remove modifications to the element count since they have no purpose in a 554 // tail predicated loop. 
Explicitly refer to the vctp operand no matter which 555 // register NumElements has been assigned to, since that is what the 556 // modifications will be using 557 if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), 558 VCTP->getOperand(1).getReg())) { 559 SmallPtrSet<MachineInstr*, 2> ElementChain; 560 SmallPtrSet<MachineInstr*, 2> Ignore = { VCTP }; 561 unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode()); 562 563 Ignore.insert(SecondaryVCTPs.begin(), SecondaryVCTPs.end()); 564 565 if (RDA.isSafeToRemove(Def, ElementChain, Ignore)) { 566 bool FoundSub = false; 567 568 for (auto *MI : ElementChain) { 569 if (isMovRegOpcode(MI->getOpcode())) 570 continue; 571 572 if (isSubImmOpcode(MI->getOpcode())) { 573 if (FoundSub || !IsValidSub(MI, ExpectedVectorWidth)) 574 return false; 575 FoundSub = true; 576 } else 577 return false; 578 } 579 580 LLVM_DEBUG(dbgs() << "ARM Loops: Will remove element count chain:\n"; 581 for (auto *MI : ElementChain) 582 dbgs() << " - " << *MI); 583 ToRemove.insert(ElementChain.begin(), ElementChain.end()); 584 } 585 } 586 return true; 587 } 588 589 static bool isRegInClass(const MachineOperand &MO, 590 const TargetRegisterClass *Class) { 591 return MO.isReg() && MO.getReg() && Class->contains(MO.getReg()); 592 } 593 594 // MVE 'narrowing' operate on half a lane, reading from half and writing 595 // to half, which are referred to has the top and bottom half. The other 596 // half retains its previous value. 597 static bool retainsPreviousHalfElement(const MachineInstr &MI) { 598 const MCInstrDesc &MCID = MI.getDesc(); 599 uint64_t Flags = MCID.TSFlags; 600 return (Flags & ARMII::RetainsPreviousHalfElement) != 0; 601 } 602 603 // Some MVE instructions read from the top/bottom halves of their operand(s) 604 // and generate a vector result with result elements that are double the 605 // width of the input. 
606 static bool producesDoubleWidthResult(const MachineInstr &MI) { 607 const MCInstrDesc &MCID = MI.getDesc(); 608 uint64_t Flags = MCID.TSFlags; 609 return (Flags & ARMII::DoubleWidthResult) != 0; 610 } 611 612 static bool isHorizontalReduction(const MachineInstr &MI) { 613 const MCInstrDesc &MCID = MI.getDesc(); 614 uint64_t Flags = MCID.TSFlags; 615 return (Flags & ARMII::HorizontalReduction) != 0; 616 } 617 618 // Can this instruction generate a non-zero result when given only zeroed 619 // operands? This allows us to know that, given operands with false bytes 620 // zeroed by masked loads, that the result will also contain zeros in those 621 // bytes. 622 static bool canGenerateNonZeros(const MachineInstr &MI) { 623 624 // Check for instructions which can write into a larger element size, 625 // possibly writing into a previous zero'd lane. 626 if (producesDoubleWidthResult(MI)) 627 return true; 628 629 switch (MI.getOpcode()) { 630 default: 631 break; 632 // FIXME: VNEG FP and -0? I think we'll need to handle this once we allow 633 // fp16 -> fp32 vector conversions. 634 // Instructions that perform a NOT will generate 1s from 0s. 635 case ARM::MVE_VMVN: 636 case ARM::MVE_VORN: 637 // Count leading zeros will do just that! 638 case ARM::MVE_VCLZs8: 639 case ARM::MVE_VCLZs16: 640 case ARM::MVE_VCLZs32: 641 return true; 642 } 643 return false; 644 } 645 646 // Look at its register uses to see if it only can only receive zeros 647 // into its false lanes which would then produce zeros. Also check that 648 // the output register is also defined by an FalseLanesZero instruction 649 // so that if tail-predication happens, the lanes that aren't updated will 650 // still be zeros. 
static bool producesFalseLanesZero(MachineInstr &MI,
                                   const TargetRegisterClass *QPRs,
                                   const ReachingDefAnalysis &RDA,
                                   InstSet &FalseLanesZero) {
  // FalseLanesZero holds the instructions already accepted by this check; it
  // is only queried here, the caller is responsible for inserting MI.
  if (canGenerateNonZeros(MI))
    return false;

  bool isPredicated = isVectorPredicated(&MI);
  // Predicated loads will write zeros to the falsely predicated bytes of the
  // destination register.
  if (MI.mayLoad())
    return isPredicated;

  // An unpredicated 'vmov.i32 #0'.
  auto IsZeroInit = [](MachineInstr *Def) {
    return !isVectorPredicated(Def) &&
           Def->getOpcode() == ARM::MVE_VMOVimmi32 &&
           Def->getOperand(1).getImm() == 0;
  };

  // Horizontal reductions are allowed to have operands outside of QPRs; those
  // operands are skipped below.
  bool AllowScalars = isHorizontalReduction(MI);
  for (auto &MO : MI.operands()) {
    if (!MO.isReg() || !MO.getReg())
      continue;
    if (!isRegInClass(MO, QPRs) && AllowScalars)
      continue;

    // Check that this instruction will produce zeros in its false lanes:
    // - If it only consumes false lanes zero or constant 0 (vmov #0)
    // - If it's predicated, it only matters that its def register already has
    //   false lane zeros, so we can ignore the uses.
    SmallPtrSet<MachineInstr *, 2> Defs;
    RDA.getGlobalReachingDefs(&MI, MO.getReg(), Defs);
    for (auto *Def : Defs) {
      if (Def == &MI || FalseLanesZero.count(Def) || IsZeroInit(Def))
        continue;
      if (MO.isUse() && isPredicated)
        continue;
      return false;
    }
  }
  LLVM_DEBUG(dbgs() << "ARM Loops: Always False Zeros: " << MI);
  return true;
}

bool LowOverheadLoop::ValidateLiveOuts() {
  // We want to find out if the tail-predicated version of this loop will
  // produce the same values as the loop in its original form. For this to
  // be true, the newly inserted implicit predication must not change the
  // (observable) results.
  // We're doing this because many instructions in the loop will not be
  // predicated and so the conversion from VPT predication to tail-predication
  // can result in different values being produced; due to the tail-predication
  // preventing many instructions from updating their falsely predicated
  // lanes. This analysis assumes that all the instructions perform lane-wise
  // operations and don't perform any exchanges.
  // A masked load, whether through VPT or tail predication, will write zeros
  // to any of the falsely predicated bytes. So, from the loads, we know that
  // the false lanes are zeroed and here we're trying to track that those false
  // lanes remain zero, or where they change, the differences are masked away
  // by their user(s).
  // All MVE stores have to be predicated, so we know that any predicate load
  // operands, or stored results are equivalent already. Other explicitly
  // predicated instructions will perform the same operation in the original
  // loop and the tail-predicated form too. Because of this, we can insert
  // loads, stores and other predicated instructions into our Predicated
  // set and build from there.
  const TargetRegisterClass *QPRs = TRI.getRegClass(ARM::MQPRRegClassID);
  SetVector<MachineInstr *> FalseLanesUnknown;
  SmallPtrSet<MachineInstr *, 4> FalseLanesZero;
  SmallPtrSet<MachineInstr *, 4> Predicated;
  MachineBasicBlock *Header = ML.getHeader();

  // Classify each inspected instruction of the header as predicated, known to
  // produce zeros in its false lanes, or unknown.
  for (auto &MI : *Header) {
    if (!shouldInspect(MI))
      continue;

    if (isVCTP(&MI) || isVPTOpcode(MI.getOpcode()))
      continue;

    bool isPredicated = isVectorPredicated(&MI);
    bool retainsOrReduces =
      retainsPreviousHalfElement(MI) || isHorizontalReduction(MI);

    if (isPredicated)
      Predicated.insert(&MI);
    if (producesFalseLanesZero(MI, QPRs, RDA, FalseLanesZero))
      FalseLanesZero.insert(&MI);
    else if (MI.getNumDefs() == 0)
      continue;
    else if (!isPredicated && retainsOrReduces)
      return false;
    else if (!isPredicated)
      FalseLanesUnknown.insert(&MI);
  }

  // Returns true if every user of MO's register, other than MI itself, is in
  // Predicated.
  auto HasPredicatedUsers = [this](MachineInstr *MI, const MachineOperand &MO,
                              SmallPtrSetImpl<MachineInstr *> &Predicated) {
    SmallPtrSet<MachineInstr *, 2> Uses;
    RDA.getGlobalUses(MI, MO.getReg(), Uses);
    for (auto *Use : Uses) {
      if (Use != MI && !Predicated.count(Use))
        return false;
    }
    return true;
  };

  // Visit the unknowns in reverse so that we can start at the values being
  // stored and then we can work towards the leaves, hopefully adding more
  // instructions to Predicated. Successfully terminating the loop means that
  // all the unknown values have to be found to be masked by predicated
  // user(s).
  // For any unpredicated values, we store them in NonPredicated so that we
  // can later check whether these form a reduction.
  SmallPtrSet<MachineInstr*, 2> NonPredicated;
  for (auto *MI : reverse(FalseLanesUnknown)) {
    for (auto &MO : MI->operands()) {
      if (!isRegInClass(MO, QPRs) || !MO.isDef())
        continue;
      if (!HasPredicatedUsers(MI, MO, Predicated)) {
        LLVM_DEBUG(dbgs() << "ARM Loops: Found an unknown def of : "
                          << TRI.getRegAsmName(MO.getReg()) << " at " << *MI);
        NonPredicated.insert(MI);
        break;
      }
    }
    // Any unknown false lanes have been masked away by the user(s).
    if (!NonPredicated.contains(MI))
      Predicated.insert(MI);
  }

  SmallPtrSet<MachineInstr *, 2> LiveOutMIs;
  SmallVector<MachineBasicBlock *, 2> ExitBlocks;
  ML.getExitBlocks(ExitBlocks);
  assert(ML.getNumBlocks() == 1 && "Expected single block loop!");
  assert(ExitBlocks.size() == 1 && "Expected a single exit block");
  MachineBasicBlock *ExitBB = ExitBlocks.front();
  for (const MachineBasicBlock::RegisterMaskPair &RegMask : ExitBB->liveins()) {
    // TODO: Instead of blocking predication, we could move the vctp to the exit
    // block and calculate its operand there or in the preheader.
    if (RegMask.PhysReg == ARM::VPR)
      return false;
    // Check Q-regs that are live in the exit blocks. We don't collect scalars
    // because they won't be affected by lane predication.
    if (QPRs->contains(RegMask.PhysReg))
      if (auto *MI = RDA.getLocalLiveOutMIDef(Header, RegMask.PhysReg))
        LiveOutMIs.insert(MI);
  }

  // We've already validated that any VPT predication within the loop will be
  // equivalent when we perform the predication transformation; so we know that
  // any VPT predicated instruction is predicated upon VCTP. Any live-out
  // instruction needs to be predicated, so check this here. The instructions
  // in NonPredicated have been found to be a reduction that we can ensure its
  // legality.
  for (auto *MI : LiveOutMIs) {
    if (NonPredicated.count(MI) && FalseLanesUnknown.contains(MI)) {
      LLVM_DEBUG(dbgs() << "ARM Loops: Unable to handle live out: " << *MI);
      return false;
    }
  }

  return true;
}

// Decides whether the loop has to be reverted (sets Revert), picks the
// insertion point for the loop start (InsertPt) and, when tail predication is
// still possible, validates it (updates CannotTailPredicate).
void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils) {
  if (Revert)
    return;

  if (!End->getOperand(1).isMBB())
    report_fatal_error("Expected LoopEnd to target basic block");

  // TODO Maybe there's cases where the target doesn't have to be the header,
  // but for now be safe and revert.
  if (End->getOperand(1).getMBB() != ML.getHeader()) {
    LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n");
    Revert = true;
    return;
  }

  // The WLS and LE instructions have 12-bits for the label offset. WLS
  // requires a positive offset, while LE uses negative.
  if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML.getHeader()) ||
      !BBUtils->isBBInRange(End, ML.getHeader(), 4094)) {
    LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n");
    Revert = true;
    return;
  }

  if (Start->getOpcode() == ARM::t2WhileLoopStart &&
      (BBUtils->getOffsetOf(Start) >
       BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) ||
       !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) {
    LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n");
    Revert = true;
    return;
  }

  // Note: Revert is always false here, because the function returns early
  // whenever it is (or becomes) true above.
  InsertPt = Revert ? nullptr : isSafeToDefineLR();
  if (!InsertPt) {
    LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n");
    Revert = true;
    return;
  } else
    LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt);

  if (!IsTailPredicationLegal()) {
    LLVM_DEBUG(if (!VCTP)
                 dbgs() << "ARM Loops: Didn't find a VCTP instruction.\n";
               dbgs() << "ARM Loops: Tail-predication is not valid.\n");
    return;
  }

  assert(ML.getBlocks().size() == 1 &&
         "Shouldn't be processing a loop with more than one block");
  CannotTailPredicate = !ValidateTailPredicate(InsertPt);
  LLVM_DEBUG(if (CannotTailPredicate)
             dbgs() << "ARM Loops: Couldn't validate tail predicate.\n");
}

bool LowOverheadLoop::ValidateMVEInst(MachineInstr* MI) {
  if (CannotTailPredicate)
    return false;

  if (!shouldInspect(*MI))
    return true;

  if (MI->getOpcode() == ARM::MVE_VPSEL ||
      MI->getOpcode() == ARM::MVE_VPNOT) {
    // TODO: Allow VPSEL and VPNOT, we currently cannot because:
    // 1) It will use the VPR as a predicate operand, but doesn't have to be
    //    inside a VPT block, which means we can assert while building up
    //    the VPT block because we don't find another VPT or VPST to begin a new
    //    one.
    // 2) VPSEL still requires a VPR operand even after tail predicating,
    //    which means we can't remove it unless there is another
    //    instruction, such as vcmp, that can provide the VPR def.
    return false;
  }

  if (isVCTP(MI)) {
    // If we find another VCTP, check whether it uses the same value as the main VCTP.
    // If it does, store it in the SecondaryVCTPs set, else refuse it.
892 if (VCTP) { 893 if (!VCTP->getOperand(1).isIdenticalTo(MI->getOperand(1)) || 894 !RDA.hasSameReachingDef(VCTP, MI, MI->getOperand(1).getReg())) { 895 LLVM_DEBUG(dbgs() << "ARM Loops: Found VCTP with a different reaching " 896 "definition from the main VCTP"); 897 return false; 898 } 899 LLVM_DEBUG(dbgs() << "ARM Loops: Found secondary VCTP: " << *MI); 900 SecondaryVCTPs.insert(MI); 901 } else { 902 LLVM_DEBUG(dbgs() << "ARM Loops: Found 'main' VCTP: " << *MI); 903 VCTP = MI; 904 } 905 } else if (isVPTOpcode(MI->getOpcode())) { 906 if (MI->getOpcode() != ARM::MVE_VPST) { 907 assert(MI->findRegisterDefOperandIdx(ARM::VPR) != -1 && 908 "VPT does not implicitly define VPR?!"); 909 CurrentPredicate.clear(); 910 CurrentPredicate.insert(MI); 911 } 912 913 VPTBlocks.emplace_back(MI, CurrentPredicate); 914 CurrentBlock = &VPTBlocks.back(); 915 return true; 916 } 917 918 // Inspect uses first so that any instructions that alter the VPR don't 919 // alter the predicate upon themselves. 920 const MCInstrDesc &MCID = MI->getDesc(); 921 bool IsUse = false; 922 bool IsDef = false; 923 for (int i = MI->getNumOperands() - 1; i >= 0; --i) { 924 const MachineOperand &MO = MI->getOperand(i); 925 if (!MO.isReg() || MO.getReg() != ARM::VPR) 926 continue; 927 928 if (MO.isDef()) { 929 CurrentPredicate.insert(MI); 930 IsDef = true; 931 } else if (ARM::isVpred(MCID.OpInfo[i].OperandType)) { 932 CurrentBlock->addInst(MI, CurrentPredicate); 933 IsUse = true; 934 } else { 935 LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI); 936 return false; 937 } 938 } 939 940 // If this instruction defines the VPR, update the predicate for the 941 // proceeding instructions. 942 if (IsDef) { 943 // Clear the existing predicate when we're not in VPT Active state. 
944 if (!isVectorPredicated(MI)) 945 CurrentPredicate.clear(); 946 CurrentPredicate.insert(MI); 947 LLVM_DEBUG(dbgs() << "ARM Loops: Adding Predicate: " << *MI); 948 } 949 950 // If we find a vpr def that is not already predicated on the vctp, we've 951 // got disjoint predicates that may not be equivalent when we do the 952 // conversion. 953 if (IsDef && !IsUse && VCTP && !isVCTP(MI)) { 954 LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI); 955 return false; 956 } 957 958 // If we find an instruction that has been marked as not valid for tail 959 // predication, only allow the instruction if it's contained within a valid 960 // VPT block. 961 bool RequiresExplicitPredication = 962 (MCID.TSFlags & ARMII::ValidForTailPredication) == 0; 963 if (isDomainMVE(MI) && RequiresExplicitPredication) { 964 LLVM_DEBUG(if (!IsUse) 965 dbgs() << "ARM Loops: Can't tail predicate: " << *MI); 966 return IsUse; 967 } 968 969 // If the instruction is already explicitly predicated, then the conversion 970 // will be fine, but ensure that all store operations are predicated. 971 return !IsUse && MI->mayStore() ? 
false : true; 972 } 973 974 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) { 975 const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget()); 976 if (!ST.hasLOB()) 977 return false; 978 979 MF = &mf; 980 LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n"); 981 982 MLI = &getAnalysis<MachineLoopInfo>(); 983 RDA = &getAnalysis<ReachingDefAnalysis>(); 984 MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness); 985 MRI = &MF->getRegInfo(); 986 TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo()); 987 TRI = ST.getRegisterInfo(); 988 BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF)); 989 BBUtils->computeAllBlockSizes(); 990 BBUtils->adjustBBOffsetsAfter(&MF->front()); 991 992 bool Changed = false; 993 for (auto ML : *MLI) { 994 if (!ML->getParentLoop()) 995 Changed |= ProcessLoop(ML); 996 } 997 Changed |= RevertNonLoops(); 998 return Changed; 999 } 1000 1001 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { 1002 1003 bool Changed = false; 1004 1005 // Process inner loops first. 1006 for (auto I = ML->begin(), E = ML->end(); I != E; ++I) 1007 Changed |= ProcessLoop(*I); 1008 1009 LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n"; 1010 if (auto *Preheader = ML->getLoopPreheader()) 1011 dbgs() << " - " << Preheader->getName() << "\n"; 1012 else if (auto *Preheader = MLI->findLoopPreheader(ML)) 1013 dbgs() << " - " << Preheader->getName() << "\n"; 1014 else if (auto *Preheader = MLI->findLoopPreheader(ML, true)) 1015 dbgs() << " - " << Preheader->getName() << "\n"; 1016 for (auto *MBB : ML->getBlocks()) 1017 dbgs() << " - " << MBB->getName() << "\n"; 1018 ); 1019 1020 // Search the given block for a loop start instruction. If one isn't found, 1021 // and there's only one predecessor block, search that one too. 
1022 std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart = 1023 [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* { 1024 for (auto &MI : *MBB) { 1025 if (isLoopStart(MI)) 1026 return &MI; 1027 } 1028 if (MBB->pred_size() == 1) 1029 return SearchForStart(*MBB->pred_begin()); 1030 return nullptr; 1031 }; 1032 1033 LowOverheadLoop LoLoop(*ML, *MLI, *RDA, *TRI, *TII); 1034 // Search the preheader for the start intrinsic. 1035 // FIXME: I don't see why we shouldn't be supporting multiple predecessors 1036 // with potentially multiple set.loop.iterations, so we need to enable this. 1037 if (LoLoop.Preheader) 1038 LoLoop.Start = SearchForStart(LoLoop.Preheader); 1039 else 1040 return false; 1041 1042 // Find the low-overhead loop components and decide whether or not to fall 1043 // back to a normal loop. Also look for a vctp instructions and decide 1044 // whether we can convert that predicate using tail predication. 1045 for (auto *MBB : reverse(ML->getBlocks())) { 1046 for (auto &MI : *MBB) { 1047 if (MI.isDebugValue()) 1048 continue; 1049 else if (MI.getOpcode() == ARM::t2LoopDec) 1050 LoLoop.Dec = &MI; 1051 else if (MI.getOpcode() == ARM::t2LoopEnd) 1052 LoLoop.End = &MI; 1053 else if (isLoopStart(MI)) 1054 LoLoop.Start = &MI; 1055 else if (MI.getDesc().isCall()) { 1056 // TODO: Though the call will require LE to execute again, does this 1057 // mean we should revert? Always executing LE hopefully should be 1058 // faster than performing a sub,cmp,br or even subs,br. 1059 LoLoop.Revert = true; 1060 LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n"); 1061 } else { 1062 // Record VPR defs and build up their corresponding vpt blocks. 1063 // Check we know how to tail predicate any mve instructions. 
1064 LoLoop.AnalyseMVEInst(&MI); 1065 } 1066 } 1067 } 1068 1069 LLVM_DEBUG(LoLoop.dump()); 1070 if (!LoLoop.FoundAllComponents()) { 1071 LLVM_DEBUG(dbgs() << "ARM Loops: Didn't find loop start, update, end\n"); 1072 return false; 1073 } 1074 1075 // Check that the only instruction using LoopDec is LoopEnd. 1076 // TODO: Check for copy chains that really have no effect. 1077 SmallPtrSet<MachineInstr*, 2> Uses; 1078 RDA->getReachingLocalUses(LoLoop.Dec, ARM::LR, Uses); 1079 if (Uses.size() > 1 || !Uses.count(LoLoop.End)) { 1080 LLVM_DEBUG(dbgs() << "ARM Loops: Unable to remove LoopDec.\n"); 1081 LoLoop.Revert = true; 1082 } 1083 LoLoop.CheckLegality(BBUtils.get()); 1084 Expand(LoLoop); 1085 return true; 1086 } 1087 1088 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a 1089 // beq that branches to the exit branch. 1090 // TODO: We could also try to generate a cbz if the value in LR is also in 1091 // another low register. 1092 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const { 1093 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI); 1094 MachineBasicBlock *MBB = MI->getParent(); 1095 MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), 1096 TII->get(ARM::t2CMPri)); 1097 MIB.add(MI->getOperand(0)); 1098 MIB.addImm(0); 1099 MIB.addImm(ARMCC::AL); 1100 MIB.addReg(ARM::NoRegister); 1101 1102 MachineBasicBlock *DestBB = MI->getOperand(1).getMBB(); 1103 unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ? 
1104 ARM::tBcc : ARM::t2Bcc; 1105 1106 MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc)); 1107 MIB.add(MI->getOperand(1)); // branch target 1108 MIB.addImm(ARMCC::EQ); // condition code 1109 MIB.addReg(ARM::CPSR); 1110 MI->eraseFromParent(); 1111 } 1112 1113 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI) const { 1114 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI); 1115 MachineBasicBlock *MBB = MI->getParent(); 1116 SmallPtrSet<MachineInstr*, 1> Ignore; 1117 for (auto I = MachineBasicBlock::iterator(MI), E = MBB->end(); I != E; ++I) { 1118 if (I->getOpcode() == ARM::t2LoopEnd) { 1119 Ignore.insert(&*I); 1120 break; 1121 } 1122 } 1123 1124 // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS. 1125 bool SetFlags = RDA->isSafeToDefRegAt(MI, ARM::CPSR, Ignore); 1126 1127 MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), 1128 TII->get(ARM::t2SUBri)); 1129 MIB.addDef(ARM::LR); 1130 MIB.add(MI->getOperand(1)); 1131 MIB.add(MI->getOperand(2)); 1132 MIB.addImm(ARMCC::AL); 1133 MIB.addReg(0); 1134 1135 if (SetFlags) { 1136 MIB.addReg(ARM::CPSR); 1137 MIB->getOperand(5).setIsDef(true); 1138 } else 1139 MIB.addReg(0); 1140 1141 MI->eraseFromParent(); 1142 return SetFlags; 1143 } 1144 1145 // Generate a subs, or sub and cmp, and a branch instead of an LE. 1146 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const { 1147 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI); 1148 1149 MachineBasicBlock *MBB = MI->getParent(); 1150 // Create cmp 1151 if (!SkipCmp) { 1152 MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), 1153 TII->get(ARM::t2CMPri)); 1154 MIB.addReg(ARM::LR); 1155 MIB.addImm(0); 1156 MIB.addImm(ARMCC::AL); 1157 MIB.addReg(ARM::NoRegister); 1158 } 1159 1160 MachineBasicBlock *DestBB = MI->getOperand(1).getMBB(); 1161 unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ? 
1162 ARM::tBcc : ARM::t2Bcc; 1163 1164 // Create bne 1165 MachineInstrBuilder MIB = 1166 BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc)); 1167 MIB.add(MI->getOperand(1)); // branch target 1168 MIB.addImm(ARMCC::NE); // condition code 1169 MIB.addReg(ARM::CPSR); 1170 MI->eraseFromParent(); 1171 } 1172 1173 // Perform dead code elimation on the loop iteration count setup expression. 1174 // If we are tail-predicating, the number of elements to be processed is the 1175 // operand of the VCTP instruction in the vector body, see getCount(), which is 1176 // register $r3 in this example: 1177 // 1178 // $lr = big-itercount-expression 1179 // .. 1180 // t2DoLoopStart renamable $lr 1181 // vector.body: 1182 // .. 1183 // $vpr = MVE_VCTP32 renamable $r3 1184 // renamable $lr = t2LoopDec killed renamable $lr, 1 1185 // t2LoopEnd renamable $lr, %vector.body 1186 // tB %end 1187 // 1188 // What we would like achieve here is to replace the do-loop start pseudo 1189 // instruction t2DoLoopStart with: 1190 // 1191 // $lr = MVE_DLSTP_32 killed renamable $r3 1192 // 1193 // Thus, $r3 which defines the number of elements, is written to $lr, 1194 // and then we want to delete the whole chain that used to define $lr, 1195 // see the comment below how this chain could look like. 1196 // 1197 void ARMLowOverheadLoops::IterationCountDCE(LowOverheadLoop &LoLoop) { 1198 if (!LoLoop.IsTailPredicationLegal()) 1199 return; 1200 1201 LLVM_DEBUG(dbgs() << "ARM Loops: Trying DCE on loop iteration count.\n"); 1202 1203 MachineInstr *Def = RDA->getMIOperand(LoLoop.Start, 0); 1204 if (!Def) { 1205 LLVM_DEBUG(dbgs() << "ARM Loops: Couldn't find iteration count.\n"); 1206 return; 1207 } 1208 1209 // Collect and remove the users of iteration count. 
1210 SmallPtrSet<MachineInstr*, 4> Killed = { LoLoop.Start, LoLoop.Dec, 1211 LoLoop.End, LoLoop.InsertPt }; 1212 SmallPtrSet<MachineInstr*, 2> Remove; 1213 if (RDA->isSafeToRemove(Def, Remove, Killed)) 1214 LoLoop.ToRemove.insert(Remove.begin(), Remove.end()); 1215 else { 1216 LLVM_DEBUG(dbgs() << "ARM Loops: Unsafe to remove loop iteration count.\n"); 1217 return; 1218 } 1219 1220 // Collect the dead code and the MBBs in which they reside. 1221 RDA->collectKilledOperands(Def, Killed); 1222 SmallPtrSet<MachineBasicBlock*, 2> BasicBlocks; 1223 for (auto *MI : Killed) 1224 BasicBlocks.insert(MI->getParent()); 1225 1226 // Collect IT blocks in all affected basic blocks. 1227 std::map<MachineInstr *, SmallPtrSet<MachineInstr *, 2>> ITBlocks; 1228 for (auto *MBB : BasicBlocks) { 1229 for (auto &MI : *MBB) { 1230 if (MI.getOpcode() != ARM::t2IT) 1231 continue; 1232 RDA->getReachingLocalUses(&MI, ARM::ITSTATE, ITBlocks[&MI]); 1233 } 1234 } 1235 1236 // If we're removing all of the instructions within an IT block, then 1237 // also remove the IT instruction. 1238 SmallPtrSet<MachineInstr*, 2> ModifiedITs; 1239 for (auto *MI : Killed) { 1240 if (MachineOperand *MO = MI->findRegisterUseOperand(ARM::ITSTATE)) { 1241 MachineInstr *IT = RDA->getMIOperand(MI, *MO); 1242 auto &CurrentBlock = ITBlocks[IT]; 1243 CurrentBlock.erase(MI); 1244 if (CurrentBlock.empty()) 1245 ModifiedITs.erase(IT); 1246 else 1247 ModifiedITs.insert(IT); 1248 } 1249 } 1250 1251 // Delete the killed instructions only if we don't have any IT blocks that 1252 // need to be modified because we need to fixup the mask. 1253 // TODO: Handle cases where IT blocks are modified. 
1254 if (ModifiedITs.empty()) { 1255 LLVM_DEBUG(dbgs() << "ARM Loops: Will remove iteration count:\n"; 1256 for (auto *MI : Killed) 1257 dbgs() << " - " << *MI); 1258 LoLoop.ToRemove.insert(Killed.begin(), Killed.end()); 1259 } else 1260 LLVM_DEBUG(dbgs() << "ARM Loops: Would need to modify IT block(s).\n"); 1261 } 1262 1263 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) { 1264 LLVM_DEBUG(dbgs() << "ARM Loops: Expanding LoopStart.\n"); 1265 // When using tail-predication, try to delete the dead code that was used to 1266 // calculate the number of loop iterations. 1267 IterationCountDCE(LoLoop); 1268 1269 MachineInstr *InsertPt = LoLoop.InsertPt; 1270 MachineInstr *Start = LoLoop.Start; 1271 MachineBasicBlock *MBB = InsertPt->getParent(); 1272 bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart; 1273 unsigned Opc = LoLoop.getStartOpcode(); 1274 MachineOperand &Count = LoLoop.getLoopStartOperand(); 1275 1276 MachineInstrBuilder MIB = 1277 BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc)); 1278 1279 MIB.addDef(ARM::LR); 1280 MIB.add(Count); 1281 if (!IsDo) 1282 MIB.add(Start->getOperand(1)); 1283 1284 // If we're inserting at a mov lr, then remove it as it's redundant. 
1285 if (InsertPt != Start) 1286 LoLoop.ToRemove.insert(InsertPt); 1287 LoLoop.ToRemove.insert(Start); 1288 LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB); 1289 return &*MIB; 1290 } 1291 1292 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) { 1293 auto RemovePredicate = [](MachineInstr *MI) { 1294 LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI); 1295 if (int PIdx = llvm::findFirstVPTPredOperandIdx(*MI)) { 1296 assert(MI->getOperand(PIdx).getImm() == ARMVCC::Then && 1297 "Expected Then predicate!"); 1298 MI->getOperand(PIdx).setImm(ARMVCC::None); 1299 MI->getOperand(PIdx+1).setReg(0); 1300 } else 1301 llvm_unreachable("trying to unpredicate a non-predicated instruction"); 1302 }; 1303 1304 // There are a few scenarios which we have to fix up: 1305 // 1. VPT Blocks with non-uniform predicates: 1306 // - a. When the divergent instruction is a vctp 1307 // - b. When the block uses a vpst, and is only predicated on the vctp 1308 // - c. When the block uses a vpt and (optionally) contains one or more 1309 // vctp. 1310 // 2. VPT Blocks with uniform predicates: 1311 // - a. The block uses a vpst, and is only predicated on the vctp 1312 for (auto &Block : LoLoop.getVPTBlocks()) { 1313 SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); 1314 if (Block.HasNonUniformPredicate()) { 1315 PredicatedMI *Divergent = Block.getDivergent(); 1316 if (isVCTP(Divergent->MI)) { 1317 // The vctp will be removed, so the block mask of the vp(s)t will need 1318 // to be recomputed. 1319 LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen()); 1320 } else if (Block.isVPST() && Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { 1321 // The VPT block has a non-uniform predicate but it uses a vpst and its 1322 // entry is guarded only by a vctp, which means we: 1323 // - Need to remove the original vpst. 1324 // - Then need to unpredicate any following instructions, until 1325 // we come across the divergent vpr def. 
1326 // - Insert a new vpst to predicate the instruction(s) that following 1327 // the divergent vpr def. 1328 // TODO: We could be producing more VPT blocks than necessary and could 1329 // fold the newly created one into a proceeding one. 1330 for (auto I = ++MachineBasicBlock::iterator(Block.getPredicateThen()), 1331 E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I) 1332 RemovePredicate(&*I); 1333 1334 // Check if the instruction defining vpr is a vcmp so it can be combined 1335 // with the VPST This should be the divergent instruction 1336 MachineInstr *VCMP = VCMPOpcodeToVPT(Divergent->MI->getOpcode()) != 0 1337 ? Divergent->MI 1338 : nullptr; 1339 1340 unsigned Size = 0; 1341 auto E = MachineBasicBlock::reverse_iterator(Divergent->MI); 1342 auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI); 1343 MachineInstr *InsertAt = nullptr; 1344 while (I != E) { 1345 InsertAt = &*I; 1346 ++Size; 1347 ++I; 1348 } 1349 MachineInstrBuilder MIB; 1350 LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " 1351 << *Block.getPredicateThen()); 1352 if (VCMP) { 1353 // Combine the VPST and VCMP into a VPT 1354 MIB = 1355 BuildMI(*InsertAt->getParent(), InsertAt, InsertAt->getDebugLoc(), 1356 TII->get(VCMPOpcodeToVPT(VCMP->getOpcode()))); 1357 MIB.addImm(ARMVCC::Then); 1358 // Register one 1359 MIB.add(VCMP->getOperand(1)); 1360 // Register two 1361 MIB.add(VCMP->getOperand(2)); 1362 // The comparison code, e.g. 
ge, eq, lt 1363 MIB.add(VCMP->getOperand(3)); 1364 LLVM_DEBUG(dbgs() 1365 << "ARM Loops: Combining with VCMP to VPT: " << *MIB); 1366 LoLoop.ToRemove.insert(VCMP); 1367 } else { 1368 // Create a VPST (with a null mask for now, we'll recompute it later) 1369 // or a VPT in case there was a VCMP right before it 1370 MIB = BuildMI(*InsertAt->getParent(), InsertAt, 1371 InsertAt->getDebugLoc(), TII->get(ARM::MVE_VPST)); 1372 MIB.addImm(0); 1373 LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB); 1374 } 1375 LoLoop.ToRemove.insert(Block.getPredicateThen()); 1376 LoLoop.BlockMasksToRecompute.insert(MIB.getInstr()); 1377 } 1378 // Else, if the block uses a vpt, iterate over the block, removing the 1379 // extra VCTPs it may contain. 1380 else if (Block.isVPT()) { 1381 bool RemovedVCTP = false; 1382 for (PredicatedMI &Elt : Block.getInsts()) { 1383 MachineInstr *MI = Elt.MI; 1384 if (isVCTP(MI)) { 1385 LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *MI); 1386 LoLoop.ToRemove.insert(MI); 1387 RemovedVCTP = true; 1388 continue; 1389 } 1390 } 1391 if (RemovedVCTP) 1392 LoLoop.BlockMasksToRecompute.insert(Block.getPredicateThen()); 1393 } 1394 } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP) && Block.isVPST()) { 1395 // A vpt block starting with VPST, is only predicated upon vctp and has no 1396 // internal vpr defs: 1397 // - Remove vpst. 1398 // - Unpredicate the remaining instructions. 
1399 LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getPredicateThen()); 1400 LoLoop.ToRemove.insert(Block.getPredicateThen()); 1401 for (auto &PredMI : Insts) 1402 RemovePredicate(PredMI.MI); 1403 } 1404 } 1405 LLVM_DEBUG(dbgs() << "ARM Loops: Removing remaining VCTPs...\n"); 1406 // Remove the "main" VCTP 1407 LoLoop.ToRemove.insert(LoLoop.VCTP); 1408 LLVM_DEBUG(dbgs() << " " << *LoLoop.VCTP); 1409 // Remove remaining secondary VCTPs 1410 for (MachineInstr *VCTP : LoLoop.SecondaryVCTPs) { 1411 // All VCTPs that aren't marked for removal yet should be unpredicated ones. 1412 // The predicated ones should have already been marked for removal when 1413 // visiting the VPT blocks. 1414 if (LoLoop.ToRemove.insert(VCTP).second) { 1415 assert(getVPTInstrPredicate(*VCTP) == ARMVCC::None && 1416 "Removing Predicated VCTP without updating the block mask!"); 1417 LLVM_DEBUG(dbgs() << " " << *VCTP); 1418 } 1419 } 1420 } 1421 1422 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) { 1423 1424 // Combine the LoopDec and LoopEnd instructions into LE(TP). 1425 auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) { 1426 MachineInstr *End = LoLoop.End; 1427 MachineBasicBlock *MBB = End->getParent(); 1428 unsigned Opc = LoLoop.IsTailPredicationLegal() ? 1429 ARM::MVE_LETP : ARM::t2LEUpdate; 1430 MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(), 1431 TII->get(Opc)); 1432 MIB.addDef(ARM::LR); 1433 MIB.add(End->getOperand(0)); 1434 MIB.add(End->getOperand(1)); 1435 LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB); 1436 LoLoop.ToRemove.insert(LoLoop.Dec); 1437 LoLoop.ToRemove.insert(End); 1438 return &*MIB; 1439 }; 1440 1441 // TODO: We should be able to automatically remove these branches before we 1442 // get here - probably by teaching analyzeBranch about the pseudo 1443 // instructions. 1444 // If there is an unconditional branch, after I, that just branches to the 1445 // next block, remove it. 
1446 auto RemoveDeadBranch = [](MachineInstr *I) { 1447 MachineBasicBlock *BB = I->getParent(); 1448 MachineInstr *Terminator = &BB->instr_back(); 1449 if (Terminator->isUnconditionalBranch() && I != Terminator) { 1450 MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB(); 1451 if (BB->isLayoutSuccessor(Succ)) { 1452 LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator); 1453 Terminator->eraseFromParent(); 1454 } 1455 } 1456 }; 1457 1458 if (LoLoop.Revert) { 1459 if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart) 1460 RevertWhile(LoLoop.Start); 1461 else 1462 LoLoop.Start->eraseFromParent(); 1463 bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec); 1464 RevertLoopEnd(LoLoop.End, FlagsAlreadySet); 1465 } else { 1466 LoLoop.Start = ExpandLoopStart(LoLoop); 1467 RemoveDeadBranch(LoLoop.Start); 1468 LoLoop.End = ExpandLoopEnd(LoLoop); 1469 RemoveDeadBranch(LoLoop.End); 1470 if (LoLoop.IsTailPredicationLegal()) 1471 ConvertVPTBlocks(LoLoop); 1472 for (auto *I : LoLoop.ToRemove) { 1473 LLVM_DEBUG(dbgs() << "ARM Loops: Erasing " << *I); 1474 I->eraseFromParent(); 1475 } 1476 for (auto *I : LoLoop.BlockMasksToRecompute) { 1477 LLVM_DEBUG(dbgs() << "ARM Loops: Recomputing VPT/VPST Block Mask: " << *I); 1478 recomputeVPTBlockMask(*I); 1479 LLVM_DEBUG(dbgs() << " ... done: " << *I); 1480 } 1481 } 1482 1483 PostOrderLoopTraversal DFS(LoLoop.ML, *MLI); 1484 DFS.ProcessLoop(); 1485 const SmallVectorImpl<MachineBasicBlock*> &PostOrder = DFS.getOrder(); 1486 for (auto *MBB : PostOrder) { 1487 recomputeLiveIns(*MBB); 1488 // FIXME: For some reason, the live-in print order is non-deterministic for 1489 // our tests and I can't out why... So just sort them. 1490 MBB->sortUniqueLiveIns(); 1491 } 1492 1493 for (auto *MBB : reverse(PostOrder)) 1494 recomputeLivenessFlags(*MBB); 1495 1496 // We've moved, removed and inserted new instructions, so update RDA. 
1497 RDA->reset(); 1498 } 1499 1500 bool ARMLowOverheadLoops::RevertNonLoops() { 1501 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n"); 1502 bool Changed = false; 1503 1504 for (auto &MBB : *MF) { 1505 SmallVector<MachineInstr*, 4> Starts; 1506 SmallVector<MachineInstr*, 4> Decs; 1507 SmallVector<MachineInstr*, 4> Ends; 1508 1509 for (auto &I : MBB) { 1510 if (isLoopStart(I)) 1511 Starts.push_back(&I); 1512 else if (I.getOpcode() == ARM::t2LoopDec) 1513 Decs.push_back(&I); 1514 else if (I.getOpcode() == ARM::t2LoopEnd) 1515 Ends.push_back(&I); 1516 } 1517 1518 if (Starts.empty() && Decs.empty() && Ends.empty()) 1519 continue; 1520 1521 Changed = true; 1522 1523 for (auto *Start : Starts) { 1524 if (Start->getOpcode() == ARM::t2WhileLoopStart) 1525 RevertWhile(Start); 1526 else 1527 Start->eraseFromParent(); 1528 } 1529 for (auto *Dec : Decs) 1530 RevertLoopDec(Dec); 1531 1532 for (auto *End : Ends) 1533 RevertLoopEnd(End); 1534 } 1535 return Changed; 1536 } 1537 1538 FunctionPass *llvm::createARMLowOverheadLoopsPass() { 1539 return new ARMLowOverheadLoops(); 1540 } 1541