//===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Module.h"
#include "llvm/SandboxIR/Utils.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm {

static cl::opt<unsigned>
    OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
                       cl::desc("Override the vector register size in bits, "
                                "which is otherwise found by querying TTI."));
static cl::opt<bool>
    AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
                 cl::desc("Allow non-power-of-2 vectorization."));

#ifndef NDEBUG
static cl::opt<bool>
    AlwaysVerify("sbvec-always-verify", cl::init(false), cl::Hidden,
                 cl::desc("Helps find bugs by verifying the IR whenever we "
                          "emit new instructions (*very* expensive)."));
#endif // NDEBUG

namespace sandboxir {

BottomUpVec::BottomUpVec(StringRef Pipeline)
    : FunctionPass("bottom-up-vec"),
      RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}

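/// \Returns the \p OpIdx'th operand of each instruction in \p Bndl, in bundle
/// order.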
static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
                                          unsigned OpIdx) {
  SmallVector<Value *, 4> Operands;
  for (Value *BndlV : Bndl) {
    auto *BndlI = cast<Instruction>(BndlV);
    Operands.push_back(BndlI->getOperand(OpIdx));
  }
  return Operands;
}

/// \Returns the BB iterator after the lowest instruction in \p Vals, or the
/// top of \p BB if no instruction is found in \p Vals.
static BasicBlock::iterator getInsertPointAfterInstrs(ArrayRef<Value *> Vals,
                                                      BasicBlock *BB) {
  auto *BotI = VecUtils::getLastPHIOrSelf(VecUtils::getLowest(Vals, BB));
  if (BotI == nullptr)
    // We are using BB->begin() (or after PHIs) as the fallback insert point.
    return BB->empty()
               ? BB->begin()
               : std::next(
                     VecUtils::getLastPHIOrSelf(&*BB->begin())->getIterator());
  return std::next(BotI->getIterator());
}

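/// Creates the vector instruction that replaces the scalar instructions in
/// \p Bndl, using the already-vectorized \p Operands, and registers the
/// scalar-to-vector mapping in IMaps. \Returns the new vector instruction.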
Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
                                      ArrayRef<Value *> Operands) {
  auto CreateVectorInstr = [](ArrayRef<Value *> Bndl,
                              ArrayRef<Value *> Operands) -> Value * {
    assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
           "Expect Instructions!");
    auto &Ctx = Bndl[0]->getContext();

    Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
    auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));

    BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(
        Bndl, cast<Instruction>(Bndl[0])->getParent());

    auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
    switch (Opcode) {
    case Instruction::Opcode::ZExt:
    case Instruction::Opcode::SExt:
    case Instruction::Opcode::FPToUI:
    case Instruction::Opcode::FPToSI:
    case Instruction::Opcode::FPExt:
    case Instruction::Opcode::PtrToInt:
    case Instruction::Opcode::IntToPtr:
    case Instruction::Opcode::SIToFP:
    case Instruction::Opcode::UIToFP:
    case Instruction::Opcode::Trunc:
    case Instruction::Opcode::FPTrunc:
    case Instruction::Opcode::BitCast: {
      assert(Operands.size() == 1u && "Casts are unary!");
      return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx,
                              "VCast");
    }
    case Instruction::Opcode::FCmp:
    case Instruction::Opcode::ICmp: {
      auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
      assert(all_of(drop_begin(Bndl),
                    [Pred](auto *SBV) {
                      return cast<CmpInst>(SBV)->getPredicate() == Pred;
                    }) &&
             "Expected same predicate across bundle.");
      return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
                             "VCmp");
    }
    case Instruction::Opcode::Select: {
      return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
                                Ctx, "Vec");
    }
    case Instruction::Opcode::FNeg: {
      auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
      auto OpC = UOp0->getOpcode();
      return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0,
                                                  WhereIt, Ctx, "Vec");
    }
    case Instruction::Opcode::Add:
    case Instruction::Opcode::FAdd:
    case Instruction::Opcode::Sub:
    case Instruction::Opcode::FSub:
    case Instruction::Opcode::Mul:
    case Instruction::Opcode::FMul:
    case Instruction::Opcode::UDiv:
    case Instruction::Opcode::SDiv:
    case Instruction::Opcode::FDiv:
    case Instruction::Opcode::URem:
    case Instruction::Opcode::SRem:
    case Instruction::Opcode::FRem:
    case Instruction::Opcode::Shl:
    case Instruction::Opcode::LShr:
    case Instruction::Opcode::AShr:
    case Instruction::Opcode::And:
    case Instruction::Opcode::Or:
    case Instruction::Opcode::Xor: {
      auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
      auto *LHS = Operands[0];
      auto *RHS = Operands[1];
      return BinaryOperator::createWithCopiedFlags(
          BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec");
    }
    case Instruction::Opcode::Load: {
      auto *Ld0 = cast<LoadInst>(Bndl[0]);
      Value *Ptr = Ld0->getPointerOperand();
      return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx,
                              "VecL");
    }
    case Instruction::Opcode::Store: {
      auto Align = cast<StoreInst>(Bndl[0])->getAlign();
      Value *Val = Operands[0];
      Value *Ptr = Operands[1];
      return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
    }
    case Instruction::Opcode::Br:
    case Instruction::Opcode::Ret:
    case Instruction::Opcode::PHI:
    case Instruction::Opcode::AddrSpaceCast:
    case Instruction::Opcode::Call:
    case Instruction::Opcode::GetElementPtr:
      llvm_unreachable("Unimplemented");
      break;
    default:
      llvm_unreachable("Unimplemented");
      break;
    }
    llvm_unreachable("Missing switch case!");
    // TODO: Propagate debug info.
  };

  auto *VecI = CreateVectorInstr(Bndl, Operands);
  if (VecI != nullptr) {
    Change = true;
    IMaps->registerVector(Bndl, VecI);
  }
  return VecI;
}

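/// Erases the candidates in DeadInstrCandidates that no longer have any uses.
/// Candidates are grouped per BB and visited bottom-to-top so that dead users
/// are erased before their (also dead) operands.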
void BottomUpVec::tryEraseDeadInstrs() {
  DenseMap<BasicBlock *, SmallVector<Instruction *>> SortedDeadInstrCandidates;
  // The dead instrs could span BBs, so we need to collect and sort them per BB.
  for (auto *DeadI : DeadInstrCandidates)
    SortedDeadInstrCandidates[DeadI->getParent()].push_back(DeadI);
  for (auto &Pair : SortedDeadInstrCandidates)
    sort(Pair.second,
         [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
  for (const auto &Pair : SortedDeadInstrCandidates) {
    for (Instruction *I : reverse(Pair.second)) {
      if (I->hasNUses(0))
        // Erase the dead instructions bottom-to-top.
        I->eraseFromParent();
    }
  }
  DeadInstrCandidates.clear();
}

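/// Creates a shuffle of \p VecOp with itself using \p Mask, inserted right
/// after \p VecOp (or at the top of \p UserBB if \p VecOp is not an
/// instruction).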
Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask,
                                  BasicBlock *UserBB) {
  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs({VecOp}, UserBB);
  return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt,
                                   VecOp->getContext(), "VShuf");
}

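/// Packs the scalar and/or vector values in \p ToPack into a single wide
/// vector by emitting a chain of insertelement instructions (plus
/// extractelement instructions for the lanes of vector elements). \Returns the
/// final value, which may be a Constant if the whole chain gets folded.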
Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack, BasicBlock *UserBB) {
  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack, UserBB);

  Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
  unsigned Lanes = VecUtils::getNumLanes(ToPack);
  Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);

  // Create a series of pack instructions.
  Value *LastInsert = PoisonValue::get(VecTy);

  Context &Ctx = ToPack[0]->getContext();

  unsigned InsertIdx = 0;
  for (Value *Elm : ToPack) {
    // An element can be either scalar or vector. We need to generate different
    // IR for each case.
    if (Elm->getType()->isVectorTy()) {
      unsigned NumElms =
          cast<FixedVectorType>(Elm->getType())->getNumElements();
      for (auto ExtrLane : seq<int>(0, NumElms)) {
        // We generate an extract-insert pair for each lane in `Elm`.
        Constant *ExtrLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
        // This may return a Constant if Elm is a Constant.
        auto *ExtrI =
            ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
        if (!isa<Constant>(ExtrI))
          WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
        Constant *InsertLaneC =
            ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
        // This may also return a Constant if ExtrI is a Constant.
        auto *InsertI = InsertElementInst::create(
            LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
        if (!isa<Constant>(InsertI)) {
          LastInsert = InsertI;
          WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
        }
      }
    } else {
      Constant *InsertLaneC =
          ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
      // This may be folded into a Constant if LastInsert is a Constant. In
      // that case we only collect the last constant.
      LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
                                             WhereIt, Ctx, "Pack");
      if (auto *NewI = dyn_cast<Instruction>(LastInsert))
        WhereIt = std::next(NewI->getIterator());
    }
  }
  return LastInsert;
}

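/// Marks the just-vectorized scalars in \p Bndl as candidates for erasure. For
/// loads and stores it also marks the pointer-operand instructions (e.g. GEPs)
/// of all lanes except the first, whose pointer is reused by the vector
/// instruction.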
void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) {
  for (Value *V : Bndl)
    DeadInstrCandidates.insert(cast<Instruction>(V));
  // Also collect the GEPs of vectorized loads and stores.
  auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
  switch (Opcode) {
  case Instruction::Opcode::Load: {
    for (Value *V : drop_begin(Bndl))
      if (auto *Ptr =
              dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand()))
        DeadInstrCandidates.insert(Ptr);
    break;
  }
  case Instruction::Opcode::Store: {
    for (Value *V : drop_begin(Bndl))
      if (auto *Ptr =
              dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand()))
        DeadInstrCandidates.insert(Ptr);
    break;
  }
  default:
    break;
  }
}

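/// Recursively vectorizes \p Bndl bottom-up, guided by the legality analysis:
/// either widens the bundle into a single vector instruction (visiting its
/// operands first), reuses an already-created vector (possibly through a
/// shuffle or a per-lane gather), or packs the scalars. \Returns the resulting
/// vector value, or nullptr if the seed bundle itself had to be packed.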
Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
                                 ArrayRef<Value *> UserBndl, unsigned Depth) {
  Value *NewVec = nullptr;
  auto *UserBB = !UserBndl.empty()
                     ? cast<Instruction>(UserBndl.front())->getParent()
                     : cast<Instruction>(Bndl[0])->getParent();
  const auto &LegalityRes = Legality->canVectorize(Bndl);
  switch (LegalityRes.getSubclassID()) {
  case LegalityResultID::Widen: {
    auto *I = cast<Instruction>(Bndl[0]);
    SmallVector<Value *, 2> VecOperands;
    switch (I->getOpcode()) {
    case Instruction::Opcode::Load:
      // Don't recurse towards the pointer operand.
      VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
      break;
    case Instruction::Opcode::Store: {
      // Don't recurse towards the pointer operand.
      auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Bndl, Depth + 1);
      VecOperands.push_back(VecOp);
      VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
      break;
    }
    default:
      // Visit all operands.
      for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
        auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Bndl, Depth + 1);
        VecOperands.push_back(VecOp);
      }
      break;
    }
    NewVec = createVectorInstr(Bndl, VecOperands);

    // Collect any potentially dead scalar instructions, including the original
    // scalars and pointer operands of loads/stores.
    if (NewVec != nullptr)
      collectPotentiallyDeadInstrs(Bndl);
    break;
  }
  case LegalityResultID::DiamondReuse: {
    NewVec = cast<DiamondReuse>(LegalityRes).getVector();
    break;
  }
  case LegalityResultID::DiamondReuseWithShuffle: {
    auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
    const ShuffleMask &Mask =
        cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
    NewVec = createShuffle(VecOp, Mask, UserBB);
    break;
  }
  case LegalityResultID::DiamondReuseMultiInput: {
    const auto &Descr =
        cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
    Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());

    // TODO: Try to get WhereIt without creating a vector.
    SmallVector<Value *, 4> DescrInstrs;
    for (const auto &ElmDescr : Descr.getDescrs()) {
      if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue()))
        DescrInstrs.push_back(I);
    }
    BasicBlock::iterator WhereIt =
        getInsertPointAfterInstrs(DescrInstrs, UserBB);

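    // Build the result one lane at a time: extract each element from its
    // source value when needed and insert it into the accumulating vector.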
    Value *LastV = PoisonValue::get(ResTy);
    for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
      Value *VecOp = ElmDescr.getValue();
      Context &Ctx = VecOp->getContext();
      Value *ValueToInsert;
      if (ElmDescr.needsExtract()) {
        ConstantInt *IdxC =
            ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx());
        ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt,
                                                   VecOp->getContext(), "VExt");
      } else {
        ValueToInsert = VecOp;
      }
      ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
      Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
                                             WhereIt, Ctx, "VIns");
      LastV = Ins;
    }
    NewVec = LastV;
    break;
  }
  case LegalityResultID::Pack: {
    // If we can't vectorize the seeds then just return.
    if (Depth == 0)
      return nullptr;
    NewVec = createPack(Bndl, UserBB);
    break;
  }
  }
#ifndef NDEBUG
  if (AlwaysVerify) {
    // This helps find broken IR by constantly verifying the function. Note that
    // this is very expensive and should only be used for debugging.
    Instruction *I0 = isa<Instruction>(Bndl[0])
                          ? cast<Instruction>(Bndl[0])
                          : cast<Instruction>(UserBndl[0]);
    assert(!Utils::verifyFunction(I0->getParent()->getParent(), dbgs()) &&
           "Broken function!");
  }
#endif
  return NewVec;
}

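/// Tries to vectorize the seed slice \p Bndl: resets the per-attempt state,
/// runs the recursive vectorizer and erases any scalars that became dead.
/// \Returns whether the IR has been modified.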
bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
  DeadInstrCandidates.clear();
  Legality->clear();
  vectorizeRec(Bndl, {}, /*Depth=*/0);
  tryEraseDeadInstrs();
  return Change;
}

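/// Collects store seeds in each basic block and tries to vectorize slices of
/// each seed bundle, starting from the widest slice that fits in a vector
/// register and stepping the slice width down until all seeds are used or the
/// width drops below two lanes.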
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
  IMaps = std::make_unique<InstrMaps>(F.getContext());
  Legality = std::make_unique<LegalityAnalysis>(
      A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
      F.getContext(), *IMaps);
  Change = false;
  const auto &DL = F.getParent()->getDataLayout();
  unsigned VecRegBits =
      OverrideVecRegBits != 0
          ? OverrideVecRegBits
          : A.getTTI()
                .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                .getFixedValue();

  // TODO: Start from innermost BBs first
  for (auto &BB : F) {
    SeedCollector SC(&BB, A.getScalarEvolution());
    for (SeedBundle &Seeds : SC.getStoreSeeds()) {
      unsigned ElmBits =
          Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
                                Seeds[Seeds.getFirstUnusedElementIdx()])),
                            DL);

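      // Steps the slice width down: a power-of-2 width is halved, anything
      // else is rounded down to the nearest power of 2.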
      auto DivideBy2 = [](unsigned Num) {
        auto Floor = VecUtils::getFloorPowerOf2(Num);
        if (Floor == Num)
          return Floor / 2;
        return Floor;
      };
      // Try to create the largest vector supported by the target. If it
      // fails, reduce the vector size by half.
      for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
                                         Seeds.getNumUnusedBits() / ElmBits);
           SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
        if (Seeds.allUsed())
          break;
        // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
        // the slice. This could be quite expensive, so we enforce a limit.
        for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
                      OE = Seeds.size();
             Offset + 1 < OE; Offset += 1) {
          // Seeds are getting used as we vectorize, so skip them.
          if (Seeds.isUsed(Offset))
            continue;
          if (Seeds.allUsed())
            break;

          auto SeedSlice =
              Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
          if (SeedSlice.empty())
            continue;

          assert(SeedSlice.size() >= 2 && "Should have been rejected!");

          // TODO: If vectorization succeeds, run the RegionPassManager on the
          // resulting region.

          // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
          SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
                                             SeedSlice.end());
          Change |= tryVectorize(SeedSliceVals);
        }
      }
    }
  }
  return Change;
}

} // namespace sandboxir
} // namespace llvm