1 //===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h" 10 #include "llvm/ADT/SmallVector.h" 11 #include "llvm/Analysis/TargetTransformInfo.h" 12 #include "llvm/SandboxIR/Function.h" 13 #include "llvm/SandboxIR/Instruction.h" 14 #include "llvm/SandboxIR/Module.h" 15 #include "llvm/SandboxIR/Utils.h" 16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h" 17 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h" 18 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" 19 20 namespace llvm { 21 22 static cl::opt<unsigned> 23 OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden, 24 cl::desc("Override the vector register size in bits, " 25 "which is otherwise found by querying TTI.")); 26 static cl::opt<bool> 27 AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden, 28 cl::desc("Allow non-power-of-2 vectorization.")); 29 30 #ifndef NDEBUG 31 static cl::opt<bool> 32 AlwaysVerify("sbvec-always-verify", cl::init(false), cl::Hidden, 33 cl::desc("Helps find bugs by verifying the IR whenever we " 34 "emit new instructions (*very* expensive).")); 35 #endif // NDEBUG 36 37 namespace sandboxir { 38 39 BottomUpVec::BottomUpVec(StringRef Pipeline) 40 : FunctionPass("bottom-up-vec"), 41 RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {} 42 43 static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl, 44 unsigned OpIdx) { 45 SmallVector<Value *, 4> Operands; 46 for (Value *BndlV : Bndl) { 47 auto *BndlI = cast<Instruction>(BndlV); 48 Operands.push_back(BndlI->getOperand(OpIdx)); 49 } 50 return Operands; 51 } 52 53 /// \Returns the BB iterator after the lowest instruction in \p Vals, or the top 54 /// of BB if no instruction found in \p Vals. 55 static BasicBlock::iterator getInsertPointAfterInstrs(ArrayRef<Value *> Vals, 56 BasicBlock *BB) { 57 auto *BotI = VecUtils::getLastPHIOrSelf(VecUtils::getLowest(Vals, BB)); 58 if (BotI == nullptr) 59 // We are using BB->begin() (or after PHIs) as the fallback insert point. 60 return BB->empty() 61 ? BB->begin() 62 : std::next( 63 VecUtils::getLastPHIOrSelf(&*BB->begin())->getIterator()); 64 return std::next(BotI->getIterator()); 65 } 66 67 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl, 68 ArrayRef<Value *> Operands) { 69 auto CreateVectorInstr = [](ArrayRef<Value *> Bndl, 70 ArrayRef<Value *> Operands) -> Value * { 71 assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) && 72 "Expect Instructions!"); 73 auto &Ctx = Bndl[0]->getContext(); 74 75 Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0])); 76 auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl)); 77 78 BasicBlock::iterator WhereIt = getInsertPointAfterInstrs( 79 Bndl, cast<Instruction>(Bndl[0])->getParent()); 80 81 auto Opcode = cast<Instruction>(Bndl[0])->getOpcode(); 82 switch (Opcode) { 83 case Instruction::Opcode::ZExt: 84 case Instruction::Opcode::SExt: 85 case Instruction::Opcode::FPToUI: 86 case Instruction::Opcode::FPToSI: 87 case Instruction::Opcode::FPExt: 88 case Instruction::Opcode::PtrToInt: 89 case Instruction::Opcode::IntToPtr: 90 case Instruction::Opcode::SIToFP: 91 case Instruction::Opcode::UIToFP: 92 case Instruction::Opcode::Trunc: 93 case Instruction::Opcode::FPTrunc: 94 case Instruction::Opcode::BitCast: { 95 assert(Operands.size() == 1u && "Casts are unary!"); 96 return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, 97 "VCast"); 98 } 99 case Instruction::Opcode::FCmp: 100 case Instruction::Opcode::ICmp: { 101 auto Pred = cast<CmpInst>(Bndl[0])->getPredicate(); 102 assert(all_of(drop_begin(Bndl), 103 [Pred](auto *SBV) { 104 return cast<CmpInst>(SBV)->getPredicate() == Pred; 105 }) && 106 "Expected same predicate across bundle."); 107 return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx, 108 "VCmp"); 109 } 110 case Instruction::Opcode::Select: { 111 return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt, 112 Ctx, "Vec"); 113 } 114 case Instruction::Opcode::FNeg: { 115 auto *UOp0 = cast<UnaryOperator>(Bndl[0]); 116 auto OpC = UOp0->getOpcode(); 117 return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, 118 WhereIt, Ctx, "Vec"); 119 } 120 case Instruction::Opcode::Add: 121 case Instruction::Opcode::FAdd: 122 case Instruction::Opcode::Sub: 123 case Instruction::Opcode::FSub: 124 case Instruction::Opcode::Mul: 125 case Instruction::Opcode::FMul: 126 case Instruction::Opcode::UDiv: 127 case Instruction::Opcode::SDiv: 128 case Instruction::Opcode::FDiv: 129 case Instruction::Opcode::URem: 130 case Instruction::Opcode::SRem: 131 case Instruction::Opcode::FRem: 132 case Instruction::Opcode::Shl: 133 case Instruction::Opcode::LShr: 134 case Instruction::Opcode::AShr: 135 case Instruction::Opcode::And: 136 case Instruction::Opcode::Or: 137 case Instruction::Opcode::Xor: { 138 auto *BinOp0 = cast<BinaryOperator>(Bndl[0]); 139 auto *LHS = Operands[0]; 140 auto *RHS = Operands[1]; 141 return BinaryOperator::createWithCopiedFlags( 142 BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec"); 143 } 144 case Instruction::Opcode::Load: { 145 auto *Ld0 = cast<LoadInst>(Bndl[0]); 146 Value *Ptr = Ld0->getPointerOperand(); 147 return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, 148 "VecL"); 149 } 150 case Instruction::Opcode::Store: { 151 auto Align = cast<StoreInst>(Bndl[0])->getAlign(); 152 Value *Val = Operands[0]; 153 Value *Ptr = Operands[1]; 154 return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx); 155 } 156 case Instruction::Opcode::Br: 157 case Instruction::Opcode::Ret: 158 case Instruction::Opcode::PHI: 159 case Instruction::Opcode::AddrSpaceCast: 160 case Instruction::Opcode::Call: 161 case Instruction::Opcode::GetElementPtr: 162 llvm_unreachable("Unimplemented"); 163 break; 164 default: 165 llvm_unreachable("Unimplemented"); 166 break; 167 } 168 llvm_unreachable("Missing switch case!"); 169 // TODO: Propagate debug info. 170 }; 171 172 auto *VecI = CreateVectorInstr(Bndl, Operands); 173 if (VecI != nullptr) { 174 Change = true; 175 IMaps->registerVector(Bndl, VecI); 176 } 177 return VecI; 178 } 179 180 void BottomUpVec::tryEraseDeadInstrs() { 181 DenseMap<BasicBlock *, SmallVector<Instruction *>> SortedDeadInstrCandidates; 182 // The dead instrs could span BBs, so we need to collect and sort them per BB. 183 for (auto *DeadI : DeadInstrCandidates) 184 SortedDeadInstrCandidates[DeadI->getParent()].push_back(DeadI); 185 for (auto &Pair : SortedDeadInstrCandidates) 186 sort(Pair.second, 187 [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); }); 188 for (const auto &Pair : SortedDeadInstrCandidates) { 189 for (Instruction *I : reverse(Pair.second)) { 190 if (I->hasNUses(0)) 191 // Erase the dead instructions bottom-to-top. 192 I->eraseFromParent(); 193 } 194 } 195 DeadInstrCandidates.clear(); 196 } 197 198 Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask, 199 BasicBlock *UserBB) { 200 BasicBlock::iterator WhereIt = getInsertPointAfterInstrs({VecOp}, UserBB); 201 return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt, 202 VecOp->getContext(), "VShuf"); 203 } 204 205 Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack, BasicBlock *UserBB) { 206 BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack, UserBB); 207 208 Type *ScalarTy = VecUtils::getCommonScalarType(ToPack); 209 unsigned Lanes = VecUtils::getNumLanes(ToPack); 210 Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes); 211 212 // Create a series of pack instructions. 213 Value *LastInsert = PoisonValue::get(VecTy); 214 215 Context &Ctx = ToPack[0]->getContext(); 216 217 unsigned InsertIdx = 0; 218 for (Value *Elm : ToPack) { 219 // An element can be either scalar or vector. We need to generate different 220 // IR for each case. 221 if (Elm->getType()->isVectorTy()) { 222 unsigned NumElms = 223 cast<FixedVectorType>(Elm->getType())->getNumElements(); 224 for (auto ExtrLane : seq<int>(0, NumElms)) { 225 // We generate extract-insert pairs, for each lane in `Elm`. 226 Constant *ExtrLaneC = 227 ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane); 228 // This may return a Constant if Elm is a Constant. 229 auto *ExtrI = 230 ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack"); 231 if (!isa<Constant>(ExtrI)) 232 WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator()); 233 Constant *InsertLaneC = 234 ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++); 235 // This may also return a Constant if ExtrI is a Constant. 236 auto *InsertI = InsertElementInst::create( 237 LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack"); 238 if (!isa<Constant>(InsertI)) { 239 LastInsert = InsertI; 240 WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator()); 241 } 242 } 243 } else { 244 Constant *InsertLaneC = 245 ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++); 246 // This may be folded into a Constant if LastInsert is a Constant. In 247 // that case we only collect the last constant. 248 LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC, 249 WhereIt, Ctx, "Pack"); 250 if (auto *NewI = dyn_cast<Instruction>(LastInsert)) 251 WhereIt = std::next(NewI->getIterator()); 252 } 253 } 254 return LastInsert; 255 } 256 257 void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) { 258 for (Value *V : Bndl) 259 DeadInstrCandidates.insert(cast<Instruction>(V)); 260 // Also collect the GEPs of vectorized loads and stores. 261 auto Opcode = cast<Instruction>(Bndl[0])->getOpcode(); 262 switch (Opcode) { 263 case Instruction::Opcode::Load: { 264 for (Value *V : drop_begin(Bndl)) 265 if (auto *Ptr = 266 dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand())) 267 DeadInstrCandidates.insert(Ptr); 268 break; 269 } 270 case Instruction::Opcode::Store: { 271 for (Value *V : drop_begin(Bndl)) 272 if (auto *Ptr = 273 dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand())) 274 DeadInstrCandidates.insert(Ptr); 275 break; 276 } 277 default: 278 break; 279 } 280 } 281 282 Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, 283 ArrayRef<Value *> UserBndl, unsigned Depth) { 284 Value *NewVec = nullptr; 285 auto *UserBB = !UserBndl.empty() 286 ? cast<Instruction>(UserBndl.front())->getParent() 287 : cast<Instruction>(Bndl[0])->getParent(); 288 const auto &LegalityRes = Legality->canVectorize(Bndl); 289 switch (LegalityRes.getSubclassID()) { 290 case LegalityResultID::Widen: { 291 auto *I = cast<Instruction>(Bndl[0]); 292 SmallVector<Value *, 2> VecOperands; 293 switch (I->getOpcode()) { 294 case Instruction::Opcode::Load: 295 // Don't recurse towards the pointer operand. 296 VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand()); 297 break; 298 case Instruction::Opcode::Store: { 299 // Don't recurse towards the pointer operand. 300 auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Bndl, Depth + 1); 301 VecOperands.push_back(VecOp); 302 VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand()); 303 break; 304 } 305 default: 306 // Visit all operands. 307 for (auto OpIdx : seq<unsigned>(I->getNumOperands())) { 308 auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Bndl, Depth + 1); 309 VecOperands.push_back(VecOp); 310 } 311 break; 312 } 313 NewVec = createVectorInstr(Bndl, VecOperands); 314 315 // Collect any potentially dead scalar instructions, including the original 316 // scalars and pointer operands of loads/stores. 317 if (NewVec != nullptr) 318 collectPotentiallyDeadInstrs(Bndl); 319 break; 320 } 321 case LegalityResultID::DiamondReuse: { 322 NewVec = cast<DiamondReuse>(LegalityRes).getVector(); 323 break; 324 } 325 case LegalityResultID::DiamondReuseWithShuffle: { 326 auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector(); 327 const ShuffleMask &Mask = 328 cast<DiamondReuseWithShuffle>(LegalityRes).getMask(); 329 NewVec = createShuffle(VecOp, Mask, UserBB); 330 break; 331 } 332 case LegalityResultID::DiamondReuseMultiInput: { 333 const auto &Descr = 334 cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr(); 335 Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size()); 336 337 // TODO: Try to get WhereIt without creating a vector. 338 SmallVector<Value *, 4> DescrInstrs; 339 for (const auto &ElmDescr : Descr.getDescrs()) { 340 if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue())) 341 DescrInstrs.push_back(I); 342 } 343 BasicBlock::iterator WhereIt = 344 getInsertPointAfterInstrs(DescrInstrs, UserBB); 345 346 Value *LastV = PoisonValue::get(ResTy); 347 for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) { 348 Value *VecOp = ElmDescr.getValue(); 349 Context &Ctx = VecOp->getContext(); 350 Value *ValueToInsert; 351 if (ElmDescr.needsExtract()) { 352 ConstantInt *IdxC = 353 ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx()); 354 ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt, 355 VecOp->getContext(), "VExt"); 356 } else { 357 ValueToInsert = VecOp; 358 } 359 ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane); 360 Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC, 361 WhereIt, Ctx, "VIns"); 362 LastV = Ins; 363 } 364 NewVec = LastV; 365 break; 366 } 367 case LegalityResultID::Pack: { 368 // If we can't vectorize the seeds then just return. 369 if (Depth == 0) 370 return nullptr; 371 NewVec = createPack(Bndl, UserBB); 372 break; 373 } 374 } 375 #ifndef NDEBUG 376 if (AlwaysVerify) { 377 // This helps find broken IR by constantly verifying the function. Note that 378 // this is very expensive and should only be used for debugging. 379 Instruction *I0 = isa<Instruction>(Bndl[0]) 380 ? cast<Instruction>(Bndl[0]) 381 : cast<Instruction>(UserBndl[0]); 382 assert(!Utils::verifyFunction(I0->getParent()->getParent(), dbgs()) && 383 "Broken function!"); 384 } 385 #endif 386 return NewVec; 387 } 388 389 bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { 390 DeadInstrCandidates.clear(); 391 Legality->clear(); 392 vectorizeRec(Bndl, {}, /*Depth=*/0); 393 tryEraseDeadInstrs(); 394 return Change; 395 } 396 397 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) { 398 IMaps = std::make_unique<InstrMaps>(F.getContext()); 399 Legality = std::make_unique<LegalityAnalysis>( 400 A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(), 401 F.getContext(), *IMaps); 402 Change = false; 403 const auto &DL = F.getParent()->getDataLayout(); 404 unsigned VecRegBits = 405 OverrideVecRegBits != 0 406 ? OverrideVecRegBits 407 : A.getTTI() 408 .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 409 .getFixedValue(); 410 411 // TODO: Start from innermost BBs first 412 for (auto &BB : F) { 413 SeedCollector SC(&BB, A.getScalarEvolution()); 414 for (SeedBundle &Seeds : SC.getStoreSeeds()) { 415 unsigned ElmBits = 416 Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType( 417 Seeds[Seeds.getFirstUnusedElementIdx()])), 418 DL); 419 420 auto DivideBy2 = [](unsigned Num) { 421 auto Floor = VecUtils::getFloorPowerOf2(Num); 422 if (Floor == Num) 423 return Floor / 2; 424 return Floor; 425 }; 426 // Try to create the largest vector supported by the target. If it fails 427 // reduce the vector size by half. 428 for (unsigned SliceElms = std::min(VecRegBits / ElmBits, 429 Seeds.getNumUnusedBits() / ElmBits); 430 SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) { 431 if (Seeds.allUsed()) 432 break; 433 // Keep trying offsets after FirstUnusedElementIdx, until we vectorize 434 // the slice. This could be quite expensive, so we enforce a limit. 435 for (unsigned Offset = Seeds.getFirstUnusedElementIdx(), 436 OE = Seeds.size(); 437 Offset + 1 < OE; Offset += 1) { 438 // Seeds are getting used as we vectorize, so skip them. 439 if (Seeds.isUsed(Offset)) 440 continue; 441 if (Seeds.allUsed()) 442 break; 443 444 auto SeedSlice = 445 Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2); 446 if (SeedSlice.empty()) 447 continue; 448 449 assert(SeedSlice.size() >= 2 && "Should have been rejected!"); 450 451 // TODO: If vectorization succeeds, run the RegionPassManager on the 452 // resulting region. 453 454 // TODO: Refactor to remove the unnecessary copy to SeedSliceVals. 455 SmallVector<Value *> SeedSliceVals(SeedSlice.begin(), 456 SeedSlice.end()); 457 Change |= tryVectorize(SeedSliceVals); 458 } 459 } 460 } 461 } 462 return Change; 463 } 464 465 } // namespace sandboxir 466 } // namespace llvm 467