//===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Module.h"
#include "llvm/SandboxIR/Utils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm {

static cl::opt<unsigned>
    OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
                       cl::desc("Override the vector register size in bits, "
                                "which is otherwise found by querying TTI."));
static cl::opt<bool>
    AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
                 cl::desc("Allow non-power-of-2 vectorization."));

namespace sandboxir {

BottomUpVec::BottomUpVec(StringRef Pipeline)
    : FunctionPass("bottom-up-vec"),
      RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}

// Returns a bundle containing the operand at index OpIdx of each instruction
// in Bndl.
static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
                                          unsigned OpIdx) {
  SmallVector<Value *, 4> Operands;
  for (Value *BndlV : Bndl) {
    auto *BndlI = cast<Instruction>(BndlV);
    Operands.push_back(BndlI->getOperand(OpIdx));
  }
  return Operands;
}

// Returns the insertion point right after the bottom-most instruction in
// Instrs.
static BasicBlock::iterator
getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
  // TODO: Use the VecUtils function for getting the bottom instr once it
  // lands.
  auto *BotI = cast<Instruction>(
      *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
        return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
      }));
  // If Bndl contains Arguments or Constants, use the beginning of the BB.
  return std::next(BotI->getIterator());
}

// Creates and returns a single vector instruction that replaces the scalar
// instructions in Bndl, using the already-vectorized Operands.
Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
                                      ArrayRef<Value *> Operands) {
  Change = true;
  assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
         "Expect Instructions!");
  auto &Ctx = Bndl[0]->getContext();

  Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
  auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));

  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);

  auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
  switch (Opcode) {
  case Instruction::Opcode::ZExt:
  case Instruction::Opcode::SExt:
  case Instruction::Opcode::FPToUI:
  case Instruction::Opcode::FPToSI:
  case Instruction::Opcode::FPExt:
  case Instruction::Opcode::PtrToInt:
  case Instruction::Opcode::IntToPtr:
  case Instruction::Opcode::SIToFP:
  case Instruction::Opcode::UIToFP:
  case Instruction::Opcode::Trunc:
  case Instruction::Opcode::FPTrunc:
  case Instruction::Opcode::BitCast: {
    assert(Operands.size() == 1u && "Casts are unary!");
    return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
  }
  case Instruction::Opcode::FCmp:
  case Instruction::Opcode::ICmp: {
    auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
    assert(all_of(drop_begin(Bndl),
                  [Pred](auto *SBV) {
                    return cast<CmpInst>(SBV)->getPredicate() == Pred;
                  }) &&
           "Expected same predicate across bundle.");
    return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
                           "VCmp");
  }
  case Instruction::Opcode::Select: {
    return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
                              Ctx, "Vec");
  }
  case Instruction::Opcode::FNeg: {
    auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
    auto OpC = UOp0->getOpcode();
    return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
                                                Ctx, "Vec");
  }
  case Instruction::Opcode::Add:
  case Instruction::Opcode::FAdd:
  case Instruction::Opcode::Sub:
  case Instruction::Opcode::FSub:
  case Instruction::Opcode::Mul:
  case Instruction::Opcode::FMul:
  case Instruction::Opcode::UDiv:
  case Instruction::Opcode::SDiv:
  case Instruction::Opcode::FDiv:
  case Instruction::Opcode::URem:
  case Instruction::Opcode::SRem:
  case Instruction::Opcode::FRem:
  case Instruction::Opcode::Shl:
  case Instruction::Opcode::LShr:
  case Instruction::Opcode::AShr:
  case Instruction::Opcode::And:
  case Instruction::Opcode::Or:
  case Instruction::Opcode::Xor: {
    auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
    auto *LHS = Operands[0];
    auto *RHS = Operands[1];
    return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
                                                 BinOp0, WhereIt, Ctx, "Vec");
  }
  case Instruction::Opcode::Load: {
    auto *Ld0 = cast<LoadInst>(Bndl[0]);
    Value *Ptr = Ld0->getPointerOperand();
    return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
  }
  case Instruction::Opcode::Store: {
    auto Align = cast<StoreInst>(Bndl[0])->getAlign();
    Value *Val = Operands[0];
    Value *Ptr = Operands[1];
    return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
  }
  case Instruction::Opcode::Br:
  case Instruction::Opcode::Ret:
  case Instruction::Opcode::PHI:
  case Instruction::Opcode::AddrSpaceCast:
  case Instruction::Opcode::Call:
  case Instruction::Opcode::GetElementPtr:
    llvm_unreachable("Unimplemented");
    break;
  default:
    llvm_unreachable("Unimplemented");
    break;
  }
  llvm_unreachable("Missing switch case!");
  // TODO: Propagate debug info.
}

