//===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Module.h"
#include "llvm/SandboxIR/Utils.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm {

static cl::opt<unsigned>
    OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
                       cl::desc("Override the vector register size in bits, "
                                "which is otherwise found by querying TTI."));
static cl::opt<bool>
    AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
                 cl::desc("Allow non-power-of-2 vectorization."));

namespace sandboxir {

BottomUpVec::BottomUpVec(StringRef Pipeline)
    : FunctionPass("bottom-up-vec"),
      RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}

/// Collects the \p OpIdx'th operand of every element of \p Bndl.
/// Every element of \p Bndl must be an Instruction (the cast below enforces
/// this in debug builds).
static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
                                          unsigned OpIdx) {
  SmallVector<Value *, 4> Operands;
  for (Value *BndlV : Bndl) {
    auto *BndlI = cast<Instruction>(BndlV);
    Operands.push_back(BndlI->getOperand(OpIdx));
  }
  return Operands;
}

/// Returns the iterator just past the bottom-most instruction in \p Instrs,
/// i.e., an insertion point that comes after every instruction in the bundle.
/// \p Instrs must all be Instructions in the same BB (comesBefore() and the
/// casts below rely on this — presumably guaranteed by the legality checks;
/// TODO confirm).
static BasicBlock::iterator
getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
  // TODO: Use the VecUtils function for getting the bottom instr once it lands.
  auto *BotI = cast<Instruction>(
      *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
        return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
      }));
  // If Bndl contains Arguments or Constants, use the beginning of the BB.
  // NOTE(review): the casts above require every element to be an Instruction,
  // so the Argument/Constant case described here is not actually handled yet.
  return std::next(BotI->getIterator());
}

/// Creates a single vector instruction that replaces the scalar instructions
/// of \p Bndl, using the already-vectorized \p Operands. The new instruction
/// is inserted right after the bottom-most instruction in \p Bndl.
/// Sets the `Change` member as a side effect. Opcodes not yet supported hit
/// llvm_unreachable.
Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
                                      ArrayRef<Value *> Operands) {
  Change = true;
  assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
         "Expect Instructions!");
  auto &Ctx = Bndl[0]->getContext();

  // The result type is <NumLanes x ScalarTy>; getNumLanes() counts lanes of
  // vector-typed bundle elements too, not just the number of elements.
  Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
  auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));

  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);

  auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
  switch (Opcode) {
  case Instruction::Opcode::ZExt:
  case Instruction::Opcode::SExt:
  case Instruction::Opcode::FPToUI:
  case Instruction::Opcode::FPToSI:
  case Instruction::Opcode::FPExt:
  case Instruction::Opcode::PtrToInt:
  case Instruction::Opcode::IntToPtr:
  case Instruction::Opcode::SIToFP:
  case Instruction::Opcode::UIToFP:
  case Instruction::Opcode::Trunc:
  case Instruction::Opcode::FPTrunc:
  case Instruction::Opcode::BitCast: {
    assert(Operands.size() == 1u && "Casts are unary!");
    return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
  }
  case Instruction::Opcode::FCmp:
  case Instruction::Opcode::ICmp: {
    // All compares in the bundle must share one predicate; the vector compare
    // uses that common predicate.
    auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
    assert(all_of(drop_begin(Bndl),
                  [Pred](auto *SBV) {
                    return cast<CmpInst>(SBV)->getPredicate() == Pred;
                  }) &&
           "Expected same predicate across bundle.");
    return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
                           "VCmp");
  }
  case Instruction::Opcode::Select: {
    return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
                              Ctx, "Vec");
  }
  case Instruction::Opcode::FNeg: {
    // Copy the FP flags from the first scalar onto the vector instruction.
    auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
    auto OpC = UOp0->getOpcode();
    return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
                                                Ctx, "Vec");
  }
  case Instruction::Opcode::Add:
  case Instruction::Opcode::FAdd:
  case Instruction::Opcode::Sub:
  case Instruction::Opcode::FSub:
  case Instruction::Opcode::Mul:
  case Instruction::Opcode::FMul:
  case Instruction::Opcode::UDiv:
  case Instruction::Opcode::SDiv:
  case Instruction::Opcode::FDiv:
  case Instruction::Opcode::URem:
  case Instruction::Opcode::SRem:
  case Instruction::Opcode::FRem:
  case Instruction::Opcode::Shl:
  case Instruction::Opcode::LShr:
  case Instruction::Opcode::AShr:
  case Instruction::Opcode::And:
  case Instruction::Opcode::Or:
  case Instruction::Opcode::Xor: {
    // Flags (nsw/nuw/fast-math) are copied from the first scalar only.
    auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
    auto *LHS = Operands[0];
    auto *RHS = Operands[1];
    return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
                                                 BinOp0, WhereIt, Ctx, "Vec");
  }
  case Instruction::Opcode::Load: {
    // The pointer and alignment come from the first (top-most lane) load.
    auto *Ld0 = cast<LoadInst>(Bndl[0]);
    Value *Ptr = Ld0->getPointerOperand();
    return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
  }
  case Instruction::Opcode::Store: {
    // Operands[0] is the (vectorized) value, Operands[1] the pointer — see
    // the operand ordering set up in vectorizeRec() for Store bundles.
    auto Align = cast<StoreInst>(Bndl[0])->getAlign();
    Value *Val = Operands[0];
    Value *Ptr = Operands[1];
    return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
  }
  case Instruction::Opcode::Br:
  case Instruction::Opcode::Ret:
  case Instruction::Opcode::PHI:
  case Instruction::Opcode::AddrSpaceCast:
  case Instruction::Opcode::Call:
  case Instruction::Opcode::GetElementPtr:
    llvm_unreachable("Unimplemented");
    break;
  default:
    llvm_unreachable("Unimplemented");
    break;
  }
  llvm_unreachable("Missing switch case!");
  // TODO: Propagate debug info.
}

