//===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Module.h"
#include "llvm/SandboxIR/Utils.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm {

static cl::opt<unsigned>
    OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
                       cl::desc("Override the vector register size in bits, "
                                "which is otherwise found by querying TTI."));
static cl::opt<bool>
    AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
                 cl::desc("Allow non-power-of-2 vectorization."));

namespace sandboxir {

BottomUpVec::BottomUpVec(StringRef Pipeline)
    : FunctionPass("bottom-up-vec"),
      RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}

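// Returns the OpIdx'th operand of each instruction in Bndl, in bundle order.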
static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
                                          unsigned OpIdx) {
  SmallVector<Value *, 4> Operands;
  for (Value *BndlV : Bndl) {
    auto *BndlI = cast<Instruction>(BndlV);
    Operands.push_back(BndlI->getOperand(OpIdx));
  }
  return Operands;
}

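// Returns the insertion point just past the bottom-most (last in program
// order) instruction in Instrs.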
static BasicBlock::iterator
getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
  // TODO: Use the VecUtils function for getting the bottom instr once it lands.
  auto *BotI = cast<Instruction>(
      *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
        return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
      }));
  // If Instrs contains Arguments or Constants, use the beginning of the BB.
  return std::next(BotI->getIterator());
}

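// Emits a single vector instruction that widens the scalar instructions in
// Bndl, using the already vectorized Operands, and inserts it right after the
// bottom-most scalar. For example, widening a bundle of two adjacent float
// loads would emit roughly (names illustrative):
//   %VecL = load <2 x float>, ptr %p
// where %p is the pointer operand of the first load in the bundle.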
Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
                                      ArrayRef<Value *> Operands) {
  Change = true;
  assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
         "Expect Instructions!");
  auto &Ctx = Bndl[0]->getContext();

  Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
  auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));

  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(Bndl);

  auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
  switch (Opcode) {
  case Instruction::Opcode::ZExt:
  case Instruction::Opcode::SExt:
  case Instruction::Opcode::FPToUI:
  case Instruction::Opcode::FPToSI:
  case Instruction::Opcode::FPExt:
  case Instruction::Opcode::PtrToInt:
  case Instruction::Opcode::IntToPtr:
  case Instruction::Opcode::SIToFP:
  case Instruction::Opcode::UIToFP:
  case Instruction::Opcode::Trunc:
  case Instruction::Opcode::FPTrunc:
  case Instruction::Opcode::BitCast: {
    assert(Operands.size() == 1u && "Casts are unary!");
    return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx, "VCast");
  }
  case Instruction::Opcode::FCmp:
  case Instruction::Opcode::ICmp: {
    auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
    assert(all_of(drop_begin(Bndl),
                  [Pred](auto *SBV) {
                    return cast<CmpInst>(SBV)->getPredicate() == Pred;
                  }) &&
           "Expected same predicate across bundle.");
    return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
                           "VCmp");
  }
  case Instruction::Opcode::Select: {
    return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
                              Ctx, "Vec");
  }
  case Instruction::Opcode::FNeg: {
    auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
    auto OpC = UOp0->getOpcode();
    return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0, WhereIt,
                                                Ctx, "Vec");
  }
  case Instruction::Opcode::Add:
  case Instruction::Opcode::FAdd:
  case Instruction::Opcode::Sub:
  case Instruction::Opcode::FSub:
  case Instruction::Opcode::Mul:
  case Instruction::Opcode::FMul:
  case Instruction::Opcode::UDiv:
  case Instruction::Opcode::SDiv:
  case Instruction::Opcode::FDiv:
  case Instruction::Opcode::URem:
  case Instruction::Opcode::SRem:
  case Instruction::Opcode::FRem:
  case Instruction::Opcode::Shl:
  case Instruction::Opcode::LShr:
  case Instruction::Opcode::AShr:
  case Instruction::Opcode::And:
  case Instruction::Opcode::Or:
  case Instruction::Opcode::Xor: {
    auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
    auto *LHS = Operands[0];
    auto *RHS = Operands[1];
    return BinaryOperator::createWithCopiedFlags(BinOp0->getOpcode(), LHS, RHS,
                                                 BinOp0, WhereIt, Ctx, "Vec");
  }
  case Instruction::Opcode::Load: {
    auto *Ld0 = cast<LoadInst>(Bndl[0]);
    Value *Ptr = Ld0->getPointerOperand();
    return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx, "VecL");
  }
  case Instruction::Opcode::Store: {
    auto Align = cast<StoreInst>(Bndl[0])->getAlign();
    Value *Val = Operands[0];
    Value *Ptr = Operands[1];
    return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
  }
  case Instruction::Opcode::Br:
  case Instruction::Opcode::Ret:
  case Instruction::Opcode::PHI:
  case Instruction::Opcode::AddrSpaceCast:
  case Instruction::Opcode::Call:
  case Instruction::Opcode::GetElementPtr:
    llvm_unreachable("Unimplemented");
    break;
  default:
    llvm_unreachable("Unimplemented");
    break;
  }
  llvm_unreachable("Missing switch case!");
  // TODO: Propagate debug info.
}

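// Erases any collected dead-instruction candidate that is left without uses,
// visiting candidates bottom-to-top, and then clears the candidate set.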
void BottomUpVec::tryEraseDeadInstrs() {
  // Visiting the dead instructions bottom-to-top.
  SmallVector<Instruction *> SortedDeadInstrCandidates(
      DeadInstrCandidates.begin(), DeadInstrCandidates.end());
  sort(SortedDeadInstrCandidates,
       [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
  for (Instruction *I : reverse(SortedDeadInstrCandidates)) {
    if (I->hasNUses(0))
      I->eraseFromParent();
  }
  DeadInstrCandidates.clear();
}

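// Packs the (scalar or vector) values in ToPack into a single vector using a
// chain of insertelement instructions, plus extractelement instructions for
// vector elements. For example, packing two i32 scalars %s0 and %s1 would emit
// roughly (names illustrative):
//   %Pack0 = insertelement <2 x i32> poison, i32 %s0, i32 0
//   %Pack1 = insertelement <2 x i32> %Pack0, i32 %s1, i32 1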
Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);

  Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
  unsigned Lanes = VecUtils::getNumLanes(ToPack);
  Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);

  // Create a series of pack instructions.
  Value *LastInsert = PoisonValue::get(VecTy);

  Context &Ctx = ToPack[0]->getContext();

  unsigned InsertIdx = 0;
  for (Value *Elm : ToPack) {
    // An element can be either scalar or vector. We need to generate different
    // IR for each case.
    if (Elm->getType()->isVectorTy()) {
      unsigned NumElms =
          cast<FixedVectorType>(Elm->getType())->getNumElements();
      for (auto ExtrLane : seq<int>(0, NumElms)) {
        // We generate extract-insert pairs for each lane in `Elm`.
        Constant *ExtrLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
        // This may return a Constant if Elm is a Constant.
        auto *ExtrI =
            ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
        if (!isa<Constant>(ExtrI))
          WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
        Constant *InsertLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
        // This may also return a Constant if ExtrI is a Constant.
        auto *InsertI = InsertElementInst::create(
            LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
        if (!isa<Constant>(InsertI)) {
          LastInsert = InsertI;
          WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
        }
      }
    } else {
      Constant *InsertLaneC =
          ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
      // This may be folded into a Constant if LastInsert is a Constant. In
      // that case we simply keep track of the folded Constant.
      LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
                                             WhereIt, Ctx, "Pack");
      if (auto *NewI = dyn_cast<Instruction>(LastInsert))
        WhereIt = std::next(NewI->getIterator());
    }
  }
  return LastInsert;
}

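// Records the just-vectorized scalars in Bndl as candidates for deletion. For
// loads and stores it also records their (typically GEP) pointer operands,
// skipping the first scalar's pointer, which is reused by the widened
// instruction.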
void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) {
  for (Value *V : Bndl)
    DeadInstrCandidates.insert(cast<Instruction>(V));
  // Also collect the GEPs of vectorized loads and stores.
  auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
  switch (Opcode) {
  case Instruction::Opcode::Load: {
    for (Value *V : drop_begin(Bndl))
      if (auto *Ptr =
              dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand()))
        DeadInstrCandidates.insert(Ptr);
    break;
  }
  case Instruction::Opcode::Store: {
    for (Value *V : drop_begin(Bndl))
      if (auto *Ptr =
              dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand()))
        DeadInstrCandidates.insert(Ptr);
    break;
  }
  default:
    break;
  }
}

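// Recursively vectorizes Bndl bottom-up. If the bundle is legal to widen, the
// operand bundles are vectorized first (without recursing into the pointer
// operand of loads/stores), the widened instruction is emitted, and the
// original scalars are recorded as potentially dead. Otherwise the values are
// packed into a vector, except at Depth 0 where packing the seeds themselves
// would be pointless.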
Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
  Value *NewVec = nullptr;
  const auto &LegalityRes = Legality->canVectorize(Bndl);
  switch (LegalityRes.getSubclassID()) {
  case LegalityResultID::Widen: {
    auto *I = cast<Instruction>(Bndl[0]);
    SmallVector<Value *, 2> VecOperands;
    switch (I->getOpcode()) {
    case Instruction::Opcode::Load:
      // Don't recurse towards the pointer operand.
      VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
      break;
    case Instruction::Opcode::Store: {
      // Don't recurse towards the pointer operand.
      auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
      VecOperands.push_back(VecOp);
      VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
      break;
    }
    default:
      // Visit all operands.
      for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
        auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
        VecOperands.push_back(VecOp);
      }
      break;
    }
    NewVec = createVectorInstr(Bndl, VecOperands);

    // Collect any potentially dead scalar instructions, including the original
    // scalars and pointer operands of loads/stores.
    if (NewVec != nullptr)
      collectPotentiallyDeadInstrs(Bndl);
    break;
  }
  case LegalityResultID::Pack: {
    // If we can't vectorize the seeds then just return.
    if (Depth == 0)
      return nullptr;
    NewVec = createPack(Bndl);
    break;
  }
  }
  return NewVec;
}

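// Attempts to vectorize a single seed bundle: clears the per-attempt state,
// runs the recursive bottom-up vectorizer and erases any scalars that became
// dead in the process.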
bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
  DeadInstrCandidates.clear();
  Legality->clear();
  vectorizeRec(Bndl, /*Depth=*/0);
  tryEraseDeadInstrs();
  return Change;
}

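// Per-function driver: determines the vector register width (either from TTI
// or from the -sbvec-vec-reg-bits override), collects store seeds in every
// basic block and tries to vectorize slices of each seed bundle, starting from
// the widest slice that fits in a register and retrying with progressively
// narrower slices.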
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
  Legality = std::make_unique<LegalityAnalysis>(
      A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
      F.getContext());
  Change = false;
  const auto &DL = F.getParent()->getDataLayout();
  unsigned VecRegBits =
      OverrideVecRegBits != 0
          ? OverrideVecRegBits
          : A.getTTI()
                .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue();

  // TODO: Start from innermost BBs first
  for (auto &BB : F) {
    SeedCollector SC(&BB, A.getScalarEvolution());
    for (SeedBundle &Seeds : SC.getStoreSeeds()) {
      unsigned ElmBits =
          Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
                                Seeds[Seeds.getFirstUnusedElementIdx()])),
                            DL);

      auto DivideBy2 = [](unsigned Num) {
        auto Floor = VecUtils::getFloorPowerOf2(Num);
        if (Floor == Num)
          return Floor / 2;
        return Floor;
      };
      // Try to create the largest vector supported by the target. If it fails,
      // reduce the vector size by half.
      for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
                                         Seeds.getNumUnusedBits() / ElmBits);
           SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
        if (Seeds.allUsed())
          break;
        // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
        // the slice. This could be quite expensive, so we enforce a limit.
        for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
                      OE = Seeds.size();
             Offset + 1 < OE; Offset += 1) {
          // Seeds are getting used as we vectorize, so skip them.
          if (Seeds.isUsed(Offset))
            continue;
          if (Seeds.allUsed())
            break;

          auto SeedSlice =
              Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
          if (SeedSlice.empty())
            continue;

          assert(SeedSlice.size() >= 2 && "Should have been rejected!");

          // TODO: If vectorization succeeds, run the RegionPassManager on the
          // resulting region.

          // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
          SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
                                             SeedSlice.end());
          Change |= tryVectorize(SeedSliceVals);
        }
      }
    }
  }
  return Change;
}

} // namespace sandboxir
} // namespace llvm