//===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Module.h"
#include "llvm/SandboxIR/Utils.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm {

static cl::opt<unsigned>
    OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
                       cl::desc("Override the vector register size in bits, "
                                "which is otherwise found by querying TTI."));
static cl::opt<bool>
    AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
                 cl::desc("Allow non-power-of-2 vectorization."));

namespace sandboxir {

BottomUpVec::BottomUpVec(StringRef Pipeline)
    : FunctionPass("bottom-up-vec"),
      RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}

static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
                                          unsigned OpIdx) {
  SmallVector<Value *, 4> Operands;
  for (Value *BndlV : Bndl) {
    auto *BndlI = cast<Instruction>(BndlV);
    Operands.push_back(BndlI->getOperand(OpIdx));
  }
  return Operands;
}

static BasicBlock::iterator
getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
  // TODO: Use the VecUtils function for getting the bottom instr once it
  // lands.
  auto *BotI = cast<Instruction>(
      *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
        return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
      }));
  // If Bndl contains Arguments or Constants, use the beginning of the BB.
  return std::next(BotI->getIterator());
}
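
// Creates a single vector instruction that combines the scalar instructions
// in Bndl, using the already-vectorized values in Operands as its operands,
// and records the scalar-to-vector mapping in IMaps.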
Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
                                      ArrayRef<Value *> Operands) {
  auto CreateVectorInstr = [](ArrayRef<Value *> Bndl,
                              ArrayRef<Value *> Operands) -> Value * {
    assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
           "Expect Instructions!");
    auto &Ctx = Bndl[0]->getContext();

    Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
    auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));

    BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);

    auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
    switch (Opcode) {
    case Instruction::Opcode::ZExt:
    case Instruction::Opcode::SExt:
    case Instruction::Opcode::FPToUI:
    case Instruction::Opcode::FPToSI:
    case Instruction::Opcode::FPExt:
    case Instruction::Opcode::PtrToInt:
    case Instruction::Opcode::IntToPtr:
    case Instruction::Opcode::SIToFP:
    case Instruction::Opcode::UIToFP:
    case Instruction::Opcode::Trunc:
    case Instruction::Opcode::FPTrunc:
    case Instruction::Opcode::BitCast: {
      assert(Operands.size() == 1u && "Casts are unary!");
      return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx,
                              "VCast");
    }
    case Instruction::Opcode::FCmp:
    case Instruction::Opcode::ICmp: {
      auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
      assert(all_of(drop_begin(Bndl),
                    [Pred](auto *SBV) {
                      return cast<CmpInst>(SBV)->getPredicate() == Pred;
                    }) &&
             "Expected same predicate across bundle.");
      return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
                             "VCmp");
    }
    case Instruction::Opcode::Select: {
      return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
                                Ctx, "Vec");
    }
    case Instruction::Opcode::FNeg: {
      auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
      auto OpC = UOp0->getOpcode();
      return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0,
                                                  WhereIt, Ctx, "Vec");
    }
    case Instruction::Opcode::Add:
    case Instruction::Opcode::FAdd:
    case Instruction::Opcode::Sub:
    case Instruction::Opcode::FSub:
    case Instruction::Opcode::Mul:
    case Instruction::Opcode::FMul:
    case Instruction::Opcode::UDiv:
    case Instruction::Opcode::SDiv:
    case Instruction::Opcode::FDiv:
    case Instruction::Opcode::URem:
    case Instruction::Opcode::SRem:
    case Instruction::Opcode::FRem:
    case Instruction::Opcode::Shl:
    case Instruction::Opcode::LShr:
    case Instruction::Opcode::AShr:
    case Instruction::Opcode::And:
    case Instruction::Opcode::Or:
    case Instruction::Opcode::Xor: {
      auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
      auto *LHS = Operands[0];
      auto *RHS = Operands[1];
      return BinaryOperator::createWithCopiedFlags(
          BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec");
    }
    case Instruction::Opcode::Load: {
      auto *Ld0 = cast<LoadInst>(Bndl[0]);
      Value *Ptr = Ld0->getPointerOperand();
      return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx,
                              "VecL");
    }
    case Instruction::Opcode::Store: {
      auto Align = cast<StoreInst>(Bndl[0])->getAlign();
      Value *Val = Operands[0];
      Value *Ptr = Operands[1];
      return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
    }
    case Instruction::Opcode::Br:
    case Instruction::Opcode::Ret:
    case Instruction::Opcode::PHI:
    case Instruction::Opcode::AddrSpaceCast:
    case Instruction::Opcode::Call:
    case Instruction::Opcode::GetElementPtr:
      llvm_unreachable("Unimplemented");
      break;
    default:
      llvm_unreachable("Unimplemented");
      break;
    }
    llvm_unreachable("Missing switch case!");
    // TODO: Propagate debug info.
  };

  auto *VecI = CreateVectorInstr(Bndl, Operands);
  if (VecI != nullptr) {
    Change = true;
    IMaps->registerVector(Bndl, VecI);
  }
  return VecI;
}
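
// Erases any candidate instruction collected during vectorization that has no
// remaining uses, visiting candidates bottom-up so that users are erased
// before their operands.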
void BottomUpVec::tryEraseDeadInstrs() {
  // Visiting the dead instructions bottom-to-top.
  SmallVector<Instruction *> SortedDeadInstrCandidates(
      DeadInstrCandidates.begin(), DeadInstrCandidates.end());
  sort(SortedDeadInstrCandidates,
       [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
  for (Instruction *I : reverse(SortedDeadInstrCandidates)) {
    if (I->hasNUses(0))
      I->eraseFromParent();
  }
  DeadInstrCandidates.clear();
}

Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask) {
  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs({VecOp});
  return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt,
                                   VecOp->getContext(), "VShuf");
}

Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);

  Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
  unsigned Lanes = VecUtils::getNumLanes(ToPack);
  Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);

  // Create a series of pack instructions.
  Value *LastInsert = PoisonValue::get(VecTy);

  Context &Ctx = ToPack[0]->getContext();

  unsigned InsertIdx = 0;
  for (Value *Elm : ToPack) {
    // An element can be either scalar or vector. We need to generate different
    // IR for each case.
    if (Elm->getType()->isVectorTy()) {
      unsigned NumElms =
          cast<FixedVectorType>(Elm->getType())->getNumElements();
      for (auto ExtrLane : seq<int>(0, NumElms)) {
        // We generate extract-insert pairs, for each lane in `Elm`.
        Constant *ExtrLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
        // This may return a Constant if Elm is a Constant.
        auto *ExtrI =
            ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
        if (!isa<Constant>(ExtrI))
          WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
        Constant *InsertLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
        // This may also return a Constant if ExtrI is a Constant.
        auto *InsertI = InsertElementInst::create(
            LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
        if (!isa<Constant>(InsertI)) {
          LastInsert = InsertI;
          WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
        }
      }
    } else {
      Constant *InsertLaneC =
          ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
      // This may be folded into a Constant if LastInsert is a Constant. In
      // that case we only collect the last constant.
      LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
                                             WhereIt, Ctx, "Pack");
      if (auto *NewI = dyn_cast<Instruction>(LastInsert))
        WhereIt = std::next(NewI->getIterator());
    }
  }
  return LastInsert;
}

void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) {
  for (Value *V : Bndl)
    DeadInstrCandidates.insert(cast<Instruction>(V));
  // Also collect the GEPs of vectorized loads and stores.
  auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
  switch (Opcode) {
  case Instruction::Opcode::Load: {
    for (Value *V : drop_begin(Bndl))
      if (auto *Ptr =
              dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand()))
        DeadInstrCandidates.insert(Ptr);
    break;
  }
  case Instruction::Opcode::Store: {
    for (Value *V : drop_begin(Bndl))
      if (auto *Ptr =
              dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand()))
        DeadInstrCandidates.insert(Ptr);
    break;
  }
  default:
    break;
  }
}
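
// Recursively vectorizes Bndl, visiting its operand bundles first as directed
// by the legality analysis. Returns the vector value that replaces Bndl, or
// nullptr if the seed bundle itself cannot be vectorized.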
Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
  Value *NewVec = nullptr;
  const auto &LegalityRes = Legality->canVectorize(Bndl);
  switch (LegalityRes.getSubclassID()) {
  case LegalityResultID::Widen: {
    auto *I = cast<Instruction>(Bndl[0]);
    SmallVector<Value *, 2> VecOperands;
    switch (I->getOpcode()) {
    case Instruction::Opcode::Load:
      // Don't recurse towards the pointer operand.
      VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
      break;
    case Instruction::Opcode::Store: {
      // Don't recurse towards the pointer operand.
      auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
      VecOperands.push_back(VecOp);
      VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
      break;
    }
    default:
      // Visit all operands.
      for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
        auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
        VecOperands.push_back(VecOp);
      }
      break;
    }
    NewVec = createVectorInstr(Bndl, VecOperands);

    // Collect any potentially dead scalar instructions, including the original
    // scalars and pointer operands of loads/stores.
    if (NewVec != nullptr)
      collectPotentiallyDeadInstrs(Bndl);
    break;
  }
  case LegalityResultID::DiamondReuse: {
    NewVec = cast<DiamondReuse>(LegalityRes).getVector();
    break;
  }
  case LegalityResultID::DiamondReuseWithShuffle: {
    auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
    const ShuffleMask &Mask =
        cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
    NewVec = createShuffle(VecOp, Mask);
    break;
  }
  case LegalityResultID::DiamondReuseMultiInput: {
    const auto &Descr =
        cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
    Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());

    // TODO: Try to get WhereIt without creating a vector.
    SmallVector<Value *, 4> DescrInstrs;
    for (const auto &ElmDescr : Descr.getDescrs()) {
      if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue()))
        DescrInstrs.push_back(I);
    }
    auto WhereIt = getInsertPointAfterInstrs(DescrInstrs);

    Value *LastV = PoisonValue::get(ResTy);
    for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
      Value *VecOp = ElmDescr.getValue();
      Context &Ctx = VecOp->getContext();
      Value *ValueToInsert;
      if (ElmDescr.needsExtract()) {
        ConstantInt *IdxC =
            ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx());
        ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt,
                                                   VecOp->getContext(), "VExt");
      } else {
        ValueToInsert = VecOp;
      }
      ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
      Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
                                             WhereIt, Ctx, "VIns");
      LastV = Ins;
    }
    NewVec = LastV;
    break;
  }
  case LegalityResultID::Pack: {
    // If we can't vectorize the seeds then just return.
    if (Depth == 0)
      return nullptr;
    NewVec = createPack(Bndl);
    break;
  }
  }
  return NewVec;
}
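
// Attempts to vectorize the seed slice in Bndl: clears the per-attempt state,
// runs the recursive vectorizer and erases any scalar instructions that
// became dead. Returns true if the IR has been modified.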
bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
  DeadInstrCandidates.clear();
  Legality->clear();
  vectorizeRec(Bndl, /*Depth=*/0);
  tryEraseDeadInstrs();
  return Change;
}
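
// Pass entry point: for each basic block, collects store seeds and tries to
// vectorize progressively narrower slices of each seed bundle, starting from
// the widest slice that fits in a vector register.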
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
  IMaps = std::make_unique<InstrMaps>(F.getContext());
  Legality = std::make_unique<LegalityAnalysis>(
      A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
      F.getContext(), *IMaps);
  Change = false;
  const auto &DL = F.getParent()->getDataLayout();
  unsigned VecRegBits =
      OverrideVecRegBits != 0
          ? OverrideVecRegBits
          : A.getTTI()
                .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue();

  // TODO: Start from innermost BBs first
  for (auto &BB : F) {
    SeedCollector SC(&BB, A.getScalarEvolution());
    for (SeedBundle &Seeds : SC.getStoreSeeds()) {
      unsigned ElmBits =
          Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
                                Seeds[Seeds.getFirstUnusedElementIdx()])),
                            DL);

      auto DivideBy2 = [](unsigned Num) {
        auto Floor = VecUtils::getFloorPowerOf2(Num);
        if (Floor == Num)
          return Floor / 2;
        return Floor;
      };
      // Try to create the largest vector supported by the target. If it fails
      // reduce the vector size by half.
      for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
                                         Seeds.getNumUnusedBits() / ElmBits);
           SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
        if (Seeds.allUsed())
          break;
        // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
        // the slice. This could be quite expensive, so we enforce a limit.
        for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
                      OE = Seeds.size();
             Offset + 1 < OE; Offset += 1) {
          // Seeds are getting used as we vectorize, so skip them.
          if (Seeds.isUsed(Offset))
            continue;
          if (Seeds.allUsed())
            break;

          auto SeedSlice =
              Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
          if (SeedSlice.empty())
            continue;

          assert(SeedSlice.size() >= 2 && "Should have been rejected!");

          // TODO: If vectorization succeeds, run the RegionPassManager on the
          // resulting region.

          // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
          SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
                                             SeedSlice.end());
          Change |= tryVectorize(SeedSliceVals);
        }
      }
    }
  }
  return Change;
}

} // namespace sandboxir
} // namespace llvm