/// Erases every candidate in DeadInstrCandidates that has no remaining uses,
/// then clears the candidate list.
void BottomUpVec::tryEraseDeadInstrs() {
  // Visiting the dead instructions bottom-to-top: erasing a user first can
  // drop the last use of an earlier candidate, letting it be erased too in
  // the same sweep.
  sort(DeadInstrCandidates,
       [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
  for (Instruction *I : reverse(DeadInstrCandidates)) {
    if (I->hasNUses(0))
      I->eraseFromParent();
  }
  DeadInstrCandidates.clear();
}

/// Packs the (scalar or vector) values in \p ToPack into one wide vector by
/// emitting a chain of insertelement instructions — plus extractelement for
/// elements that are themselves vectors. Returns the final value of the
/// chain, which may fold into a Constant when the inputs are constants.
Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);

  Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
  unsigned Lanes = VecUtils::getNumLanes(ToPack);
  Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);

  // Create a series of pack instructions.
  Value *LastInsert = PoisonValue::get(VecTy);

  Context &Ctx = ToPack[0]->getContext();

  unsigned InsertIdx = 0;
  for (Value *Elm : ToPack) {
    // An element can be either scalar or vector. We need to generate different
    // IR for each case.
    if (Elm->getType()->isVectorTy()) {
      unsigned NumElms =
          cast<FixedVectorType>(Elm->getType())->getNumElements();
      for (auto ExtrLane : seq<int>(0, NumElms)) {
        // We generate extract-insert pairs, for each lane in `Elm`.
        Constant *ExtrLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
        // This may return a Constant if Elm is a Constant.
        auto *ExtrI =
            ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
        // Only advance the insert point past a real instruction; a folded
        // Constant has no position in the BB.
        if (!isa<Constant>(ExtrI))
          WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
        Constant *InsertLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
        // This may also return a Constant if ExtrI is a Constant.
        auto *InsertI = InsertElementInst::create(
            LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
        if (!isa<Constant>(InsertI)) {
          LastInsert = InsertI;
          WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
        }
      }
    } else {
      Constant *InsertLaneC =
          ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
      // This may be folded into a Constant if LastInsert is a Constant. In
      // that case we only collect the last constant.
      LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
                                             WhereIt, Ctx, "Pack");
      if (auto *NewI = dyn_cast<Instruction>(LastInsert))
        WhereIt = std::next(NewI->getIterator());
    }
  }
  return LastInsert;
}

