xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (revision 8110af75b1500be2313e523a2d2da6bb7806b700)
1 //===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/Analysis/TargetTransformInfo.h"
12 #include "llvm/SandboxIR/Function.h"
13 #include "llvm/SandboxIR/Instruction.h"
14 #include "llvm/SandboxIR/Module.h"
15 #include "llvm/SandboxIR/Utils.h"
16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
17 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
18 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
19 
20 namespace llvm {
21 
// Hidden command-line override for the vector register width, in bits. The
// default of 0 means "no override": runOnFunction() then queries TTI for the
// fixed-width vector register size.
static cl::opt<unsigned>
    OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
                       cl::desc("Override the vector register size in bits, "
                                "which is otherwise found by querying TTI."));
// Hidden command-line switch, off by default. runOnFunction() passes its
// negation as the last argument of Seeds.getSlice() when slicing seed bundles.
static cl::opt<bool>
    AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
                 cl::desc("Allow non-power-of-2 vectorization."));
29 
30 namespace sandboxir {
31 
/// Constructs the pass under the name "bottom-up-vec" and initializes the
/// region pass manager `RPM` (named "rpm") from the textual \p Pipeline, using
/// SandboxVectorizerPassBuilder::createRegionPass to create its passes.
BottomUpVec::BottomUpVec(StringRef Pipeline)
    : FunctionPass("bottom-up-vec"),
      RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}