// Erases the collected dead-instruction candidates that no longer have uses.
void BottomUpVec::tryEraseDeadInstrs() {
  // Visit the dead instructions bottom-to-top, so that users are erased
  // before the values they use.
  SmallVector<Instruction *> SortedDeadInstrCandidates(
      DeadInstrCandidates.begin(), DeadInstrCandidates.end());
  sort(SortedDeadInstrCandidates,
       [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
  for (Instruction *I : reverse(SortedDeadInstrCandidates)) {
    if (I->hasNUses(0))
      I->eraseFromParent();
  }
  DeadInstrCandidates.clear();
}

// Packs the (scalar or vector) values in ToPack into a single wide vector,
// using a chain of insert-element instructions (plus extract-element
// instructions for vector elements).
Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);

  Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
  unsigned Lanes = VecUtils::getNumLanes(ToPack);
  Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);

  // Create a series of pack instructions.
  Value *LastInsert = PoisonValue::get(VecTy);

  Context &Ctx = ToPack[0]->getContext();

  unsigned InsertIdx = 0;
  for (Value *Elm : ToPack) {
    // An element can be either scalar or vector. We need to generate different
    // IR for each case.
    if (Elm->getType()->isVectorTy()) {
      unsigned NumElms =
          cast<FixedVectorType>(Elm->getType())->getNumElements();
      for (auto ExtrLane : seq<int>(0, NumElms)) {
        // We generate extract-insert pairs, for each lane in `Elm`.
        Constant *ExtrLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
        // This may return a Constant if Elm is a Constant.
        auto *ExtrI =
            ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
        if (!isa<Constant>(ExtrI))
          WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
        Constant *InsertLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
        // This may also be folded into a Constant if both LastInsert and
        // ExtrI are Constants. In that case we keep the folded Constant.
        LastInsert = InsertElementInst::create(LastInsert, ExtrI, InsertLaneC,
                                               WhereIt, Ctx, "VPack");
        if (auto *NewI = dyn_cast<Instruction>(LastInsert))
          WhereIt = std::next(NewI->getIterator());
      }
    } else {
      Constant *InsertLaneC =
          ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
      // This may be folded into a Constant if LastInsert is a Constant. In
      // that case we keep the folded Constant.
      LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
                                             WhereIt, Ctx, "Pack");
      if (auto *NewI = dyn_cast<Instruction>(LastInsert))
        WhereIt = std::next(NewI->getIterator());
    }
  }
  return LastInsert;
}

// Marks the now-vectorized scalar instructions in Bndl as candidates for
// deletion.
void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) {
  for (Value *V : Bndl)
    DeadInstrCandidates.insert(cast<Instruction>(V));
  // Also collect the pointer operands (typically GEPs) of the vectorized
  // loads and stores. The first lane's pointer is reused by the new vector
  // instruction, so only the remaining lanes' pointers can become dead.
  auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
  switch (Opcode) {
  case Instruction::Opcode::Load: {
    for (Value *V : drop_begin(Bndl))
      if (auto *Ptr =
              dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand()))
        DeadInstrCandidates.insert(Ptr);
    break;
  }
  case Instruction::Opcode::Store: {
    for (Value *V : drop_begin(Bndl))
      if (auto *Ptr =
              dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand()))
        DeadInstrCandidates.insert(Ptr);
    break;
  }
  default:
    break;
  }
}

