xref: /llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp (revision c7053ac202de1723c49d2f02d1c56d7a0a4481c0)
1 //===- BottomUpVec.cpp - A bottom-up vectorizer pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
10 #include "llvm/ADT/SmallVector.h"
11 #include "llvm/Analysis/TargetTransformInfo.h"
12 #include "llvm/SandboxIR/Function.h"
13 #include "llvm/SandboxIR/Instruction.h"
14 #include "llvm/SandboxIR/Module.h"
15 #include "llvm/SandboxIR/Utils.h"
16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
17 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
18 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
19 
20 namespace llvm {
21 
22 static cl::opt<unsigned>
23     OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
24                        cl::desc("Override the vector register size in bits, "
25                                 "which is otherwise found by querying TTI."));
26 static cl::opt<bool>
27     AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
28                  cl::desc("Allow non-power-of-2 vectorization."));
29 
30 namespace sandboxir {
31 
32 BottomUpVec::BottomUpVec(StringRef Pipeline)
33     : FunctionPass("bottom-up-vec"),
34       RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}
35 
36 static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
37                                           unsigned OpIdx) {
38   SmallVector<Value *, 4> Operands;
39   for (Value *BndlV : Bndl) {
40     auto *BndlI = cast<Instruction>(BndlV);
41     Operands.push_back(BndlI->getOperand(OpIdx));
42   }
43   return Operands;
44 }
45 
46 /// \Returns the BB iterator after the lowest instruction in \p Vals, or the top
47 /// of BB if no instruction found in \p Vals.
48 static BasicBlock::iterator getInsertPointAfterInstrs(ArrayRef<Value *> Vals,
49                                                       BasicBlock *BB) {
50   auto *BotI = VecUtils::getLowest(Vals);
51   if (BotI == nullptr)
52     // We are using BB->begin() as the fallback insert point if `ToPack` did
53     // not contain instructions.
54     return BB->begin();
55   return std::next(BotI->getIterator());
56 }
57 
58 Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
59                                       ArrayRef<Value *> Operands) {
60   auto CreateVectorInstr = [](ArrayRef<Value *> Bndl,
61                               ArrayRef<Value *> Operands) -> Value * {
62     assert(all_of(Bndl, [](auto *V) { return isa<Instruction>(V); }) &&
63            "Expect Instructions!");
64     auto &Ctx = Bndl[0]->getContext();
65 
66     Type *ScalarTy = VecUtils::getElementType(Utils::getExpectedType(Bndl[0]));
67     auto *VecTy = VecUtils::getWideType(ScalarTy, VecUtils::getNumLanes(Bndl));
68 
69     BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(
70         Bndl, cast<Instruction>(Bndl[0])->getParent());
71 
72     auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
73     switch (Opcode) {
74     case Instruction::Opcode::ZExt:
75     case Instruction::Opcode::SExt:
76     case Instruction::Opcode::FPToUI:
77     case Instruction::Opcode::FPToSI:
78     case Instruction::Opcode::FPExt:
79     case Instruction::Opcode::PtrToInt:
80     case Instruction::Opcode::IntToPtr:
81     case Instruction::Opcode::SIToFP:
82     case Instruction::Opcode::UIToFP:
83     case Instruction::Opcode::Trunc:
84     case Instruction::Opcode::FPTrunc:
85     case Instruction::Opcode::BitCast: {
86       assert(Operands.size() == 1u && "Casts are unary!");
87       return CastInst::create(VecTy, Opcode, Operands[0], WhereIt, Ctx,
88                               "VCast");
89     }
90     case Instruction::Opcode::FCmp:
91     case Instruction::Opcode::ICmp: {
92       auto Pred = cast<CmpInst>(Bndl[0])->getPredicate();
93       assert(all_of(drop_begin(Bndl),
94                     [Pred](auto *SBV) {
95                       return cast<CmpInst>(SBV)->getPredicate() == Pred;
96                     }) &&
97              "Expected same predicate across bundle.");
98       return CmpInst::create(Pred, Operands[0], Operands[1], WhereIt, Ctx,
99                              "VCmp");
100     }
101     case Instruction::Opcode::Select: {
102       return SelectInst::create(Operands[0], Operands[1], Operands[2], WhereIt,
103                                 Ctx, "Vec");
104     }
105     case Instruction::Opcode::FNeg: {
106       auto *UOp0 = cast<UnaryOperator>(Bndl[0]);
107       auto OpC = UOp0->getOpcode();
108       return UnaryOperator::createWithCopiedFlags(OpC, Operands[0], UOp0,
109                                                   WhereIt, Ctx, "Vec");
110     }
111     case Instruction::Opcode::Add:
112     case Instruction::Opcode::FAdd:
113     case Instruction::Opcode::Sub:
114     case Instruction::Opcode::FSub:
115     case Instruction::Opcode::Mul:
116     case Instruction::Opcode::FMul:
117     case Instruction::Opcode::UDiv:
118     case Instruction::Opcode::SDiv:
119     case Instruction::Opcode::FDiv:
120     case Instruction::Opcode::URem:
121     case Instruction::Opcode::SRem:
122     case Instruction::Opcode::FRem:
123     case Instruction::Opcode::Shl:
124     case Instruction::Opcode::LShr:
125     case Instruction::Opcode::AShr:
126     case Instruction::Opcode::And:
127     case Instruction::Opcode::Or:
128     case Instruction::Opcode::Xor: {
129       auto *BinOp0 = cast<BinaryOperator>(Bndl[0]);
130       auto *LHS = Operands[0];
131       auto *RHS = Operands[1];
132       return BinaryOperator::createWithCopiedFlags(
133           BinOp0->getOpcode(), LHS, RHS, BinOp0, WhereIt, Ctx, "Vec");
134     }
135     case Instruction::Opcode::Load: {
136       auto *Ld0 = cast<LoadInst>(Bndl[0]);
137       Value *Ptr = Ld0->getPointerOperand();
138       return LoadInst::create(VecTy, Ptr, Ld0->getAlign(), WhereIt, Ctx,
139                               "VecL");
140     }
141     case Instruction::Opcode::Store: {
142       auto Align = cast<StoreInst>(Bndl[0])->getAlign();
143       Value *Val = Operands[0];
144       Value *Ptr = Operands[1];
145       return StoreInst::create(Val, Ptr, Align, WhereIt, Ctx);
146     }
147     case Instruction::Opcode::Br:
148     case Instruction::Opcode::Ret:
149     case Instruction::Opcode::PHI:
150     case Instruction::Opcode::AddrSpaceCast:
151     case Instruction::Opcode::Call:
152     case Instruction::Opcode::GetElementPtr:
153       llvm_unreachable("Unimplemented");
154       break;
155     default:
156       llvm_unreachable("Unimplemented");
157       break;
158     }
159     llvm_unreachable("Missing switch case!");
160     // TODO: Propagate debug info.
161   };
162 
163   auto *VecI = CreateVectorInstr(Bndl, Operands);
164   if (VecI != nullptr) {
165     Change = true;
166     IMaps->registerVector(Bndl, VecI);
167   }
168   return VecI;
169 }
170 
171 void BottomUpVec::tryEraseDeadInstrs() {
172   DenseMap<BasicBlock *, SmallVector<Instruction *>> SortedDeadInstrCandidates;
173   // The dead instrs could span BBs, so we need to collect and sort them per BB.
174   for (auto *DeadI : DeadInstrCandidates)
175     SortedDeadInstrCandidates[DeadI->getParent()].push_back(DeadI);
176   for (auto &Pair : SortedDeadInstrCandidates)
177     sort(Pair.second,
178          [](Instruction *I1, Instruction *I2) { return I1->comesBefore(I2); });
179   for (const auto &Pair : SortedDeadInstrCandidates) {
180     for (Instruction *I : reverse(Pair.second)) {
181       if (I->hasNUses(0))
182         // Erase the dead instructions bottom-to-top.
183         I->eraseFromParent();
184     }
185   }
186   DeadInstrCandidates.clear();
187 }
188 
189 Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask,
190                                   BasicBlock *UserBB) {
191   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs({VecOp}, UserBB);
192   return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt,
193                                    VecOp->getContext(), "VShuf");
194 }
195 
196 Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack, BasicBlock *UserBB) {
197   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack, UserBB);
198 
199   Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
200   unsigned Lanes = VecUtils::getNumLanes(ToPack);
201   Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);
202 
203   // Create a series of pack instructions.
204   Value *LastInsert = PoisonValue::get(VecTy);
205 
206   Context &Ctx = ToPack[0]->getContext();
207 
208   unsigned InsertIdx = 0;
209   for (Value *Elm : ToPack) {
210     // An element can be either scalar or vector. We need to generate different
211     // IR for each case.
212     if (Elm->getType()->isVectorTy()) {
213       unsigned NumElms =
214           cast<FixedVectorType>(Elm->getType())->getNumElements();
215       for (auto ExtrLane : seq<int>(0, NumElms)) {
216         // We generate extract-insert pairs, for each lane in `Elm`.
217         Constant *ExtrLaneC =
218             ConstantInt::getSigned(Type::getInt32Ty(Ctx), ExtrLane);
219         // This may return a Constant if Elm is a Constant.
220         auto *ExtrI =
221             ExtractElementInst::create(Elm, ExtrLaneC, WhereIt, Ctx, "VPack");
222         if (!isa<Constant>(ExtrI))
223           WhereIt = std::next(cast<Instruction>(ExtrI)->getIterator());
224         Constant *InsertLaneC =
225             ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
226         // This may also return a Constant if ExtrI is a Constant.
227         auto *InsertI = InsertElementInst::create(
228             LastInsert, ExtrI, InsertLaneC, WhereIt, Ctx, "VPack");
229         if (!isa<Constant>(InsertI)) {
230           LastInsert = InsertI;
231           WhereIt = std::next(cast<Instruction>(LastInsert)->getIterator());
232         }
233       }
234     } else {
235       Constant *InsertLaneC =
236           ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
237       // This may be folded into a Constant if LastInsert is a Constant. In
238       // that case we only collect the last constant.
239       LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
240                                              WhereIt, Ctx, "Pack");
241       if (auto *NewI = dyn_cast<Instruction>(LastInsert))
242         WhereIt = std::next(NewI->getIterator());
243     }
244   }
245   return LastInsert;
246 }
247 
248 void BottomUpVec::collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl) {
249   for (Value *V : Bndl)
250     DeadInstrCandidates.insert(cast<Instruction>(V));
251   // Also collect the GEPs of vectorized loads and stores.
252   auto Opcode = cast<Instruction>(Bndl[0])->getOpcode();
253   switch (Opcode) {
254   case Instruction::Opcode::Load: {
255     for (Value *V : drop_begin(Bndl))
256       if (auto *Ptr =
257               dyn_cast<Instruction>(cast<LoadInst>(V)->getPointerOperand()))
258         DeadInstrCandidates.insert(Ptr);
259     break;
260   }
261   case Instruction::Opcode::Store: {
262     for (Value *V : drop_begin(Bndl))
263       if (auto *Ptr =
264               dyn_cast<Instruction>(cast<StoreInst>(V)->getPointerOperand()))
265         DeadInstrCandidates.insert(Ptr);
266     break;
267   }
268   default:
269     break;
270   }
271 }
272 
273 Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
274                                  ArrayRef<Value *> UserBndl, unsigned Depth) {
275   Value *NewVec = nullptr;
276   auto *UserBB = !UserBndl.empty()
277                      ? cast<Instruction>(UserBndl.front())->getParent()
278                      : cast<Instruction>(Bndl[0])->getParent();
279   const auto &LegalityRes = Legality->canVectorize(Bndl);
280   switch (LegalityRes.getSubclassID()) {
281   case LegalityResultID::Widen: {
282     auto *I = cast<Instruction>(Bndl[0]);
283     SmallVector<Value *, 2> VecOperands;
284     switch (I->getOpcode()) {
285     case Instruction::Opcode::Load:
286       // Don't recurse towards the pointer operand.
287       VecOperands.push_back(cast<LoadInst>(I)->getPointerOperand());
288       break;
289     case Instruction::Opcode::Store: {
290       // Don't recurse towards the pointer operand.
291       auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Bndl, Depth + 1);
292       VecOperands.push_back(VecOp);
293       VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
294       break;
295     }
296     default:
297       // Visit all operands.
298       for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
299         auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Bndl, Depth + 1);
300         VecOperands.push_back(VecOp);
301       }
302       break;
303     }
304     NewVec = createVectorInstr(Bndl, VecOperands);
305 
306     // Collect any potentially dead scalar instructions, including the original
307     // scalars and pointer operands of loads/stores.
308     if (NewVec != nullptr)
309       collectPotentiallyDeadInstrs(Bndl);
310     break;
311   }
312   case LegalityResultID::DiamondReuse: {
313     NewVec = cast<DiamondReuse>(LegalityRes).getVector();
314     break;
315   }
316   case LegalityResultID::DiamondReuseWithShuffle: {
317     auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
318     const ShuffleMask &Mask =
319         cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
320     NewVec = createShuffle(VecOp, Mask, UserBB);
321     break;
322   }
323   case LegalityResultID::DiamondReuseMultiInput: {
324     const auto &Descr =
325         cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
326     Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
327 
328     // TODO: Try to get WhereIt without creating a vector.
329     SmallVector<Value *, 4> DescrInstrs;
330     for (const auto &ElmDescr : Descr.getDescrs()) {
331       if (auto *I = dyn_cast<Instruction>(ElmDescr.getValue()))
332         DescrInstrs.push_back(I);
333     }
334     BasicBlock::iterator WhereIt =
335         getInsertPointAfterInstrs(DescrInstrs, UserBB);
336 
337     Value *LastV = PoisonValue::get(ResTy);
338     for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
339       Value *VecOp = ElmDescr.getValue();
340       Context &Ctx = VecOp->getContext();
341       Value *ValueToInsert;
342       if (ElmDescr.needsExtract()) {
343         ConstantInt *IdxC =
344             ConstantInt::get(Type::getInt32Ty(Ctx), ElmDescr.getExtractIdx());
345         ValueToInsert = ExtractElementInst::create(VecOp, IdxC, WhereIt,
346                                                    VecOp->getContext(), "VExt");
347       } else {
348         ValueToInsert = VecOp;
349       }
350       ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
351       Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
352                                              WhereIt, Ctx, "VIns");
353       LastV = Ins;
354     }
355     NewVec = LastV;
356     break;
357   }
358   case LegalityResultID::Pack: {
359     // If we can't vectorize the seeds then just return.
360     if (Depth == 0)
361       return nullptr;
362     NewVec = createPack(Bndl, UserBB);
363     break;
364   }
365   }
366   return NewVec;
367 }
368 
369 bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
370   DeadInstrCandidates.clear();
371   Legality->clear();
372   vectorizeRec(Bndl, {}, /*Depth=*/0);
373   tryEraseDeadInstrs();
374   return Change;
375 }
376 
377 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
378   IMaps = std::make_unique<InstrMaps>(F.getContext());
379   Legality = std::make_unique<LegalityAnalysis>(
380       A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
381       F.getContext(), *IMaps);
382   Change = false;
383   const auto &DL = F.getParent()->getDataLayout();
384   unsigned VecRegBits =
385       OverrideVecRegBits != 0
386           ? OverrideVecRegBits
387           : A.getTTI()
388                 .getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
389                 .getFixedValue();
390 
391   // TODO: Start from innermost BBs first
392   for (auto &BB : F) {
393     SeedCollector SC(&BB, A.getScalarEvolution());
394     for (SeedBundle &Seeds : SC.getStoreSeeds()) {
395       unsigned ElmBits =
396           Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
397                                 Seeds[Seeds.getFirstUnusedElementIdx()])),
398                             DL);
399 
400       auto DivideBy2 = [](unsigned Num) {
401         auto Floor = VecUtils::getFloorPowerOf2(Num);
402         if (Floor == Num)
403           return Floor / 2;
404         return Floor;
405       };
406       // Try to create the largest vector supported by the target. If it fails
407       // reduce the vector size by half.
408       for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
409                                          Seeds.getNumUnusedBits() / ElmBits);
410            SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
411         if (Seeds.allUsed())
412           break;
413         // Keep trying offsets after FirstUnusedElementIdx, until we vectorize
414         // the slice. This could be quite expensive, so we enforce a limit.
415         for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
416                       OE = Seeds.size();
417              Offset + 1 < OE; Offset += 1) {
418           // Seeds are getting used as we vectorize, so skip them.
419           if (Seeds.isUsed(Offset))
420             continue;
421           if (Seeds.allUsed())
422             break;
423 
424           auto SeedSlice =
425               Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
426           if (SeedSlice.empty())
427             continue;
428 
429           assert(SeedSlice.size() >= 2 && "Should have been rejected!");
430 
431           // TODO: If vectorization succeeds, run the RegionPassManager on the
432           // resulting region.
433 
434           // TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
435           SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
436                                              SeedSlice.end());
437           Change |= tryVectorize(SeedSliceVals);
438         }
439       }
440     }
441   }
442   return Change;
443 }
444 
445 } // namespace sandboxir
446 } // namespace llvm
447