//===-- ARMLowOverheadLoops.cpp - CodeGen Low-overhead Loops ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// Finalize v8.1-m low-overhead loops by converting the associated pseudo
/// instructions into machine operations.
/// The expectation is that the loop contains three pseudo instructions:
/// - t2*LoopStart - placed in the preheader or pre-preheader. The do-loop
///   form should be in the preheader, whereas the while form should be in the
///   preheader's only predecessor.
/// - t2LoopDec - placed within the loop body.
/// - t2LoopEnd - the loop latch terminator.
///
/// In addition to this, we also look for the presence of the VCTP instruction,
/// which determines whether we can generate the tail-predicated low-overhead
/// loop form.
21 /// 22 //===----------------------------------------------------------------------===// 23 24 #include "ARM.h" 25 #include "ARMBaseInstrInfo.h" 26 #include "ARMBaseRegisterInfo.h" 27 #include "ARMBasicBlockInfo.h" 28 #include "ARMSubtarget.h" 29 #include "llvm/ADT/SetOperations.h" 30 #include "llvm/ADT/SmallSet.h" 31 #include "llvm/CodeGen/MachineFunctionPass.h" 32 #include "llvm/CodeGen/MachineLoopInfo.h" 33 #include "llvm/CodeGen/MachineLoopUtils.h" 34 #include "llvm/CodeGen/MachineRegisterInfo.h" 35 #include "llvm/CodeGen/Passes.h" 36 #include "llvm/CodeGen/ReachingDefAnalysis.h" 37 #include "llvm/MC/MCInstrDesc.h" 38 39 using namespace llvm; 40 41 #define DEBUG_TYPE "arm-low-overhead-loops" 42 #define ARM_LOW_OVERHEAD_LOOPS_NAME "ARM Low Overhead Loops pass" 43 44 namespace { 45 46 struct PredicatedMI { 47 MachineInstr *MI = nullptr; 48 SetVector<MachineInstr*> Predicates; 49 50 public: 51 PredicatedMI(MachineInstr *I, SetVector<MachineInstr*> &Preds) : 52 MI(I) { 53 Predicates.insert(Preds.begin(), Preds.end()); 54 } 55 }; 56 57 // Represent a VPT block, a list of instructions that begins with a VPST and 58 // has a maximum of four proceeding instructions. All instructions within the 59 // block are predicated upon the vpr and we allow instructions to define the 60 // vpr within in the block too. 
61 class VPTBlock { 62 std::unique_ptr<PredicatedMI> VPST; 63 PredicatedMI *Divergent = nullptr; 64 SmallVector<PredicatedMI, 4> Insts; 65 66 public: 67 VPTBlock(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { 68 VPST = std::make_unique<PredicatedMI>(MI, Preds); 69 } 70 71 void addInst(MachineInstr *MI, SetVector<MachineInstr*> &Preds) { 72 LLVM_DEBUG(dbgs() << "ARM Loops: Adding predicated MI: " << *MI); 73 if (!Divergent && !set_difference(Preds, VPST->Predicates).empty()) { 74 Divergent = &Insts.back(); 75 LLVM_DEBUG(dbgs() << " - has divergent predicate: " << *Divergent->MI); 76 } 77 Insts.emplace_back(MI, Preds); 78 assert(Insts.size() <= 4 && "Too many instructions in VPT block!"); 79 } 80 81 // Have we found an instruction within the block which defines the vpr? If 82 // so, not all the instructions in the block will have the same predicate. 83 bool HasNonUniformPredicate() const { 84 return Divergent != nullptr; 85 } 86 87 // Is the given instruction part of the predicate set controlling the entry 88 // to the block. 89 bool IsPredicatedOn(MachineInstr *MI) const { 90 return VPST->Predicates.count(MI); 91 } 92 93 // Is the given instruction the only predicate which controls the entry to 94 // the block. 
95 bool IsOnlyPredicatedOn(MachineInstr *MI) const { 96 return IsPredicatedOn(MI) && VPST->Predicates.size() == 1; 97 } 98 99 unsigned size() const { return Insts.size(); } 100 SmallVectorImpl<PredicatedMI> &getInsts() { return Insts; } 101 MachineInstr *getVPST() const { return VPST->MI; } 102 PredicatedMI *getDivergent() const { return Divergent; } 103 }; 104 105 struct LowOverheadLoop { 106 107 MachineLoop *ML = nullptr; 108 MachineFunction *MF = nullptr; 109 MachineInstr *InsertPt = nullptr; 110 MachineInstr *Start = nullptr; 111 MachineInstr *Dec = nullptr; 112 MachineInstr *End = nullptr; 113 MachineInstr *VCTP = nullptr; 114 VPTBlock *CurrentBlock = nullptr; 115 SetVector<MachineInstr*> CurrentPredicate; 116 SmallVector<VPTBlock, 4> VPTBlocks; 117 bool Revert = false; 118 bool CannotTailPredicate = false; 119 120 LowOverheadLoop(MachineLoop *ML) : ML(ML) { 121 MF = ML->getHeader()->getParent(); 122 } 123 124 bool RecordVPTBlocks(MachineInstr *MI); 125 126 // If this is an MVE instruction, check that we know how to use tail 127 // predication with it. 128 void AnalyseMVEInst(MachineInstr *MI) { 129 if (CannotTailPredicate) 130 return; 131 132 if (!RecordVPTBlocks(MI)) { 133 CannotTailPredicate = true; 134 return; 135 } 136 137 const MCInstrDesc &MCID = MI->getDesc(); 138 uint64_t Flags = MCID.TSFlags; 139 if ((Flags & ARMII::DomainMask) != ARMII::DomainMVE) 140 return; 141 142 if ((Flags & ARMII::ValidForTailPredication) == 0) { 143 LLVM_DEBUG(dbgs() << "ARM Loops: Can't tail predicate: " << *MI); 144 CannotTailPredicate = true; 145 } 146 } 147 148 bool IsTailPredicationLegal() const { 149 // For now, let's keep things really simple and only support a single 150 // block for tail predication. 
151 return !Revert && FoundAllComponents() && VCTP && 152 !CannotTailPredicate && ML->getNumBlocks() == 1; 153 } 154 155 bool ValidateTailPredicate(MachineInstr *StartInsertPt, 156 ReachingDefAnalysis *RDA, 157 MachineLoopInfo *MLI); 158 159 // Is it safe to define LR with DLS/WLS? 160 // LR can be defined if it is the operand to start, because it's the same 161 // value, or if it's going to be equivalent to the operand to Start. 162 MachineInstr *IsSafeToDefineLR(ReachingDefAnalysis *RDA); 163 164 // Check the branch targets are within range and we satisfy our 165 // restrictions. 166 void CheckLegality(ARMBasicBlockUtils *BBUtils, ReachingDefAnalysis *RDA, 167 MachineLoopInfo *MLI); 168 169 bool FoundAllComponents() const { 170 return Start && Dec && End; 171 } 172 173 SmallVectorImpl<VPTBlock> &getVPTBlocks() { return VPTBlocks; } 174 175 // Return the loop iteration count, or the number of elements if we're tail 176 // predicating. 177 MachineOperand &getCount() { 178 return IsTailPredicationLegal() ? 179 VCTP->getOperand(1) : Start->getOperand(0); 180 } 181 182 unsigned getStartOpcode() const { 183 bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart; 184 if (!IsTailPredicationLegal()) 185 return IsDo ? 
ARM::t2DLS : ARM::t2WLS; 186 187 return VCTPOpcodeToLSTP(VCTP->getOpcode(), IsDo); 188 } 189 190 void dump() const { 191 if (Start) dbgs() << "ARM Loops: Found Loop Start: " << *Start; 192 if (Dec) dbgs() << "ARM Loops: Found Loop Dec: " << *Dec; 193 if (End) dbgs() << "ARM Loops: Found Loop End: " << *End; 194 if (VCTP) dbgs() << "ARM Loops: Found VCTP: " << *VCTP; 195 if (!FoundAllComponents()) 196 dbgs() << "ARM Loops: Not a low-overhead loop.\n"; 197 else if (!(Start && Dec && End)) 198 dbgs() << "ARM Loops: Failed to find all loop components.\n"; 199 } 200 }; 201 202 class ARMLowOverheadLoops : public MachineFunctionPass { 203 MachineFunction *MF = nullptr; 204 MachineLoopInfo *MLI = nullptr; 205 ReachingDefAnalysis *RDA = nullptr; 206 const ARMBaseInstrInfo *TII = nullptr; 207 MachineRegisterInfo *MRI = nullptr; 208 const TargetRegisterInfo *TRI = nullptr; 209 std::unique_ptr<ARMBasicBlockUtils> BBUtils = nullptr; 210 211 public: 212 static char ID; 213 214 ARMLowOverheadLoops() : MachineFunctionPass(ID) { } 215 216 void getAnalysisUsage(AnalysisUsage &AU) const override { 217 AU.setPreservesCFG(); 218 AU.addRequired<MachineLoopInfo>(); 219 AU.addRequired<ReachingDefAnalysis>(); 220 MachineFunctionPass::getAnalysisUsage(AU); 221 } 222 223 bool runOnMachineFunction(MachineFunction &MF) override; 224 225 MachineFunctionProperties getRequiredProperties() const override { 226 return MachineFunctionProperties().set( 227 MachineFunctionProperties::Property::NoVRegs).set( 228 MachineFunctionProperties::Property::TracksLiveness); 229 } 230 231 StringRef getPassName() const override { 232 return ARM_LOW_OVERHEAD_LOOPS_NAME; 233 } 234 235 private: 236 bool ProcessLoop(MachineLoop *ML); 237 238 bool RevertNonLoops(); 239 240 void RevertWhile(MachineInstr *MI) const; 241 242 bool RevertLoopDec(MachineInstr *MI, bool AllowFlags = false) const; 243 244 void RevertLoopEnd(MachineInstr *MI, bool SkipCmp = false) const; 245 246 void RemoveLoopUpdate(LowOverheadLoop &LoLoop); 
247 248 void ConvertVPTBlocks(LowOverheadLoop &LoLoop); 249 250 MachineInstr *ExpandLoopStart(LowOverheadLoop &LoLoop); 251 252 void Expand(LowOverheadLoop &LoLoop); 253 254 }; 255 } 256 257 char ARMLowOverheadLoops::ID = 0; 258 259 INITIALIZE_PASS(ARMLowOverheadLoops, DEBUG_TYPE, ARM_LOW_OVERHEAD_LOOPS_NAME, 260 false, false) 261 262 MachineInstr *LowOverheadLoop::IsSafeToDefineLR(ReachingDefAnalysis *RDA) { 263 // We can define LR because LR already contains the same value. 264 if (Start->getOperand(0).getReg() == ARM::LR) 265 return Start; 266 267 unsigned CountReg = Start->getOperand(0).getReg(); 268 auto IsMoveLR = [&CountReg](MachineInstr *MI) { 269 return MI->getOpcode() == ARM::tMOVr && 270 MI->getOperand(0).getReg() == ARM::LR && 271 MI->getOperand(1).getReg() == CountReg && 272 MI->getOperand(2).getImm() == ARMCC::AL; 273 }; 274 275 MachineBasicBlock *MBB = Start->getParent(); 276 277 // Find an insertion point: 278 // - Is there a (mov lr, Count) before Start? If so, and nothing else writes 279 // to Count before Start, we can insert at that mov. 280 if (auto *LRDef = RDA->getReachingMIDef(Start, ARM::LR)) 281 if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg)) 282 return LRDef; 283 284 // - Is there a (mov lr, Count) after Start? If so, and nothing else writes 285 // to Count after Start, we can insert at that mov. 286 if (auto *LRDef = RDA->getLocalLiveOutMIDef(MBB, ARM::LR)) 287 if (IsMoveLR(LRDef) && RDA->hasSameReachingDef(Start, LRDef, CountReg)) 288 return LRDef; 289 290 // We've found no suitable LR def and Start doesn't use LR directly. Can we 291 // just define LR anyway? 292 if (!RDA->isRegUsedAfter(Start, ARM::LR)) 293 return Start; 294 295 return nullptr; 296 } 297 298 // Can we safely move 'From' to just before 'To'? To satisfy this, 'From' must 299 // not define a register that is used by any instructions, after and including, 300 // 'To'. These instructions also must not redefine any of Froms operands. 
301 template<typename Iterator> 302 static bool IsSafeToMove(MachineInstr *From, MachineInstr *To, ReachingDefAnalysis *RDA) { 303 SmallSet<int, 2> Defs; 304 // First check that From would compute the same value if moved. 305 for (auto &MO : From->operands()) { 306 if (!MO.isReg() || MO.isUndef() || !MO.getReg()) 307 continue; 308 if (MO.isDef()) 309 Defs.insert(MO.getReg()); 310 else if (!RDA->hasSameReachingDef(From, To, MO.getReg())) 311 return false; 312 } 313 314 // Now walk checking that the rest of the instructions will compute the same 315 // value. 316 for (auto I = ++Iterator(From), E = Iterator(To); I != E; ++I) { 317 for (auto &MO : I->operands()) 318 if (MO.isReg() && MO.getReg() && MO.isUse() && Defs.count(MO.getReg())) 319 return false; 320 } 321 return true; 322 } 323 324 bool LowOverheadLoop::ValidateTailPredicate(MachineInstr *StartInsertPt, 325 ReachingDefAnalysis *RDA, 326 MachineLoopInfo *MLI) { 327 // All predication within the loop should be based on vctp. If the block 328 // isn't predicated on entry, check whether the vctp is within the block 329 // and that all other instructions are then predicated on it. 330 for (auto &Block : VPTBlocks) { 331 if (Block.IsPredicatedOn(VCTP)) 332 continue; 333 if (!Block.HasNonUniformPredicate() || !isVCTP(Block.getDivergent()->MI)) 334 return false; 335 SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); 336 for (auto &PredMI : Insts) { 337 if (PredMI.Predicates.count(VCTP) || isVCTP(PredMI.MI)) 338 continue; 339 LLVM_DEBUG(dbgs() << "ARM Loops: Can't convert: " << *PredMI.MI 340 << " - which is predicated on:\n"; 341 for (auto *MI : PredMI.Predicates) 342 dbgs() << " - " << *MI; 343 ); 344 return false; 345 } 346 } 347 348 // For tail predication, we need to provide the number of elements, instead 349 // of the iteration count, to the loop start instruction. The number of 350 // elements is provided to the vctp instruction, so we need to check that 351 // we can use this register at InsertPt. 
352 Register NumElements = VCTP->getOperand(1).getReg(); 353 354 // If the register is defined within loop, then we can't perform TP. 355 // TODO: Check whether this is just a mov of a register that would be 356 // available. 357 if (RDA->getReachingDef(VCTP, NumElements) >= 0) { 358 LLVM_DEBUG(dbgs() << "ARM Loops: VCTP operand is defined in the loop.\n"); 359 return false; 360 } 361 362 // The element count register maybe defined after InsertPt, in which case we 363 // need to try to move either InsertPt or the def so that the [w|d]lstp can 364 // use the value. 365 MachineBasicBlock *InsertBB = InsertPt->getParent(); 366 if (!RDA->isReachingDefLiveOut(InsertPt, NumElements)) { 367 if (auto *ElemDef = RDA->getLocalLiveOutMIDef(InsertBB, NumElements)) { 368 if (IsSafeToMove<MachineBasicBlock::reverse_iterator>(ElemDef, InsertPt, RDA)) { 369 ElemDef->removeFromParent(); 370 InsertBB->insert(MachineBasicBlock::iterator(InsertPt), ElemDef); 371 LLVM_DEBUG(dbgs() << "ARM Loops: Moved element count def: " 372 << *ElemDef); 373 } else if (IsSafeToMove<MachineBasicBlock::iterator>(InsertPt, ElemDef, RDA)) { 374 InsertPt->removeFromParent(); 375 InsertBB->insertAfter(MachineBasicBlock::iterator(ElemDef), InsertPt); 376 LLVM_DEBUG(dbgs() << "ARM Loops: Moved start past: " << *ElemDef); 377 } else 378 return false; 379 } 380 } 381 382 // Especially in the case of while loops, InsertBB may not be the 383 // preheader, so we need to check that the register isn't redefined 384 // before entering the loop. 385 auto CannotProvideElements = [&RDA](MachineBasicBlock *MBB, 386 Register NumElements) { 387 // NumElements is redefined in this block. 388 if (RDA->getReachingDef(&MBB->back(), NumElements) >= 0) 389 return true; 390 391 // Don't continue searching up through multiple predecessors. 392 if (MBB->pred_size() > 1) 393 return true; 394 395 return false; 396 }; 397 398 // First, find the block that looks like the preheader. 
399 MachineBasicBlock *MBB = MLI->findLoopPreheader(ML, true); 400 if (!MBB) 401 return false; 402 403 // Then search backwards for a def, until we get to InsertBB. 404 while (MBB != InsertBB) { 405 if (CannotProvideElements(MBB, NumElements)) 406 return false; 407 MBB = *MBB->pred_begin(); 408 } 409 410 LLVM_DEBUG(dbgs() << "ARM Loops: Will use tail predication.\n"); 411 return true; 412 } 413 414 void LowOverheadLoop::CheckLegality(ARMBasicBlockUtils *BBUtils, 415 ReachingDefAnalysis *RDA, 416 MachineLoopInfo *MLI) { 417 if (Revert) 418 return; 419 420 if (!End->getOperand(1).isMBB()) 421 report_fatal_error("Expected LoopEnd to target basic block"); 422 423 // TODO Maybe there's cases where the target doesn't have to be the header, 424 // but for now be safe and revert. 425 if (End->getOperand(1).getMBB() != ML->getHeader()) { 426 LLVM_DEBUG(dbgs() << "ARM Loops: LoopEnd is not targetting header.\n"); 427 Revert = true; 428 return; 429 } 430 431 // The WLS and LE instructions have 12-bits for the label offset. WLS 432 // requires a positive offset, while LE uses negative. 433 if (BBUtils->getOffsetOf(End) < BBUtils->getOffsetOf(ML->getHeader()) || 434 !BBUtils->isBBInRange(End, ML->getHeader(), 4094)) { 435 LLVM_DEBUG(dbgs() << "ARM Loops: LE offset is out-of-range\n"); 436 Revert = true; 437 return; 438 } 439 440 if (Start->getOpcode() == ARM::t2WhileLoopStart && 441 (BBUtils->getOffsetOf(Start) > 442 BBUtils->getOffsetOf(Start->getOperand(1).getMBB()) || 443 !BBUtils->isBBInRange(Start, Start->getOperand(1).getMBB(), 4094))) { 444 LLVM_DEBUG(dbgs() << "ARM Loops: WLS offset is out-of-range!\n"); 445 Revert = true; 446 return; 447 } 448 449 InsertPt = Revert ? 
nullptr : IsSafeToDefineLR(RDA); 450 if (!InsertPt) { 451 LLVM_DEBUG(dbgs() << "ARM Loops: Unable to find safe insertion point.\n"); 452 Revert = true; 453 return; 454 } else 455 LLVM_DEBUG(dbgs() << "ARM Loops: Start insertion point: " << *InsertPt); 456 457 if (!IsTailPredicationLegal()) { 458 LLVM_DEBUG(dbgs() << "ARM Loops: Tail-predication is not valid.\n"); 459 return; 460 } 461 462 assert(ML->getBlocks().size() == 1 && 463 "Shouldn't be processing a loop with more than one block"); 464 CannotTailPredicate = !ValidateTailPredicate(InsertPt, RDA, MLI); 465 LLVM_DEBUG(if (CannotTailPredicate) 466 dbgs() << "ARM Loops: Couldn't validate tail predicate.\n"); 467 } 468 469 bool LowOverheadLoop::RecordVPTBlocks(MachineInstr* MI) { 470 // Only support a single vctp. 471 if (isVCTP(MI) && VCTP) 472 return false; 473 474 // Start a new vpt block when we discover a vpt. 475 if (MI->getOpcode() == ARM::MVE_VPST) { 476 VPTBlocks.emplace_back(MI, CurrentPredicate); 477 CurrentBlock = &VPTBlocks.back(); 478 return true; 479 } 480 481 if (isVCTP(MI)) 482 VCTP = MI; 483 484 unsigned VPROpNum = MI->getNumOperands() - 1; 485 bool IsUse = false; 486 if (MI->getOperand(VPROpNum).isReg() && 487 MI->getOperand(VPROpNum).getReg() == ARM::VPR && 488 MI->getOperand(VPROpNum).isUse()) { 489 // If this instruction is predicated by VPR, it will be its last 490 // operand. Also check that it's only 'Then' predicated. 
491 if (!MI->getOperand(VPROpNum-1).isImm() || 492 MI->getOperand(VPROpNum-1).getImm() != ARMVCC::Then) { 493 LLVM_DEBUG(dbgs() << "ARM Loops: Found unhandled predicate on: " 494 << *MI); 495 return false; 496 } 497 CurrentBlock->addInst(MI, CurrentPredicate); 498 IsUse = true; 499 } 500 501 bool IsDef = false; 502 for (unsigned i = 0; i < MI->getNumOperands() - 1; ++i) { 503 const MachineOperand &MO = MI->getOperand(i); 504 if (!MO.isReg() || MO.getReg() != ARM::VPR) 505 continue; 506 507 if (MO.isDef()) { 508 CurrentPredicate.insert(MI); 509 IsDef = true; 510 } else { 511 LLVM_DEBUG(dbgs() << "ARM Loops: Found instruction using vpr: " << *MI); 512 return false; 513 } 514 } 515 516 // If we find a vpr def that is not already predicated on the vctp, we've 517 // got disjoint predicates that may not be equivalent when we do the 518 // conversion. 519 if (IsDef && !IsUse && VCTP && !isVCTP(MI)) { 520 LLVM_DEBUG(dbgs() << "ARM Loops: Found disjoint vpr def: " << *MI); 521 return false; 522 } 523 524 return true; 525 } 526 527 bool ARMLowOverheadLoops::runOnMachineFunction(MachineFunction &mf) { 528 const ARMSubtarget &ST = static_cast<const ARMSubtarget&>(mf.getSubtarget()); 529 if (!ST.hasLOB()) 530 return false; 531 532 MF = &mf; 533 LLVM_DEBUG(dbgs() << "ARM Loops on " << MF->getName() << " ------------- \n"); 534 535 MLI = &getAnalysis<MachineLoopInfo>(); 536 RDA = &getAnalysis<ReachingDefAnalysis>(); 537 MF->getProperties().set(MachineFunctionProperties::Property::TracksLiveness); 538 MRI = &MF->getRegInfo(); 539 TII = static_cast<const ARMBaseInstrInfo*>(ST.getInstrInfo()); 540 TRI = ST.getRegisterInfo(); 541 BBUtils = std::unique_ptr<ARMBasicBlockUtils>(new ARMBasicBlockUtils(*MF)); 542 BBUtils->computeAllBlockSizes(); 543 BBUtils->adjustBBOffsetsAfter(&MF->front()); 544 545 bool Changed = false; 546 for (auto ML : *MLI) { 547 if (!ML->getParentLoop()) 548 Changed |= ProcessLoop(ML); 549 } 550 Changed |= RevertNonLoops(); 551 return Changed; 552 } 553 554 bool 
ARMLowOverheadLoops::ProcessLoop(MachineLoop *ML) { 555 556 bool Changed = false; 557 558 // Process inner loops first. 559 for (auto I = ML->begin(), E = ML->end(); I != E; ++I) 560 Changed |= ProcessLoop(*I); 561 562 LLVM_DEBUG(dbgs() << "ARM Loops: Processing loop containing:\n"; 563 if (auto *Preheader = ML->getLoopPreheader()) 564 dbgs() << " - " << Preheader->getName() << "\n"; 565 else if (auto *Preheader = MLI->findLoopPreheader(ML)) 566 dbgs() << " - " << Preheader->getName() << "\n"; 567 for (auto *MBB : ML->getBlocks()) 568 dbgs() << " - " << MBB->getName() << "\n"; 569 ); 570 571 // Search the given block for a loop start instruction. If one isn't found, 572 // and there's only one predecessor block, search that one too. 573 std::function<MachineInstr*(MachineBasicBlock*)> SearchForStart = 574 [&SearchForStart](MachineBasicBlock *MBB) -> MachineInstr* { 575 for (auto &MI : *MBB) { 576 if (isLoopStart(MI)) 577 return &MI; 578 } 579 if (MBB->pred_size() == 1) 580 return SearchForStart(*MBB->pred_begin()); 581 return nullptr; 582 }; 583 584 LowOverheadLoop LoLoop(ML); 585 // Search the preheader for the start intrinsic. 586 // FIXME: I don't see why we shouldn't be supporting multiple predecessors 587 // with potentially multiple set.loop.iterations, so we need to enable this. 588 if (auto *Preheader = ML->getLoopPreheader()) 589 LoLoop.Start = SearchForStart(Preheader); 590 else if (auto *Preheader = MLI->findLoopPreheader(ML, true)) 591 LoLoop.Start = SearchForStart(Preheader); 592 else 593 return false; 594 595 // Find the low-overhead loop components and decide whether or not to fall 596 // back to a normal loop. Also look for a vctp instructions and decide 597 // whether we can convert that predicate using tail predication. 
598 for (auto *MBB : reverse(ML->getBlocks())) { 599 for (auto &MI : *MBB) { 600 if (MI.getOpcode() == ARM::t2LoopDec) 601 LoLoop.Dec = &MI; 602 else if (MI.getOpcode() == ARM::t2LoopEnd) 603 LoLoop.End = &MI; 604 else if (isLoopStart(MI)) 605 LoLoop.Start = &MI; 606 else if (MI.getDesc().isCall()) { 607 // TODO: Though the call will require LE to execute again, does this 608 // mean we should revert? Always executing LE hopefully should be 609 // faster than performing a sub,cmp,br or even subs,br. 610 LoLoop.Revert = true; 611 LLVM_DEBUG(dbgs() << "ARM Loops: Found call.\n"); 612 } else { 613 // Record VPR defs and build up their corresponding vpt blocks. 614 // Check we know how to tail predicate any mve instructions. 615 LoLoop.AnalyseMVEInst(&MI); 616 } 617 618 // We need to ensure that LR is not used or defined inbetween LoopDec and 619 // LoopEnd. 620 if (!LoLoop.Dec || LoLoop.End || LoLoop.Revert) 621 continue; 622 623 // If we find that LR has been written or read between LoopDec and 624 // LoopEnd, expect that the decremented value is being used else where. 625 // Because this value isn't actually going to be produced until the 626 // latch, by LE, we would need to generate a real sub. The value is also 627 // likely to be copied/reloaded for use of LoopEnd - in which in case 628 // we'd need to perform an add because it gets subtracted again by LE! 629 // The other option is to then generate the other form of LE which doesn't 630 // perform the sub. 
631 for (auto &MO : MI.operands()) { 632 if (MI.getOpcode() != ARM::t2LoopDec && MO.isReg() && 633 MO.getReg() == ARM::LR) { 634 LLVM_DEBUG(dbgs() << "ARM Loops: Found LR Use/Def: " << MI); 635 LoLoop.Revert = true; 636 break; 637 } 638 } 639 } 640 } 641 642 LLVM_DEBUG(LoLoop.dump()); 643 if (!LoLoop.FoundAllComponents()) 644 return false; 645 646 LoLoop.CheckLegality(BBUtils.get(), RDA, MLI); 647 Expand(LoLoop); 648 return true; 649 } 650 651 // WhileLoopStart holds the exit block, so produce a cmp lr, 0 and then a 652 // beq that branches to the exit branch. 653 // TODO: We could also try to generate a cbz if the value in LR is also in 654 // another low register. 655 void ARMLowOverheadLoops::RevertWhile(MachineInstr *MI) const { 656 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp: " << *MI); 657 MachineBasicBlock *MBB = MI->getParent(); 658 MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), 659 TII->get(ARM::t2CMPri)); 660 MIB.add(MI->getOperand(0)); 661 MIB.addImm(0); 662 MIB.addImm(ARMCC::AL); 663 MIB.addReg(ARM::NoRegister); 664 665 MachineBasicBlock *DestBB = MI->getOperand(1).getMBB(); 666 unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ? 667 ARM::tBcc : ARM::t2Bcc; 668 669 MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc)); 670 MIB.add(MI->getOperand(1)); // branch target 671 MIB.addImm(ARMCC::EQ); // condition code 672 MIB.addReg(ARM::CPSR); 673 MI->eraseFromParent(); 674 } 675 676 bool ARMLowOverheadLoops::RevertLoopDec(MachineInstr *MI, 677 bool SetFlags) const { 678 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to sub: " << *MI); 679 MachineBasicBlock *MBB = MI->getParent(); 680 681 // If nothing defines CPSR between LoopDec and LoopEnd, use a t2SUBS. 
682 if (SetFlags && 683 (RDA->isRegUsedAfter(MI, ARM::CPSR) || 684 !RDA->hasSameReachingDef(MI, &MBB->back(), ARM::CPSR))) 685 SetFlags = false; 686 687 MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), 688 TII->get(ARM::t2SUBri)); 689 MIB.addDef(ARM::LR); 690 MIB.add(MI->getOperand(1)); 691 MIB.add(MI->getOperand(2)); 692 MIB.addImm(ARMCC::AL); 693 MIB.addReg(0); 694 695 if (SetFlags) { 696 MIB.addReg(ARM::CPSR); 697 MIB->getOperand(5).setIsDef(true); 698 } else 699 MIB.addReg(0); 700 701 MI->eraseFromParent(); 702 return SetFlags; 703 } 704 705 // Generate a subs, or sub and cmp, and a branch instead of an LE. 706 void ARMLowOverheadLoops::RevertLoopEnd(MachineInstr *MI, bool SkipCmp) const { 707 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting to cmp, br: " << *MI); 708 709 MachineBasicBlock *MBB = MI->getParent(); 710 // Create cmp 711 if (!SkipCmp) { 712 MachineInstrBuilder MIB = BuildMI(*MBB, MI, MI->getDebugLoc(), 713 TII->get(ARM::t2CMPri)); 714 MIB.addReg(ARM::LR); 715 MIB.addImm(0); 716 MIB.addImm(ARMCC::AL); 717 MIB.addReg(ARM::NoRegister); 718 } 719 720 MachineBasicBlock *DestBB = MI->getOperand(1).getMBB(); 721 unsigned BrOpc = BBUtils->isBBInRange(MI, DestBB, 254) ? 
722 ARM::tBcc : ARM::t2Bcc; 723 724 // Create bne 725 MachineInstrBuilder MIB = 726 BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(BrOpc)); 727 MIB.add(MI->getOperand(1)); // branch target 728 MIB.addImm(ARMCC::NE); // condition code 729 MIB.addReg(ARM::CPSR); 730 MI->eraseFromParent(); 731 } 732 733 MachineInstr* ARMLowOverheadLoops::ExpandLoopStart(LowOverheadLoop &LoLoop) { 734 MachineInstr *InsertPt = LoLoop.InsertPt; 735 MachineInstr *Start = LoLoop.Start; 736 MachineBasicBlock *MBB = InsertPt->getParent(); 737 bool IsDo = Start->getOpcode() == ARM::t2DoLoopStart; 738 unsigned Opc = LoLoop.getStartOpcode(); 739 MachineOperand &Count = LoLoop.getCount(); 740 741 MachineInstrBuilder MIB = 742 BuildMI(*MBB, InsertPt, InsertPt->getDebugLoc(), TII->get(Opc)); 743 744 MIB.addDef(ARM::LR); 745 MIB.add(Count); 746 if (!IsDo) 747 MIB.add(Start->getOperand(1)); 748 749 // When using tail-predication, try to delete the dead code that was used to 750 // calculate the number of loop iterations. 751 if (LoLoop.IsTailPredicationLegal()) { 752 SmallVector<MachineInstr*, 4> Killed; 753 SmallVector<MachineInstr*, 4> Dead; 754 if (auto *Def = RDA->getReachingMIDef(Start, 755 Start->getOperand(0).getReg())) { 756 Killed.push_back(Def); 757 758 while (!Killed.empty()) { 759 MachineInstr *Def = Killed.back(); 760 Killed.pop_back(); 761 Dead.push_back(Def); 762 for (auto &MO : Def->operands()) { 763 if (!MO.isReg() || !MO.isKill()) 764 continue; 765 766 MachineInstr *Kill = RDA->getReachingMIDef(Def, MO.getReg()); 767 if (Kill && RDA->getNumUses(Kill, MO.getReg()) == 1) 768 Killed.push_back(Kill); 769 } 770 } 771 for (auto *MI : Dead) 772 MI->eraseFromParent(); 773 } 774 } 775 776 // If we're inserting at a mov lr, then remove it as it's redundant. 
777 if (InsertPt != Start) 778 InsertPt->eraseFromParent(); 779 Start->eraseFromParent(); 780 LLVM_DEBUG(dbgs() << "ARM Loops: Inserted start: " << *MIB); 781 return &*MIB; 782 } 783 784 // Goal is to optimise and clean-up these loops: 785 // 786 // vector.body: 787 // renamable $vpr = MVE_VCTP32 renamable $r3, 0, $noreg 788 // renamable $r3, dead $cpsr = tSUBi8 killed renamable $r3(tied-def 0), 4 789 // .. 790 // $lr = MVE_DLSTP_32 renamable $r3 791 // 792 // The SUB is the old update of the loop iteration count expression, which 793 // is no longer needed. This sub is removed when the element count, which is in 794 // r3 in this example, is defined by an instruction in the loop, and it has 795 // no uses. 796 // 797 void ARMLowOverheadLoops::RemoveLoopUpdate(LowOverheadLoop &LoLoop) { 798 Register ElemCount = LoLoop.VCTP->getOperand(1).getReg(); 799 MachineInstr *LastInstrInBlock = &LoLoop.VCTP->getParent()->back(); 800 801 LLVM_DEBUG(dbgs() << "ARM Loops: Trying to remove loop update stmt\n"); 802 803 if (LoLoop.ML->getNumBlocks() != 1) { 804 LLVM_DEBUG(dbgs() << "ARM Loops: single block loop expected\n"); 805 return; 806 } 807 808 LLVM_DEBUG(dbgs() << "ARM Loops: Analyzing MO: "; 809 LoLoop.VCTP->getOperand(1).dump()); 810 811 // Find the definition we are interested in removing, if there is one. 812 MachineInstr *Def = RDA->getReachingMIDef(LastInstrInBlock, ElemCount); 813 if (!Def) 814 return; 815 816 // Bail if we define CPSR and it is not dead 817 if (!Def->registerDefIsDead(ARM::CPSR, TRI)) { 818 LLVM_DEBUG(dbgs() << "ARM Loops: CPSR is not dead\n"); 819 return; 820 } 821 822 // Bail if elemcount is used in exit blocks, i.e. if it is live-in. 823 if (isRegLiveInExitBlocks(LoLoop.ML, ElemCount)) { 824 LLVM_DEBUG(dbgs() << "ARM Loops: Elemcount is live-out, can't remove stmt\n"); 825 return; 826 } 827 828 // Bail if there are uses after this Def in the block. 
829 SmallVector<MachineInstr*, 4> Uses; 830 RDA->getReachingLocalUses(Def, ElemCount, Uses); 831 if (Uses.size()) { 832 LLVM_DEBUG(dbgs() << "ARM Loops: Local uses in block, can't remove stmt\n"); 833 return; 834 } 835 836 Uses.clear(); 837 RDA->getAllInstWithUseBefore(Def, ElemCount, Uses); 838 839 // Remove Def if there are no uses, or if the only use is the VCTP 840 // instruction. 841 if (!Uses.size() || (Uses.size() == 1 && Uses[0] == LoLoop.VCTP)) { 842 LLVM_DEBUG(dbgs() << "ARM Loops: Removing loop update instruction: "; 843 Def->dump()); 844 Def->eraseFromParent(); 845 } 846 } 847 848 void ARMLowOverheadLoops::ConvertVPTBlocks(LowOverheadLoop &LoLoop) { 849 auto RemovePredicate = [](MachineInstr *MI) { 850 LLVM_DEBUG(dbgs() << "ARM Loops: Removing predicate from: " << *MI); 851 unsigned OpNum = MI->getNumOperands() - 1; 852 assert(MI->getOperand(OpNum-1).getImm() == ARMVCC::Then && 853 "Expected Then predicate!"); 854 MI->getOperand(OpNum-1).setImm(ARMVCC::None); 855 MI->getOperand(OpNum).setReg(0); 856 }; 857 858 // There are a few scenarios which we have to fix up: 859 // 1) A VPT block with is only predicated by the vctp and has no internal vpr 860 // defs. 861 // 2) A VPT block which is only predicated by the vctp but has an internal 862 // vpr def. 863 // 3) A VPT block which is predicated upon the vctp as well as another vpr 864 // def. 865 // 4) A VPT block which is not predicated upon a vctp, but contains it and 866 // all instructions within the block are predicated upon in. 867 868 for (auto &Block : LoLoop.getVPTBlocks()) { 869 SmallVectorImpl<PredicatedMI> &Insts = Block.getInsts(); 870 if (Block.HasNonUniformPredicate()) { 871 PredicatedMI *Divergent = Block.getDivergent(); 872 if (isVCTP(Divergent->MI)) { 873 // The vctp will be removed, so the size of the vpt block needs to be 874 // modified. 
875 uint64_t Size = getARMVPTBlockMask(Block.size() - 1); 876 Block.getVPST()->getOperand(0).setImm(Size); 877 LLVM_DEBUG(dbgs() << "ARM Loops: Modified VPT block mask.\n"); 878 } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { 879 // The VPT block has a non-uniform predicate but it's entry is guarded 880 // only by a vctp, which means we: 881 // - Need to remove the original vpst. 882 // - Then need to unpredicate any following instructions, until 883 // we come across the divergent vpr def. 884 // - Insert a new vpst to predicate the instruction(s) that following 885 // the divergent vpr def. 886 // TODO: We could be producing more VPT blocks than necessary and could 887 // fold the newly created one into a proceeding one. 888 for (auto I = ++MachineBasicBlock::iterator(Block.getVPST()), 889 E = ++MachineBasicBlock::iterator(Divergent->MI); I != E; ++I) 890 RemovePredicate(&*I); 891 892 unsigned Size = 0; 893 auto E = MachineBasicBlock::reverse_iterator(Divergent->MI); 894 auto I = MachineBasicBlock::reverse_iterator(Insts.back().MI); 895 MachineInstr *InsertAt = nullptr; 896 while (I != E) { 897 InsertAt = &*I; 898 ++Size; 899 ++I; 900 } 901 MachineInstrBuilder MIB = BuildMI(*InsertAt->getParent(), InsertAt, 902 InsertAt->getDebugLoc(), 903 TII->get(ARM::MVE_VPST)); 904 MIB.addImm(getARMVPTBlockMask(Size)); 905 LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); 906 LLVM_DEBUG(dbgs() << "ARM Loops: Created VPST: " << *MIB); 907 Block.getVPST()->eraseFromParent(); 908 } 909 } else if (Block.IsOnlyPredicatedOn(LoLoop.VCTP)) { 910 // A vpt block which is only predicated upon vctp and has no internal vpr 911 // defs: 912 // - Remove vpst. 913 // - Unpredicate the remaining instructions. 
914 LLVM_DEBUG(dbgs() << "ARM Loops: Removing VPST: " << *Block.getVPST()); 915 Block.getVPST()->eraseFromParent(); 916 for (auto &PredMI : Insts) 917 RemovePredicate(PredMI.MI); 918 } 919 } 920 921 LLVM_DEBUG(dbgs() << "ARM Loops: Removing VCTP: " << *LoLoop.VCTP); 922 LoLoop.VCTP->eraseFromParent(); 923 } 924 925 void ARMLowOverheadLoops::Expand(LowOverheadLoop &LoLoop) { 926 927 // Combine the LoopDec and LoopEnd instructions into LE(TP). 928 auto ExpandLoopEnd = [this](LowOverheadLoop &LoLoop) { 929 MachineInstr *End = LoLoop.End; 930 MachineBasicBlock *MBB = End->getParent(); 931 unsigned Opc = LoLoop.IsTailPredicationLegal() ? 932 ARM::MVE_LETP : ARM::t2LEUpdate; 933 MachineInstrBuilder MIB = BuildMI(*MBB, End, End->getDebugLoc(), 934 TII->get(Opc)); 935 MIB.addDef(ARM::LR); 936 MIB.add(End->getOperand(0)); 937 MIB.add(End->getOperand(1)); 938 LLVM_DEBUG(dbgs() << "ARM Loops: Inserted LE: " << *MIB); 939 940 LoLoop.End->eraseFromParent(); 941 LoLoop.Dec->eraseFromParent(); 942 return &*MIB; 943 }; 944 945 // TODO: We should be able to automatically remove these branches before we 946 // get here - probably by teaching analyzeBranch about the pseudo 947 // instructions. 948 // If there is an unconditional branch, after I, that just branches to the 949 // next block, remove it. 
950 auto RemoveDeadBranch = [](MachineInstr *I) { 951 MachineBasicBlock *BB = I->getParent(); 952 MachineInstr *Terminator = &BB->instr_back(); 953 if (Terminator->isUnconditionalBranch() && I != Terminator) { 954 MachineBasicBlock *Succ = Terminator->getOperand(0).getMBB(); 955 if (BB->isLayoutSuccessor(Succ)) { 956 LLVM_DEBUG(dbgs() << "ARM Loops: Removing branch: " << *Terminator); 957 Terminator->eraseFromParent(); 958 } 959 } 960 }; 961 962 if (LoLoop.Revert) { 963 if (LoLoop.Start->getOpcode() == ARM::t2WhileLoopStart) 964 RevertWhile(LoLoop.Start); 965 else 966 LoLoop.Start->eraseFromParent(); 967 bool FlagsAlreadySet = RevertLoopDec(LoLoop.Dec, true); 968 RevertLoopEnd(LoLoop.End, FlagsAlreadySet); 969 } else { 970 LoLoop.Start = ExpandLoopStart(LoLoop); 971 RemoveDeadBranch(LoLoop.Start); 972 LoLoop.End = ExpandLoopEnd(LoLoop); 973 RemoveDeadBranch(LoLoop.End); 974 if (LoLoop.IsTailPredicationLegal()) { 975 RemoveLoopUpdate(LoLoop); 976 ConvertVPTBlocks(LoLoop); 977 } 978 } 979 } 980 981 bool ARMLowOverheadLoops::RevertNonLoops() { 982 LLVM_DEBUG(dbgs() << "ARM Loops: Reverting any remaining pseudos...\n"); 983 bool Changed = false; 984 985 for (auto &MBB : *MF) { 986 SmallVector<MachineInstr*, 4> Starts; 987 SmallVector<MachineInstr*, 4> Decs; 988 SmallVector<MachineInstr*, 4> Ends; 989 990 for (auto &I : MBB) { 991 if (isLoopStart(I)) 992 Starts.push_back(&I); 993 else if (I.getOpcode() == ARM::t2LoopDec) 994 Decs.push_back(&I); 995 else if (I.getOpcode() == ARM::t2LoopEnd) 996 Ends.push_back(&I); 997 } 998 999 if (Starts.empty() && Decs.empty() && Ends.empty()) 1000 continue; 1001 1002 Changed = true; 1003 1004 for (auto *Start : Starts) { 1005 if (Start->getOpcode() == ARM::t2WhileLoopStart) 1006 RevertWhile(Start); 1007 else 1008 Start->eraseFromParent(); 1009 } 1010 for (auto *Dec : Decs) 1011 RevertLoopDec(Dec); 1012 1013 for (auto *End : Ends) 1014 RevertLoopEnd(End); 1015 } 1016 return Changed; 1017 } 1018 1019 FunctionPass 
*llvm::createARMLowOverheadLoopsPass() { 1020 return new ARMLowOverheadLoops(); 1021 } 1022