35 
36 static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
37                                           unsigned OpIdx) {
38   SmallVector<Value *, 4> Operands;
39   for (Value *BndlV : Bndl) {
40     auto *BndlI = cast<Instruction>(BndlV);
41     Operands.push_back(BndlI->getOperand(OpIdx));
42   }
43   return Operands;
44 }
45 
46 /// \Returns the BB iterator after the lowest instruction in \p Vals, or the top
47 /// of BB if no instruction found in \p Vals.
48 static BasicBlock::iterator getInsertPointAfterInstrs(ArrayRef<Value *> Vals,
49                                                       BasicBlock *BB) {
50   auto *BotI = VecUtils::getLowest(Vals);
51   if (BotI == nullptr)
52     // We are using BB->begin() as the fallback insert point if `ToPack` did
53     // not contain instructions.
54     return BB->begin();
55   return std::next(BotI->getIterator());
56 }
57 
58 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
59                                       ArrayRef<Value *> Operands) {
60   auto CreateVectorInstr = [](ArrayRef<Value *> Bndl,
61                               ArrayRef<Value *> Operands) -> Value * {
62     assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
63            "Expect Instructions!");
64     auto &Ctx = Bndl[0]->getContext();
65 
66     Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
67     auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
68 
69     BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(
70         Bndl, cast<Instruction>(Bndl[0])->getParent());
71 
72     auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
73     switch (Opcode) {
74     case Instruction::Opcode::ZExt:
75     case Instruction::Opcode::SExt:
76     case Instruction::Opcode::FPToUI:
77     case Instruction::Opcode::FPToSI:
78     case Instruction::Opcode::FPExt:
79     case Instruction::Opcode::PtrToInt:
80     case Instruction::Opcode::IntToPtr:
81     case Instruction::Opcode::SIToFP:
82     case Instruction::Opcode::UIToFP:
83     case Instruction::Opcode::Trunc:
84     case Instruction::Opcode::FPTrunc:
85     case Instruction::Opcode::BitCast: {
86       assert(Operands.size() == 1u && "Casts are unary!");
87       return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx,
88                               "VCast");
89     }
90     case Instruction::Opcode::FCmp:
91     case Instruction::Opcode::ICmp: {
92       auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
93       assert(all_of(drop_begin(Bndl),
94                     [Pred](auto *SBV) {
95                       return cast<CmpInst>(SBV)->getPredicate() == Pred;
96                     }) &&
97              "Expected same predicate across bundle.");
98       return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
99                              "VCmp");
100     }
101     case Instruction::Opcode::Select: {
102       return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
103                                 Ctx, "Vec");
104     }
105     case Instruction::Opcode::FNeg: {
106       auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
107       auto OpC = UOp0->getOpcode();
108       return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0,
109                                                   WhereIt, Ctx, "Vec");
110     }
111     case Instruction::Opcode::Add:
112     case Instruction::Opcode::FAdd:
113     case Instruction::Opcode::Sub:
114     case Instruction::Opcode::FSub:
115     case Instruction::Opcode::Mul:
116     case Instruction::Opcode::FMul:
117     case Instruction::Opcode::UDiv:
118     case Instruction::Opcode::SDiv:
119     case Instruction::Opcode::FDiv:
120     case Instruction::Opcode::URem:
121     case Instruction::Opcode::SRem:
122     case Instruction::Opcode::FRem:
123     case Instruction::Opcode::Shl:
124     case Instruction::Opcode::LShr:
125     case Instruction::Opcode::AShr:
126     case Instruction::Opcode::And:
127     case Instruction::Opcode::Or:
128     case Instruction::Opcode::Xor: {
129       auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
130       auto *LHS = Operands[0];
131       auto *RHS = Operands[1];
132       return BinaryOperator::createWithCopiedFlags(
133           BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec");
134     }
135     case Instruction::Opcode::Load: {
136       auto *Ld0 = cast<LoadInst>(Bndl[0]);
137       Value *Ptr = Ld0->getPointerOperand();
138       return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx,
139                               "VecL");
140     }
141     case Instruction::Opcode::Store: {
142       auto Align = cast<StoreInst>(Bndl[0])->getAlign();
143       Value *Val = Operands[0];
144       Value *Ptr = Operands[1];
145       return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
146     }
147     case Instruction::Opcode::Br:
148     case Instruction::Opcode::Ret:
149     case Instruction::Opcode::PHI:
150     case Instruction::Opcode::AddrSpaceCast:
151     case Instruction::Opcode::Call:
152     case Instruction::Opcode::GetElementPtr:
153       llvm_unreachable("Unimplemented");
154       break;
155     default:
156       llvm_unreachable("Unimplemented");
157       break;
158     }
159     llvm_unreachable("Missing switch case!");
160     // TODO: Propagate debug info.
161   };
162 
163   auto *VecI = CreateVectorInstr(Bndl, Operands);
164   if (VecI != nullptr) {
165     Change = true;
166     IMaps->registerVector(Bndl, VecI);
167   }
168   return VecI;
169 }
170 
171 void BottomUpVec::tryEraseDeadInstrs() {
172   // Visiting the dead instructions bottom-to-top.
173   SmallVector<Instruction *> SortedDeadInstrCandidates(
174       DeadInstrCandidates.begin(), DeadInstrCandidates.end());
175   sort(SortedDeadInstrCandidates,
176        [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
177   for (Instruction *I : reverse(SortedDeadInstrCandidates)) {
178     if (I->hasNUses(0))
179       I->eraseFromParent();
180   }
181   DeadInstrCandidates.clear();
182 }
183 
184 Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask,
185                                   BasicBlock *UserBB) {
186   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs({VecOp}, UserBB);
187   return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt,
188                                    VecOp->getContext(), "VShuf");
189 }
190 
191 Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack, BasicBlock *UserBB) {
192   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack, UserBB);
193 
194   Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
195   unsigned Lanes = VecUtils::getNumLanes(ToPack);
196   Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);
197 
198   // Create a series of pack instructions.
199   Value *LastInsert = PoisonValue::get(VecTy);
200 
201   Context &Ctx = ToPack[0]->getContext();
202 
203   unsigned InsertIdx = 0;
204   for (Value *Elm : ToPack) {
205     // An element can be either scalar or vector. We need to generate different
206     // IR for each case.
207     if (Elm->getType()->isVectorTy()) {
208       unsigned NumElms =
209           cast<FixedVectorType>(Elm->getType())->getNumElements();
210       for (auto ExtrLane : seq<int>(0, NumElms)) {
211         // We generate extract-insert pairs, for each lane in `Elm`.
212         Constant *ExtrLaneC =
213             ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
214         // This may return a Constant if Elm is a Constant.
215         auto *ExtrI =
216             ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
217         if (!isa<Constant>(ExtrI))
218           WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
219         Constant *InsertLaneC =
220             ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
221         // This may also return a Constant if ExtrI is a Constant.
222         auto *InsertI = InsertElementInst::create(
223             LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
224         if (!isa<Constant>(InsertI)) {
225           LastInsert = InsertI;
226           WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
227         }
228       }
229     } else {
230       Constant *InsertLaneC =
231           ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
232       // This may be folded into a Constant if LastInsert is a Constant. In
233       // that case we only collect the last constant.
234       LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
235                                              WhereIt, Ctx, "Pack");
236       if (auto *NewI = dyn_cast<Instruction>(LastInsert))
237         WhereIt = std::next(NewI->getIterator());
238     }
239   }
240   return LastInsert;
241 }
242 
243 void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) {
244   for (Value *V : Bndl)
245     DeadInstrCandidates.insert(cast<Instruction>(V));
246   // Also collect the GEPs of vectorized loads and stores.
247   auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
248   switch (Opcode) {
249   case Instruction::Opcode::Load: {
250     for (Value *V : drop_begin(Bndl))
251       if (auto *Ptr =
252               dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand()))
253         DeadInstrCandidates.insert(Ptr);
254     break;
255   }
256   case Instruction::Opcode::Store: {
257     for (Value *V : drop_begin(Bndl))
258       if (auto *Ptr =
259               dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand()))
260         DeadInstrCandidates.insert(Ptr);
261     break;
262   }
263   default:
264     break;
265   }
266 }
267 
268 Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
269                                  ArrayRef<Value *> UserBndl, unsigned Depth) {
270   Value *NewVec = nullptr;
271   auto *UserBB = !UserBndl.empty()
272                      ? cast<Instruction>(UserBndl.front())->getParent()
273                      : cast<Instruction>(Bndl[0])->getParent();
274   const auto &LegalityRes = Legality->canVectorize(Bndl);
275   switch (LegalityRes.getSubclassID()) {
276   case LegalityResultID::Widen: {
277     auto *I = cast<Instruction>(Bndl[0]);
278     SmallVector<Value *, 2> VecOperands;
279     switch (I->getOpcode()) {
280     case Instruction::Opcode::Load:
281       // Don't recurse towards the pointer operand.
282       VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
283       break;
284     case Instruction::Opcode::Store: {
285       // Don't recurse towards the pointer operand.
286       auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Bndl, Depth + 1);
287       VecOperands.push_back(VecOp);
288       VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
289       break;
290     }
291     default:
292       // Visit all operands.
293       for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
294         auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Bndl, Depth + 1);
295         VecOperands.push_back(VecOp);
296       }
297       break;
298     }
299     NewVec = createVectorInstr(Bndl, VecOperands);
300 
301     // Collect any potentially dead scalar instructions, including the original
302     // scalars and pointer operands of loads/stores.
303     if (NewVec != nullptr)
304       collectPotentiallyDeadInstrs(Bndl);
305     break;
306   }
307   case LegalityResultID::DiamondReuse: {
308     NewVec = cast<DiamondReuse>(LegalityRes).getVector();
309     break;
310   }
311   case LegalityResultID::DiamondReuseWithShuffle: {
312     auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
313     const ShuffleMask &Mask =
314         cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
315     NewVec = createShuffle(VecOp, Mask, UserBB);
316     break;
317   }
318   case LegalityResultID::DiamondReuseMultiInput: {
319     const auto &Descr =
320         cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
321     Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
322 
323     // TODO: Try to get WhereIt without creating a vector.
324     SmallVector<Value *, 4> DescrInstrs;
325     for (const auto &ElmDescr : Descr.getDescrs()) {
326       if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue()))
327         DescrInstrs.push_back(I);
328     }
329     BasicBlock::iterator WhereIt =
330         getInsertPointAfterInstrs(DescrInstrs, UserBB);
331 
332     Value *LastV = PoisonValue::get(ResTy);
333     for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
334       Value *VecOp = ElmDescr.getValue();
335       Context &Ctx = VecOp->getContext();
336       Value *ValueToInsert;
337       if (ElmDescr.needsExtract()) {
338         ConstantInt *IdxC =
339             ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx());
340         ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt,
341                                                    VecOp->getContext(), "VExt");
342       } else {
343         ValueToInsert = VecOp;
344       }
345       ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
346       Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
347                                              WhereIt, Ctx, "VIns");
348       LastV = Ins;
349     }
350     NewVec = LastV;
351     break;
352   }
353   case LegalityResultID::Pack: {
354     // If we can't vectorize the seeds then just return.
355     if (Depth == 0)
356       return nullptr;
357     NewVec = createPack(Bndl, UserBB);
358     break;
359   }
360   }
361   return NewVec;
362 }
363 
364 bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
365   DeadInstrCandidates.clear();
366   Legality->clear();
367   vectorizeRec(Bndl, {}, /*Depth=*/0);
368   tryEraseDeadInstrs();
369   return Change;
370 }
371 
372 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
373   IMaps = std::make_unique<InstrMaps>(F.getContext());
374   Legality = std::make_unique<LegalityAnalysis>(
375       A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
376       F.getContext(), *IMaps);
377   Change = false;
378   const auto &DL = F.getParent()->getDataLayout();
379   unsigned VecRegBits =
380       OverrideVecRegBits != 0
381           ? OverrideVecRegBits
382           : A.getTTI()
383                 .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
384                 .getFixedValue();
385 
386   // TODO: Start from innermost BBs first
387   for (auto &BB : F) {
388     SeedCollector SC(&BB, A.getScalarEvolution());
389     for (SeedBundle &Seeds : SC.getStoreSeeds()) {
390       unsigned ElmBits =
391           Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
392                                 Seeds[Seeds.getFirstUnusedElementIdx()])),
393                             DL);
394 
395       auto DivideBy2 = [](unsigned Num) {
396         auto Floor = VecUtils::getFloorPowerOf2(Num);
397         if (Floor == Num)
398           return Floor / 2;
399         return Floor;
400       };
401       // Try to create the largest vector supported by the target. If it fails
402       // reduce the vector size by half.
403       for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
404                                          Seeds.getNumUnusedBits() / ElmBits);
405            SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
406         if (Seeds.allUsed())
407           break;
408         // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
409         // the slice. This could be quite expensive, so we enforce a limit.
410         for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
411                       OE = Seeds.size();
412              Offset + 1 < OE; Offset += 1) {
413           // Seeds are getting used as we vectorize, so skip them.
414           if (Seeds.isUsed(Offset))
415             continue;
416           if (Seeds.allUsed())
417             break;
418 
419           auto SeedSlice =
420               Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
421           if (SeedSlice.empty())
422             continue;
423 
424           assert(SeedSlice.size() >= 2 && "Should have been rejected!");
425 
426           // TODO: If vectorization succeeds, run the RegionPassManager on the
427           // resulting region.
428 
429           // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
430           SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
431                                              SeedSlice.end());
432           Change |= tryVectorize(SeedSliceVals);
433         }
434       }
435     }
436   }
437   return Change;
438 }
439 
440 } // namespace sandboxir
441 } // namespace llvm
442