1 //===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file 9 /// Finalize v8.1-m low-overhead loops by converting the associated pseudo 10 /// instructions into machine operations. 11 /// The expectation is that the loop contains three pseudo instructions: 12 /// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop 13 /// form should be in the preheader, whereas the while form should be in the 14 /// preheaders only predecessor. 15 /// - t2LoopDec - placed within in the loop body. 16 /// - t2LoopEnd - the loop latch terminator. 17 /// 18 //===----------------------------------------------------------------------===// 19 20 #include "ARM.h" 21 #include "ARMBaseInstrInfo.h" 22 #include "ARMBaseRegisterInfo.h" 23 #include "ARMBasicBlockInfo.h" 24 #include "ARMSubtarget.h" 25 #include "llvm/ADT/SetOperations.h" 26 #include "llvm/ADT/SmallSet.h" 27 #include "llvm/CodeGen/MachineFunctionPass.h" 28 #include "llvm/CodeGen/MachineLoopInfo.h" 29 #include "llvm/CodeGen/MachineLoopUtils.h" 30 #include "llvm/CodeGen/MachineRegisterInfo.h" 31 #include "llvm/CodeGen/Passes.h" 32 #include "llvm/CodeGen/ReachingDefAnalysis.h" 33 #include "llvm/MC/MCInstrDesc.h" 34 35 using namespace llvm; 36 37 #define DEBUG_TYPE "arm-low-overhead-loops" 38 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass" 39 40 namespace { 41 42 struct PredicatedMI { 43 MachineInstr *MI = nullptr; 44 SetVector<MachineInstr*> Predicates; 45 46 public: 47 PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) : 48 MI(I) { 49 Predicates.insert(Preds.begin(), Preds.end()); 50 } 51 }; 52 53 // Represent a VPT block, a list of instructions that begins with a VPST and 54 // has a maximum of four proceeding instructions. All instructions within the 55 // block are predicated upon the vpr and we allow instructions to define the 56 // vpr within in the block too. 57 class VPTBlock { 58 std::unique_ptr<PredicatedMI> VPST; 59 PredicatedMI *Divergent = nullptr; 60 SmallVector<PredicatedMI, 4> Insts; 61 62 public: 63 VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { 64 VPST = std::make_unique<PredicatedMI>(MI, Preds); 65 } 66 67 void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { 68 LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI); 69 if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) { 70 Divergent = &Insts.back(); 71 LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI); 72 } 73 Insts.emplace_back(MI, Preds); 74 assert(Insts.size() <= 4 && "Too many instructions in VPT block!"); 75 } 76 77 // Have we found an instruction within the block which defines the vpr? If 78 // so, not all the instructions in the block will have the same predicate. 79 bool HasNonUniformPredicate() const { 80 return Divergent != nullptr; 81 } 82 83 // Is the given instruction part of the predicate set controlling the entry 84 // to the block. 85 bool IsPredicatedOn(MachineInstr *MI) const { 86 return VPST->Predicates.count(MI); 87 } 88 89 // Is the given instruction the only predicate which controls the entry to 90 // the block. 91 bool IsOnlyPredicatedOn(MachineInstr *MI) const { 92 return IsPredicatedOn(MI) && VPST->Predicates.size() == 1; 93 } 94 95 unsigned size() const { return Insts.size(); } 96 SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; } 97 MachineInstr *getVPST() const { return VPST->MI; } 98 PredicatedMI *getDivergent() const { return Divergent; } 99 }; 100 101 struct LowOverheadLoop { 102 103 MachineLoop *ML = nullptr; 104 MachineFunction *MF = nullptr; 105 MachineInstr *InsertPt = nullptr; 106 MachineInstr *Start = nullptr; 107 MachineInstr *Dec = nullptr; 108 MachineInstr *End = nullptr; 109 MachineInstr *VCTP = nullptr; 110 VPTBlock *CurrentBlock = nullptr; 111 SetVector<MachineInstr*> CurrentPredicate; 112 SmallVector<VPTBlock, 4> VPTBlocks; 113 bool Revert = false; 114 bool CannotTailPredicate = false; 115 116 LowOverheadLoop(MachineLoop *ML) : ML(ML) { 117 MF = ML->getHeader()->getParent(); 118 } 119 120 bool RecordVPTBlocks(MachineInstr *MI); 121 122 // If this is an MVE instruction, check that we know how to use tail 123 // predication with it. 124 void AnalyseMVEInst(MachineInstr *MI) { 125 if (CannotTailPredicate) 126 return; 127 128 if (!RecordVPTBlocks(MI)) { 129 CannotTailPredicate = true; 130 return; 131 } 132 133 const MCInstrDesc &MCID = MI->getDesc(); 134 uint64_t Flags = MCID.TSFlags; 135 if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE) 136 return; 137 138 if ((Flags & ARMII::ValidForTailPredication) == 0) { 139 LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI); 140 CannotTailPredicate = true; 141 } 142 } 143 144 bool IsTailPredicationLegal() const { 145 // For now, let's keep things really simple and only support a single 146 // block for tail predication. 147 return !Revert && FoundAllComponents() && VCTP && 148 !CannotTailPredicate && ML->getNumBlocks() == 1; 149 } 150 151 bool ValidateTailPredicate(MachineInstr *StartInsertPt, 152 ReachingDefAnalysis *RDA, 153 MachineLoopInfo *MLI); 154 155 // Is it safe to define LR with DLS/WLS? 156 // LR can be defined if it is the operand to start, because it's the same 157 // value, or if it's going to be equivalent to the operand to Start. 158 MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA); 159 160 // Check the branch targets are within range and we satisfy our 161 // restrictions. 162 void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA, 163 MachineLoopInfo *MLI); 164 165 bool FoundAllComponents() const { 166 return Start && Dec && End; 167 } 168 169 SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; } 170 171 // Return the loop iteration count, or the number of elements if we're tail 172 // predicating. 173 MachineOperand &getCount() { 174 return IsTailPredicationLegal() ? 175 VCTP->getOperand(1) : Start->getOperand(0); 176 } 177 178 unsigned getStartOpcode() const { 179 bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart; 180 if (!IsTailPredicationLegal()) 181 return IsDo ? ARM::t2DLS : ARM::t2WLS; 182 183 return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo); 184 } 185 186 void dump() const { 187 if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start; 188 if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec; 189 if (End) dbgs() << "ARM Loops: Found Loop End: " << *End; 190 if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP; 191 if (!FoundAllComponents()) 192 dbgs() << "ARM Loops: Not a low-overhead loop.\n"; 193 else if (!(Start && Dec && End)) 194 dbgs() << "ARM Loops: Failed to find all loop components.\n"; 195 } 196 }; 197 198 class ARMLowOverheadLoops : public MachineFunctionPass { 199 MachineFunction *MF = nullptr; 200 MachineLoopInfo *MLI = nullptr; 201 ReachingDefAnalysis *RDA = nullptr; 202 const ARMBaseInstrInfo *TII = nullptr; 203 MachineRegisterInfo *MRI = nullptr; 204 const TargetRegisterInfo *TRI = nullptr; 205 std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr; 206 207 public: 208 static char ID; 209 210 ARMLowOverheadLoops() : MachineFunctionPass(ID) { } 211 212 void getAnalysisUsage(AnalysisUsage &AU) const override { 213 AU.setPreservesCFG(); 214 AU.addRequired<MachineLoopInfo>(); 215 AU.addRequired<ReachingDefAnalysis>(); 216 MachineFunctionPass::getAnalysisUsage(AU); 217 } 218 219 bool runOnMachineFunction(MachineFunction &MF) override; 220 221 MachineFunctionProperties getRequiredProperties() const override { 222 return MachineFunctionProperties().set( 223 MachineFunctionProperties::Property::NoVRegs).set( 224 MachineFunctionProperties::Property::TracksLiveness); 225 } 226 227 StringRef getPassName() const override { 228 return ARM_LOW_OVERHEAD_LOOPS_NAME; 229 } 230 231 private: 232 bool ProcessLoop(MachineLoop *ML); 233 234 bool RevertNonLoops(); 235 236 void RevertWhile(MachineInstr *MI) const; 237 238 bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const; 239 240 void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const; 241 242 void RemoveLoopUpdate(LowOverheadLoop &LoLoop); 243 244 void ConvertVPTBlocks(LowOverheadLoop &LoLoop); 245 246 MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop); 247 248 void Expand(LowOverheadLoop &LoLoop); 249 250 }; 251 } 252 253 char ARMLowOverheadLoops::ID = 0; 254 255 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME, 256 false, false) 257 258 MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) { 259 // We can define LR because LR already contains the same value. 260 if (Start->getOperand(0).getReg() == ARM::LR) 261 return Start; 262 263 unsigned CountReg = Start->getOperand(0).getReg(); 264 auto IsMoveLR = [&CountReg](MachineInstr *MI) { 265 return MI->getOpcode() == ARM::tMOVr && 266 MI->getOperand(0).getReg() == ARM::LR && 267 MI->getOperand(1).getReg() == CountReg && 268 MI->getOperand(2).getImm() == ARMCC::AL; 269 }; 270 271 MachineBasicBlock *MBB = Start->getParent(); 272 273 // Find an insertion point: 274 // - Is there a (mov lr, Count) before Start? If so, and nothing else writes 275 // to Count before Start, we can insert at that mov. 276 if (auto *LRDef = RDA->getReachingMIDef(Start, ARM::LR)) 277 if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg)) 278 return LRDef; 279 280 // - Is there a (mov lr, Count) after Start? If so, and nothing else writes 281 // to Count after Start, we can insert at that mov. 282 if (auto *LRDef = RDA->getLocalLiveOutMIDef(MBB, ARM::LR)) 283 if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg)) 284 return LRDef; 285 286 // We've found no suitable LR def and Start doesn't use LR directly. Can we 287 // just define LR anyway? 288 if (!RDA->isRegUsedAfter(Start, ARM::LR)) 289 return Start; 290 291 return nullptr; 292 } 293 294 // Can we safely move 'From' to just before 'To'? To satisfy this, 'From' must 295 // not define a register that is used by any instructions, after and including, 296 // 'To'. These instructions also must not redefine any of Froms operands. 297 template<typename Iterator> 298 static bool IsSafeToMove(MachineInstr *From, MachineInstr *To, ReachingDefAnalysis *RDA) { 299 SmallSet<int, 2> Defs; 300 // First check that From would compute the same value if moved. 301 for (auto &MO : From->operands()) { 302 if (!MO.isReg() || MO.isUndef() || !MO.getReg()) 303 continue; 304 if (MO.isDef()) 305 Defs.insert(MO.getReg()); 306 else if (!RDA->hasSameReachingDef(From, To, MO.getReg())) 307 return false; 308 } 309 310 // Now walk checking that the rest of the instructions will compute the same 311 // value. 312 for (auto I = ++Iterator(From), E = Iterator(To); I != E; ++I) { 313 for (auto &MO : I->operands()) 314 if (MO.isReg() && MO.getReg() && MO.isUse() && Defs.count(MO.getReg())) 315 return false; 316 } 317 return true; 318 } 319 320 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt, 321 ReachingDefAnalysis *RDA, 322 MachineLoopInfo *MLI) { 323 // All predication within the loop should be based on vctp. If the block 324 // isn't predicated on entry, check whether the vctp is within the block 325 // and that all other instructions are then predicated on it. 326 for (auto &Block : VPTBlocks) { 327 if (Block.IsPredicatedOn(VCTP)) 328 continue; 329 if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) 330 return false; 331 SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); 332 for (auto &PredMI : Insts) { 333 if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI)) 334 continue; 335 LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI 336 << " - which is predicated on:\n"; 337 for (auto *MI : PredMI.Predicates) 338 dbgs() << " - " << *MI; 339 ); 340 return false; 341 } 342 } 343 344 // For tail predication, we need to provide the number of elements, instead 345 // of the iteration count, to the loop start instruction. The number of 346 // elements is provided to the vctp instruction, so we need to check that 347 // we can use this register at InsertPt. 348 Register NumElements = VCTP->getOperand(1).getReg(); 349 350 // If the register is defined within loop, then we can't perform TP. 351 // TODO: Check whether this is just a mov of a register that would be 352 // available. 353 if (RDA->getReachingDef(VCTP, NumElements) >= 0) { 354 LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n"); 355 return false; 356 } 357 358 // The element count register maybe defined after InsertPt, in which case we 359 // need to try to move either InsertPt or the def so that the [w|d]lstp can 360 // use the value. 361 MachineBasicBlock *InsertBB = InsertPt->getParent(); 362 if (!RDA->isReachingDefLiveOut(InsertPt, NumElements)) { 363 if (auto *ElemDef = RDA->getLocalLiveOutMIDef(InsertBB, NumElements)) { 364 if (IsSafeToMove<MachineBasicBlock::reverse_iterator>(ElemDef, InsertPt, RDA)) { 365 ElemDef->removeFromParent(); 366 InsertBB->insert(MachineBasicBlock::iterator(InsertPt), ElemDef); 367 LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: " 368 << *ElemDef); 369 } else if (IsSafeToMove<MachineBasicBlock::iterator>(InsertPt, ElemDef, RDA)) { 370 InsertPt->removeFromParent(); 371 InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef), InsertPt); 372 LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef); 373 } else 374 return false; 375 } 376 } 377 378 // Especially in the case of while loops, InsertBB may not be the 379 // preheader, so we need to check that the register isn't redefined 380 // before entering the loop. 381 auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB, 382 Register NumElements) { 383 // NumElements is redefined in this block. 384 if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0) 385 return true; 386 387 // Don't continue searching up through multiple predecessors. 388 if (MBB->pred_size() > 1) 389 return true; 390 391 return false; 392 }; 393 394 // First, find the block that looks like the preheader. 395 MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true); 396 if (!MBB) 397 return false; 398 399 // Then search backwards for a def, until we get to InsertBB. 400 while (MBB != InsertBB) { 401 if (CannotProvideElements(MBB, NumElements)) 402 return false; 403 MBB = *MBB->pred_begin(); 404 } 405 406 LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n"); 407 return true; 408 } 409 410 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, 411 ReachingDefAnalysis *RDA, 412 MachineLoopInfo *MLI) { 413 if (Revert) 414 return; 415 416 if (!End->getOperand(1).isMBB()) 417 report_fatal_error("Expected LoopEnd to target basic block"); 418 419 // TODO Maybe there's cases where the target doesn't have to be the header, 420 // but for now be safe and revert. 421 if (End->getOperand(1).getMBB() != ML->getHeader()) { 422 LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n"); 423 Revert = true; 424 return; 425 } 426 427 // The WLS and LE instructions have 12-bits for the label offset. WLS 428 // requires a positive offset, while LE uses negative. 429 if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) || 430 !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) { 431 LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n"); 432 Revert = true; 433 return; 434 } 435 436 if (Start->getOpcode() == ARM::t2WhileLoopStart && 437 (BBUtils->getOffsetOf(Start) > 438 BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) || 439 !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) { 440 LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n"); 441 Revert = true; 442 return; 443 } 444 445 InsertPt = Revert ? nullptr : IsSafeToDefineLR(RDA); 446 if (!InsertPt) { 447 LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n"); 448 Revert = true; 449 return; 450 } else 451 LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt); 452 453 if (!IsTailPredicationLegal()) { 454 LLVM_DEBUG(dbgs() << "ARM Loops: Tail-predication is not valid.\n"); 455 return; 456 } 457 458 assert(ML->getBlocks().size() == 1 && 459 "Shouldn't be processing a loop with more than one block"); 460 CannotTailPredicate = !ValidateTailPredicate(InsertPt, RDA, MLI); 461 LLVM_DEBUG(if (CannotTailPredicate) 462 dbgs() << "ARM Loops: Couldn't validate tail predicate.\n"); 463 } 464 465 bool LowOverheadLoop::RecordVPTBlocks(MachineInstr* MI) { 466 // Only support a single vctp. 467 if (isVCTP(MI) && VCTP) 468 return false; 469 470 // Start a new vpt block when we discover a vpt. 471 if (MI->getOpcode() == ARM::MVE_VPST) { 472 VPTBlocks.emplace_back(MI, CurrentPredicate); 473 CurrentBlock = &VPTBlocks.back(); 474 return true; 475 } 476 477 if (isVCTP(MI)) 478 VCTP = MI; 479 480 unsigned VPROpNum = MI->getNumOperands() - 1; 481 bool IsUse = false; 482 if (MI->getOperand(VPROpNum).isReg() && 483 MI->getOperand(VPROpNum).getReg() == ARM::VPR && 484 MI->getOperand(VPROpNum).isUse()) { 485 // If this instruction is predicated by VPR, it will be its last 486 // operand. Also check that it's only 'Then' predicated. 487 if (!MI->getOperand(VPROpNum-1).isImm() || 488 MI->getOperand(VPROpNum-1).getImm() != ARMVCC::Then) { 489 LLVM_DEBUG(dbgs() << "ARM Loops: Found unhandled predicate on: " 490 << *MI); 491 return false; 492 } 493 CurrentBlock->addInst(MI, CurrentPredicate); 494 IsUse = true; 495 } 496 497 bool IsDef = false; 498 for (unsigned i = 0; i < MI->getNumOperands() - 1; ++i) { 499 const MachineOperand &MO = MI->getOperand(i); 500 if (!MO.isReg() || MO.getReg() != ARM::VPR) 501 continue; 502 503 if (MO.isDef()) { 504 CurrentPredicate.insert(MI); 505 IsDef = true; 506 } else { 507 LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI); 508 return false; 509 } 510 } 511 512 // If we find a vpr def that is not already predicated on the vctp, we've 513 // got disjoint predicates that may not be equivalent when we do the 514 // conversion. 515 if (IsDef && !IsUse && VCTP && !isVCTP(MI)) { 516 LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI); 517 return false; 518 } 519 520 return true; 521 } 522 523 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) { 524 const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget()); 525 if (!ST.hasLOB()) 526 return false; 527 528 MF = &mf; 529 LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n"); 530 531 MLI = &getAnalysis<MachineLoopInfo>(); 532 RDA = &getAnalysis<ReachingDefAnalysis>(); 533 MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness); 534 MRI = &MF->getRegInfo(); 535 TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo()); 536 TRI = ST.getRegisterInfo(); 537 BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF)); 538 BBUtils->computeAllBlockSizes(); 539 BBUtils->adjustBBOffsetsAfter(&MF->front()); 540 541 bool Changed = false; 542 for (auto ML : *MLI) { 543 if (!ML->getParentLoop()) 544 Changed |= ProcessLoop(ML); 545 } 546 Changed |= RevertNonLoops(); 547 return Changed; 548 } 549 550 bool ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { 551 552 bool Changed = false; 553 554 // Process inner loops first. 555 for (auto I = ML->begin(), E = ML->end(); I != E; ++I) 556 Changed |= ProcessLoop(*I); 557 558 LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n"; 559 if (auto *Preheader = ML->getLoopPreheader()) 560 dbgs() << " - " << Preheader->getName() << "\n"; 561 else if (auto *Preheader = MLI->findLoopPreheader(ML)) 562 dbgs() << " - " << Preheader->getName() << "\n"; 563 for (auto *MBB : ML->getBlocks()) 564 dbgs() << " - " << MBB->getName() << "\n"; 565 ); 566 567 // Search the given block for a loop start instruction. If one isn't found, 568 // and there's only one predecessor block, search that one too. 569 std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart = 570 [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* { 571 for (auto &MI : *MBB) { 572 if (isLoopStart(MI)) 573 return &MI; 574 } 575 if (MBB->pred_size() == 1) 576 return SearchForStart(*MBB->pred_begin()); 577 return nullptr; 578 }; 579 580 LowOverheadLoop LoLoop(ML); 581 // Search the preheader for the start intrinsic. 582 // FIXME: I don't see why we shouldn't be supporting multiple predecessors 583 // with potentially multiple set.loop.iterations, so we need to enable this. 584 if (auto *Preheader = ML->getLoopPreheader()) 585 LoLoop.Start = SearchForStart(Preheader); 586 else if (auto *Preheader = MLI->findLoopPreheader(ML, true)) 587 LoLoop.Start = SearchForStart(Preheader); 588 else 589 return false; 590 591 // Find the low-overhead loop components and decide whether or not to fall 592 // back to a normal loop. Also look for a vctp instructions and decide 593 // whether we can convert that predicate using tail predication. 594 for (auto *MBB : reverse(ML->getBlocks())) { 595 for (auto &MI : *MBB) { 596 if (MI.getOpcode() == ARM::t2LoopDec) 597 LoLoop.Dec = &MI; 598 else if (MI.getOpcode() == ARM::t2LoopEnd) 599 LoLoop.End = &MI; 600 else if (isLoopStart(MI)) 601 LoLoop.Start = &MI; 602 else if (MI.getDesc().isCall()) { 603 // TODO: Though the call will require LE to execute again, does this 604 // mean we should revert? Always executing LE hopefully should be 605 // faster than performing a sub,cmp,br or even subs,br. 606 LoLoop.Revert = true; 607 LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n"); 608 } else { 609 // Record VPR defs and build up their corresponding vpt blocks. 610 // Check we know how to tail predicate any mve instructions. 611 LoLoop.AnalyseMVEInst(&MI); 612 } 613 614 // We need to ensure that LR is not used or defined inbetween LoopDec and 615 // LoopEnd. 616 if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert) 617 continue; 618 619 // If we find that LR has been written or read between LoopDec and 620 // LoopEnd, expect that the decremented value is being used else where. 621 // Because this value isn't actually going to be produced until the 622 // latch, by LE, we would need to generate a real sub. The value is also 623 // likely to be copied/reloaded for use of LoopEnd - in which in case 624 // we'd need to perform an add because it gets subtracted again by LE! 625 // The other option is to then generate the other form of LE which doesn't 626 // perform the sub. 627 for (auto &MO : MI.operands()) { 628 if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() && 629 MO.getReg() == ARM::LR) { 630 LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI); 631 LoLoop.Revert = true; 632 break; 633 } 634 } 635 } 636 } 637 638 LLVM_DEBUG(LoLoop.dump()); 639 if (!LoLoop.FoundAllComponents()) 640 return false; 641 642 LoLoop.CheckLegality(BBUtils.get(), RDA, MLI); 643 Expand(LoLoop); 644 return true; 645 } 646 647 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a 648 // beq that branches to the exit branch. 649 // TODO: We could also try to generate a cbz if the value in LR is also in 650 // another low register. 651 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const { 652 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI); 653 MachineBasicBlock *MBB = MI->getParent(); 654 MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), 655 TII->get(ARM::t2CMPri)); 656 MIB.add(MI->getOperand(0)); 657 MIB.addImm(0); 658 MIB.addImm(ARMCC::AL); 659 MIB.addReg(ARM::NoRegister); 660 661 MachineBasicBlock *DestBB = MI->getOperand(1).getMBB(); 662 unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ? 663 ARM::tBcc : ARM::t2Bcc; 664 665 MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc)); 666 MIB.add(MI->getOperand(1)); // branch target 667 MIB.addImm(ARMCC::EQ); // condition code 668 MIB.addReg(ARM::CPSR); 669 MI->eraseFromParent(); 670 } 671 672 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI, 673 bool SetFlags) const { 674 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI); 675 MachineBasicBlock *MBB = MI->getParent(); 676 677 // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS. 678 if (SetFlags && 679 (RDA->isRegUsedAfter(MI, ARM::CPSR) || 680 !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR))) 681 SetFlags = false; 682 683 MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), 684 TII->get(ARM::t2SUBri)); 685 MIB.addDef(ARM::LR); 686 MIB.add(MI->getOperand(1)); 687 MIB.add(MI->getOperand(2)); 688 MIB.addImm(ARMCC::AL); 689 MIB.addReg(0); 690 691 if (SetFlags) { 692 MIB.addReg(ARM::CPSR); 693 MIB->getOperand(5).setIsDef(true); 694 } else 695 MIB.addReg(0); 696 697 MI->eraseFromParent(); 698 return SetFlags; 699 } 700 701 // Generate a subs, or sub and cmp, and a branch instead of an LE. 702 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const { 703 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI); 704 705 MachineBasicBlock *MBB = MI->getParent(); 706 // Create cmp 707 if (!SkipCmp) { 708 MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), 709 TII->get(ARM::t2CMPri)); 710 MIB.addReg(ARM::LR); 711 MIB.addImm(0); 712 MIB.addImm(ARMCC::AL); 713 MIB.addReg(ARM::NoRegister); 714 } 715 716 MachineBasicBlock *DestBB = MI->getOperand(1).getMBB(); 717 unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ? 718 ARM::tBcc : ARM::t2Bcc; 719 720 // Create bne 721 MachineInstrBuilder MIB = 722 BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc)); 723 MIB.add(MI->getOperand(1)); // branch target 724 MIB.addImm(ARMCC::NE); // condition code 725 MIB.addReg(ARM::CPSR); 726 MI->eraseFromParent(); 727 } 728 729 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) { 730 MachineInstr *InsertPt = LoLoop.InsertPt; 731 MachineInstr *Start = LoLoop.Start; 732 MachineBasicBlock *MBB = InsertPt->getParent(); 733 bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart; 734 unsigned Opc = LoLoop.getStartOpcode(); 735 MachineOperand &Count = LoLoop.getCount(); 736 737 MachineInstrBuilder MIB = 738 BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc)); 739 740 MIB.addDef(ARM::LR); 741 MIB.add(Count); 742 if (!IsDo) 743 MIB.add(Start->getOperand(1)); 744 745 // When using tail-predication, try to delete the dead code that was used to 746 // calculate the number of loop iterations. 747 if (LoLoop.IsTailPredicationLegal()) { 748 SmallVector<MachineInstr*, 4> Killed; 749 SmallVector<MachineInstr*, 4> Dead; 750 if (auto *Def = RDA->getReachingMIDef(Start, 751 Start->getOperand(0).getReg())) { 752 Killed.push_back(Def); 753 754 while (!Killed.empty()) { 755 MachineInstr *Def = Killed.back(); 756 Killed.pop_back(); 757 Dead.push_back(Def); 758 for (auto &MO : Def->operands()) { 759 if (!MO.isReg() || !MO.isKill()) 760 continue; 761 762 MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg()); 763 if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1) 764 Killed.push_back(Kill); 765 } 766 } 767 for (auto *MI : Dead) 768 MI->eraseFromParent(); 769 } 770 } 771 772 // If we're inserting at a mov lr, then remove it as it's redundant. 773 if (InsertPt != Start) 774 InsertPt->eraseFromParent(); 775 Start->eraseFromParent(); 776 LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB); 777 return &*MIB; 778 } 779 780 // Goal is to optimise and clean-up these loops: 781 // 782 // vector.body: 783 // renamable $vpr = MVE_VCTP32 renamable $r3, 0, $noreg 784 // renamable $r3, dead $cpsr = tSUBi8 killed renamable $r3(tied-def 0), 4 785 // .. 786 // $lr = MVE_DLSTP_32 renamable $r3 787 // 788 // The SUB is the old update of the loop iteration count expression, which 789 // is no longer needed. This sub is removed when the element count, which is in 790 // r3 in this example, is defined by an instruction in the loop, and it has 791 // no uses. 792 // 793 void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) { 794 Register ElemCount = LoLoop.VCTP->getOperand(1).getReg(); 795 MachineInstr *LastInstrInBlock = &LoLoop.VCTP->getParent()->back(); 796 797 LLVM_DEBUG(dbgs() << "ARM Loops: Trying to remove loop update stmt\n"); 798 799 if (LoLoop.ML->getNumBlocks() != 1) { 800 LLVM_DEBUG(dbgs() << "ARM Loops: single block loop expected\n"); 801 return; 802 } 803 804 LLVM_DEBUG(dbgs() << "ARM Loops: Analyzing MO: "; 805 LoLoop.VCTP->getOperand(1).dump()); 806 807 // Find the definition we are interested in removing, if there is one. 808 MachineInstr *Def = RDA->getReachingMIDef(LastInstrInBlock, ElemCount); 809 if (!Def) 810 return; 811 812 // Bail if we define CPSR and it is not dead 813 if (!Def->registerDefIsDead(ARM::CPSR, TRI)) { 814 LLVM_DEBUG(dbgs() << "ARM Loops: CPSR is not dead\n"); 815 return; 816 } 817 818 // Bail if elemcount is used in exit blocks, i.e. if it is live-in. 819 if (isRegLiveInExitBlocks(LoLoop.ML, ElemCount)) { 820 LLVM_DEBUG(dbgs() << "ARM Loops: Elemcount is live-out, can't remove stmt\n"); 821 return; 822 } 823 824 // Bail if there are uses after this Def in the block. 825 SmallVector<MachineInstr*, 4> Uses; 826 RDA->getReachingLocalUses(Def, ElemCount, Uses); 827 if (Uses.size()) { 828 LLVM_DEBUG(dbgs() << "ARM Loops: Local uses in block, can't remove stmt\n"); 829 return; 830 } 831 832 Uses.clear(); 833 RDA->getAllInstWithUseBefore(Def, ElemCount, Uses); 834 835 // Remove Def if there are no uses, or if the only use is the VCTP 836 // instruction. 837 if (!Uses.size() || (Uses.size() == 1 && Uses[0] == LoLoop.VCTP)) { 838 LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop update instruction: "; 839 Def->dump()); 840 Def->eraseFromParent(); 841 } 842 } 843 844 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) { 845 auto RemovePredicate = [](MachineInstr *MI) { 846 LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI); 847 unsigned OpNum = MI->getNumOperands() - 1; 848 assert(MI->getOperand(OpNum-1).getImm() == ARMVCC::Then && 849 "Expected Then predicate!"); 850 MI->getOperand(OpNum-1).setImm(ARMVCC::None); 851 MI->getOperand(OpNum).setReg(0); 852 }; 853 854 // There are a few scenarios which we have to fix up: 855 // 1) A VPT block with is only predicated by the vctp and has no internal vpr 856 // defs. 857 // 2) A VPT block which is only predicated by the vctp but has an internal 858 // vpr def. 859 // 3) A VPT block which is predicated upon the vctp as well as another vpr 860 // def. 861 // 4) A VPT block which is not predicated upon a vctp, but contains it and 862 // all instructions within the block are predicated upon in. 863 864 for (auto &Block : LoLoop.getVPTBlocks()) { 865 SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); 866 if (Block.HasNonUniformPredicate()) { 867 PredicatedMI *Divergent = Block.getDivergent(); 868 if (isVCTP(Divergent->MI)) { 869 // The vctp will be removed, so the size of the vpt block needs to be 870 // modified. 871 uint64_t Size = getARMVPTBlockMask(Block.size() - 1); 872 Block.getVPST()->getOperand(0).setImm(Size); 873 LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n"); 874 } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { 875 // The VPT block has a non-uniform predicate but it's entry is guarded 876 // only by a vctp, which means we: 877 // - Need to remove the original vpst. 878 // - Then need to unpredicate any following instructions, until 879 // we come across the divergent vpr def. 880 // - Insert a new vpst to predicate the instruction(s) that following 881 // the divergent vpr def. 882 // TODO: We could be producing more VPT blocks than necessary and could 883 // fold the newly created one into a proceeding one. 884 for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()), 885 E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I) 886 RemovePredicate(&*I); 887 888 unsigned Size = 0; 889 auto E = MachineBasicBlock::reverse_iterator(Divergent->MI); 890 auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI); 891 MachineInstr *InsertAt = nullptr; 892 while (I != E) { 893 InsertAt = &*I; 894 ++Size; 895 ++I; 896 } 897 MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt, 898 InsertAt->getDebugLoc(), 899 TII->get(ARM::MVE_VPST)); 900 MIB.addImm(getARMVPTBlockMask(Size)); 901 LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); 902 LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB); 903 Block.getVPST()->eraseFromParent(); 904 } 905 } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { 906 // A vpt block which is only predicated upon vctp and has no internal vpr 907 // defs: 908 // - Remove vpst. 909 // - Unpredicate the remaining instructions. 910 LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); 911 Block.getVPST()->eraseFromParent(); 912 for (auto &PredMI : Insts) 913 RemovePredicate(PredMI.MI); 914 } 915 } 916 917 LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP); 918 LoLoop.VCTP->eraseFromParent(); 919 } 920 921 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) { 922 923 // Combine the LoopDec and LoopEnd instructions into LE(TP). 924 auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) { 925 MachineInstr *End = LoLoop.End; 926 MachineBasicBlock *MBB = End->getParent(); 927 unsigned Opc = LoLoop.IsTailPredicationLegal() ? 928 ARM::MVE_LETP : ARM::t2LEUpdate; 929 MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(), 930 TII->get(Opc)); 931 MIB.addDef(ARM::LR); 932 MIB.add(End->getOperand(0)); 933 MIB.add(End->getOperand(1)); 934 LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB); 935 936 LoLoop.End->eraseFromParent(); 937 LoLoop.Dec->eraseFromParent(); 938 return &*MIB; 939 }; 940 941 // TODO: We should be able to automatically remove these branches before we 942 // get here - probably by teaching analyzeBranch about the pseudo 943 // instructions. 944 // If there is an unconditional branch, after I, that just branches to the 945 // next block, remove it. 946 auto RemoveDeadBranch = [](MachineInstr *I) { 947 MachineBasicBlock *BB = I->getParent(); 948 MachineInstr *Terminator = &BB->instr_back(); 949 if (Terminator->isUnconditionalBranch() && I != Terminator) { 950 MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB(); 951 if (BB->isLayoutSuccessor(Succ)) { 952 LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator); 953 Terminator->eraseFromParent(); 954 } 955 } 956 }; 957 958 if (LoLoop.Revert) { 959 if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart) 960 RevertWhile(LoLoop.Start); 961 else 962 LoLoop.Start->eraseFromParent(); 963 bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true); 964 RevertLoopEnd(LoLoop.End, FlagsAlreadySet); 965 } else { 966 LoLoop.Start = ExpandLoopStart(LoLoop); 967 RemoveDeadBranch(LoLoop.Start); 968 LoLoop.End = ExpandLoopEnd(LoLoop); 969 RemoveDeadBranch(LoLoop.End); 970 if (LoLoop.IsTailPredicationLegal()) { 971 RemoveLoopUpdate(LoLoop); 972 ConvertVPTBlocks(LoLoop); 973 } 974 } 975 } 976 977 bool ARMLowOverheadLoops::RevertNonLoops() { 978 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n"); 979 bool Changed = false; 980 981 for (auto &MBB : *MF) { 982 SmallVector<MachineInstr*, 4> Starts; 983 SmallVector<MachineInstr*, 4> Decs; 984 SmallVector<MachineInstr*, 4> Ends; 985 986 for (auto &I : MBB) { 987 if (isLoopStart(I)) 988 Starts.push_back(&I); 989 else if (I.getOpcode() == ARM::t2LoopDec) 990 Decs.push_back(&I); 991 else if (I.getOpcode() == ARM::t2LoopEnd) 992 Ends.push_back(&I); 993 } 994 995 if (Starts.empty() && Decs.empty() && Ends.empty()) 996 continue; 997 998 Changed = true; 999 1000 for (auto *Start : Starts) { 1001 if (Start->getOpcode() == ARM::t2WhileLoopStart) 1002 RevertWhile(Start); 1003 else 1004 Start->eraseFromParent(); 1005 } 1006 for (auto *Dec : Decs) 1007 RevertLoopDec(Dec); 1008 1009 for (auto *End : Ends) 1010 RevertLoopEnd(End); 1011 } 1012 return Changed; 1013 } 1014 1015 FunctionPass *llvm::createARMLowOverheadLoopsPass() { 1016 return new ARMLowOverheadLoops(); 1017 } 1018