/// Recursively tries to vectorize \p Bndl bottom-up. On a Widen decision the
/// operand bundles are vectorized first and then the wide instruction is
/// created; on a Pack decision a pack sequence is emitted instead — unless
/// \p Bndl is the seed bundle itself (Depth == 0), in which case the attempt
/// is abandoned. Returns the vectorized value, or nullptr.
Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
  Value *NewVec = nullptr;
  const auto &LegalityRes = Legality->canVectorize(Bndl);
  switch (LegalityRes.getSubclassID()) {
  case LegalityResultID::Widen: {
    auto *I = cast<Instruction>(Bndl[0]);
    SmallVector<Value *, 2> VecOperands;
    switch (I->getOpcode()) {
    case Instruction::Opcode::Load:
      // Don't recurse towards the pointer operand.
      VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
      break;
    case Instruction::Opcode::Store: {
      // Don't recurse towards the pointer operand.
      auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
      VecOperands.push_back(VecOp);
      VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
      break;
    }
    default:
      // Visit all operands.
      for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
        auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
        VecOperands.push_back(VecOp);
      }
      break;
    }
    NewVec = createVectorInstr(Bndl, VecOperands);

    // Collect the original scalar instructions as they may be dead.
    if (NewVec != nullptr) {
      for (Value *V : Bndl)
        DeadInstrCandidates.push_back(cast<Instruction>(V));
    }
    break;
  }
  case LegalityResultID::Pack: {
    // If we can't vectorize the seeds then just return.
    if (Depth == 0)
      return nullptr;
    NewVec = createPack(Bndl);
    break;
  }
  }
  return NewVec;
}

/// Entry point for a single vectorization attempt on seed bundle \p Bndl:
/// resets per-attempt state, runs the recursive vectorizer, and erases any
/// scalars made dead by it.
/// NOTE(review): this returns the pass-wide `Change` flag rather than
/// whether THIS bundle changed anything — harmless for the current caller
/// (which ORs it into `Change`), but confirm that is the intent.
bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
  DeadInstrCandidates.clear();
  Legality->clear();
  vectorizeRec(Bndl, /*Depth=*/0);
  tryEraseDeadInstrs();
  return Change;
}

/// Runs the bottom-up vectorizer over \p F: collects store seeds per BB and,
/// for each seed bundle, tries progressively narrower slices (halving the
/// lane count) at every unused offset. Returns true if the IR changed.
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
  Legality = std::make_unique<LegalityAnalysis>(
      A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
      F.getContext());
  Change = false;
  const auto &DL = F.getParent()->getDataLayout();
  // The target's fixed-width vector register size bounds the widest slice we
  // attempt; -sbvec-vec-reg-bits overrides the TTI query (e.g. for tests).
  unsigned VecRegBits =
      OverrideVecRegBits != 0
          ? OverrideVecRegBits
          : A.getTTI()
                .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue();

  // TODO: Start from innermost BBs first
  for (auto &BB : F) {
    SeedCollector SC(&BB, A.getScalarEvolution());
    for (SeedBundle &Seeds : SC.getStoreSeeds()) {
      // Element size (in bits) of the first still-unused seed; used to turn
      // bit budgets into lane counts.
      unsigned ElmBits =
          Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
                                Seeds[Seeds.getFirstUnusedElementIdx()])),
                            DL);

      // For a power of 2 returns half of it; otherwise rounds down to the
      // next power of 2. Drives the halving of the slice width below.
      auto DivideBy2 = [](unsigned Num) {
        auto Floor = VecUtils::getFloorPowerOf2(Num);
        if (Floor == Num)
          return Floor / 2;
        return Floor;
      };
      // Try to create the largest vector supported by the target. If it fails
      // reduce the vector size by half.
      for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
                                         Seeds.getNumUnusedBits() / ElmBits);
           SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
        if (Seeds.allUsed())
          break;
        // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
        // the slice. This could be quite expensive, so we enforce a limit.
        // NOTE(review): no explicit limit is enforced here beyond the
        // `Offset + 1 < OE` bound (slices need >= 2 elements) — the TODO
        // limit mentioned above appears not to be implemented yet.
        for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
                      OE = Seeds.size();
             Offset + 1 < OE; Offset += 1) {
          // Seeds are getting used as we vectorize, so skip them.
          if (Seeds.isUsed(Offset))
            continue;
          if (Seeds.allUsed())
            break;

          auto SeedSlice =
              Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
          if (SeedSlice.empty())
            continue;

          assert(SeedSlice.size() >= 2 && "Should have been rejected!");

          // TODO: If vectorization succeeds, run the RegionPassManager on the
          // resulting region.

          // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
          SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
                                             SeedSlice.end());
          Change |= tryVectorize(SeedSliceVals);
        }
      }
    }
  }
  return Change;
}

} // namespace sandboxir
} // namespace llvm