1 //===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h" 10 #include "llvm/ADT/SmallVector.h" 11 #include "llvm/Analysis/TargetTransformInfo.h" 12 #include "llvm/SandboxIR/Function.h" 13 #include "llvm/SandboxIR/Instruction.h" 14 #include "llvm/SandboxIR/Module.h" 15 #include "llvm/SandboxIR/Utils.h" 16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h" 17 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h" 18 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" 19 20 namespace llvm { 21 22 static cl::opt<unsigned> 23 OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden, 24 cl::desc("Override the vector register size in bits, " 25 "which is otherwise found by querying TTI.")); 26 static cl::opt<bool> 27 AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden, 28 cl::desc("Allow non-power-of-2 vectorization.")); 29 30 namespace sandboxir { 31 32 BottomUpVec::BottomUpVec(StringRef Pipeline) 33 : FunctionPass("bottom-up-vec"), 34 RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {} 35 36 static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl, 37 unsigned OpIdx) { 38 SmallVector<Value *, 4> Operands; 39 for (Value *BndlV : Bndl) { 40 auto *BndlI = cast<Instruction>(BndlV); 41 Operands.push_back(BndlI->getOperand(OpIdx)); 42 } 43 return Operands; 44 } 45 46 /// \Returns the BB iterator after the lowest instruction in \p Vals, or the top 47 /// of BB if no instruction found in \p Vals. 48 static BasicBlock::iterator getInsertPointAfterInstrs(ArrayRef<Value *> Vals, 49 BasicBlock *BB) { 50 auto *BotI = VecUtils::getLowest(Vals); 51 if (BotI == nullptr) 52 // We are using BB->begin() as the fallback insert point if `ToPack` did 53 // not contain instructions. 54 return BB->begin(); 55 return std::next(BotI->getIterator()); 56 } 57 58 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl, 59 ArrayRef<Value *> Operands) { 60 auto CreateVectorInstr = [](ArrayRef<Value *> Bndl, 61 ArrayRef<Value *> Operands) -> Value * { 62 assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) && 63 "Expect Instructions!"); 64 auto &Ctx = Bndl[0]->getContext(); 65 66 Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0])); 67 auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl)); 68 69 BasicBlock::iterator WhereIt = getInsertPointAfterInstrs( 70 Bndl, cast<Instruction>(Bndl[0])->getParent()); 71 72 auto Opcode = cast<Instruction>(Bndl[0])->getOpcode(); 73 switch (Opcode) { 74 case Instruction::Opcode::ZExt: 75 case Instruction::Opcode::SExt: 76 case Instruction::Opcode::FPToUI: 77 case Instruction::Opcode::FPToSI: 78 case Instruction::Opcode::FPExt: 79 case Instruction::Opcode::PtrToInt: 80 case Instruction::Opcode::IntToPtr: 81 case Instruction::Opcode::SIToFP: 82 case Instruction::Opcode::UIToFP: 83 case Instruction::Opcode::Trunc: 84 case Instruction::Opcode::FPTrunc: 85 case Instruction::Opcode::BitCast: { 86 assert(Operands.size() == 1u && "Casts are unary!"); 87 return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, 88 "VCast"); 89 } 90 case Instruction::Opcode::FCmp: 91 case Instruction::Opcode::ICmp: { 92 auto Pred = cast<CmpInst>(Bndl[0])->getPredicate(); 93 assert(all_of(drop_begin(Bndl), 94 [Pred](auto *SBV) { 95 return cast<CmpInst>(SBV)->getPredicate() == Pred; 96 }) && 97 "Expected same predicate across bundle."); 98 return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx, 99 "VCmp"); 100 } 101 case Instruction::Opcode::Select: { 102 return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt, 103 Ctx, "Vec"); 104 } 105 case Instruction::Opcode::FNeg: { 106 auto *UOp0 = cast<UnaryOperator>(Bndl[0]); 107 auto OpC = UOp0->getOpcode(); 108 return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, 109 WhereIt, Ctx, "Vec"); 110 } 111 case Instruction::Opcode::Add: 112 case Instruction::Opcode::FAdd: 113 case Instruction::Opcode::Sub: 114 case Instruction::Opcode::FSub: 115 case Instruction::Opcode::Mul: 116 case Instruction::Opcode::FMul: 117 case Instruction::Opcode::UDiv: 118 case Instruction::Opcode::SDiv: 119 case Instruction::Opcode::FDiv: 120 case Instruction::Opcode::URem: 121 case Instruction::Opcode::SRem: 122 case Instruction::Opcode::FRem: 123 case Instruction::Opcode::Shl: 124 case Instruction::Opcode::LShr: 125 case Instruction::Opcode::AShr: 126 case Instruction::Opcode::And: 127 case Instruction::Opcode::Or: 128 case Instruction::Opcode::Xor: { 129 auto *BinOp0 = cast<BinaryOperator>(Bndl[0]); 130 auto *LHS = Operands[0]; 131 auto *RHS = Operands[1]; 132 return BinaryOperator::createWithCopiedFlags( 133 BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec"); 134 } 135 case Instruction::Opcode::Load: { 136 auto *Ld0 = cast<LoadInst>(Bndl[0]); 137 Value *Ptr = Ld0->getPointerOperand(); 138 return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, 139 "VecL"); 140 } 141 case Instruction::Opcode::Store: { 142 auto Align = cast<StoreInst>(Bndl[0])->getAlign(); 143 Value *Val = Operands[0]; 144 Value *Ptr = Operands[1]; 145 return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx); 146 } 147 case Instruction::Opcode::Br: 148 case Instruction::Opcode::Ret: 149 case Instruction::Opcode::PHI: 150 case Instruction::Opcode::AddrSpaceCast: 151 case Instruction::Opcode::Call: 152 case Instruction::Opcode::GetElementPtr: 153 llvm_unreachable("Unimplemented"); 154 break; 155 default: 156 llvm_unreachable("Unimplemented"); 157 break; 158 } 159 llvm_unreachable("Missing switch case!"); 160 // TODO: Propagate debug info. 161 }; 162 163 auto *VecI = CreateVectorInstr(Bndl, Operands); 164 if (VecI != nullptr) { 165 Change = true; 166 IMaps->registerVector(Bndl, VecI); 167 } 168 return VecI; 169 } 170 171 void BottomUpVec::tryEraseDeadInstrs() { 172 // Visiting the dead instructions bottom-to-top. 173 SmallVector<Instruction *> SortedDeadInstrCandidates( 174 DeadInstrCandidates.begin(), DeadInstrCandidates.end()); 175 sort(SortedDeadInstrCandidates, 176 [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); }); 177 for (Instruction *I : reverse(SortedDeadInstrCandidates)) { 178 if (I->hasNUses(0)) 179 I->eraseFromParent(); 180 } 181 DeadInstrCandidates.clear(); 182 } 183 184 Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask, 185 BasicBlock *UserBB) { 186 BasicBlock::iterator WhereIt = getInsertPointAfterInstrs({VecOp}, UserBB); 187 return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt, 188 VecOp->getContext(), "VShuf"); 189 } 190 191 Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack, BasicBlock *UserBB) { 192 BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack, UserBB); 193 194 Type *ScalarTy = VecUtils::getCommonScalarType(ToPack); 195 unsigned Lanes = VecUtils::getNumLanes(ToPack); 196 Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes); 197 198 // Create a series of pack instructions. 199 Value *LastInsert = PoisonValue::get(VecTy); 200 201 Context &Ctx = ToPack[0]->getContext(); 202 203 unsigned InsertIdx = 0; 204 for (Value *Elm : ToPack) { 205 // An element can be either scalar or vector. We need to generate different 206 // IR for each case. 207 if (Elm->getType()->isVectorTy()) { 208 unsigned NumElms = 209 cast<FixedVectorType>(Elm->getType())->getNumElements(); 210 for (auto ExtrLane : seq<int>(0, NumElms)) { 211 // We generate extract-insert pairs, for each lane in `Elm`. 212 Constant *ExtrLaneC = 213 ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane); 214 // This may return a Constant if Elm is a Constant. 215 auto *ExtrI = 216 ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack"); 217 if (!isa<Constant>(ExtrI)) 218 WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator()); 219 Constant *InsertLaneC = 220 ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++); 221 // This may also return a Constant if ExtrI is a Constant. 222 auto *InsertI = InsertElementInst::create( 223 LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack"); 224 if (!isa<Constant>(InsertI)) { 225 LastInsert = InsertI; 226 WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator()); 227 } 228 } 229 } else { 230 Constant *InsertLaneC = 231 ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++); 232 // This may be folded into a Constant if LastInsert is a Constant. In 233 // that case we only collect the last constant. 234 LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC, 235 WhereIt, Ctx, "Pack"); 236 if (auto *NewI = dyn_cast<Instruction>(LastInsert)) 237 WhereIt = std::next(NewI->getIterator()); 238 } 239 } 240 return LastInsert; 241 } 242 243 void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) { 244 for (Value *V : Bndl) 245 DeadInstrCandidates.insert(cast<Instruction>(V)); 246 // Also collect the GEPs of vectorized loads and stores. 247 auto Opcode = cast<Instruction>(Bndl[0])->getOpcode(); 248 switch (Opcode) { 249 case Instruction::Opcode::Load: { 250 for (Value *V : drop_begin(Bndl)) 251 if (auto *Ptr = 252 dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand())) 253 DeadInstrCandidates.insert(Ptr); 254 break; 255 } 256 case Instruction::Opcode::Store: { 257 for (Value *V : drop_begin(Bndl)) 258 if (auto *Ptr = 259 dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand())) 260 DeadInstrCandidates.insert(Ptr); 261 break; 262 } 263 default: 264 break; 265 } 266 } 267 268 Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, 269 ArrayRef<Value *> UserBndl, unsigned Depth) { 270 Value *NewVec = nullptr; 271 auto *UserBB = !UserBndl.empty() 272 ? cast<Instruction>(UserBndl.front())->getParent() 273 : cast<Instruction>(Bndl[0])->getParent(); 274 const auto &LegalityRes = Legality->canVectorize(Bndl); 275 switch (LegalityRes.getSubclassID()) { 276 case LegalityResultID::Widen: { 277 auto *I = cast<Instruction>(Bndl[0]); 278 SmallVector<Value *, 2> VecOperands; 279 switch (I->getOpcode()) { 280 case Instruction::Opcode::Load: 281 // Don't recurse towards the pointer operand. 282 VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand()); 283 break; 284 case Instruction::Opcode::Store: { 285 // Don't recurse towards the pointer operand. 286 auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Bndl, Depth + 1); 287 VecOperands.push_back(VecOp); 288 VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand()); 289 break; 290 } 291 default: 292 // Visit all operands. 293 for (auto OpIdx : seq<unsigned>(I->getNumOperands())) { 294 auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Bndl, Depth + 1); 295 VecOperands.push_back(VecOp); 296 } 297 break; 298 } 299 NewVec = createVectorInstr(Bndl, VecOperands); 300 301 // Collect any potentially dead scalar instructions, including the original 302 // scalars and pointer operands of loads/stores. 303 if (NewVec != nullptr) 304 collectPotentiallyDeadInstrs(Bndl); 305 break; 306 } 307 case LegalityResultID::DiamondReuse: { 308 NewVec = cast<DiamondReuse>(LegalityRes).getVector(); 309 break; 310 } 311 case LegalityResultID::DiamondReuseWithShuffle: { 312 auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector(); 313 const ShuffleMask &Mask = 314 cast<DiamondReuseWithShuffle>(LegalityRes).getMask(); 315 NewVec = createShuffle(VecOp, Mask, UserBB); 316 break; 317 } 318 case LegalityResultID::DiamondReuseMultiInput: { 319 const auto &Descr = 320 cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr(); 321 Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size()); 322 323 // TODO: Try to get WhereIt without creating a vector. 324 SmallVector<Value *, 4> DescrInstrs; 325 for (const auto &ElmDescr : Descr.getDescrs()) { 326 if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue())) 327 DescrInstrs.push_back(I); 328 } 329 BasicBlock::iterator WhereIt = 330 getInsertPointAfterInstrs(DescrInstrs, UserBB); 331 332 Value *LastV = PoisonValue::get(ResTy); 333 for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) { 334 Value *VecOp = ElmDescr.getValue(); 335 Context &Ctx = VecOp->getContext(); 336 Value *ValueToInsert; 337 if (ElmDescr.needsExtract()) { 338 ConstantInt *IdxC = 339 ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx()); 340 ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt, 341 VecOp->getContext(), "VExt"); 342 } else { 343 ValueToInsert = VecOp; 344 } 345 ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane); 346 Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC, 347 WhereIt, Ctx, "VIns"); 348 LastV = Ins; 349 } 350 NewVec = LastV; 351 break; 352 } 353 case LegalityResultID::Pack: { 354 // If we can't vectorize the seeds then just return. 355 if (Depth == 0) 356 return nullptr; 357 NewVec = createPack(Bndl, UserBB); 358 break; 359 } 360 } 361 return NewVec; 362 } 363 364 bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { 365 DeadInstrCandidates.clear(); 366 Legality->clear(); 367 vectorizeRec(Bndl, {}, /*Depth=*/0); 368 tryEraseDeadInstrs(); 369 return Change; 370 } 371 372 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) { 373 IMaps = std::make_unique<InstrMaps>(F.getContext()); 374 Legality = std::make_unique<LegalityAnalysis>( 375 A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(), 376 F.getContext(), *IMaps); 377 Change = false; 378 const auto &DL = F.getParent()->getDataLayout(); 379 unsigned VecRegBits = 380 OverrideVecRegBits != 0 381 ? OverrideVecRegBits 382 : A.getTTI() 383 .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) 384 .getFixedValue(); 385 386 // TODO: Start from innermost BBs first 387 for (auto &BB : F) { 388 SeedCollector SC(&BB, A.getScalarEvolution()); 389 for (SeedBundle &Seeds : SC.getStoreSeeds()) { 390 unsigned ElmBits = 391 Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType( 392 Seeds[Seeds.getFirstUnusedElementIdx()])), 393 DL); 394 395 auto DivideBy2 = [](unsigned Num) { 396 auto Floor = VecUtils::getFloorPowerOf2(Num); 397 if (Floor == Num) 398 return Floor / 2; 399 return Floor; 400 }; 401 // Try to create the largest vector supported by the target. If it fails 402 // reduce the vector size by half. 403 for (unsigned SliceElms = std::min(VecRegBits / ElmBits, 404 Seeds.getNumUnusedBits() / ElmBits); 405 SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) { 406 if (Seeds.allUsed()) 407 break; 408 // Keep trying offsets after FirstUnusedElementIdx, until we vectorize 409 // the slice. This could be quite expensive, so we enforce a limit. 410 for (unsigned Offset = Seeds.getFirstUnusedElementIdx(), 411 OE = Seeds.size(); 412 Offset + 1 < OE; Offset += 1) { 413 // Seeds are getting used as we vectorize, so skip them. 414 if (Seeds.isUsed(Offset)) 415 continue; 416 if (Seeds.allUsed()) 417 break; 418 419 auto SeedSlice = 420 Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2); 421 if (SeedSlice.empty()) 422 continue; 423 424 assert(SeedSlice.size() >= 2 && "Should have been rejected!"); 425 426 // TODO: If vectorization succeeds, run the RegionPassManager on the 427 // resulting region. 428 429 // TODO: Refactor to remove the unnecessary copy to SeedSliceVals. 430 SmallVector<Value *> SeedSliceVals(SeedSlice.begin(), 431 SeedSlice.end()); 432 Change |= tryVectorize(SeedSliceVals); 433 } 434 } 435 } 436 } 437 return Change; 438 } 439 440 } // namespace sandboxir 441 } // namespace llvm 442