//===- BottomUpVec.cpp - A bottom-up vectorizer pass ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Module.h"
#include "llvm/SandboxIR/Utils.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm {

static cl::opt<unsigned>
    OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
                       cl::desc("Override the vector register size in bits, "
                                "which is otherwise found by querying TTI."));
static cl::opt<bool>
    AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
                 cl::desc("Allow non-power-of-2 vectorization."));

namespace sandboxir {

BottomUpVec::BottomUpVec(StringRef Pipeline)
    : FunctionPass("bottom-up-vec"),
      RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}

/// \Returns the operand at index \p OpIdx of every instruction in \p Bndl.
static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
                                          unsigned OpIdx) {
  SmallVector<Value *, 4> Operands;
  for (Value *BndlV : Bndl) {
    auto *BndlI = cast<Instruction>(BndlV);
    Operands.push_back(BndlI->getOperand(OpIdx));
  }
  return Operands;
}

/// \Returns the BB iterator after the lowest instruction in \p Vals, or the
/// top of BB if no instruction was found in \p Vals.
static BasicBlock::iterator getInsertPointAfterInstrs(ArrayRef<Value *> Vals,
                                                      BasicBlock *BB) {
  auto *BotI = VecUtils::getLowest(Vals);
  if (BotI == nullptr)
    // We are using BB->begin() as the fallback insert point if `Vals` did
    // not contain any instructions.
    return BB->begin();
  return std::next(BotI->getIterator());
}
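
// Creates a single vector instruction that replaces the scalar instructions
// in Bndl, using Operands as the already-vectorized operands, and records the
// scalars-to-vector mapping in IMaps.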
Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
                                      ArrayRef<Value *> Operands) {
  auto CreateVectorInstr = [](ArrayRef<Value *> Bndl,
                              ArrayRef<Value *> Operands) -> Value * {
    assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
           "Expect Instructions!");
    auto &Ctx = Bndl[0]->getContext();

    Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
    auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));

    BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(
        Bndl, cast<Instruction>(Bndl[0])->getParent());

    auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
    switch (Opcode) {
    case Instruction::Opcode::ZExt:
    case Instruction::Opcode::SExt:
    case Instruction::Opcode::FPToUI:
    case Instruction::Opcode::FPToSI:
    case Instruction::Opcode::FPExt:
    case Instruction::Opcode::PtrToInt:
    case Instruction::Opcode::IntToPtr:
    case Instruction::Opcode::SIToFP:
    case Instruction::Opcode::UIToFP:
    case Instruction::Opcode::Trunc:
    case Instruction::Opcode::FPTrunc:
    case Instruction::Opcode::BitCast: {
      assert(Operands.size() == 1u && "Casts are unary!");
      return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx,
                              "VCast");
    }
    case Instruction::Opcode::FCmp:
    case Instruction::Opcode::ICmp: {
      auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
      assert(all_of(drop_begin(Bndl),
                    [Pred](auto *SBV) {
                      return cast<CmpInst>(SBV)->getPredicate() == Pred;
                    }) &&
             "Expected same predicate across bundle.");
      return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
                             "VCmp");
    }
    case Instruction::Opcode::Select: {
      return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
                                Ctx, "Vec");
    }
    case Instruction::Opcode::FNeg: {
      auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
      auto OpC = UOp0->getOpcode();
      return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0,
                                                  WhereIt, Ctx, "Vec");
    }
    case Instruction::Opcode::Add:
    case Instruction::Opcode::FAdd:
    case Instruction::Opcode::Sub:
    case Instruction::Opcode::FSub:
    case Instruction::Opcode::Mul:
    case Instruction::Opcode::FMul:
    case Instruction::Opcode::UDiv:
    case Instruction::Opcode::SDiv:
    case Instruction::Opcode::FDiv:
    case Instruction::Opcode::URem:
    case Instruction::Opcode::SRem:
    case Instruction::Opcode::FRem:
    case Instruction::Opcode::Shl:
    case Instruction::Opcode::LShr:
    case Instruction::Opcode::AShr:
    case Instruction::Opcode::And:
    case Instruction::Opcode::Or:
    case Instruction::Opcode::Xor: {
      auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
      auto *LHS = Operands[0];
      auto *RHS = Operands[1];
      return BinaryOperator::createWithCopiedFlags(
          BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec");
    }
    case Instruction::Opcode::Load: {
      auto *Ld0 = cast<LoadInst>(Bndl[0]);
      Value *Ptr = Ld0->getPointerOperand();
      return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx,
                              "VecL");
    }
    case Instruction::Opcode::Store: {
      auto Align = cast<StoreInst>(Bndl[0])->getAlign();
      Value *Val = Operands[0];
      Value *Ptr = Operands[1];
      return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
    }
    case Instruction::Opcode::Br:
    case Instruction::Opcode::Ret:
    case Instruction::Opcode::PHI:
    case Instruction::Opcode::AddrSpaceCast:
    case Instruction::Opcode::Call:
    case Instruction::Opcode::GetElementPtr:
      llvm_unreachable("Unimplemented");
      break;
    default:
      llvm_unreachable("Unimplemented");
      break;
    }
    llvm_unreachable("Missing switch case!");
    // TODO: Propagate debug info.
  };

  auto *VecI = CreateVectorInstr(Bndl, Operands);
  if (VecI != nullptr) {
    Change = true;
    IMaps->registerVector(Bndl, VecI);
  }
  return VecI;
}
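
// Erases dead-instruction candidates that are left without uses, processing
// each basic block bottom-up so that users are erased before their operands.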
void BottomUpVec::tryEraseDeadInstrs() {
  DenseMap<BasicBlock *, SmallVector<Instruction *>> SortedDeadInstrCandidates;
  // The dead instrs could span BBs, so we need to collect and sort them per
  // BB.
  for (auto *DeadI : DeadInstrCandidates)
    SortedDeadInstrCandidates[DeadI->getParent()].push_back(DeadI);
  for (auto &Pair : SortedDeadInstrCandidates)
    sort(Pair.second,
         [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
  for (const auto &Pair : SortedDeadInstrCandidates) {
    for (Instruction *I : reverse(Pair.second)) {
      if (I->hasNUses(0))
        // Erase the dead instructions bottom-to-top.
        I->eraseFromParent();
    }
  }
  DeadInstrCandidates.clear();
}

Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask,
                                  BasicBlock *UserBB) {
  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs({VecOp}, UserBB);
  return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt,
                                   VecOp->getContext(), "VShuf");
}

Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack, BasicBlock *UserBB) {
  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack, UserBB);

  Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
  unsigned Lanes = VecUtils::getNumLanes(ToPack);
  Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);

  // Create a series of pack instructions.
  Value *LastInsert = PoisonValue::get(VecTy);

  Context &Ctx = ToPack[0]->getContext();

  unsigned InsertIdx = 0;
  for (Value *Elm : ToPack) {
    // An element can be either scalar or vector. We need to generate different
    // IR for each case.
    if (Elm->getType()->isVectorTy()) {
      unsigned NumElms =
          cast<FixedVectorType>(Elm->getType())->getNumElements();
      for (auto ExtrLane : seq<int>(0, NumElms)) {
        // We generate extract-insert pairs, for each lane in `Elm`.
        Constant *ExtrLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
        // This may return a Constant if Elm is a Constant.
        auto *ExtrI =
            ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
        if (!isa<Constant>(ExtrI))
          WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
        Constant *InsertLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
        // This may also return a Constant if ExtrI is a Constant.
        auto *InsertI = InsertElementInst::create(
            LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
        if (!isa<Constant>(InsertI)) {
          LastInsert = InsertI;
          WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
        }
      }
    } else {
      Constant *InsertLaneC =
          ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
      // This may be folded into a Constant if LastInsert is a Constant. In
      // that case we only collect the last constant.
      LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
                                             WhereIt, Ctx, "Pack");
      if (auto *NewI = dyn_cast<Instruction>(LastInsert))
        WhereIt = std::next(NewI->getIterator());
    }
  }
  return LastInsert;
}
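
// Marks the scalar instructions in Bndl as candidates for erasure. For loads
// and stores it also marks the pointer operands of all but the first lane,
// since only the first lane's pointer is reused by the vector instruction.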
void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) {
  for (Value *V : Bndl)
    DeadInstrCandidates.insert(cast<Instruction>(V));
  // Also collect the GEPs of vectorized loads and stores.
  auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
  switch (Opcode) {
  case Instruction::Opcode::Load: {
    for (Value *V : drop_begin(Bndl))
      if (auto *Ptr =
              dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand()))
        DeadInstrCandidates.insert(Ptr);
    break;
  }
  case Instruction::Opcode::Store: {
    for (Value *V : drop_begin(Bndl))
      if (auto *Ptr =
              dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand()))
        DeadInstrCandidates.insert(Ptr);
    break;
  }
  default:
    break;
  }
}

Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
                                 ArrayRef<Value *> UserBndl, unsigned Depth) {
  Value *NewVec = nullptr;
  auto *UserBB = !UserBndl.empty()
                     ? cast<Instruction>(UserBndl.front())->getParent()
                     : cast<Instruction>(Bndl[0])->getParent();
  const auto &LegalityRes = Legality->canVectorize(Bndl);
  switch (LegalityRes.getSubclassID()) {
  case LegalityResultID::Widen: {
    auto *I = cast<Instruction>(Bndl[0]);
    SmallVector<Value *, 2> VecOperands;
    switch (I->getOpcode()) {
    case Instruction::Opcode::Load:
      // Don't recurse towards the pointer operand.
      VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
      break;
    case Instruction::Opcode::Store: {
      // Don't recurse towards the pointer operand.
      auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Bndl, Depth + 1);
      VecOperands.push_back(VecOp);
      VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
      break;
    }
    default:
      // Visit all operands.
      for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
        auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Bndl, Depth + 1);
        VecOperands.push_back(VecOp);
      }
      break;
    }
    NewVec = createVectorInstr(Bndl, VecOperands);

    // Collect any potentially dead scalar instructions, including the
    // original scalars and pointer operands of loads/stores.
    if (NewVec != nullptr)
      collectPotentiallyDeadInstrs(Bndl);
    break;
  }
  case LegalityResultID::DiamondReuse: {
    NewVec = cast<DiamondReuse>(LegalityRes).getVector();
    break;
  }
  case LegalityResultID::DiamondReuseWithShuffle: {
    auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
    const ShuffleMask &Mask =
        cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
    NewVec = createShuffle(VecOp, Mask, UserBB);
    break;
  }
  case LegalityResultID::DiamondReuseMultiInput: {
    const auto &Descr =
        cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
    Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());

    // TODO: Try to get WhereIt without creating a vector.
    SmallVector<Value *, 4> DescrInstrs;
    for (const auto &ElmDescr : Descr.getDescrs()) {
      if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue()))
        DescrInstrs.push_back(I);
    }
    BasicBlock::iterator WhereIt =
        getInsertPointAfterInstrs(DescrInstrs, UserBB);

    Value *LastV = PoisonValue::get(ResTy);
    for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
      Value *VecOp = ElmDescr.getValue();
      Context &Ctx = VecOp->getContext();
      Value *ValueToInsert;
      if (ElmDescr.needsExtract()) {
        ConstantInt *IdxC =
            ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx());
        ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt,
                                                   VecOp->getContext(), "VExt");
      } else {
        ValueToInsert = VecOp;
      }
      ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
      Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
                                             WhereIt, Ctx, "VIns");
      LastV = Ins;
    }
    NewVec = LastV;
    break;
  }
  case LegalityResultID::Pack: {
    // If we can't vectorize the seeds then just return.
    if (Depth == 0)
      return nullptr;
    NewVec = createPack(Bndl, UserBB);
    break;
  }
  }
  return NewVec;
}
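
// Attempts to vectorize a single seed slice: resets the per-attempt state,
// recursively widens the bundle, and erases any scalars that became dead.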
bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
  DeadInstrCandidates.clear();
  Legality->clear();
  vectorizeRec(Bndl, {}, /*Depth=*/0);
  tryEraseDeadInstrs();
  return Change;
}
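
// Collects store seeds for each basic block and tries to vectorize slices of
// them, starting from the widest slice the target's vector registers support
// and shrinking the slice width while unused seeds remain.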
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
  IMaps = std::make_unique<InstrMaps>(F.getContext());
  Legality = std::make_unique<LegalityAnalysis>(
      A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
      F.getContext(), *IMaps);
  Change = false;
  const auto &DL = F.getParent()->getDataLayout();
  unsigned VecRegBits =
      OverrideVecRegBits != 0
          ? OverrideVecRegBits
          : A.getTTI()
                .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue();

  // TODO: Start from innermost BBs first.
  for (auto &BB : F) {
    SeedCollector SC(&BB, A.getScalarEvolution());
    for (SeedBundle &Seeds : SC.getStoreSeeds()) {
      unsigned ElmBits =
          Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
                                Seeds[Seeds.getFirstUnusedElementIdx()])),
                            DL);

      auto DivideBy2 = [](unsigned Num) {
        auto Floor = VecUtils::getFloorPowerOf2(Num);
        if (Floor == Num)
          return Floor / 2;
        return Floor;
      };
      // Try to create the largest vector supported by the target. If it fails
      // reduce the vector size by half.
      for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
                                         Seeds.getNumUnusedBits() / ElmBits);
           SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
        if (Seeds.allUsed())
          break;
        // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
        // the slice. This could be quite expensive, so we enforce a limit.
        for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
                      OE = Seeds.size();
             Offset + 1 < OE; Offset += 1) {
          // Seeds are getting used as we vectorize, so skip them.
          if (Seeds.isUsed(Offset))
            continue;
          if (Seeds.allUsed())
            break;

          auto SeedSlice =
              Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
          if (SeedSlice.empty())
            continue;

          assert(SeedSlice.size() >= 2 && "Should have been rejected!");

          // TODO: If vectorization succeeds, run the RegionPassManager on the
          // resulting region.

          // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
          SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
                                             SeedSlice.end());
          Change |= tryVectorize(SeedSliceVals);
        }
      }
    }
  }
  return Change;
}

} // namespace sandboxir
} // namespace llvm