// Recursively vectorizes Bndl and its operand bundles. Returns the new vector
// (or packed) value, or null if the seed bundle itself cannot be vectorized.
Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
  Value *NewVec = nullptr;
  const auto &LegalityRes = Legality->canVectorize(Bndl);
  switch (LegalityRes.getSubclassID()) {
  case LegalityResultID::Widen: {
    auto *I = cast<Instruction>(Bndl[0]);
    SmallVector<Value *, 2> VecOperands;
    switch (I->getOpcode()) {
    case Instruction::Opcode::Load:
      // Don't recurse towards the pointer operand.
      VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
      break;
    case Instruction::Opcode::Store: {
      // Don't recurse towards the pointer operand.
      auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
      VecOperands.push_back(VecOp);
      VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
      break;
    }
    default:
      // Visit all operands.
      for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
        auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
        VecOperands.push_back(VecOp);
      }
      break;
    }
    NewVec = createVectorInstr(Bndl, VecOperands);

    // Collect any potentially dead scalar instructions, including the original
    // scalars and pointer operands of loads/stores.
    if (NewVec != nullptr)
      collectPotentiallyDeadInstrs(Bndl);
    break;
  }
  case LegalityResultID::Pack: {
    // If we can't vectorize the seeds then just return.
    if (Depth == 0)
      return nullptr;
    NewVec = createPack(Bndl);
    break;
  }
  }
  return NewVec;
}

// Tries to vectorize the seed bundle Bndl and erases any scalar instructions
// that became dead as a result. Returns whether the IR has changed so far.
bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
  DeadInstrCandidates.clear();
  Legality->clear();
  vectorizeRec(Bndl, /*Depth=*/0);
  tryEraseDeadInstrs();
  return Change;
}

bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
  Legality = std::make_unique<LegalityAnalysis>(
      A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
      F.getContext());
  Change = false;
  const auto &DL = F.getParent()->getDataLayout();
  unsigned VecRegBits =
      OverrideVecRegBits != 0
          ? OverrideVecRegBits
          : A.getTTI()
                .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue();

  // TODO: Start from innermost BBs first.
  for (auto &BB : F) {
    SeedCollector SC(&BB, A.getScalarEvolution());
    for (SeedBundle &Seeds : SC.getStoreSeeds()) {
      unsigned ElmBits =
          Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
                                Seeds[Seeds.getFirstUnusedElementIdx()])),
                            DL);

      auto DivideBy2 = [](unsigned Num) {
        auto Floor = VecUtils::getFloorPowerOf2(Num);
        if (Floor == Num)
          return Floor / 2;
        return Floor;
      };
      // Try to create the largest vector supported by the target. If it
      // fails, reduce the vector size by half.
332 for (unsigned SliceElms = std::min(VecRegBits / ElmBits, 333 Seeds.getNumUnusedBits() / ElmBits); 334 SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) { 335 if (Seeds.allUsed()) 336 break; 337 // Keep trying offsets after FirstUnusedElementIdx, until we vectorize 338 // the slice. This could be quite expensive, so we enforce a limit. 339 for (unsigned Offset = Seeds.getFirstUnusedElementIdx(), 340 OE = Seeds.size(); 341 Offset + 1 < OE; Offset += 1) { 342 // Seeds are getting used as we vectorize, so skip them. 343 if (Seeds.isUsed(Offset)) 344 continue; 345 if (Seeds.allUsed()) 346 break; 347 348 auto SeedSlice = 349 Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2); 350 if (SeedSlice.empty()) 351 continue; 352 353 assert(SeedSlice.size() >= 2 && "Should have been rejected!"); 354 355 // TODO: If vectorization succeeds, run the RegionPassManager on the 356 // resulting region. 357 358 // TODO: Refactor to remove the unnecessary copy to SeedSliceVals. 359 SmallVector<Value *> SeedSliceVals(SeedSlice.begin(), 360 SeedSlice.end()); 361 Change |= tryVectorize(SeedSliceVals); 362 } 363 } 364 } 365 } 366 return Change; 367 } 368 369 } // namespace sandboxir 370 } // namespace llvm 371