xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (revision 6312beef788a209dc7d73c2c10b36197dab1cff3)
1 //===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/Analysis/TargetTransformInfo.h"
12 #include "llvm/SandboxIR/Function.h"
13 #include "llvm/SandboxIR/Instruction.h"
14 #include "llvm/SandboxIR/Module.h"
15 #include "llvm/SandboxIR/Utils.h"
16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
17 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
18 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
19 
20 namespace llvm {
21 
22 static cl::opt<unsigned>
23     OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
24                        cl::desc("Override the vector register size in bits, "
25                                 "which is otherwise found by querying TTI."));
26 static cl::opt<bool>
27     AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
28                  cl::desc("Allow non-power-of-2 vectorization."));
29 
30 namespace sandboxir {
31 
32 BottomUpVec::BottomUpVec(StringRef Pipeline)
33     : FunctionPass("bottom-up-vec"),
34       RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}
35 
36 static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
37                                           unsigned OpIdx) {
38   SmallVector<Value *, 4> Operands;
39   for (Value *BndlV : Bndl) {
40     auto *BndlI = cast<Instruction>(BndlV);
41     Operands.push_back(BndlI->getOperand(OpIdx));
42   }
43   return Operands;
44 }
45 
46 static BasicBlock::iterator
47 getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
48   // TODO: Use the VecUtils function for getting the bottom instr once it lands.
49   auto *BotI = cast<Instruction>(
50       *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
51         return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
52       }));
53   // If Bndl contains Arguments or Constants, use the beginning of the BB.
54   return std::next(BotI->getIterator());
55 }
56 
57 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
58                                       ArrayRef<Value *> Operands) {
59   Change = true;
60   assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
61          "Expect Instructions!");
62   auto &Ctx = Bndl[0]->getContext();
63 
64   Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
65   auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
66 
67   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);
68 
69   auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
70   switch (Opcode) {
71   case Instruction::Opcode::ZExt:
72   case Instruction::Opcode::SExt:
73   case Instruction::Opcode::FPToUI:
74   case Instruction::Opcode::FPToSI:
75   case Instruction::Opcode::FPExt:
76   case Instruction::Opcode::PtrToInt:
77   case Instruction::Opcode::IntToPtr:
78   case Instruction::Opcode::SIToFP:
79   case Instruction::Opcode::UIToFP:
80   case Instruction::Opcode::Trunc:
81   case Instruction::Opcode::FPTrunc:
82   case Instruction::Opcode::BitCast: {
83     assert(Operands.size() == 1u && "Casts are unary!");
84     return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
85   }
86   case Instruction::Opcode::FCmp:
87   case Instruction::Opcode::ICmp: {
88     auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
89     assert(all_of(drop_begin(Bndl),
90                   [Pred](auto *SBV) {
91                     return cast<CmpInst>(SBV)->getPredicate() == Pred;
92                   }) &&
93            "Expected same predicate across bundle.");
94     return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
95                            "VCmp");
96   }
97   case Instruction::Opcode::Select: {
98     return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
99                               Ctx, "Vec");
100   }
101   case Instruction::Opcode::FNeg: {
102     auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
103     auto OpC = UOp0->getOpcode();
104     return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
105                                                 Ctx, "Vec");
106   }
107   case Instruction::Opcode::Add:
108   case Instruction::Opcode::FAdd:
109   case Instruction::Opcode::Sub:
110   case Instruction::Opcode::FSub:
111   case Instruction::Opcode::Mul:
112   case Instruction::Opcode::FMul:
113   case Instruction::Opcode::UDiv:
114   case Instruction::Opcode::SDiv:
115   case Instruction::Opcode::FDiv:
116   case Instruction::Opcode::URem:
117   case Instruction::Opcode::SRem:
118   case Instruction::Opcode::FRem:
119   case Instruction::Opcode::Shl:
120   case Instruction::Opcode::LShr:
121   case Instruction::Opcode::AShr:
122   case Instruction::Opcode::And:
123   case Instruction::Opcode::Or:
124   case Instruction::Opcode::Xor: {
125     auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
126     auto *LHS = Operands[0];
127     auto *RHS = Operands[1];
128     return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
129                                                  BinOp0, WhereIt, Ctx, "Vec");
130   }
131   case Instruction::Opcode::Load: {
132     auto *Ld0 = cast<LoadInst>(Bndl[0]);
133     Value *Ptr = Ld0->getPointerOperand();
134     return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
135   }
136   case Instruction::Opcode::Store: {
137     auto Align = cast<StoreInst>(Bndl[0])->getAlign();
138     Value *Val = Operands[0];
139     Value *Ptr = Operands[1];
140     return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
141   }
142   case Instruction::Opcode::Br:
143   case Instruction::Opcode::Ret:
144   case Instruction::Opcode::PHI:
145   case Instruction::Opcode::AddrSpaceCast:
146   case Instruction::Opcode::Call:
147   case Instruction::Opcode::GetElementPtr:
148     llvm_unreachable("Unimplemented");
149     break;
150   default:
151     llvm_unreachable("Unimplemented");
152     break;
153   }
154   llvm_unreachable("Missing switch case!");
155   // TODO: Propagate debug info.
156 }
157 
158 void BottomUpVec::tryEraseDeadInstrs() {
159   // Visiting the dead instructions bottom-to-top.
160   sort(DeadInstrCandidates,
161        [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
162   for (Instruction *I : reverse(DeadInstrCandidates)) {
163     if (I->hasNUses(0))
164       I->eraseFromParent();
165   }
166   DeadInstrCandidates.clear();
167 }
168 
169 Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
170   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);
171 
172   Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
173   unsigned Lanes = VecUtils::getNumLanes(ToPack);
174   Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);
175 
176   // Create a series of pack instructions.
177   Value *LastInsert = PoisonValue::get(VecTy);
178 
179   Context &Ctx = ToPack[0]->getContext();
180 
181   unsigned InsertIdx = 0;
182   for (Value *Elm : ToPack) {
183     // An element can be either scalar or vector. We need to generate different
184     // IR for each case.
185     if (Elm->getType()->isVectorTy()) {
186       unsigned NumElms =
187           cast<FixedVectorType>(Elm->getType())->getNumElements();
188       for (auto ExtrLane : seq<int>(0, NumElms)) {
189         // We generate extract-insert pairs, for each lane in `Elm`.
190         Constant *ExtrLaneC =
191             ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
192         // This may return a Constant if Elm is a Constant.
193         auto *ExtrI =
194             ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
195         if (!isa<Constant>(ExtrI))
196           WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
197         Constant *InsertLaneC =
198             ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
199         // This may also return a Constant if ExtrI is a Constant.
200         auto *InsertI = InsertElementInst::create(
201             LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
202         if (!isa<Constant>(InsertI)) {
203           LastInsert = InsertI;
204           WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
205         }
206       }
207     } else {
208       Constant *InsertLaneC =
209           ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
210       // This may be folded into a Constant if LastInsert is a Constant. In
211       // that case we only collect the last constant.
212       LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
213                                              WhereIt, Ctx, "Pack");
214       if (auto *NewI = dyn_cast<Instruction>(LastInsert))
215         WhereIt = std::next(NewI->getIterator());
216     }
217   }
218   return LastInsert;
219 }
220 
221 Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
222   Value *NewVec = nullptr;
223   const auto &LegalityRes = Legality->canVectorize(Bndl);
224   switch (LegalityRes.getSubclassID()) {
225   case LegalityResultID::Widen: {
226     auto *I = cast<Instruction>(Bndl[0]);
227     SmallVector<Value *, 2> VecOperands;
228     switch (I->getOpcode()) {
229     case Instruction::Opcode::Load:
230       // Don't recurse towards the pointer operand.
231       VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
232       break;
233     case Instruction::Opcode::Store: {
234       // Don't recurse towards the pointer operand.
235       auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
236       VecOperands.push_back(VecOp);
237       VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
238       break;
239     }
240     default:
241       // Visit all operands.
242       for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
243         auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
244         VecOperands.push_back(VecOp);
245       }
246       break;
247     }
248     NewVec = createVectorInstr(Bndl, VecOperands);
249 
250     // Collect the original scalar instructions as they may be dead.
251     if (NewVec != nullptr) {
252       for (Value *V : Bndl)
253         DeadInstrCandidates.push_back(cast<Instruction>(V));
254     }
255     break;
256   }
257   case LegalityResultID::Pack: {
258     // If we can't vectorize the seeds then just return.
259     if (Depth == 0)
260       return nullptr;
261     NewVec = createPack(Bndl);
262     break;
263   }
264   }
265   return NewVec;
266 }
267 
268 bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
269   DeadInstrCandidates.clear();
270   Legality->clear();
271   vectorizeRec(Bndl, /*Depth=*/0);
272   tryEraseDeadInstrs();
273   return Change;
274 }
275 
276 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
277   Legality = std::make_unique<LegalityAnalysis>(
278       A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
279       F.getContext());
280   Change = false;
281   const auto &DL = F.getParent()->getDataLayout();
282   unsigned VecRegBits =
283       OverrideVecRegBits != 0
284           ? OverrideVecRegBits
285           : A.getTTI()
286                 .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
287                 .getFixedValue();
288 
289   // TODO: Start from innermost BBs first
290   for (auto &BB : F) {
291     SeedCollector SC(&BB, A.getScalarEvolution());
292     for (SeedBundle &Seeds : SC.getStoreSeeds()) {
293       unsigned ElmBits =
294           Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
295                                 Seeds[Seeds.getFirstUnusedElementIdx()])),
296                             DL);
297 
298       auto DivideBy2 = [](unsigned Num) {
299         auto Floor = VecUtils::getFloorPowerOf2(Num);
300         if (Floor == Num)
301           return Floor / 2;
302         return Floor;
303       };
304       // Try to create the largest vector supported by the target. If it fails
305       // reduce the vector size by half.
306       for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
307                                          Seeds.getNumUnusedBits() / ElmBits);
308            SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
309         if (Seeds.allUsed())
310           break;
311         // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
312         // the slice. This could be quite expensive, so we enforce a limit.
313         for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
314                       OE = Seeds.size();
315              Offset + 1 < OE; Offset += 1) {
316           // Seeds are getting used as we vectorize, so skip them.
317           if (Seeds.isUsed(Offset))
318             continue;
319           if (Seeds.allUsed())
320             break;
321 
322           auto SeedSlice =
323               Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
324           if (SeedSlice.empty())
325             continue;
326 
327           assert(SeedSlice.size() >= 2 && "Should have been rejected!");
328 
329           // TODO: If vectorization succeeds, run the RegionPassManager on the
330           // resulting region.
331 
332           // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
333           SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
334                                              SeedSlice.end());
335           Change |= tryVectorize(SeedSliceVals);
336         }
337       }
338     }
339   }
340   return Change;
341 }
342 
343 } // namespace sandboxir
344 } // namespace llvm
345